# teng_ml/util/pad.py
# Provenance: commit 33d1945de2b00eefce596c1fc221cd091d9ab719, "added padding function".
import torch
import torch.nn.utils.rnn as rnn
import numpy as np


class PadSequences:
    """Collate callable that zero-pads variable-length sequences into one batch.

    An instance is called with a batch of ``(data, label)`` samples and returns
    the padded data together with the true sequence lengths and the labels, all
    ordered from the longest sequence to the shortest.
    """

    def __call__(self, batch):
        """Pad the sequences of ``batch`` to a common length.

        Args:
            batch: Iterable of ``(data, label)`` pairs. ``data`` must expose
                ``shape[0]`` as its sequence length (array or tensor);
                ``label`` must be something ``np.array`` can stack.

        Returns:
            A tuple ``(sequences_padded, lengths, labels)``:

            * ``sequences_padded`` -- float32 tensor, batch dimension first,
              padded with zeros up to the longest sequence in the batch.
            * ``lengths`` -- int32 tensor with the unpadded length of each
              sequence.
            * ``labels`` -- float32 tensor of the labels.

            All three follow the same longest-first ordering.
        """
        # Longest first: this is the order pack_padded_sequence expects by
        # default. sorted() is stable, so equal-length samples keep their
        # relative order.
        ordered = sorted(batch, key=lambda sample: sample[0].shape[0], reverse=True)

        # as_tensor accepts numpy arrays as well as tensors of any dtype and
        # converts them to float32 (the dtype the legacy torch.Tensor
        # constructor produced for array input).
        sequences = [torch.as_tensor(data, dtype=torch.float32) for data, _ in ordered]

        # Stack the labels through numpy first so that a list of scalars or
        # arrays becomes a single array before it is handed to torch.
        labels = torch.as_tensor(
            np.array([label for _, label in ordered]), dtype=torch.float32
        )
        lengths = torch.as_tensor(
            [seq.shape[0] for seq in sequences], dtype=torch.int32
        )

        sequences_padded = rnn.pad_sequence(sequences, batch_first=True)
        return sequences_padded, lengths, labels