"""
UL2 Data Collator for PyTorch + Transformers
==============================================
Standalone implementation of UL2 (Unified Language Learner) denoising objectives
for encoder-decoder models (T5, UL2, Flan-T5, etc.).
Based on: "Unifying Language Learning Paradigms" (Tay et al., 2022)
https://arxiv.org/abs/2205.05131
This collator implements mixture-of-denoisers with configurable span corruption:
- [R] Regular denoising (standard T5-like span corruption)
- [S] Sequential denoising (prefix/causal language modeling)
- [X] Extreme denoising (aggressive corruption for robustness)
Requirements:
torch>=1.13.0
transformers>=4.20.0
numpy>=1.21.0
Author: Extracted from catie-aq/flashT5
License: MIT
"""
from typing import Dict, List, Tuple
import numpy as np
import torch
from transformers import AutoTokenizer, BatchEncoding
from transformers.data.data_collator import DataCollatorMixin
# ============================================================================
# RECOMMENDED CONFIGURATIONS
# ============================================================================
UL2_7_DENOISER = {
"denoiser_list": [
{"mu": 3.0, "r": 0.15, "max_spans": 512, "prefix": "[R]"}, # Regular short
{"mu": 8.0, "r": 0.15, "max_spans": 512, "prefix": "[R]"}, # Regular long
{"mu": 4.0, "r": 0.0, "max_spans": 1, "prefix": "[S]"}, # Sequential (causal)
{"mu": 3.0, "r": 0.5, "max_spans": 512, "prefix": "[X]"}, # Extreme aggressive
{"mu": 8.0, "r": 0.15, "max_spans": 512, "prefix": "[X]"}, # Extreme long
{"mu": 64.0, "r": 0.15, "max_spans": 512, "prefix": "[X]"}, # Extreme very long
{
"mu": 64.0,
"r": 0.5,
"max_spans": 512,
"prefix": "[X]",
}, # Extreme very aggressive
],
"denoiser_proportions": [0.165, 0.165, 0.34, 0.0825, 0.0825, 0.0825, 0.0825],
}
T5_STANDARD = {
"denoiser_list": [{"mu": 3.0, "r": 0.15, "max_spans": 512, "prefix": ""}],
"denoiser_proportions": [1.0],
}
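# Each denoiser config is paired positionally with a sampling weight. A quick
# sanity check (a minimal sketch, using the UL2_7_DENOISER preset above):
assert len(UL2_7_DENOISER["denoiser_list"]) == len(
    UL2_7_DENOISER["denoiser_proportions"]
)
assert abs(sum(UL2_7_DENOISER["denoiser_proportions"]) - 1.0) < 1e-6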
# ============================================================================
# MAIN DATA COLLATOR
# ============================================================================
class DataCollatorForUL2MLM(DataCollatorMixin):
"""
Data collator for UL2-style masked language modeling with mixture of denoisers.
Implements span corruption with configurable parameters per denoiser type.
Supports dynamic batching with sequence packing for efficient training.
Args:
tokenizer: HuggingFace tokenizer (must have extra_id tokens and special prefixes)
max_length: Maximum input sequence length
max_labels_length: Maximum target/label sequence length
batch_size: Target batch size for packing
denoiser_list: List of denoiser configs, each with:
- mu (float): Mean span length for corruption
- r (float): Noise density (0.0-1.0, fraction of tokens to corrupt)
- max_spans (int): Maximum number of corrupted spans
- prefix (str): Task prefix token (e.g., "[R]", "[S]", "[X]")
denoiser_proportions: Sampling probabilities for each denoiser (must sum to 1.0)
causal: If True, format for causal LM; if False, encoder-decoder format
random_chunk: If True, randomly sample chunks from long sequences
fixed_batch_size: If True, pad batches to exact batch_size with wrapping
min_size_inputs: Minimum sequence length to include (filter shorter)
Example:
        >>> tokenizer = AutoTokenizer.from_pretrained("google/t5-v1_1-base")
        >>> tokenizer.add_special_tokens({"additional_special_tokens": ["[R]", "[S]", "[X]"]})
>>> collator = DataCollatorForUL2MLM(
... tokenizer=tokenizer,
... max_length=512,
... max_labels_length=128,
... batch_size=32,
... **UL2_7_DENOISER
... )
"""
def __init__(
self,
tokenizer: AutoTokenizer,
max_length: int,
max_labels_length: int,
batch_size: int,
denoiser_list: List[Dict],
denoiser_proportions: List[float],
causal: bool = False,
random_chunk: bool = True,
fixed_batch_size: bool = False,
min_size_inputs: int = 10,
):
super().__init__()
# Normalize proportions to sum to 1.0
self.denoiser_proportions = denoiser_proportions
if sum(self.denoiser_proportions) != 1.0:
self.denoiser_proportions = [
x / sum(self.denoiser_proportions) for x in self.denoiser_proportions
]
self.denoiser_list = denoiser_list
self.tokenizer = tokenizer
# Encode prefix tokens (e.g., "[R]" -> token IDs, excluding EOS)
self.prefixes = [
tokenizer.encode(denoiser["prefix"], return_tensors="np").flatten()[:-1]
for denoiser in denoiser_list
]
# Extract extra_id sentinel tokens (assumed contiguous)
self.extra_ids = sorted(
[
tokenizer.all_special_ids[i]
for i, token in enumerate(tokenizer.all_special_tokens)
if "extra" in token
],
reverse=True,
)
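        # After the reverse sort, self.extra_ids[0] is the highest sentinel id
        # and self.extra_ids[-1] the lowest (for T5-style vocabularies this is
        # <extra_id_0> down to <extra_id_99>).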
self.max_length = max_length
self.batch_size = batch_size
self.max_labels_length = max_labels_length
# Pre-compute optimal lengths for each denoiser
max_prefix_len = max(len(p) for p in self.prefixes)
self.denoiser_optimal_len = [
self.compute_input_and_target_lengths(
max_length - max_prefix_len, denoiser["r"], denoiser["mu"]
)
for denoiser in self.denoiser_list
]
self.causal = causal
self.random_chunk = random_chunk
self.fixed_batch_size = fixed_batch_size
self.min_size_inputs = min_size_inputs
def is_special_token(self, x: np.ndarray) -> np.ndarray:
"""Check if token IDs are sentinel tokens (extra_id range)."""
return (x <= self.extra_ids[0]) & (x >= self.extra_ids[-1])
def _best_fit(
self, input_ids: List[np.ndarray], labels: List[np.ndarray]
) -> Tuple[List[np.ndarray], List[np.ndarray]]:
"""
Pack multiple sequences into batches using first-fit bin packing.
Maximizes GPU utilization by concatenating short sequences until
max_length/max_labels_length constraints are hit.
Returns:
Tuple of (packed_inputs, packed_labels)
"""
batch_inputs = []
batch_labels = []
for _ in range(self.batch_size):
bin_inputs = []
bin_labels = []
bin_input_length = 0
bin_label_length = 0
bin_special_tokens = 0
for i, (x, item_labels) in enumerate(zip(input_ids, labels)):
if x is None:
continue
size_inputs = x.shape[1]
size_labels = item_labels.shape[1]
num_new_special_tokens = self.is_special_token(x).sum()
# Check if sequence fits in current bin
if (
bin_input_length + size_inputs < self.max_length
and bin_label_length + size_labels < self.max_labels_length
and bin_special_tokens + num_new_special_tokens
< len(self.extra_ids)
):
bin_inputs.append(x)
bin_labels.append(item_labels)
bin_input_length += size_inputs
bin_label_length += size_labels
bin_special_tokens += num_new_special_tokens
# Mark as used
input_ids[i] = None
labels[i] = None
if bin_inputs:
batch_inputs.append(np.concatenate(bin_inputs, axis=1))
batch_labels.append(np.concatenate(bin_labels, axis=1))
return batch_inputs, batch_labels
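    # Illustrative sketch of the packing (assuming the labels also fit): with
    # max_length=16 and inputs of lengths [10, 9, 5], the first bin packs the
    # length-10 and length-5 sequences (10 + 5 < 16) and the length-9 sequence
    # starts the next bin.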
def __call__(self, examples: List[Dict[str, np.ndarray]]) -> BatchEncoding:
"""
Process a batch of examples with UL2 denoising.
Args:
examples: List of dicts with 'input_ids' (shape: [1, seq_len])
Returns:
BatchEncoding with 'input_ids', 'attention_mask', 'labels'
"""
# Filter examples that are too short
examples = [
x for x in examples if x["input_ids"].shape[1] > self.min_size_inputs
]
input_batch_size = len(examples)
# Sample denoiser for each example
denoisers_sample = np.random.choice(
range(len(self.denoiser_list)),
input_batch_size,
p=self.denoiser_proportions,
)
# Add length field if missing
has_length = "length" in examples[0]
if not has_length:
for i in range(len(examples)):
examples[i]["length"] = examples[i]["input_ids"].shape[1]
# Truncate to optimal length for each denoiser
truncated_examples = []
for i, x in enumerate(examples):
max_len = self.denoiser_optimal_len[denoisers_sample[i]][0]
if x["length"] > max_len:
start = 0
if self.random_chunk:
start = np.random.randint(0, x["length"] - max_len)
new_input_ids = x["input_ids"][:, start : start + max_len]
truncated_examples.append(
{"input_ids": new_input_ids, "length": np.array(max_len)}
)
else:
truncated_examples.append(x)
examples = truncated_examples
# Generate noise masks for span corruption
spans_noise_masks = [
self.random_spans_noise_mask(
x["length"], self.denoiser_list[denoisers_sample[i]]
)
for i, x in enumerate(examples)
]
# Create sentinel IDs for masked and unmasked spans
input_ids_sentinel = [
self.create_sentinel_ids(x.astype(np.int8)) for x in spans_noise_masks
]
labels_sentinel = [
self.create_sentinel_ids((~x).astype(np.int8)) for x in spans_noise_masks
]
# Apply masks and add prefixes
input_ids = [
self.filter_input_ids(
x["input_ids"],
input_ids_sentinel[i],
np.expand_dims(self.prefixes[denoisers_sample[i]], axis=0),
)
for i, x in enumerate(examples)
]
labels = [
self.filter_input_ids(x["input_ids"], labels_sentinel[i], with_eos=False)
for i, x in enumerate(examples)
]
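        # Inputs now keep the clean text with sentinels standing in for the
        # noise spans; labels, built from the inverted mask, contain the noise
        # spans each preceded by the matching sentinel (standard T5-style
        # targets).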
# Pack sequences if needed
if len(input_ids) == self.batch_size:
batch_inputs, batch_labels = input_ids, labels
else:
batch_inputs, batch_labels = self._best_fit(input_ids, labels)
# Replace sentinel placeholders with actual extra_id tokens
labels = [
np.where(
self.is_special_token(x),
self.extra_ids[0] - np.cumsum(self.is_special_token(x)) + 1,
x,
)
for x in batch_labels
]
input_ids = [
np.where(
self.is_special_token(x),
self.extra_ids[0] - np.cumsum(self.is_special_token(x)) + 1,
x,
)
for x in batch_inputs
]
# Add EOS to labels
labels = [
np.concatenate(
[x, np.full((1, 1), self.tokenizer.eos_token_id, dtype=np.int32)],
axis=-1,
)
for x in labels
]
# Pad sequences
if self.causal:
# Causal: left-pad inputs, right-pad labels
labels = np.concatenate(
[
np.pad(
x,
((0, 0), (0, self.max_labels_length - x.shape[1])),
constant_values=self.tokenizer.pad_token_id,
)
for x in labels
],
axis=0,
)
input_ids = np.concatenate(
[
np.pad(
x,
((0, 0), (self.max_length - x.shape[1], 0)),
constant_values=self.tokenizer.pad_token_id,
)
for x in input_ids
],
axis=0,
)
else:
# Encoder-decoder: right-pad everything
labels = np.concatenate(
[
np.pad(
x,
((0, 0), (0, self.max_labels_length - x.shape[1])),
constant_values=self.tokenizer.pad_token_id,
)
for x in labels
],
axis=0,
)
input_ids = np.concatenate(
[
np.pad(
x,
((0, 0), (0, self.max_length - x.shape[1])),
constant_values=self.tokenizer.pad_token_id,
)
for x in input_ids
],
axis=0,
)
# Fixed batch size padding with wrapping
if self.fixed_batch_size and input_ids.shape[0] < self.batch_size:
input_ids = np.pad(
input_ids,
((0, self.batch_size - input_ids.shape[0]), (0, 0)),
mode="wrap",
)
labels = np.pad(
labels, ((0, self.batch_size - labels.shape[0]), (0, 0)), mode="wrap"
)
# Construct final batch
batch = {}
if not self.causal:
batch["input_ids"] = torch.from_numpy(input_ids)
causal_labels = torch.from_numpy(labels)
else:
# Concatenate inputs and labels for causal LM
batch["input_ids"] = torch.from_numpy(
np.concatenate([input_ids, labels], axis=-1)
)
causal_labels = batch["input_ids"].clone()
batch["attention_mask"] = batch["input_ids"] != self.tokenizer.pad_token_id
causal_labels[causal_labels == self.tokenizer.pad_token_id] = -100
batch["labels"] = causal_labels
return batch
def compute_input_and_target_lengths(
self, inputs_length: int, noise_density: float, mean_noise_span_length: float
) -> Tuple[int, int]:
"""
Compute optimal token lengths to avoid padding.
Given desired input length, noise parameters, and mean span length,
calculates the required number of tokens in raw text and encoded targets.
Adapted from T5's random_spans_helper:
https://github.com/google-research/text-to-text-transfer-transformer/blob/main/t5/data/preprocessors.py#L2466
Args:
inputs_length: Desired tokenized input sequence length
noise_density: Fraction of tokens to corrupt (0.0-1.0)
mean_noise_span_length: Average span length for corruption
Returns:
(tokens_length, targets_length) tuple
"""
def _tokens_length_to_inputs_length_targets_length(tokens_length):
num_noise_tokens = int(round(tokens_length * noise_density))
num_nonnoise_tokens = tokens_length - num_noise_tokens
num_noise_spans = int(round(num_noise_tokens / mean_noise_span_length))
# Inputs: all nonnoise tokens + sentinels + EOS
_input_length = num_nonnoise_tokens + num_noise_spans + 1
_output_length = num_noise_tokens + num_noise_spans + 1
return _input_length, _output_length
tokens_length = inputs_length
# Special case: causal LM (no corruption)
if noise_density == 0.0:
return (
self.max_labels_length
- 2
+ int(self.max_length // mean_noise_span_length)
- 2,
inputs_length,
)
# Find optimal tokens_length
while (
_tokens_length_to_inputs_length_targets_length(tokens_length + 1)[0]
<= inputs_length
):
tokens_length += 1
inputs_length, targets_length = _tokens_length_to_inputs_length_targets_length(
tokens_length
)
# Adjust for 50% noise density edge case
if noise_density == 0.5 and targets_length > inputs_length:
tokens_length -= 1
targets_length -= 1
return tokens_length, targets_length
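    # Worked example (regular denoising, r=0.15, mu=3.0): for a desired
    # inputs_length of 512, the search above settles on tokens_length=568 raw
    # tokens, which corrupt into 512 input tokens (483 kept + 28 sentinels +
    # EOS) and 114 target tokens (85 noise + 28 sentinels + EOS).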
def random_spans_noise_mask(
self, sequence_length: int, denoiser_params: Dict
) -> np.ndarray:
"""
Generate boolean mask for span corruption.
Creates alternating spans of noise/non-noise tokens with configurable
density and span length. Special handling for sequential denoising (max_spans=1).
Adapted from T5's random_spans_noise_mask:
https://github.com/google-research/text-to-text-transfer-transformer/blob/main/t5/data/preprocessors.py#L2682
Args:
sequence_length: Length of input sequence
denoiser_params: Dict with 'mu', 'r', 'max_spans' keys
Returns:
Boolean array of shape [sequence_length] where True = noise
"""
mean_noise_span_length = denoiser_params["mu"]
noise_density = denoiser_params["r"]
max_num_spans = denoiser_params["max_spans"]
if max_num_spans == 1:
# Force single span at sequence start (for sequential/causal denoising)
prefix_span = int(np.round(sequence_length / mean_noise_span_length))
masked_span = sequence_length - prefix_span
interleaved_span_lengths = np.array([prefix_span, masked_span])
else:
# Standard span corruption
num_noise_tokens = int(np.round(sequence_length * noise_density))
num_noise_tokens = min(max(num_noise_tokens, 1), sequence_length - 1)
num_noise_spans = min(
max_num_spans, int(np.round(num_noise_tokens / mean_noise_span_length))
)
num_noise_spans = max(num_noise_spans, 1)
num_nonnoise_tokens = sequence_length - num_noise_tokens
def _random_segmentation(num_items, num_segments):
"""Randomly partition items into non-empty segments."""
mask_indices = np.arange(num_items - 1) < (num_segments - 1)
np.random.shuffle(mask_indices)
first_in_segment = np.pad(mask_indices, [[1, 0]])
segment_id = np.cumsum(first_in_segment)
_, segment_length = np.unique(segment_id, return_counts=True)
return segment_length
noise_span_lengths = _random_segmentation(num_noise_tokens, num_noise_spans)
nonnoise_span_lengths = _random_segmentation(
num_nonnoise_tokens, num_noise_spans
)
interleaved_span_lengths = np.reshape(
np.stack([nonnoise_span_lengths, noise_span_lengths], axis=1),
[num_noise_spans * 2],
)
# Convert spans to boolean mask
span_starts = np.cumsum(interleaved_span_lengths)[:-1]
span_start_indicator = np.zeros((sequence_length,), dtype=np.int8)
span_start_indicator[span_starts] = True
span_num = np.cumsum(span_start_indicator)
is_noise = np.equal(span_num % 2, 1)
return is_noise
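    # Illustrative sketch: with sequence_length=10, r=0.3, mu=1.5 there are
    # round(10 * 0.3) = 3 noise tokens split over round(3 / 1.5) = 2 spans;
    # one possible mask is [F F F F T F F F T T]. Non-noise and noise spans
    # always alternate, starting with a non-noise span.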
def create_sentinel_ids(self, mask_indices: np.ndarray) -> np.ndarray:
"""
Convert boolean mask to sentinel token IDs.
Replaces span starts with increasing sentinel IDs.
Consecutive masked tokens are marked for deletion (-1).
Args:
mask_indices: Boolean array where True = masked
Returns:
Array with sentinel IDs at span starts, 0 elsewhere
"""
start_indices = mask_indices - np.roll(mask_indices, 1, axis=-1) * mask_indices
start_indices[0] = mask_indices[0]
sentinel_ids = np.where(
start_indices != 0, np.cumsum(start_indices, axis=-1), start_indices
)
sentinel_ids = np.where(
sentinel_ids != 0, (self.extra_ids[0] - sentinel_ids), 0
)
sentinel_ids -= mask_indices - start_indices
return sentinel_ids
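    # Illustrative sketch: mask_indices = [0, 1, 1, 0, 1] yields
    # [0, E - 1, -1, 0, E - 2] with E = self.extra_ids[0]: each span start
    # gets a decreasing placeholder id, continuation tokens inside a span get
    # -1 (removed later), and unmasked positions stay 0.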
def filter_input_ids(
self,
input_ids: np.ndarray,
sentinel_ids: np.ndarray,
prefixes: np.ndarray = None,
with_eos: bool = True,
) -> np.ndarray:
"""
Apply sentinel mask and fuse consecutive masks.
Replaces original tokens with sentinels, removes -1 markers,
and optionally adds prefix and EOS tokens.
Args:
input_ids: Original token IDs
sentinel_ids: Sentinel mask from create_sentinel_ids
prefixes: Optional prefix tokens to prepend
with_eos: Whether to append EOS token
Returns:
Filtered token ID array
"""
batch_size = input_ids.shape[0]
# Replace with sentinels
input_ids = np.where(sentinel_ids != 0, sentinel_ids, input_ids)
# Remove EOS and negative markers
input_ids = input_ids[input_ids != self.tokenizer.eos_token_id]
input_ids = input_ids[input_ids >= 0].reshape((batch_size, -1))
# Add prefix
if prefixes is not None:
input_ids = np.concatenate([prefixes, input_ids], axis=-1)
# Add EOS
if with_eos:
input_ids = np.concatenate(
[
input_ids,
np.full(
(batch_size, 1), self.tokenizer.eos_token_id, dtype=np.int32
),
],
axis=-1,
)
return input_ids
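    # Illustrative sketch (continuing the example above): tokens
    # [t0, t1, t2, t3, t4] with sentinel ids [0, E - 1, -1, 0, E - 2] become
    # [t0, E - 1, t3, E - 2] plus optional prefix and EOS; the placeholder
    # sentinels are remapped to real <extra_id_*> tokens in __call__ after
    # packing.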
# ============================================================================
# USAGE EXAMPLE
# ============================================================================
if __name__ == "__main__":
"""
Example: Train T5 with UL2 objectives
Requires:
- Tokenizer with extra_id tokens (T5, UL2, Flan-T5)
- Tokenizer extended with special prefixes: [R], [S], [X]
"""
from transformers import AutoTokenizer
# Load tokenizer (ensure it has extra_id tokens)
tokenizer = AutoTokenizer.from_pretrained("google/t5-v1_1-base")
# Add special denoiser prefix tokens
special_tokens = {"additional_special_tokens": ["[R]", "[S]", "[X]"]}
tokenizer.add_special_tokens(special_tokens)
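    # When training a real model, resize its embeddings after adding the
    # prefix tokens, e.g.: model.resize_token_embeddings(len(tokenizer))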
# Initialize UL2 collator
collator = DataCollatorForUL2MLM(
tokenizer=tokenizer,
max_length=512,
max_labels_length=128,
batch_size=32,
**UL2_7_DENOISER, # Use 7-denoiser mixture
)
# Or use standard T5 denoising
# collator = DataCollatorForUL2MLM(
# tokenizer=tokenizer,
# max_length=512,
# max_labels_length=128,
# batch_size=32,
# **T5_STANDARD
# )
# Create dummy examples
dummy_data = [
{"input_ids": np.random.randint(0, 1000, (1, 256))} for _ in range(64)
]
# Collate batch
batch = collator(dummy_data)
print(f"Input IDs shape: {batch['input_ids'].shape}")
print(f"Labels shape: {batch['labels'].shape}")
print(f"Attention mask shape: {batch['attention_mask'].shape}")
# Use with PyTorch DataLoader
# from torch.utils.data import DataLoader
# dataloader = DataLoader(dataset, collate_fn=collator, batch_size=None)
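    # Optional: decode one corrupted input to eyeball the sentinel layout
    # (a minimal sketch; `tokenizer` and `batch` are defined above).
    print(tokenizer.decode(batch["input_ids"][0], skip_special_tokens=False))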