teng-ml/teng_ml/rnn/rnn.py

import torch
import torch.nn as nn
class RNN(nn.Module):
    """
    (Bi)LSTM for name classification
    """
    def __init__(self, input_size, hidden_size, num_layers, num_classes, bidirectional):
        super(RNN, self).__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.is_bidirectional = bidirectional
        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=True, bidirectional=bidirectional)
        if bidirectional:
            self.fc = nn.Linear(hidden_size * 2, num_classes)
        else:
            self.fc = nn.Linear(hidden_size, num_classes)
        self.softmax = nn.Softmax(dim=1)
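        # note: nn.CrossEntropyLoss expects raw logits and applies log-softmax
        # internally, so if that loss is used for training, this explicit
        # softmax layer is typically left out of the model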
        self.D = 2 if self.is_bidirectional else 1

    def forward(self, x):
        device = x.device
        # h0: initial hidden states
        # c0: initial cell states
        if len(x.shape) == 2:    # x: (seq_length, features)
            h0 = torch.zeros(self.D * self.num_layers, self.hidden_size).to(device)
            c0 = torch.zeros(self.D * self.num_layers, self.hidden_size).to(device)
        elif len(x.shape) == 3:  # x: (batch, seq_length, features)
            batch_size = x.shape[0]
            h0 = torch.zeros(self.D * self.num_layers, batch_size, self.hidden_size).to(device)
            c0 = torch.zeros(self.D * self.num_layers, batch_size, self.hidden_size).to(device)
        else:
            raise ValueError(f"RNN.forward: invalid input shape: {x.shape}. Must be (batch, seq_length, features) or (seq_length, features)")
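        # (if h0/c0 were not passed to nn.LSTM it would default them to zeros
        #  anyway; building them explicitly just makes the expected shapes visible)
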
        # lstm: (batch, seq_length, features) -> out: (batch, seq_length, D * hidden_size)
        out, (h_n, c_n) = self.lstm(x, (h0, c0))
print(f"forward: out.shape={out.shape} TODO verify comment")
        # out: (N, L, D * hidden_size)
        # h_n: (D * num_layers, N, hidden_size)
        # c_n: (D * num_layers, N, hidden_size)
        # print(f"out({out.shape})={out}")
        # print(f"h_n({h_n.shape})={h_n}")
        # print(f"c_n({c_n.shape})={c_n}")
        # print(f"out({out.shape})=...")
        # print(f"h_n({h_n.shape})=...")
        # print(f"c_n({c_n.shape})=...")
"""
# select only last layer [-1] -> last layer,
last_layer_state = h_n.view(self.num_layers, D, batch_size, self.hidden_size)[-1]
if D == 1:
# [1, batch_size, hidden_size] -> [batch_size, hidden_size]
X = last_layer_state.squeeze() # TODO what if batch_size == 1
elif D == 2:
h_1, h_2 = last_layer_state[0], last_layer_state[1] # states of both directions
# concatenate both states, X-size: (Batch, hidden_size * 2
X = torch.cat((h_1, h_2), dim=1)
else:
raise ValueError("D must be 1 or 2")
""" # all this is quivalent to line below
        # select the last time step (note: this indexing assumes a batched 3-D input)
        out = out[:, -1, :]
        # fc fully connected layer: (*, D * hidden_size) -> (*, num_classes)
        out = self.fc(out)
        # softmax over the class dimension: (batch, num_classes) -> (batch, num_classes), rows sum to 1
        out = self.softmax(out)
        return out
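

if __name__ == "__main__":
    # Minimal usage sketch: the sizes below (4 input features, sequence length 100,
    # 5 classes, batch of 8) are illustrative assumptions, not values taken from
    # the rest of the project.
    model = RNN(input_size=4, hidden_size=32, num_layers=2, num_classes=5, bidirectional=True)
    x = torch.randn(8, 100, 4)   # (batch, seq_length, features)
    y = model(x)                 # (batch, num_classes); rows sum to 1 after softmax
    print(y.shape)               # torch.Size([8, 5])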