added padding function
This commit is contained in:
parent
37bb1f444e
commit
33d1945de2
14
teng_ml/util/pad.py
Normal file
14
teng_ml/util/pad.py
Normal file
@ -0,0 +1,14 @@
|
||||
import torch
|
||||
import torch.nn.utils.rnn as rnn
|
||||
import numpy as np
|
||||
|
||||
class PadSequences:
|
||||
def __call__(self, batch):
|
||||
# batch = [(data, label)]
|
||||
# sort by length
|
||||
sorted_batch = sorted(batch, key=lambda sample: sample[0].shape[0], reverse=True)
|
||||
sequences = [torch.Tensor(sample[0]) for sample in sorted_batch]
|
||||
labels = torch.Tensor(np.array([sample[1] for sample in sorted_batch]))
|
||||
lengths = torch.IntTensor(np.array([seq.shape[0] for seq in sequences]))
|
||||
sequences_padded = rnn.pad_sequence(sequences, batch_first=True)
|
||||
return sequences_padded, lengths, labels
|
Loading…
Reference in New Issue
Block a user