UL2 Data Collator for PyTorch + Transformers
| """ | |
| UL2 Data Collator for PyTorch + Transformers | |
| ============================================== | |
| Standalone implementation of UL2 (Unified Language Learner) denoising objectives | |
| for encoder-decoder models (T5, UL2, Flan-T5, etc.). | |
| Based on: "Unifying Language Learning Paradigms" (Tay et al., 2022) | |
| https://arxiv.org/abs/2205.05131 | |
| This collator implements mixture-of-denoisers with configurable span corruption: | |
| - [R] Regular denoising (standard T5-like span corruption) | |
| - [S] Sequential denoising (prefix/causal language modeling) | |
| - [X] Extreme denoising (aggressive corruption for robustness) | |
| Requirements: | |
| torch>=1.13.0 | |
| transformers>=4.20.0 | |
| numpy>=1.21.0 | |
| Author: Extracted from catie-aq/flashT5 | |
| License: MIT | |
| """ | |

from typing import Dict, List, Tuple

import numpy as np
import torch
from transformers import AutoTokenizer, BatchEncoding
from transformers.data.data_collator import DataCollatorMixin

# ============================================================================
# RECOMMENDED CONFIGURATIONS
# ============================================================================

UL2_7_DENOISER = {
    "denoiser_list": [
        {"mu": 3.0, "r": 0.15, "max_spans": 512, "prefix": "[R]"},  # Regular, short spans
        {"mu": 8.0, "r": 0.15, "max_spans": 512, "prefix": "[R]"},  # Regular, long spans
        {"mu": 4.0, "r": 0.0, "max_spans": 1, "prefix": "[S]"},  # Sequential (causal)
        {"mu": 3.0, "r": 0.5, "max_spans": 512, "prefix": "[X]"},  # Extreme, aggressive
        {"mu": 8.0, "r": 0.15, "max_spans": 512, "prefix": "[X]"},  # Extreme, long spans
        {"mu": 64.0, "r": 0.15, "max_spans": 512, "prefix": "[X]"},  # Extreme, very long spans
        {"mu": 64.0, "r": 0.5, "max_spans": 512, "prefix": "[X]"},  # Extreme, very aggressive
    ],
    "denoiser_proportions": [0.165, 0.165, 0.34, 0.0825, 0.0825, 0.0825, 0.0825],
}
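
# Note: the proportions above split the mixture roughly evenly per paradigm:
# ~1/3 [R] (0.165 + 0.165), ~1/3 [S] (0.34), and ~1/3 [X] (4 x 0.0825). They
# are re-normalized in the collator, so only their relative values matter.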

T5_STANDARD = {
    "denoiser_list": [{"mu": 3.0, "r": 0.15, "max_spans": 512, "prefix": ""}],
    "denoiser_proportions": [1.0],
}


# ============================================================================
# MAIN DATA COLLATOR
# ============================================================================


class DataCollatorForUL2MLM(DataCollatorMixin):
    """
    Data collator for UL2-style masked language modeling with a mixture of denoisers.

    Implements span corruption with configurable parameters per denoiser type.
    Supports dynamic batching with sequence packing for efficient training.

    Args:
        tokenizer: HuggingFace tokenizer (must have extra_id tokens and special prefixes)
        max_length: Maximum input sequence length
        max_labels_length: Maximum target/label sequence length
        batch_size: Target batch size for packing
        denoiser_list: List of denoiser configs, each with:
            - mu (float): Mean span length for corruption
            - r (float): Noise density (0.0-1.0, fraction of tokens to corrupt)
            - max_spans (int): Maximum number of corrupted spans
            - prefix (str): Task prefix token (e.g., "[R]", "[S]", "[X]")
        denoiser_proportions: Sampling weights for each denoiser (normalized to sum to 1.0)
        causal: If True, format for causal LM; if False, encoder-decoder format
        random_chunk: If True, randomly sample chunks from long sequences
        fixed_batch_size: If True, pad batches to exactly batch_size rows by wrapping
        min_size_inputs: Minimum sequence length to include (shorter examples are filtered out)

    Example:
        >>> tokenizer = AutoTokenizer.from_pretrained("google/t5-v1_1-base")
        >>> collator = DataCollatorForUL2MLM(
        ...     tokenizer=tokenizer,
        ...     max_length=512,
        ...     max_labels_length=128,
        ...     batch_size=32,
        ...     **UL2_7_DENOISER
        ... )
    """

    def __init__(
        self,
        tokenizer: AutoTokenizer,
        max_length: int,
        max_labels_length: int,
        batch_size: int,
        denoiser_list: List[Dict],
        denoiser_proportions: List[float],
        causal: bool = False,
        random_chunk: bool = True,
        fixed_batch_size: bool = False,
        min_size_inputs: int = 10,
    ):
        super().__init__()

        # Normalize proportions to sum to 1.0 (avoids brittle exact float comparison)
        total = sum(denoiser_proportions)
        self.denoiser_proportions = [x / total for x in denoiser_proportions]

        self.denoiser_list = denoiser_list
        self.tokenizer = tokenizer

        # Encode prefix tokens (e.g., "[R]" -> token IDs); [:-1] drops the EOS
        # that T5-style tokenizers append to every encoded sequence
        self.prefixes = [
            tokenizer.encode(denoiser["prefix"], return_tensors="np").flatten()[:-1]
            for denoiser in denoiser_list
        ]

        # Extract extra_id sentinel tokens (assumed contiguous), largest ID first
        self.extra_ids = sorted(
            [
                tokenizer.all_special_ids[i]
                for i, token in enumerate(tokenizer.all_special_tokens)
                if "extra" in token
            ],
            reverse=True,
        )

        self.max_length = max_length
        self.batch_size = batch_size
        self.max_labels_length = max_labels_length

        # Pre-compute optimal lengths for each denoiser
        max_prefix_len = max(len(p) for p in self.prefixes)
        self.denoiser_optimal_len = [
            self.compute_input_and_target_lengths(
                max_length - max_prefix_len, denoiser["r"], denoiser["mu"]
            )
            for denoiser in self.denoiser_list
        ]

        self.causal = causal
        self.random_chunk = random_chunk
        self.fixed_batch_size = fixed_batch_size
        self.min_size_inputs = min_size_inputs

    def is_special_token(self, x: np.ndarray) -> np.ndarray:
        """Check which token IDs fall in the sentinel (extra_id) range."""
        return (x <= self.extra_ids[0]) & (x >= self.extra_ids[-1])
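
    # For google/t5-v1_1-base, the 100 sentinels occupy IDs 32000-32099, so
    # extra_ids[0] == 32099 (<extra_id_0>) and extra_ids[-1] == 32000; any ID
    # in that closed range is treated as a sentinel. (Values shown for
    # illustration; other tokenizers will differ.)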

    def _best_fit(
        self, input_ids: List[np.ndarray], labels: List[np.ndarray]
    ) -> Tuple[List[np.ndarray], List[np.ndarray]]:
        """
        Pack multiple sequences into batches using first-fit bin packing.

        Maximizes GPU utilization by concatenating short sequences until the
        max_length/max_labels_length constraints are hit.

        Returns:
            Tuple of (packed_inputs, packed_labels)
        """
        batch_inputs = []
        batch_labels = []

        for _ in range(self.batch_size):
            bin_inputs = []
            bin_labels = []
            bin_input_length = 0
            bin_label_length = 0
            bin_special_tokens = 0

            for i, (x, item_labels) in enumerate(zip(input_ids, labels)):
                if x is None:
                    continue

                size_inputs = x.shape[1]
                size_labels = item_labels.shape[1]
                num_new_special_tokens = self.is_special_token(x).sum()

                # Check if the sequence fits in the current bin; the sentinel
                # budget guards against exhausting the tokenizer's extra_id tokens
                if (
                    bin_input_length + size_inputs < self.max_length
                    and bin_label_length + size_labels < self.max_labels_length
                    and bin_special_tokens + num_new_special_tokens
                    < len(self.extra_ids)
                ):
                    bin_inputs.append(x)
                    bin_labels.append(item_labels)
                    bin_input_length += size_inputs
                    bin_label_length += size_labels
                    bin_special_tokens += num_new_special_tokens

                    # Mark as used
                    input_ids[i] = None
                    labels[i] = None

            if bin_inputs:
                batch_inputs.append(np.concatenate(bin_inputs, axis=1))
                batch_labels.append(np.concatenate(bin_labels, axis=1))

        return batch_inputs, batch_labels
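
    # Packing illustration (hypothetical sizes, max_length=10, label limit not
    # binding): input lengths [6, 3, 5, 2] produce
    #   bin 1: 6 + 3 = 9  (5 and 2 no longer fit)
    #   bin 2: 5 + 2 = 7
    # i.e. greedy first-fit in arrival order, one bin per batch row.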

    def __call__(self, examples: List[Dict[str, np.ndarray]]) -> BatchEncoding:
        """
        Process a batch of examples with UL2 denoising.

        Args:
            examples: List of dicts with 'input_ids' (shape: [1, seq_len])

        Returns:
            BatchEncoding with 'input_ids', 'attention_mask', 'labels'
        """
        # Filter out examples that are too short
        examples = [
            x for x in examples if x["input_ids"].shape[1] > self.min_size_inputs
        ]
        input_batch_size = len(examples)

        # Sample a denoiser for each example
        denoisers_sample = np.random.choice(
            range(len(self.denoiser_list)),
            input_batch_size,
            p=self.denoiser_proportions,
        )

        # Add a length field if missing
        has_length = "length" in examples[0]
        if not has_length:
            for i in range(len(examples)):
                examples[i]["length"] = examples[i]["input_ids"].shape[1]

        # Truncate each example to its denoiser's optimal length
        truncated_examples = []
        for i, x in enumerate(examples):
            max_len = self.denoiser_optimal_len[denoisers_sample[i]][0]
            if x["length"] > max_len:
                start = 0
                if self.random_chunk:
                    start = np.random.randint(0, x["length"] - max_len)
                new_input_ids = x["input_ids"][:, start : start + max_len]
                truncated_examples.append(
                    {"input_ids": new_input_ids, "length": np.array(max_len)}
                )
            else:
                truncated_examples.append(x)
        examples = truncated_examples

        # Generate noise masks for span corruption
        spans_noise_masks = [
            self.random_spans_noise_mask(
                x["length"], self.denoiser_list[denoisers_sample[i]]
            )
            for i, x in enumerate(examples)
        ]

        # Create sentinel IDs for masked (inputs) and unmasked (labels) spans
        input_ids_sentinel = [
            self.create_sentinel_ids(x.astype(np.int8)) for x in spans_noise_masks
        ]
        labels_sentinel = [
            self.create_sentinel_ids((~x).astype(np.int8)) for x in spans_noise_masks
        ]

        # Apply the masks and prepend the denoiser prefixes
        input_ids = [
            self.filter_input_ids(
                x["input_ids"],
                input_ids_sentinel[i],
                np.expand_dims(self.prefixes[denoisers_sample[i]], axis=0),
            )
            for i, x in enumerate(examples)
        ]
        labels = [
            self.filter_input_ids(x["input_ids"], labels_sentinel[i], with_eos=False)
            for i, x in enumerate(examples)
        ]

        # Pack sequences if needed
        if len(input_ids) == self.batch_size:
            batch_inputs, batch_labels = input_ids, labels
        else:
            batch_inputs, batch_labels = self._best_fit(input_ids, labels)

        # Replace sentinel placeholders with actual extra_id tokens, renumbered
        # from <extra_id_0> so packed sequences stay consistent
        labels = [
            np.where(
                self.is_special_token(x),
                self.extra_ids[0] - np.cumsum(self.is_special_token(x)) + 1,
                x,
            )
            for x in batch_labels
        ]
        input_ids = [
            np.where(
                self.is_special_token(x),
                self.extra_ids[0] - np.cumsum(self.is_special_token(x)) + 1,
                x,
            )
            for x in batch_inputs
        ]

        # Append EOS to labels
        labels = [
            np.concatenate(
                [x, np.full((1, 1), self.tokenizer.eos_token_id, dtype=np.int32)],
                axis=-1,
            )
            for x in labels
        ]

        # Pad sequences
        if self.causal:
            # Causal: left-pad inputs, right-pad labels
            labels = np.concatenate(
                [
                    np.pad(
                        x,
                        ((0, 0), (0, self.max_labels_length - x.shape[1])),
                        constant_values=self.tokenizer.pad_token_id,
                    )
                    for x in labels
                ],
                axis=0,
            )
            input_ids = np.concatenate(
                [
                    np.pad(
                        x,
                        ((0, 0), (self.max_length - x.shape[1], 0)),
                        constant_values=self.tokenizer.pad_token_id,
                    )
                    for x in input_ids
                ],
                axis=0,
            )
        else:
            # Encoder-decoder: right-pad everything
            labels = np.concatenate(
                [
                    np.pad(
                        x,
                        ((0, 0), (0, self.max_labels_length - x.shape[1])),
                        constant_values=self.tokenizer.pad_token_id,
                    )
                    for x in labels
                ],
                axis=0,
            )
            input_ids = np.concatenate(
                [
                    np.pad(
                        x,
                        ((0, 0), (0, self.max_length - x.shape[1])),
                        constant_values=self.tokenizer.pad_token_id,
                    )
                    for x in input_ids
                ],
                axis=0,
            )

        # Pad to a fixed batch size by wrapping (repeating) rows
        if self.fixed_batch_size and input_ids.shape[0] < self.batch_size:
            input_ids = np.pad(
                input_ids,
                ((0, self.batch_size - input_ids.shape[0]), (0, 0)),
                mode="wrap",
            )
            labels = np.pad(
                labels, ((0, self.batch_size - labels.shape[0]), (0, 0)), mode="wrap"
            )

        # Construct the final batch
        batch = {}
        if not self.causal:
            batch["input_ids"] = torch.from_numpy(input_ids)
            final_labels = torch.from_numpy(labels)
        else:
            # Concatenate inputs and labels for causal LM
            batch["input_ids"] = torch.from_numpy(
                np.concatenate([input_ids, labels], axis=-1)
            )
            final_labels = batch["input_ids"].clone()

        batch["attention_mask"] = batch["input_ids"] != self.tokenizer.pad_token_id
        final_labels[final_labels == self.tokenizer.pad_token_id] = -100
        batch["labels"] = final_labels

        return batch
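
    # Schematic single-example walkthrough (encoder-decoder, [R] denoiser),
    # with word-level tokens for readability:
    #   original : the quick brown fox jumps over the lazy dog
    #   input_ids: [R] the quick <extra_id_0> jumps over the <extra_id_1> </s>
    #   labels   : <extra_id_0> brown fox <extra_id_1> lazy dog </s>
    # Masked spans move to the labels behind matching sentinels; pad positions
    # in the labels are set to -100 so the loss ignores them.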

    def compute_input_and_target_lengths(
        self, inputs_length: int, noise_density: float, mean_noise_span_length: float
    ) -> Tuple[int, int]:
        """
        Compute optimal token lengths to avoid padding.

        Given the desired input length, noise density, and mean span length,
        calculates the required number of raw-text tokens and the resulting
        encoded target length.

        Adapted from T5's random_spans_helper:
        https://github.com/google-research/text-to-text-transfer-transformer/blob/main/t5/data/preprocessors.py#L2466

        Args:
            inputs_length: Desired tokenized input sequence length
            noise_density: Fraction of tokens to corrupt (0.0-1.0)
            mean_noise_span_length: Average span length for corruption

        Returns:
            (tokens_length, targets_length) tuple
        """

        def _tokens_length_to_inputs_length_targets_length(tokens_length):
            num_noise_tokens = int(round(tokens_length * noise_density))
            num_nonnoise_tokens = tokens_length - num_noise_tokens
            num_noise_spans = int(round(num_noise_tokens / mean_noise_span_length))
            # Inputs: all non-noise tokens + one sentinel per span + EOS
            _input_length = num_nonnoise_tokens + num_noise_spans + 1
            _output_length = num_noise_tokens + num_noise_spans + 1
            return _input_length, _output_length

        tokens_length = inputs_length

        # Special case: causal LM (no corruption)
        if noise_density == 0.0:
            return (
                self.max_labels_length
                - 2
                + int(self.max_length // mean_noise_span_length)
                - 2,
                inputs_length,
            )

        # Find the largest tokens_length whose encoded input still fits
        while (
            _tokens_length_to_inputs_length_targets_length(tokens_length + 1)[0]
            <= inputs_length
        ):
            tokens_length += 1

        inputs_length, targets_length = _tokens_length_to_inputs_length_targets_length(
            tokens_length
        )

        # Adjust for the 50% noise density edge case
        if noise_density == 0.5 and targets_length > inputs_length:
            tokens_length -= 1
            targets_length -= 1

        return tokens_length, targets_length
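
    # Worked example: compute_input_and_target_lengths(512, 0.15, 3.0) settles
    # on tokens_length=568, since round(568 * 0.15) = 85 noise tokens in
    # round(85 / 3.0) = 28 spans gives an encoded input of
    # (568 - 85) + 28 + 1 = 512 and targets of 85 + 28 + 1 = 114.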

    def random_spans_noise_mask(
        self, sequence_length: int, denoiser_params: Dict
    ) -> np.ndarray:
        """
        Generate a boolean mask for span corruption.

        Creates alternating spans of noise/non-noise tokens with configurable
        density and span length. Sequential denoising (max_spans=1) is handled
        specially: the prefix is kept and a single noise span covers the suffix.

        Adapted from T5's random_spans_noise_mask:
        https://github.com/google-research/text-to-text-transfer-transformer/blob/main/t5/data/preprocessors.py#L2682

        Args:
            sequence_length: Length of the input sequence
            denoiser_params: Dict with 'mu', 'r', 'max_spans' keys

        Returns:
            Boolean array of shape [sequence_length] where True = noise
        """
        mean_noise_span_length = denoiser_params["mu"]
        noise_density = denoiser_params["r"]
        max_num_spans = denoiser_params["max_spans"]

        if max_num_spans == 1:
            # Sequential/causal denoising: keep a prefix of ~1/mu of the
            # sequence and mask everything after it as one noise span
            prefix_span = int(np.round(sequence_length / mean_noise_span_length))
            masked_span = sequence_length - prefix_span
            interleaved_span_lengths = np.array([prefix_span, masked_span])
        else:
            # Standard span corruption
            num_noise_tokens = int(np.round(sequence_length * noise_density))
            num_noise_tokens = min(max(num_noise_tokens, 1), sequence_length - 1)
            num_noise_spans = min(
                max_num_spans, int(np.round(num_noise_tokens / mean_noise_span_length))
            )
            num_noise_spans = max(num_noise_spans, 1)
            num_nonnoise_tokens = sequence_length - num_noise_tokens

            def _random_segmentation(num_items, num_segments):
                """Randomly partition items into non-empty segments."""
                mask_indices = np.arange(num_items - 1) < (num_segments - 1)
                np.random.shuffle(mask_indices)
                first_in_segment = np.pad(mask_indices, [[1, 0]])
                segment_id = np.cumsum(first_in_segment)
                _, segment_length = np.unique(segment_id, return_counts=True)
                return segment_length

            noise_span_lengths = _random_segmentation(
                num_noise_tokens, num_noise_spans
            )
            nonnoise_span_lengths = _random_segmentation(
                num_nonnoise_tokens, num_noise_spans
            )
            interleaved_span_lengths = np.reshape(
                np.stack([nonnoise_span_lengths, noise_span_lengths], axis=1),
                [num_noise_spans * 2],
            )

        # Convert span lengths to a boolean mask (even spans clean, odd spans noise)
        span_starts = np.cumsum(interleaved_span_lengths)[:-1]
        span_start_indicator = np.zeros((sequence_length,), dtype=np.int8)
        span_start_indicator[span_starts] = True
        span_num = np.cumsum(span_start_indicator)
        is_noise = np.equal(span_num % 2, 1)
        return is_noise
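
    # Deterministic example: sequence_length=10, r=0.3, mu=3.0 gives
    # num_noise_tokens=3 and num_noise_spans=1, so the interleaved lengths are
    # [7, 3] and the mask is [F F F F F F F T T T] (the single noise span lands
    # at the end; with several spans, the split points are random).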

    def create_sentinel_ids(self, mask_indices: np.ndarray) -> np.ndarray:
        """
        Convert a boolean mask to sentinel token IDs.

        Replaces span starts with decreasing sentinel IDs. Consecutive masked
        tokens (span continuations) are marked for deletion with -1.

        Args:
            mask_indices: Int (0/1) array where 1 = masked

        Returns:
            Array with sentinel IDs at span starts, -1 at span continuations,
            and 0 elsewhere
        """
        # 1 exactly where a masked span begins
        start_indices = mask_indices - np.roll(mask_indices, 1, axis=-1) * mask_indices
        start_indices[0] = mask_indices[0]

        # Number the spans 1, 2, 3, ... at their start positions
        sentinel_ids = np.where(
            start_indices != 0, np.cumsum(start_indices, axis=-1), start_indices
        )
        # Map span number k to the placeholder ID extra_ids[0] - k
        sentinel_ids = np.where(
            sentinel_ids != 0, (self.extra_ids[0] - sentinel_ids), 0
        )
        # Mark span continuations with -1 so filter_input_ids can drop them
        sentinel_ids -= mask_indices - start_indices

        return sentinel_ids
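
    # Worked example with E := extra_ids[0] and mask [0, 1, 1, 0, 1]:
    #   start_indices -> [0, 1, 0, 0, 1]
    #   span numbers  -> [0, 1, 0, 0, 2]
    #   sentinel_ids  -> [0, E-1, -1, 0, E-2]
    # Position 2 (a continuation of the first span) is tagged -1 for deletion;
    # the E-k placeholders are renumbered to real <extra_id_*> IDs in __call__.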

    def filter_input_ids(
        self,
        input_ids: np.ndarray,
        sentinel_ids: np.ndarray,
        prefixes: np.ndarray = None,
        with_eos: bool = True,
    ) -> np.ndarray:
        """
        Apply the sentinel mask and fuse consecutive masked tokens.

        Replaces original tokens with sentinels, removes the -1 deletion
        markers, and optionally prepends prefix tokens and appends EOS.

        Args:
            input_ids: Original token IDs, shape [1, seq_len]
            sentinel_ids: Sentinel mask from create_sentinel_ids
            prefixes: Optional prefix tokens to prepend
            with_eos: Whether to append an EOS token

        Returns:
            Filtered token ID array
        """
        batch_size = input_ids.shape[0]

        # Replace masked positions with sentinels
        input_ids = np.where(sentinel_ids != 0, sentinel_ids, input_ids)

        # Drop any existing EOS and the -1 deletion markers
        # (boolean indexing flattens, so reshape back afterwards)
        input_ids = input_ids[input_ids != self.tokenizer.eos_token_id]
        input_ids = input_ids[input_ids >= 0].reshape((batch_size, -1))

        # Prepend the task prefix
        if prefixes is not None:
            input_ids = np.concatenate([prefixes, input_ids], axis=-1)

        # Append EOS
        if with_eos:
            input_ids = np.concatenate(
                [
                    input_ids,
                    np.full(
                        (batch_size, 1), self.tokenizer.eos_token_id, dtype=np.int32
                    ),
                ],
                axis=-1,
            )

        return input_ids
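
    # Continuing the create_sentinel_ids example (E := extra_ids[0]):
    #   input_ids [[10, 11, 12, 13, 14]] with sentinel_ids [0, E-1, -1, 0, E-2]
    #   -> np.where : [[10, E-1, -1, 13, E-2]]
    #   -> drop < 0 : [[10, E-1, 13, E-2]]
    #   -> then the prefix and EOS are attached as requested.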


# ============================================================================
# USAGE EXAMPLE
# ============================================================================

if __name__ == "__main__":
    """
    Example: train T5 with UL2 objectives.

    Requires:
        - A tokenizer with extra_id tokens (T5, UL2, Flan-T5)
        - The tokenizer extended with the special prefixes [R], [S], [X]
    """
    # Load a tokenizer (ensure it has extra_id tokens)
    tokenizer = AutoTokenizer.from_pretrained("google/t5-v1_1-base")

    # Add the special denoiser prefix tokens
    special_tokens = {"additional_special_tokens": ["[R]", "[S]", "[X]"]}
    tokenizer.add_special_tokens(special_tokens)
    # Note: if you train a model with these prefixes, resize its embeddings to
    # match, e.g. model.resize_token_embeddings(len(tokenizer))

    # Initialize the UL2 collator
    collator = DataCollatorForUL2MLM(
        tokenizer=tokenizer,
        max_length=512,
        max_labels_length=128,
        batch_size=32,
        **UL2_7_DENOISER,  # Use the 7-denoiser mixture
    )

    # Or use standard T5 denoising:
    # collator = DataCollatorForUL2MLM(
    #     tokenizer=tokenizer,
    #     max_length=512,
    #     max_labels_length=128,
    #     batch_size=32,
    #     **T5_STANDARD,
    # )

    # Create dummy examples
    dummy_data = [
        {"input_ids": np.random.randint(0, 1000, (1, 256))} for _ in range(64)
    ]

    # Collate a batch
    batch = collator(dummy_data)

    print(f"Input IDs shape: {batch['input_ids'].shape}")
    print(f"Labels shape: {batch['labels'].shape}")
    print(f"Attention mask shape: {batch['attention_mask'].shape}")

    # Use with a PyTorch DataLoader. The collator expects a *list* of examples,
    # so either let the loader batch for you (batch_size=64 here feeds 64 raw
    # examples that get packed down to the collator's batch_size) or pass
    # batch_size=None with a dataset that already yields lists:
    # from torch.utils.data import DataLoader
    # dataloader = DataLoader(dataset, collate_fn=collator, batch_size=64)
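
    # Hypothetical end-to-end sketch (the dataset name "my_corpus" and its
    # "text" column are assumptions, not part of this file). It tokenizes a
    # text dataset into the [1, seq_len] arrays the collator expects, then
    # trains with the HF Trainer:
    #
    # from datasets import load_dataset
    # from transformers import T5ForConditionalGeneration, Trainer, TrainingArguments
    #
    # model = T5ForConditionalGeneration.from_pretrained("google/t5-v1_1-base")
    # model.resize_token_embeddings(len(tokenizer))  # account for [R]/[S]/[X]
    #
    # def tokenize(example):
    #     ids = tokenizer(example["text"], return_tensors="np").input_ids
    #     return {"input_ids": ids}  # shape [1, seq_len], as __call__ expects
    #
    # dataset = load_dataset("my_corpus", split="train").map(
    #     tokenize, remove_columns=["text"]
    # )
    # dataset.set_format("numpy")  # so examples arrive as np arrays
    #
    # trainer = Trainer(
    #     model=model,
    #     args=TrainingArguments(output_dir="ul2-t5", remove_unused_columns=False),
    #     train_dataset=dataset,
    #     data_collator=collator,
    # )
    # trainer.train()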