T5 CausalLM Constrained Generation Using Tries
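The script below constrains a fine-tuned T5 classifier so that decoding can only ever produce one of a fixed set of class names: each candidate label is tokenized, the token-id sequences are stored in a trie, and Hugging Face's `prefix_allowed_tokens_fn` generation hook restricts every decoding step to the children of the prefix generated so far. Before the full Accelerate-based script, here is a minimal, self-contained sketch of that idea (it uses the stock `t5-small` checkpoint and a plain list scan in place of the marisa trie; the labels and query are illustrative only):

import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained("t5-small")
model = transformers.AutoModelForSeq2SeqLM.from_pretrained("t5-small").eval()

labels = ["Soccer", "Cricket", "Handball"]
# Prepend 0 (T5's decoder_start_token_id, which is also the pad token) so the
# stored sequences line up with the decoder prefix that generate() builds.
allowed = [[0] + tokenizer.encode(x) for x in labels]

def prefix_allowed_tokens_fn(batch_id, prefix):
    prefix = prefix.tolist()
    # Allow exactly the tokens that extend some stored label sequence.
    return list(
        {
            seq[len(prefix)]
            for seq in allowed
            if len(seq) > len(prefix) and seq[: len(prefix)] == prefix
        }
    )

inputs = tokenizer("classify: who won the world cup?", return_tensors="pt")
outs = model.generate(
    **inputs,
    max_length=max(len(seq) for seq in allowed),
    prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
)
print(tokenizer.batch_decode(outs, skip_special_tokens=True))  # one of `labels`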
import functools

import pandas as pd
import torch
import transformers
from accelerate import Accelerator
from datasets import Dataset
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

from t5_training_utils import (
    GenerationType,
    build_prefix_allowed_tokens_fn,
    convert_to_features,
    get_gen_type_attributes,
    get_model_full_name,
    get_prediction_name,
)
torch_dtype = "auto"
model_ckpt = "t5-base"
gen_type = GenerationType.ALL_TOKENS
input_max_length = 512
label_max_length = 6
use_task_prefix = True
class_names = [
    "Soccer",
    "Cricket",
    "Handball",
    "Snow Cycling",
]
non_eligible_classes = {
    "Snow Cycling",
}
non_eligible_idx = [
    i for i, c in enumerate(class_names) if c in non_eligible_classes
]
num_classes = len(class_names)
# Model training
### Uncomment a config section for the model type
## For small test run
train_batch_size = 8
eval_batch_size = 8
epochs = 30
save_every_k_epochs = 5
seed = 3333
torch.manual_seed(seed)
logging_steps = 100
eval_step = 100
learning_rate = 2e-5
weight_decay = 0.01
data_version = "guidelines-fixed-occasion"

model_full_name = get_model_full_name(model_ckpt, gen_type, epochs, data_version)
def get_model(model_local_ckpt):
    model = transformers.AutoModelForSeq2SeqLM.from_pretrained(model_local_ckpt)
    return model.eval()
def get_dataset(
    data_path, tokenizer, class_text_map, task_prefix, accelerator, nrows=1024, offset=0
):
    df_data = pd.read_csv(data_path, sep="\t", nrows=nrows + offset).rename(
        columns={"query": "text"}
    )
    df_data = df_data.iloc[offset : offset + nrows]
    print(df_data)
    dataset = Dataset.from_pandas(df_data)
    dataset.reset_format()
    with accelerator.main_process_first():
        dataset = dataset.map(
            functools.partial(
                convert_to_features,
                class_text_map=class_text_map,
                task_prefix=task_prefix,
                input_max_length=input_max_length,
                query_key="text",
                label_key=None,
                tokenizer=tokenizer,
            )
        )
    return dataset, df_data
def get_predictions_accelerate(data_path, model_local_ckpt, nrows=1024, offset=0):
    accelerator = Accelerator()
    device = accelerator.device
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_ckpt)
    class_text_map, max_decoding_length, task_prefix = get_gen_type_attributes(
        gen_type, tokenizer, class_names
    )
    task_prefix = task_prefix if use_task_prefix else ""
    dataset, df_data = get_dataset(
        data_path,
        tokenizer,
        class_text_map,
        task_prefix,
        accelerator,
        nrows=nrows,
        offset=offset,
    )
    model = get_model(model_local_ckpt)
    model = model.to(device)
    # Each allowed sequence starts with 0, T5's decoder_start_token_id (the pad
    # token), so trie lookups line up with the decoder prefix that generate()
    # builds.
    allowed_sequences = [[0] + tokenizer.encode(x) for x in class_text_map.values()]
    dataset.set_format("pt")
    # shuffle must stay False: predictions are written back to df_data by row
    # order below.
    custom_dataloader = DataLoader(
        dataset, shuffle=False, batch_size=eval_batch_size, num_workers=4
    )
    model, custom_dataloader = accelerator.prepare(model, custom_dataloader)
    preds = []
    with torch.no_grad():
        for batch in tqdm(
            custom_dataloader, disable=not accelerator.is_local_main_process
        ):
            batch_input_ids = batch["input_ids"].to(device)
            batch_attention_mask = batch["attention_mask"].to(device)
            # For DDP models use accelerator.unwrap_model(model).generate(inputs)
            # Taken from: https://github.com/huggingface/transformers/issues/18974
            batch_outs = accelerator.unwrap_model(model).generate(
                input_ids=batch_input_ids,
                attention_mask=batch_attention_mask,
                max_length=max_decoding_length,
                prefix_allowed_tokens_fn=build_prefix_allowed_tokens_fn(
                    allowed_sequences
                ),
            )
            batch_outs = accelerator.pad_across_processes(
                batch_outs, dim=1, pad_index=tokenizer.pad_token_id
            )
            batch_outs = accelerator.gather_for_metrics(batch_outs).cpu().numpy()
            preds.extend(tokenizer.batch_decode(batch_outs, skip_special_tokens=True))
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        if len(preds) != len(dataset):
            raise ValueError(
                f"Predictions and labels have different lengths. preds: {len(preds)} "
                f"labels: {len(dataset)}"
            )
        pred_col = get_prediction_name(model_full_name)
        df_data[pred_col] = preds
        # Map generated class text back to the raw class names.
        class_text_map_reversed = {val: key for key, val in class_text_map.items()}
        df_data[pred_col] = df_data[pred_col].apply(lambda x: class_text_map_reversed[x])
        eligible = ~df_data[pred_col].isin(non_eligible_classes)
        df_data["eligible_pred"] = eligible
        output_path = data_path.replace(".tsv", f".predicted.{offset}.{nrows}.tsv")
        print(df_data)
        print(f"Writing df_data with predictions to {output_path}")
        df_data.to_csv(output_path, sep="\t", index=False)
        return df_data
def main():
    data_path = "data.tsv"
    offset = 400_000
    nrows = 153  # 600_000
    model_local_ckpt = "./model_path/checkpoint-2830"
    print(data_path)
    print(nrows)
    print(model_local_ckpt)
    get_predictions_accelerate(data_path, model_local_ckpt, nrows=nrows, offset=offset)


if __name__ == "__main__":
    main()
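The helpers imported at the top (`GenerationType`, `build_prefix_allowed_tokens_fn`, `convert_to_features`, and the rest) come from `t5_training_utils.py`, listed below.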
import enum
import os
import string
from typing import List, Mapping

import marisa_trie
import torch


class GenerationType(enum.Enum):
    ALL_TOKENS = "all"  # generate the full class name
    THREE_TOKENS = "gen3"  # class name truncated to 3 tokens
    TWO_TOKENS = "gen2"  # class name truncated to 2 tokens
    ONE_TOKEN = "gen1"  # class name truncated to 1 token
    OPTION_ID = "abcd"  # generate a single option letter (A, B, ...)
def get_model_full_name(
    model_ckpt, gen_type: GenerationType, epochs: int, data_version: str = "0"
):
    model_base_name = model_ckpt.split("/")[-1]
    gen_type = gen_type.value
    model_full_name = f"{model_base_name}_{gen_type}_ep{epochs}_dt{data_version}"
    return model_full_name


def get_outdir(model_name: str):
    return os.path.join("./classifiers/", model_name)


def get_prediction_name(model_name: str):
    return f"{model_name}_predicted"
class MarisaTrie(object):
    """Trie over token-id sequences, stored as strings in a marisa_trie.Trie."""

    def __init__(
        self,
        sequences: List[List[int]] = [],
        cache_first_branch=True,
        max_token_id=256001,
    ):
        # Map each token id to a character, skipping the 55000-65000 range,
        # which contains surrogate code points that cannot be UTF-8 encoded.
        self.int2char = [chr(i) for i in range(min(max_token_id, 55000))] + (
            [chr(i) for i in range(65000, max_token_id + 10000)]
            if max_token_id >= 55000
            else []
        )
        self.char2int = {self.int2char[i]: i for i in range(max_token_id)}
        self.cache_first_branch = cache_first_branch
        if self.cache_first_branch:
            self.zero_iter = list({sequence[0] for sequence in sequences})
            assert len(self.zero_iter) == 1
            self.first_iter = list({sequence[1] for sequence in sequences})
        self.trie = marisa_trie.Trie(
            "".join([self.int2char[i] for i in sequence]) for sequence in sequences
        )

    def get(self, prefix_sequence: List[int]):
        if self.cache_first_branch and len(prefix_sequence) == 0:
            return self.zero_iter
        elif (
            self.cache_first_branch
            and len(prefix_sequence) == 1
            and self.zero_iter == prefix_sequence
        ):
            return self.first_iter
        else:
            key = "".join([self.int2char[i] for i in prefix_sequence])
            return list(
                {
                    self.char2int[e[len(key)]]
                    for e in self.trie.keys(key)
                    if len(e) > len(key)
                }
            )

    def __iter__(self):
        for sequence in self.trie.iterkeys():
            yield [self.char2int[e] for e in sequence]

    def __len__(self):
        return len(self.trie)

    def __getitem__(self, value):
        return self.get(value)
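# Example (hypothetical token ids; with cache_first_branch every sequence
# must share the same first id, here the decoder start token 0):
#   trie = MarisaTrie([[0, 5, 9], [0, 5, 12], [0, 7]])
#   trie.get([])      -> [0]
#   trie.get([0])     -> [5, 7]    (order may vary)
#   trie.get([0, 5])  -> [9, 12]   (order may vary)
#   trie.get([0, 7])  -> []        (sequence exhausted)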
def map_class_name(tokenizer, class_raw_name, num_tokens=None, delim=" "):
    if num_tokens is None:
        return class_raw_name
    class_raw_words = class_raw_name.split(delim)
    for i in range(1, len(class_raw_words) + 1):
        class_name_candidate = " ".join(class_raw_words[:i])
        tokens = tokenizer.tokenize(class_name_candidate)
        if len(tokens) == num_tokens:
            return class_name_candidate
    raise ValueError(
        f"Cannot find class name at the specified num_tokens: {class_raw_name}, {num_tokens}"
    )


def create_class_text_map(tokenizer, class_raw_names, num_tokens, delim="_"):
    res = {}
    for raw_name in class_raw_names:
        res[raw_name] = map_class_name(tokenizer, raw_name, num_tokens, delim=delim)
    return res
def get_task_prefix(gen_type, class_text_map):
    task_prefix = "Classify query intent into one of the following categories: "
    if (
        gen_type == GenerationType.THREE_TOKENS
        or gen_type == GenerationType.TWO_TOKENS
        or gen_type == GenerationType.ONE_TOKEN
        or gen_type == GenerationType.ALL_TOKENS
    ):
        classes = [f"'{x}'" for x in class_text_map.values()]
        task_prefix += ", ".join(classes)
        task_prefix += ". query: "
    elif gen_type == GenerationType.OPTION_ID:
        classes = [
            f"{val}: {' '.join(key.split('_')[:-1])}"
            for key, val in class_text_map.items()
        ]
        task_prefix += "\n" + "\n".join(classes)
        task_prefix += "\nquery: "
    return task_prefix
def get_gen_type_attributes(gen_type, tokenizer, class_names):
    if gen_type == GenerationType.THREE_TOKENS:
        class_text_map = create_class_text_map(tokenizer, class_names, 3)
        max_decoding_length = 3
    elif gen_type == GenerationType.TWO_TOKENS:
        class_text_map = create_class_text_map(tokenizer, class_names, 2)
        max_decoding_length = 2
    elif gen_type == GenerationType.ONE_TOKEN:
        class_text_map = create_class_text_map(tokenizer, class_names, 1)
        max_decoding_length = 1
    elif gen_type == GenerationType.ALL_TOKENS:
        class_text_map = create_class_text_map(tokenizer, class_names, None)
        max_decoding_length = max(
            [len([0] + tokenizer.encode(x)) for x in class_text_map.values()]
        )
    elif gen_type == GenerationType.OPTION_ID:
        class_text_map = {}
        for i, raw_name in enumerate(class_names):
            class_text_map[raw_name] = string.ascii_uppercase[i]
        max_decoding_length = 2
    else:
        raise ValueError(f"Unknown `gen_type`: {gen_type}")
    task_prefix = get_task_prefix(gen_type, class_text_map)
    return class_text_map, max_decoding_length, task_prefix
def convert_to_features(
    example_batch,
    class_text_map: Mapping[str, str],
    task_prefix: str,
    input_max_length=512,
    label_max_length=16,
    query_key="query",
    label_key="expected_single",
    class_names=None,
    tokenizer=None,
):
    q = example_batch[query_key]
    example_batch["input_text"] = f"{task_prefix}{q}"
    input_encodings = tokenizer(
        example_batch["input_text"],
        padding="max_length",
        max_length=input_max_length,
        truncation=True,
    )
    encodings = {
        "inputs": example_batch["input_text"],
        "input_ids": input_encodings["input_ids"],
        "attention_mask": input_encodings["attention_mask"],
    }
    if label_key:
        label = class_text_map[class_names[example_batch[label_key]]]
        example_batch["target_text"] = f"{label}"
        target_encodings = tokenizer(
            example_batch["target_text"],
            padding="max_length",
            max_length=label_max_length,
            truncation=True,
        )
        encodings["labels"] = target_encodings["input_ids"]
    return encodings
def preprocess_logits_for_metrics(logits, labels):
    """
    Original Trainer may have a memory leak.
    This is a workaround to avoid storing too many tensors that are not needed.
    """
    pred_ids = torch.argmax(logits[0], dim=-1)
    return pred_ids, labels
def build_prefix_allowed_tokens_fn(allowed_sequences):
    """Returns a function that provides the next allowed tokens given the prefix `seq`."""
    t = MarisaTrie(allowed_sequences)

    def fn(unused_batch_id, seq):
        # `seq` arrives as a torch tensor from generate(); convert it to a
        # plain list of ints before querying the trie.
        return t.get(seq.tolist())

    return fn
def process_golden_labels(example_batch, class_text_map, class_names):
    def fn(expected_text):
        return [class_text_map[y.strip()] for y in expected_text.split(",")]

    example_batch["golden_labels"] = fn(class_names[example_batch["Label"]])
    return example_batch


def compute_metrics(preds):
    # Stub: metrics are not computed in this inference-only setup.
    return {}
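For reference, a short sketch of how these helpers compose at inference time, mirroring the driver script above (it assumes the stock `t5-small` checkpoint rather than a fine-tuned one, and the class names and query are illustrative only):

import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained("t5-small")
model = transformers.AutoModelForSeq2SeqLM.from_pretrained("t5-small").eval()

class_names = ["Soccer", "Cricket", "Handball"]
class_text_map, max_decoding_length, task_prefix = get_gen_type_attributes(
    GenerationType.ALL_TOKENS, tokenizer, class_names
)
# Prepend 0 (T5's decoder start token) to match the decoder prefix.
allowed_sequences = [[0] + tokenizer.encode(x) for x in class_text_map.values()]

inputs = tokenizer(f"{task_prefix}who won the last test match?", return_tensors="pt")
outs = model.generate(
    **inputs,
    max_length=max_decoding_length,
    prefix_allowed_tokens_fn=build_prefix_allowed_tokens_fn(allowed_sequences),
)
print(tokenizer.batch_decode(outs, skip_special_tokens=True))  # a class name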