A script for training a FlaxAutoModelForCausalLM from scratch: it builds a small GPT-2 configuration, tokenizes and blocks the ELI5 dataset, and runs a pmap-based data-parallel training loop with JAX/Flax.
import math
from typing import Any, Callable, Dict, Iterator, List, Mapping, MutableMapping, Optional, Tuple, cast

import datasets
import flax
import jax
import numpy
import optax
import tqdm
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from transformers import CONFIG_MAPPING, AutoTokenizer, FlaxAutoModelForCausalLM

Array = Any
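

# Linear warmup followed by linear decay to zero over the remaining training steps.
# Note: main() below uses a constant learning rate; to use this schedule, pass its
# return value as the learning_rate argument of optax.adamw.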
def create_learning_rate_fn(
    train_ds_size: int,
    train_batch_size: int,
    num_train_epochs: int,
    num_warmup_steps: int,
    learning_rate: float,
) -> Callable[[int], Array]:
    steps_per_epoch = train_ds_size // train_batch_size
    num_train_steps = steps_per_epoch * num_train_epochs
    warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
    decay_fn = optax.linear_schedule(
        init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
    )
    schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
    return cast(Callable[[int], Array], schedule_fn)
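

# Weight-decay mask: exclude biases and LayerNorm parameters from decay.
# Intended for the mask argument of optax.adamw (e.g. optax.adamw(1e-4, mask=decay_mask_fn));
# main() below does not wire it in.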
def decay_mask_fn(params: Mapping[str, Any]) -> Mapping[str, Any]:
    flat_params = flax.traverse_util.flatten_dict(params)  # type: ignore[no-untyped-call]
    # find out all LayerNorm parameters
    layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
    layer_norm_named_params = {
        layer[-2:]
        for layer_norm_name in layer_norm_candidates
        for layer in flat_params.keys()
        if layer_norm_name in "".join(layer).lower()
    }
    flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
    return cast(Mapping[str, Any], flax.traverse_util.unflatten_dict(flat_mask))  # type: ignore[no-untyped-call]
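

# Tokenizes raw text and concatenates/regroups it into fixed-size blocks for causal LM
# training, copying input_ids into labels.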
class DataModule:
    def __init__(
        self,
        tokenizer: AutoTokenizer,
        block_size: int = 128,
        text_column_name: str = "text",
    ) -> None:
        self.tokenizer = tokenizer
        self.block_size = block_size
        self.text_column_name = text_column_name

    def get_model_kwargs(self) -> Dict[str, Any]:
        return dict(
            vocab_size=self.tokenizer.vocab_size,
            bos_token_id=self.tokenizer.bos_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
        )

    def read_dataset(self, dataset: datasets.Dataset) -> datasets.Dataset:
        def tokenize(examples: Mapping[str, Any]) -> Dict[str, Any]:
            return cast(Dict[str, Any], self.tokenizer(examples[self.text_column_name], truncation=True))

        def group_texts(examples: Mapping[str, Any]) -> Dict[str, Array]:
            concatenated_examples: Dict[str, Any] = {k: sum(examples[k], []) for k in examples.keys()}
            total_length = len(concatenated_examples[list(examples.keys())[0]])
            if total_length >= self.block_size:
                total_length = (total_length // self.block_size) * self.block_size
            result = {
                k: [t[i : i + self.block_size] for i in range(0, total_length, self.block_size)]
                for k, t in concatenated_examples.items()
            }
            result["labels"] = result["input_ids"].copy()
            return result

        return cast(
            datasets.Dataset,
            dataset.flatten()
            .map(
                tokenize,
                batched=True,
                remove_columns=dataset.column_names,
            )
            .map(group_texts, batched=True),
        )
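

# Iterates over a datasets.Dataset in (optionally shuffled) index batches, yielding
# dicts of numpy arrays. Shuffling uses a JAX PRNG key for reproducibility.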
class BatchIterator:
    def __init__(
        self,
        rng: Any,
        dataset: datasets.Dataset,
        batch_size: int,
        shuffle: bool,
        drop_last: bool,
    ) -> None:
        self.rng = rng
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last

        batch_idx: Array
        if self.shuffle:
            batch_idx = numpy.asarray(jax.random.permutation(self.rng, len(self.dataset)))
        else:
            batch_idx = numpy.arange(len(self.dataset))

        if self.drop_last:
            steps_per_epoch = len(self.dataset) // self.batch_size
            batch_idx = batch_idx[: steps_per_epoch * self.batch_size]
            batch_idx = batch_idx.reshape((steps_per_epoch, self.batch_size))
        else:
            steps_per_epoch = math.ceil(len(self.dataset) / self.batch_size)
            batch_idx = numpy.array_split(batch_idx, steps_per_epoch)

        self.current_step = 0
        self.batch_idx = batch_idx

    def __len__(self) -> int:
        if self.drop_last:
            return len(self.dataset) // self.batch_size
        return math.ceil(len(self.dataset) / self.batch_size)

    def __next__(self) -> Dict[str, Any]:
        if self.current_step >= len(self):
            raise StopIteration
        batch = self.dataset[self.batch_idx[self.current_step]]
        batch = {k: numpy.array(v) for k, v in batch.items()}
        self.current_step += 1
        return cast(Dict[str, Any], batch)

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        return self
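

# Factory that binds batching options and builds a fresh BatchIterator per epoch
# from an RNG key and a dataset.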
class DataLoader:
    def __init__(self, batch_size: int, shuffle: bool = True, drop_last: bool = True) -> None:
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last

    def __call__(self, rng: Any, dataset: datasets.Dataset) -> BatchIterator:
        return BatchIterator(
            rng=rng,
            dataset=dataset,
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            drop_last=self.drop_last,
        )
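

# TrainState extended with a dropout RNG key; replicate() copies the state to every
# local device and shards the dropout key so each device gets a distinct stream.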
class TrainState(train_state.TrainState):  # type: ignore[no-untyped-call]
    dropout_rng: Array

    def replicate(self) -> "TrainState":
        return cast("TrainState", flax.jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng)))  # type: ignore[no-untyped-call]
class Trainer:
    def __init__(
        self,
        train_dataloader: Callable[[Any, datasets.Dataset], Iterator[Mapping[str, Any]]],
        eval_dataloader: Optional[Callable[[Any, datasets.Dataset], Iterator[Mapping[str, Any]]]] = None,
        optimizer: optax.GradientTransformation = optax.adamw(1e-4),
        num_epochs: int = 1,
    ) -> None:
        self.train_dataloader = train_dataloader
        self.eval_dataloader = eval_dataloader
        self.optimizer = optimizer
        self.num_epochs = num_epochs

    def train(
        self,
        rng: Any,
        model: FlaxAutoModelForCausalLM,
        train_dataset: datasets.Dataset,
        valid_dataset: Optional[datasets.Dataset] = None,
    ) -> TrainState:
        if self.eval_dataloader is None and valid_dataset is not None:
            raise ValueError("You must provide an eval_dataloader to use a validation dataset.")

        rng, dropout_rng = jax.random.split(rng)
        state = TrainState.create(  # type: ignore[no-untyped-call]
            apply_fn=model.__call__,
            params=model.params,
            tx=self.optimizer,
            dropout_rng=dropout_rng,
        )
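
        # Next-token prediction loss: drop the last logit and the first label so that
        # position i predicts token i + 1, then average the cross entropy.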
        def loss_fn(logits: Array, labels: Array) -> Array:
            shift_logits = logits[..., :-1, :]
            shift_labels = labels[..., 1:]
            loss = optax.softmax_cross_entropy(
                shift_logits,
                onehot(shift_labels, shift_logits.shape[-1]),  # type: ignore[no-untyped-call]
            )
            return loss.mean()
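
        # Single training step run under pmap: split the dropout key, compute loss and
        # gradients, average gradients across devices, and apply the optimizer update.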
        def train_step(state: TrainState, batch: MutableMapping[str, Any]) -> Tuple[TrainState, Dict[str, Any]]:
            dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)

            def compute_loss(params: Any) -> Any:
                labels = batch.pop("labels")
                logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
                loss = loss_fn(logits, labels)
                return loss

            grad_fn = jax.value_and_grad(compute_loss)
            loss, grad = grad_fn(state.params)
            grad = jax.lax.pmean(grad, "batch")  # type: ignore[no-untyped-call]
            new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)  # type: ignore[no-untyped-call]

            metrics = {"loss": loss}
            metrics = jax.lax.pmean(metrics, axis_name="batch")  # type: ignore[no-untyped-call]
            return new_state, metrics
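
        # Evaluation step: forward pass without dropout, returning the device-averaged loss.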
        def eval_step(params: MutableMapping[str, Any], batch: MutableMapping[str, Any]) -> Dict[str, Any]:
            labels = batch.pop("labels")
            logits = model(**batch, params=params, train=False)[0]
            loss = loss_fn(logits, labels)

            # summarize metrics
            metrics = {"loss": loss}
            metrics = jax.lax.pmean(metrics, axis_name="batch")  # type: ignore[no-untyped-call]
            return metrics

        p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
        p_eval_step = jax.pmap(eval_step, "batch")
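
        # Replicate parameters and optimizer state across local devices, then run the
        # epoch loop; train batches are sharded across devices with shard().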
        state = state.replicate()

        train_metric_history: List[Dict[str, Any]] = []
        for epoch in tqdm.tqdm(range(self.num_epochs), desc="Epoch ...", position=0):
            rng, input_rng = jax.random.split(rng)

            with tqdm.tqdm(
                self.train_dataloader(input_rng, train_dataset),
                desc="Training ...",
                position=1,
                leave=False,
            ) as train_loader:
                for batch in train_loader:
                    batch = shard(batch)  # type: ignore[no-untyped-call]
                    state, train_metric = p_train_step(state, batch)
                    train_metric_history.append(train_metric)
                    train_loader.set_postfix(**train_metric)

            if valid_dataset is not None:
                assert self.eval_dataloader is not None
                eval_metric_history: List[Dict[str, Any]] = []
                with tqdm.tqdm(
                    self.eval_dataloader(input_rng, valid_dataset),
                    desc="Evaluating ...",
                    position=2,
                    leave=False,
                ) as eval_loader:
                    for batch in eval_loader:
                        eval_batch_size = batch["input_ids"].shape[0]
                        eval_metric = flax.jax_utils.pad_shard_unpad(p_eval_step, static_return=True)(  # type: ignore[no-untyped-call]
                            state.params,
                            batch,
                            min_device_batch=eval_batch_size,
                        )
                        eval_metric_history.append(eval_metric)
                        eval_loader.set_postfix(**eval_metric, perplexity=math.exp(eval_metric["loss"]))

                eval_metrics = get_metrics(eval_metric_history)  # type: ignore[no-untyped-call]
                eval_metrics = jax.tree_util.tree_map(lambda x: jax.numpy.mean(x).item(), eval_metrics)

                try:
                    eval_metrics["perplexity"] = math.exp(eval_metrics["loss"])
                except OverflowError:
                    eval_metrics["perplexity"] = float("inf")

        return cast(TrainState, state)
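

# Trains a small GPT-2 configuration (4 layers, 4 heads, 128-dim embeddings) from
# scratch on a slice of the ELI5 dataset.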
def main() -> None:
    model_config = {
        "model_type": "gpt2",
        "n_layer": 4,
        "n_head": 4,
        "n_embd": 128,
    }

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    datamodule = DataModule(tokenizer=tokenizer)

    train_dataset = datamodule.read_dataset(
        cast(datasets.Dataset, datasets.load_dataset("eli5", split="train_asks[:1000]"))
        .flatten()
        .map(lambda x: {"text": " ".join(x["answers.text"])}, batched=False)
    )
    valid_dataset = datamodule.read_dataset(
        cast(datasets.Dataset, datasets.load_dataset("eli5", split="validation_asks[:1000]"))
        .flatten()
        .map(lambda x: {"text": " ".join(x["answers.text"])}, batched=False)
    )

    trainer = Trainer(
        train_dataloader=DataLoader(batch_size=32, shuffle=True, drop_last=True),
        eval_dataloader=DataLoader(batch_size=32, shuffle=False, drop_last=False),
        optimizer=optax.adamw(learning_rate=1e-4),
        num_epochs=3,
    )
    model = FlaxAutoModelForCausalLM.from_config(
        CONFIG_MAPPING[model_config["model_type"]](
            **model_config,
            **datamodule.get_model_kwargs(),
        ),
    )
    trainer.train(
        rng=jax.random.PRNGKey(0),
        model=model,
        train_dataset=train_dataset,
        valid_dataset=valid_dataset,
    )


if __name__ == "__main__":
    main()