"""A script for training FlaxAutoModelForCausalLM from scratch."""
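
# The imports below assume that the following packages are installed (inferred from the
# imports themselves, not from any documentation in the gist): jax, flax, optax, datasets,
# numpy, tqdm, and transformers with its Flax extras (e.g. `pip install "transformers[flax]"`).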
import math
from typing import Any, Callable, Dict, Iterator, List, Mapping, MutableMapping, Optional, Tuple, cast

import datasets
import flax
import jax
import numpy
import optax
import tqdm
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from transformers import CONFIG_MAPPING, AutoTokenizer, FlaxAutoModelForCausalLM

Array = Any


def create_learning_rate_fn(
    train_ds_size: int,
    train_batch_size: int,
    num_train_epochs: int,
    num_warmup_steps: int,
    learning_rate: float,
) -> Callable[[int], Array]:
    """Build a learning-rate schedule with linear warmup followed by linear decay."""
    steps_per_epoch = train_ds_size // train_batch_size
    num_train_steps = steps_per_epoch * num_train_epochs
    warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
    decay_fn = optax.linear_schedule(
        init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
    )
    schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
    return cast(Callable[[int], Array], schedule_fn)


def decay_mask_fn(params: Mapping[str, Any]) -> Mapping[str, Any]:
    """Build a mask that excludes bias and LayerNorm parameters from weight decay."""
    flat_params = flax.traverse_util.flatten_dict(params)  # type: ignore[no-untyped-call]
    # find out all LayerNorm parameters
    layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
    layer_norm_named_params = {
        layer[-2:]
        for layer_norm_name in layer_norm_candidates
        for layer in flat_params.keys()
        if layer_norm_name in "".join(layer).lower()
    }
    flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
    return cast(Mapping[str, Any], flax.traverse_util.unflatten_dict(flat_mask))  # type: ignore[no-untyped-call]
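
# A minimal sketch (not part of the original script): `create_learning_rate_fn` and
# `decay_mask_fn` are defined above but never used by `main()` below. They could be wired
# into the optimizer passed to `Trainer` roughly like this; the warmup steps and weight
# decay values are placeholders:
#
#     schedule = create_learning_rate_fn(
#         train_ds_size=len(train_dataset),
#         train_batch_size=32,
#         num_train_epochs=3,
#         num_warmup_steps=100,
#         learning_rate=1e-4,
#     )
#     optimizer = optax.adamw(learning_rate=schedule, weight_decay=0.01, mask=decay_mask_fn)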


class DataModule:
    """Tokenizes raw text and groups it into fixed-size blocks for causal LM training."""

    def __init__(
        self,
        tokenizer: AutoTokenizer,
        block_size: int = 128,
        text_column_name: str = "text",
    ) -> None:
        self.tokenizer = tokenizer
        self.block_size = block_size
        self.text_column_name = text_column_name

    def get_model_kwargs(self) -> Dict[str, Any]:
        return dict(
            vocab_size=self.tokenizer.vocab_size,
            bos_token_id=self.tokenizer.bos_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
        )

    def read_dataset(self, dataset: datasets.Dataset) -> datasets.Dataset:
        def tokenize(examples: Mapping[str, Any]) -> Dict[str, Any]:
            return cast(Dict[str, Any], self.tokenizer(examples[self.text_column_name], truncation=True))

        def group_texts(examples: Mapping[str, Any]) -> Dict[str, Array]:
            # Concatenate all tokenized texts, then split them into blocks of `block_size`,
            # dropping the remainder that does not fill a complete block.
            concatenated_examples: Dict[str, Any] = {k: sum(examples[k], []) for k in examples.keys()}
            total_length = len(concatenated_examples[list(examples.keys())[0]])
            if total_length >= self.block_size:
                total_length = (total_length // self.block_size) * self.block_size
            result = {
                k: [t[i : i + self.block_size] for i in range(0, total_length, self.block_size)]
                for k, t in concatenated_examples.items()
            }
            result["labels"] = result["input_ids"].copy()
            return result

        return cast(
            datasets.Dataset,
            dataset.flatten()
            .map(
                tokenize,
                batched=True,
                remove_columns=dataset.column_names,
            )
            .map(group_texts, batched=True),
        )
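
# A minimal usage sketch for DataModule (assuming `raw_dataset` is a `datasets.Dataset`
# with a "text" column, as built in `main()` below). `read_dataset` yields examples with
# `input_ids`, `attention_mask`, and `labels`, each of length `block_size`:
#
#     tokenizer = AutoTokenizer.from_pretrained("gpt2")
#     datamodule = DataModule(tokenizer=tokenizer, block_size=128)
#     lm_dataset = datamodule.read_dataset(raw_dataset)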


class BatchIterator:
    """Yields the dataset as batches of numpy arrays, optionally shuffled; with
    `drop_last` the final partial batch is skipped."""

    def __init__(
        self,
        rng: Any,
        dataset: datasets.Dataset,
        batch_size: int,
        shuffle: bool,
        drop_last: bool,
    ) -> None:
        self.rng = rng
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last

        batch_idx: Array
        if self.shuffle:
            batch_idx = numpy.asarray(jax.random.permutation(self.rng, len(self.dataset)))
        else:
            batch_idx = numpy.arange(len(self.dataset))
        if self.drop_last:
            steps_per_epoch = len(self.dataset) // self.batch_size
            batch_idx = batch_idx[: steps_per_epoch * self.batch_size]
            batch_idx = batch_idx.reshape((steps_per_epoch, self.batch_size))
        else:
            steps_per_epoch = math.ceil(len(self.dataset) / self.batch_size)
            batch_idx = numpy.array_split(batch_idx, steps_per_epoch)

        self.current_step = 0
        self.batch_idx = batch_idx

    def __len__(self) -> int:
        if self.drop_last:
            return len(self.dataset) // self.batch_size
        return math.ceil(len(self.dataset) / self.batch_size)

    def __next__(self) -> Dict[str, Any]:
        if self.current_step >= len(self):
            raise StopIteration
        batch = self.dataset[self.batch_idx[self.current_step]]
        batch = {k: numpy.array(v) for k, v in batch.items()}
        self.current_step += 1
        return cast(Dict[str, Any], batch)

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        return self


class DataLoader:
    """Factory that pairs an RNG and a dataset into a BatchIterator."""

    def __init__(self, batch_size: int, shuffle: bool = True, drop_last: bool = True) -> None:
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last

    def __call__(self, rng: Any, dataset: datasets.Dataset) -> BatchIterator:
        return BatchIterator(
            rng=rng,
            dataset=dataset,
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            drop_last=self.drop_last,
        )
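
# A minimal sketch of iterating batches directly (outside of `Trainer`), assuming
# `lm_dataset` was produced by `DataModule.read_dataset`. Each batch is a dict of numpy
# arrays of shape (batch_size, block_size); with `drop_last=False` the last batch may be
# smaller:
#
#     loader = DataLoader(batch_size=8, shuffle=False, drop_last=False)
#     for batch in loader(jax.random.PRNGKey(0), lm_dataset):
#         print(batch["input_ids"].shape)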


class TrainState(train_state.TrainState):  # type: ignore[no-untyped-call]
    """Train state that additionally tracks a dropout RNG so it can be replicated per device."""

    dropout_rng: Array

    def replicate(self) -> "TrainState":
        return cast("TrainState", flax.jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng)))  # type: ignore[no-untyped-call]


class Trainer:
    """Runs pmap-based data-parallel training and evaluation for a Flax causal LM."""

    def __init__(
        self,
        train_dataloader: Callable[[Any, datasets.Dataset], Iterator[Mapping[str, Any]]],
        eval_dataloader: Optional[Callable[[Any, datasets.Dataset], Iterator[Mapping[str, Any]]]] = None,
        optimizer: optax.GradientTransformation = optax.adamw(1e-4),
        num_epochs: int = 1,
    ) -> None:
        self.train_dataloader = train_dataloader
        self.eval_dataloader = eval_dataloader
        self.optimizer = optimizer
        self.num_epochs = num_epochs

    def train(
        self,
        rng: Any,
        model: FlaxAutoModelForCausalLM,
        train_dataset: datasets.Dataset,
        valid_dataset: Optional[datasets.Dataset] = None,
    ) -> TrainState:
        if self.eval_dataloader is None and valid_dataset is not None:
            raise ValueError("You must provide an eval_dataloader to use a validation dataset.")

        rng, dropout_rng = jax.random.split(rng)
        state = TrainState.create(  # type: ignore[no-untyped-call]
            apply_fn=model.__call__,
            params=model.params,
            tx=self.optimizer,
            dropout_rng=dropout_rng,
        )

        def loss_fn(logits: Array, labels: Array) -> Array:
            # Shift so that tokens < n predict token n.
            shift_logits = logits[..., :-1, :]
            shift_labels = labels[..., 1:]
            loss = optax.softmax_cross_entropy(
                shift_logits,
                onehot(shift_labels, shift_logits.shape[-1]),  # type: ignore[no-untyped-call]
            )
            return loss.mean()

        def train_step(state: TrainState, batch: MutableMapping[str, Any]) -> Tuple[TrainState, Dict[str, Any]]:
            dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)

            def compute_loss(params: Any) -> Any:
                labels = batch.pop("labels")
                logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
                loss = loss_fn(logits, labels)
                return loss

            grad_fn = jax.value_and_grad(compute_loss)
            loss, grad = grad_fn(state.params)
            grad = jax.lax.pmean(grad, "batch")  # type: ignore[no-untyped-call]
            new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)  # type: ignore[no-untyped-call]
            metrics = {"loss": loss}
            metrics = jax.lax.pmean(metrics, axis_name="batch")  # type: ignore[no-untyped-call]
            return new_state, metrics

        def eval_step(params: MutableMapping[str, Any], batch: MutableMapping[str, Any]) -> Dict[str, Any]:
            labels = batch.pop("labels")
            logits = model(**batch, params=params, train=False)[0]
            loss = loss_fn(logits, labels)
            # summarize metrics
            metrics = {"loss": loss}
            metrics = jax.lax.pmean(metrics, axis_name="batch")  # type: ignore[no-untyped-call]
            return metrics

        p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
        p_eval_step = jax.pmap(eval_step, "batch")

        state = state.replicate()

        train_metric_history: List[Dict[str, Any]] = []
        for epoch in tqdm.tqdm(range(self.num_epochs), desc="Epoch ...", position=0):
            rng, input_rng = jax.random.split(rng)
            with tqdm.tqdm(
                self.train_dataloader(input_rng, train_dataset),
                desc="Training ...",
                position=1,
                leave=False,
            ) as train_loader:
                for batch in train_loader:
                    batch = shard(batch)  # type: ignore[no-untyped-call]
                    state, train_metric = p_train_step(state, batch)
                    train_metric_history.append(train_metric)
                    train_loader.set_postfix(**train_metric)

            if valid_dataset is not None:
                assert self.eval_dataloader is not None
                eval_metric_history: List[Dict[str, Any]] = []
                with tqdm.tqdm(
                    self.eval_dataloader(input_rng, valid_dataset),
                    desc="Evaluating ...",
                    position=2,
                    leave=False,
                ) as eval_loader:
                    for batch in eval_loader:
                        eval_batch_size = batch["input_ids"].shape[0]
                        eval_metric = flax.jax_utils.pad_shard_unpad(p_eval_step, static_return=True)(  # type: ignore[no-untyped-call]
                            state.params,
                            batch,
                            min_device_batch=eval_batch_size,
                        )
                        eval_metric_history.append(eval_metric)
                        eval_loader.set_postfix(**eval_metric, perplexity=math.exp(eval_metric["loss"]))

                eval_metrics = get_metrics(eval_metric_history)  # type: ignore[no-untyped-call]
                eval_metrics = jax.tree_util.tree_map(lambda x: jax.numpy.mean(x).item(), eval_metrics)
                try:
                    eval_metrics["perplexity"] = math.exp(eval_metrics["loss"])
                except OverflowError:
                    eval_metrics["perplexity"] = float("inf")
                tqdm.tqdm.write(f"epoch {epoch + 1}: eval metrics = {eval_metrics}")

        return cast(TrainState, state)


def main() -> None:
    model_config = {
        "model_type": "gpt2",
        "n_layer": 4,
        "n_head": 4,
        "n_embd": 128,
    }

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    datamodule = DataModule(tokenizer=tokenizer)

    train_dataset = datamodule.read_dataset(
        cast(datasets.Dataset, datasets.load_dataset("eli5", split="train_asks[:1000]"))
        .flatten()
        .map(lambda x: {"text": " ".join(x["answers.text"])}, batched=False)
    )
    valid_dataset = datamodule.read_dataset(
        cast(datasets.Dataset, datasets.load_dataset("eli5", split="validation_asks[:1000]"))
        .flatten()
        .map(lambda x: {"text": " ".join(x["answers.text"])}, batched=False)
    )

    trainer = Trainer(
        train_dataloader=DataLoader(batch_size=32, shuffle=True, drop_last=True),
        eval_dataloader=DataLoader(batch_size=32, shuffle=False, drop_last=False),
        optimizer=optax.adamw(learning_rate=1e-4),
        num_epochs=3,
    )
    model = FlaxAutoModelForCausalLM.from_config(
        CONFIG_MAPPING[model_config["model_type"]](
            **model_config,
            **datamodule.get_model_kwargs(),
        ),
    )
    trainer.train(
        rng=jax.random.PRNGKey(0),
        model=model,
        train_dataset=train_dataset,
        valid_dataset=valid_dataset,
    )


if __name__ == "__main__":
    main()
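
# Usage sketch (the filename below is an assumption; the gist does not name the file):
#
#     python train_flax_clm.py
#
# Note that `shard` in `Trainer.train` splits each training batch evenly across local
# devices, so on a multi-device host the batch size of 32 must be divisible by the
# device count.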