teng-ml/teng_ml/util/pad.py
2023-08-14 18:43:53 +02:00

15 lines
622 B
Python

import torch
import torch.nn.utils.rnn as rnn
import numpy as np
class PadSequences:
    """Callable that collates variable-length samples into one padded batch.

    NOTE(review): presumably passed as ``collate_fn`` to a
    ``torch.utils.data.DataLoader`` -- confirm against the caller.
    """

    def __call__(self, batch):
        """Sort a batch by sequence length and zero-pad it to a common length.

        Args:
            batch: Iterable of ``(data, label)`` pairs. ``data`` must expose
                ``shape[0]`` (its sequence length) and be convertible with
                ``torch.Tensor(...)``; assumed to be a NumPy array or a
                tensor -- TODO confirm.

        Returns:
            A tuple ``(sequences_padded, lengths, labels)``:

            * ``sequences_padded`` -- output of ``pad_sequence`` with
              ``batch_first=True``, i.e. batch dimension first and shorter
              sequences filled with zeros up to the longest one.
            * ``lengths`` -- int32 tensor holding the unpadded length of each
              sequence, in the same (descending) order.
            * ``labels`` -- labels converted via ``torch.Tensor``, reordered
              to match the sorted sequences.
        """
        # Longest sequence first. The sort is stable, so samples of equal
        # length keep their relative order. Descending order is what
        # pack_padded_sequence expects by default (assumed downstream use).
        sorted_batch = sorted(
            batch, key=lambda sample: sample[0].shape[0], reverse=True
        )

        sequences = [torch.Tensor(data) for data, _ in sorted_batch]
        # Labels go through one NumPy array so they become a single tensor
        # rather than a list of scalars.
        labels = torch.Tensor(np.array([label for _, label in sorted_batch]))
        # Record the true lengths before padding hides them.
        lengths = torch.tensor(
            [seq.shape[0] for seq in sequences], dtype=torch.int32
        )

        sequences_padded = rnn.pad_sequence(sequences, batch_first=True)
        return sequences_padded, lengths, labels