Collate functions manipulate and merge a list of samples to form a mini-batch, see torch.utils.data.DataLoader. An example use case is batching sequences of variable length, which requires padding each sample to the maximum length in the batch.
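The padding use case can be sketched in plain PyTorch. The function `pad_collate` below is hypothetical and not part of audtorch. It receives the list of samples that `DataLoader` gathers for one batch, records each sample's true length, and zero-pads all samples to the longest one with `torch.nn.utils.rnn.pad_sequence`.

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader


def pad_collate(samples):
    """Hypothetical collate function: merge 1-D tensors of different lengths."""
    # Record the true lengths before padding makes them unrecoverable.
    lengths = torch.tensor([sample.shape[0] for sample in samples])
    # Zero-pad every sample up to the longest one -> shape (N, max_len).
    padded = pad_sequence(samples, batch_first=True, padding_value=0.0)
    return padded, lengths


# A plain list works as a map-style dataset (it has __getitem__ and __len__).
dataset = [torch.ones(3), torch.ones(5), torch.ones(2)]
loader = DataLoader(dataset, batch_size=3, collate_fn=pad_collate)
padded, lengths = next(iter(loader))
print(padded.shape)      # torch.Size([3, 5])
print(lengths.tolist())  # [3, 5, 2]
```

Returning the lengths alongside the padded batch lets later layers ignore the padded positions.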


class audtorch.collate.Collation

Abstract interface for collation classes.

All other collation classes should subclass it. All subclasses should override __call__, which executes the actual collate function.
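A minimal sketch of the subclassing pattern follows. To stay self-contained it defines a local stand-in for audtorch.collate.Collation, assuming the interface only requires `__call__`; the subclass `TruncateToShortest` is hypothetical. Because `DataLoader` only needs a callable, an instance of such a class can be passed directly as `collate_fn`.

```python
import torch


class Collation:
    """Stand-in for audtorch.collate.Collation: subclasses override __call__."""

    def __call__(self, batch):
        raise NotImplementedError


class TruncateToShortest(Collation):
    """Hypothetical collation that crops every 1-D sample to the shortest one."""

    def __call__(self, batch):
        shortest = min(sample.shape[0] for sample in batch)
        # All samples now share one length, so they can be stacked directly.
        return torch.stack([sample[:shortest] for sample in batch])


collate_fn = TruncateToShortest()
batch = collate_fn([torch.arange(4), torch.arange(6)])
print(batch.shape)  # torch.Size([2, 4])
```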


class audtorch.collate.Seq2Seq(sequence_dimensions, batch_first=None, pad_values=[0, 0], sort_sequences=True)

Pads mini-batches to longest contained sequence for seq2seq-purposes.

This class pads features and targets to the longest sequence in the batch. Before padding, length information is extracted from them.
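The steps described above (extract lengths, then pad along a chosen sequence dimension) can be sketched as follows. The helper `pad_to_longest` is hypothetical and is not audtorch's implementation; its `dim` and `value` arguments play the roles of one entry of sequence_dimensions and pad_values. It reuses the tensor shapes from the example at the end of this section.

```python
import torch
import torch.nn.functional as F


def pad_to_longest(tensors, dim, value=0.0):
    """Hypothetical helper: pad along `dim` to the longest length, stack batch-first."""
    # Extract the lengths first; after padding they can no longer be recovered.
    lengths = torch.tensor([t.shape[dim] for t in tensors])
    longest = int(lengths.max())
    padded = []
    for t in tensors:
        moved = t.movedim(dim, -1)  # bring the sequence dimension to the end
        moved = F.pad(moved, (0, longest - moved.shape[-1]), value=value)
        padded.append(moved.movedim(-1, dim))  # restore the original layout
    return torch.stack(padded), lengths


features, feats_lengths = pad_to_longest(
    [torch.zeros(161, 108), torch.zeros(161, 223)], dim=-1)
targets, tgt_lengths = pad_to_longest(
    [torch.zeros(10), torch.zeros(12)], dim=-1, value=-1.0)
print(list(features.shape), feats_lengths.tolist())  # [2, 161, 223] [108, 223]
print(list(targets.shape), tgt_lengths.tolist())     # [2, 12] [10, 12]
```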


The tensors can be sorted in descending order of features’ lengths by enabling sort_sequences. This anticipates the requirements of torch.nn.utils.rnn.pack_padded_sequence() (with its default enforce_sorted=True), which is commonly used to feed padded batches to recurrent layers.
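The following sketch, written with plain PyTorch rather than audtorch, shows why the descending order matters: it sorts a padded batch by length and then packs it with `pack_padded_sequence`, which by default expects exactly this order, before passing it to a GRU. The sequence sizes and the GRU configuration are illustrative assumptions.

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence

# Three variable-length sequences with one feature per time step: (S, 1).
sequences = [torch.ones(3, 1), torch.ones(5, 1), torch.ones(2, 1)]
lengths = torch.tensor([s.shape[0] for s in sequences])
padded = pad_sequence(sequences, batch_first=True)  # shape (3, 5, 1)

# Sort the batch in descending order of length, as enforce_sorted=True expects.
sorted_lengths, order = torch.sort(lengths, descending=True)
padded = padded[order]

packed = pack_padded_sequence(padded, sorted_lengths, batch_first=True)
rnn = torch.nn.GRU(input_size=1, hidden_size=4, batch_first=True)
output, hidden = rnn(packed)  # output is a PackedSequence as well
print(order.tolist())               # [1, 0, 2]
print(packed.batch_sizes.tolist())  # [3, 3, 2, 1, 1]
```

Passing enforce_sorted=False would let PyTorch sort internally instead; sorting in the collate function keeps the default, stricter path usable.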

  • sequence_dimensions holds the sequence dimension of features and targets
  • batch_first controls the output shape of features and targets
  • pad_values controls the values used to pad features and targets
  • sort_sequences controls whether sequences are sorted in descending order of features’ lengths

Parameters:
  • sequence_dimensions (list of ints) – indices of the sequence dimension in the feature and target tensors. Position 0 holds the sequence dimension of features, position 1 that of targets. Negative indexing is permitted.
  • batch_first (bool or None, optional) – determines the output shape of the collate function. If None, the original shape of features and targets is kept, with the batch dimension prepended. See Shape for more information. Default: None
  • pad_values (list, optional) – values to pad shorter sequences with. Position 0 holds the value for features, position 1 the value for targets. Default: [0, 0]
  • sort_sequences (bool, optional) – whether to sort sequences in descending order of features’ lengths. Default: True

Shape:
  • Input: \((*, S, *)\), where \(*\) denotes any number of further dimensions (excluding the batch size \(N\)) and \(S\) is the sequence dimension.
  • Output:
    • features:
      • \((N, *, S, *)\) if batch_first is None, i.e. the original input shape with the batch size \(N\) prepended
      • \((N, S, *, *)\) if batch_first is True
      • \((S, N, *, *)\) if batch_first is False
    • feats_lengths: \((N,)\)
    • targets: analogous to features
    • tgt_lengths: analogous to feats_lengths
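The three batch_first layouts can be illustrated with `torch.Tensor.movedim`. The helper `arrange_batch` is hypothetical, not audtorch's code, and it assumes that the remaining \(*\) dimensions keep their relative order when \(S\) is moved, which the shape notation does not state explicitly.

```python
import torch


def arrange_batch(batched, seq_dim, batch_first):
    """Hypothetical helper: rearrange a padded batch of shape (N, *, S, *).

    `seq_dim` indexes S in a single *unbatched* sample, as in sequence_dimensions.
    """
    # Translate the per-sample index into an index of the batched tensor.
    dim = seq_dim + 1 if seq_dim >= 0 else batched.dim() + seq_dim
    if batch_first is None:
        return batched  # (N, *, S, *): original layout, batch dim prepended
    if batch_first:
        return batched.movedim(dim, 1)  # (N, S, *, *)
    return batched.movedim(dim, 0)  # (S, N, *, *)


features = torch.zeros(2, 161, 223)  # N=2, feature dim 161, S=223
for option in (None, True, False):
    print(option, list(arrange_batch(features, -1, option).shape))
# None [2, 161, 223]
# True [2, 223, 161]
# False [223, 2, 161]
```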


>>> import torch
>>> from audtorch.collate import Seq2Seq
>>> # data format: FS = (feature dimension, sequence dimension)
>>> batch = [[torch.zeros(161, 108), torch.zeros(10)],
...          [torch.zeros(161, 223), torch.zeros(12)]]
>>> collate_fn = Seq2Seq([-1, -1], batch_first=None)
>>> features = collate_fn(batch)[0]
>>> list(features.shape)
[2, 161, 223]