teng-ml/teng_ml/rnn/rnn.py
import torch
import torch.nn as nn


class RNN(nn.Module):
    """
    (Bi)LSTM for name classification
    """
    def __init__(self, input_size, hidden_size, num_layers, num_classes, bidirectional):
        super(RNN, self).__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.is_bidirectional = bidirectional
        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=True, bidirectional=bidirectional)
        if bidirectional:
            # both directions are concatenated, so the classifier sees 2 * hidden_size features
            self.fc = nn.Linear(hidden_size * 2, num_classes)
        else:
            self.fc = nn.Linear(hidden_size, num_classes)
        # dim=-1 normalizes over the class dimension for both batched and unbatched outputs
        self.softmax = nn.Softmax(dim=-1)
        # D = number of directions
        self.D = 2 if self.is_bidirectional else 1

    def forward(self, x):
        device = x.device
        # h0: initial hidden states
        # c0: initial cell states
        if len(x.shape) == 2:    # x: (seq_length, features), unbatched
            h0 = torch.zeros(self.D * self.num_layers, self.hidden_size).to(device)
            c0 = torch.zeros(self.D * self.num_layers, self.hidden_size).to(device)
        elif len(x.shape) == 3:  # x: (batch, seq_length, features)
            batch_size = x.shape[0]
            h0 = torch.zeros(self.D * self.num_layers, batch_size, self.hidden_size).to(device)
            c0 = torch.zeros(self.D * self.num_layers, batch_size, self.hidden_size).to(device)
        else:
            raise ValueError(f"RNN.forward: invalid input shape: {x.shape}. Must be (batch, seq_length, features) or (seq_length, features)")
        # lstm: (N, L, features) -> (N, L, D * hidden_size)
        out, (h_n, c_n) = self.lstm(x, (h0, c0))
        # out: (N, L, D * hidden_size)          - hidden state of every time step of the last layer
        # h_n: (D * num_layers, N, hidden_size) - final hidden state of every layer
        # c_n: (D * num_layers, N, hidden_size) - final cell state of every layer
        # (without the batch dimension N for unbatched input)
        # print(f"forward: out.shape={out.shape}")
        # print(f"out({out.shape})={out}")
        # print(f"h_n({h_n.shape})={h_n}")
        # print(f"c_n({c_n.shape})={c_n}")
        # print(f"out({out.shape})=...")
        # print(f"h_n({h_n.shape})=...")
        # print(f"c_n({c_n.shape})=...")
"""
# select only last layer [-1] -> last layer,
last_layer_state = h_n.view(self.num_layers, D, batch_size, self.hidden_size)[-1]
if D == 1:
# [1, batch_size, hidden_size] -> [batch_size, hidden_size]
X = last_layer_state.squeeze() # TODO what if batch_size == 1
elif D == 2:
h_1, h_2 = last_layer_state[0], last_layer_state[1] # states of both directions
# concatenate both states, X-size: (Batch, hidden_size * 2
X = torch.cat((h_1, h_2), dim=1)
else:
raise ValueError("D must be 1 or 2")
""" # all this is quivalent to line below
        # select the last time step: (N, L, D * hidden_size) -> (N, D * hidden_size)
        # (the ellipsis also handles the unbatched case (L, D * hidden_size) -> (D * hidden_size))
        out = out[..., -1, :]
        # fc fully connected layer: (*, D * hidden_size) -> (*, num_classes)
        out = self.fc(out)
        # softmax over the class dimension: (*, num_classes) -> (*, num_classes)
        out = self.softmax(out)
        return out
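

# --- Minimal usage sketch (illustrative only, not part of the original module) ---
# Shows how this RNN class could be instantiated and called on dummy data.
# All hyperparameters and tensor sizes below are assumptions for demonstration,
# not values taken from the teng-ml project. The unbatched call assumes
# PyTorch >= 1.11, which accepts unbatched input to nn.LSTM.
if __name__ == "__main__":
    model = RNN(input_size=3, hidden_size=8, num_layers=2, num_classes=5, bidirectional=True)

    batch = torch.randn(4, 20, 3)        # (batch, seq_length, features)
    probs = model(batch)                  # (batch, num_classes), each row sums to 1
    print(probs.shape)                    # torch.Size([4, 5])

    single = torch.randn(20, 3)           # unbatched (seq_length, features)
    print(model(single).shape)            # torch.Size([5])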