import torch
import torch.nn as nn

class RNN(nn.Module):
    """
    (Bi)LSTM for name classification
    """

    def __init__(self, input_size, hidden_size, num_layers, num_classes, bidirectional):
        super().__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.is_bidirectional = bidirectional
        # number of directions: 2 for a bidirectional LSTM, 1 otherwise
        self.D = 2 if bidirectional else 1
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=bidirectional,
        )
        # a BiLSTM concatenates the hidden states of both directions,
        # so the classifier input is twice as wide
        if bidirectional:
            self.fc = nn.Linear(hidden_size * 2, num_classes)
        else:
            self.fc = nn.Linear(hidden_size, num_classes)
        # dim=-1 normalizes over the class dimension for both batched and unbatched input
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        device = x.device

        # h0: initial hidden states
        # c0: initial cell states
        if len(x.shape) == 2:  # x: (seq_length, features), unbatched
            h0 = torch.zeros(self.D * self.num_layers, self.hidden_size).to(device)
            c0 = torch.zeros(self.D * self.num_layers, self.hidden_size).to(device)
        elif len(x.shape) == 3:  # x: (batch, seq_length, features)
            batch_size = x.shape[0]
            h0 = torch.zeros(self.D * self.num_layers, batch_size, self.hidden_size).to(device)
            c0 = torch.zeros(self.D * self.num_layers, batch_size, self.hidden_size).to(device)
        else:
            raise ValueError(f"RNN.forward: invalid input shape: {x.shape}. Must be (batch, seq_length, features) or (seq_length, features)")

        # lstm: (batch, seq_length, features) -> (batch, seq_length, D * hidden_size)
        out, (h_n, c_n) = self.lstm(x, (h0, c0))
        # out: (N, L, D * hidden_size), or (L, D * hidden_size) for unbatched input
        # h_n: (D * num_layers, N, hidden_size), or (D * num_layers, hidden_size) for unbatched input
        # c_n: (D * num_layers, N, hidden_size), or (D * num_layers, hidden_size) for unbatched input
        # print(f"forward: out.shape={out.shape}")
        # print(f"out({out.shape})={out}")
        # print(f"h_n({h_n.shape})={h_n}")
        # print(f"c_n({c_n.shape})={c_n}")

        """
        # select only the last layer: [-1] -> last layer
        last_layer_state = h_n.view(self.num_layers, D, batch_size, self.hidden_size)[-1]
        if D == 1:
            # [1, batch_size, hidden_size] -> [batch_size, hidden_size]
            X = last_layer_state.squeeze()  # TODO what if batch_size == 1
        elif D == 2:
            h_1, h_2 = last_layer_state[0], last_layer_state[1]  # states of both directions
            # concatenate both states, X-size: (batch, hidden_size * 2)
            X = torch.cat((h_1, h_2), dim=1)
        else:
            raise ValueError("D must be 1 or 2")
        """  # all this is equivalent to the line below
        out = out[..., -1, :]  # select the last time step (works for batched and unbatched input)

        # fc fully connected layer: (*, D * hidden_size) -> (*, num_classes)
        out = self.fc(out)

        # softmax: (*) -> (*), class probabilities
        out = self.softmax(out)

        return out
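
# Minimal usage sketch. The hyperparameter values below (57 one-hot character
# features, 18 target classes, hidden size 128) are illustrative assumptions,
# not values fixed by this module.
if __name__ == "__main__":
    model = RNN(input_size=57, hidden_size=128, num_layers=2, num_classes=18, bidirectional=True)

    batch = torch.randn(4, 10, 57)   # (batch, seq_length, features)
    probs = model(batch)             # (batch, num_classes), each row sums to 1
    print(probs.shape)

    single = torch.randn(10, 57)     # unbatched: (seq_length, features)
    print(model(single).shape)       # (num_classes,)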