Last active
August 2, 2021 00:25
-
-
Save enijkamp/79ea43b81d0e9af232cccf96c3d13168 to your computer and use it in GitHub Desktop.
resharding.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
''' | |
python3.8 -m venv .venv | |
source .venv/bin/activate | |
pip install --upgrade pip setuptools | |
pip install -r requirements.txt | |
pip install --upgrade jax==0.2.12 jaxlib==0.1.67+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html | |
''' | |
import os | |
# xla | |
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false' | |
os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform' | |
import io | |
import time | |
import multiprocessing | |
import numpy as np | |
import jax | |
import jax.numpy as jnp | |
from jax.experimental import maps | |
import optax | |
import transformers | |
from mesh_transformer.sampling import nucleaus_sample | |
from mesh_transformer.transformer_shard import CausalTransformer | |
import mesh_transformer.util as util | |
def tree_flatten_with_names(pytree, is_leaf, path='', to_id=id):
    """Map ``to_id(leaf)`` to a '/'-joined key path for every leaf of ``pytree``.

    Objects with an ``items`` attribute are walked as mappings and each key is
    appended to the path (``f'{path}/{key}'``). Other indexable objects are
    iterated without extending the path, so all of their elements share the
    container's path. ``str``/``bytes`` are treated as atoms rather than
    containers (iterating a one-character string yields itself, which would
    recurse forever).

    Args:
        pytree: nested structure of mappings / sequences / atoms.
        is_leaf: predicate; values for which it returns True are recorded
            under their path instead of being descended into.
        path: path prefix for ``pytree`` itself.
        to_id: function computing the dict key for a recorded value
            (default: ``id``). Applied at every nesting level.

    Returns:
        dict mapping ``to_id(value)`` -> path string. When two values share
        the same id, the one visited later wins.
    """
    id_to_name = {}
    if getattr(pytree, 'items', None):
        for k, v in pytree.items():
            k_path = f'{path}/{k}'
            if is_leaf(v):
                id_to_name[to_id(v)] = k_path
            else:
                # Forward to_id so nested values are keyed like top-level ones.
                id_to_name.update(
                    tree_flatten_with_names(v, is_leaf=is_leaf, path=k_path, to_id=to_id))
    elif getattr(pytree, '__getitem__', None) and not isinstance(pytree, (str, bytes)):
        for v in pytree:
            if is_leaf(v):
                id_to_name[to_id(v)] = path
            else:
                id_to_name.update(
                    tree_flatten_with_names(v, is_leaf=is_leaf, path=path, to_id=to_id))
    else:
        # Not a container: record it under the current path.
        id_to_name[to_id(pytree)] = path
    return id_to_name
def tree_leaves_with_names(pytree, to_id=id):
    """Map ``to_id(leaf)`` to its key path for every JAX leaf of ``pytree``.

    A node counts as a leaf when ``to_id(node)`` matches one of the leaves
    returned by ``jax.tree_util.tree_leaves(pytree)``. Python lists are never
    treated as leaves, so they are always descended into.

    Args:
        pytree: a JAX pytree.
        to_id: identity function used both to match leaves and as the key of
            the returned dict (default: ``id``). Its results must be hashable.

    Returns:
        dict mapping ``to_id(leaf)`` -> '/'-joined path string, as built by
        ``tree_flatten_with_names``.
    """
    # Build the id set once so each is_leaf membership test is O(1).
    leaf_ids = {to_id(leaf) for leaf in jax.tree_util.tree_leaves(pytree)}

    def is_leaf(node):
        return not isinstance(node, list) and to_id(node) in leaf_ids

    # Forward to_id so the returned keys use the same id function as matching.
    return tree_flatten_with_names(pytree, is_leaf, to_id=to_id)
def get_tree_leaves_names_original(params):
    """Return leaf path names of a network state built with the full optimizer chain.

    Side effects: switches JAX to the CPU platform and stores the optimizer in
    the caller's ``params`` dict under ``'optimizer'``.

    Args:
        params: model config dict passed to ``CausalTransformer`` (mutated).

    Returns:
        list of path strings, one per leaf of ``network.state``, in
        ``jax.tree_util.tree_leaves`` order.
    """
    jax.config.update('jax_platform_name', 'cpu')

    full_chain = optax.chain(
        optax.scale(1),
        util.clip_by_global_norm(1),
        optax.scale_by_adam(),
        optax.additive_weight_decay(0),
        optax.scale(-1),
        optax.scale_by_schedule(util.gpt3_schedule(0, 1, 0, 0)),
    )
    params['optimizer'] = full_chain

    # A 1x1 mesh on the first device, with axes named 'dp' and 'mp'.
    mesh_devices = np.array([jax.devices()[0]]).reshape((1, 1))
    with jax.experimental.maps.mesh(mesh_devices, ('dp', 'mp')):
        state = CausalTransformer(params).state
        name_of = tree_leaves_with_names(state, to_id=id)
        return [name_of[id(leaf)] for leaf in jax.tree_util.tree_leaves(state)]
def get_tree_leaves_names_reduced(params):
    """Return leaf path names of a network state built with a trivial optimizer.

    Uses ``optax.scale(0)`` as the optimizer. Side effects: switches JAX to the
    CPU platform and stores the optimizer in the caller's ``params`` dict under
    ``'optimizer'``.

    Args:
        params: model config dict passed to ``CausalTransformer`` (mutated).

    Returns:
        list of path strings, one per leaf of ``network.state``, in
        ``jax.tree_util.tree_leaves`` order.
    """
    jax.config.update('jax_platform_name', 'cpu')
    params['optimizer'] = optax.scale(0)

    # A 1x1 mesh on the first device, with axes named 'dp' and 'mp'.
    mesh_devices = np.array([jax.devices()[0]]).reshape((1, 1))
    with jax.experimental.maps.mesh(mesh_devices, ('dp', 'mp')):
        state = CausalTransformer(params).state
        name_of = tree_leaves_with_names(state, to_id=id)
        return [name_of[id(leaf)] for leaf in jax.tree_util.tree_leaves(state)]
# TODO(nijkamp): rewrite this mess | |
def reshard(x, old_shape, do_shard_ln, do_shard_bias):
    """Merge a stack of per-shard arrays into an array of shape ``old_shape``.

    ``x`` is expected to be ``np.stack`` of the per-shard arrays, so axis 0
    indexes the input shard. All branches operate on the host with NumPy.

    Args:
        x: stacked shard arrays of rank 1, 2 or 3.
        old_shape: target shape (for rank-3 input, a rank-3 shape).
        do_shard_ln: rank-2 only: keep the first shard. This assumes the values
            are identical across shards, which is not verified here.
        do_shard_bias: rank-2 only (when ``do_shard_ln`` is False): sum over
            shards, then reshape to ``old_shape``.

    Returns:
        numpy array. The rank-1 and ``do_shard_ln`` branches return ``x[0:1]``
        without reshaping; every other branch returns shape ``old_shape``.

    Raises:
        NotImplementedError: for unsupported ranks or rank-3 layouts.
    """
    ndim = len(x.shape)

    if ndim == 1:
        # One value per shard (e.g. a step counter, presumably): keep the first.
        return x[0:1]

    if ndim == 2:
        if do_shard_ln:
            # TODO(nijkamp): for this case, (x[1:] == 0).all() or (x[1:] == 1).all() should hold
            return x[0:1]
        if do_shard_bias:
            # TODO(nijkamp): sum() bias terms, is this correct?
            return np.reshape(np.sum(x, axis=0), old_shape)
        # Concatenate the shards end to end.
        return x.reshape(old_shape)

    if ndim == 3:
        n_shards, rows, cols = x.shape
        if n_shards * cols == old_shape[2]:
            # Shards split the last axis: move the shard axis next to it, then
            # merge, i.e. concatenate along columns.
            return np.transpose(x, (1, 0, 2)).reshape(old_shape)
        if n_shards * rows == old_shape[1]:
            # Shards split the middle axis: a plain reshape concatenates rows.
            return x.reshape(old_shape)
        raise NotImplementedError(f"unimplemented, {x.shape}, {old_shape}")

    # Report only the shape; interpolating the array itself can be enormous.
    raise NotImplementedError(f"unimplemented, {x.shape}")
def read_shard(ckpt_dir, pieces=16):
    """Load every array from ``{ckpt_dir}0.npz`` .. ``{ckpt_dir}{pieces-1}.npz``.

    ``ckpt_dir`` is concatenated directly with the file name, so it must end
    with a path separator.

    Args:
        ckpt_dir: directory prefix (with trailing separator).
        pieces: number of numbered ``.npz`` files to read.

    Returns:
        list of numpy arrays, ordered by file index and, within a file, by
        the order the archive lists its members.
    """
    out = []
    for idx in range(pieces):
        file_path = ckpt_dir + f"{idx}.npz"
        # Read the whole file in one call, then parse it from memory.
        with open(file_path, "rb") as f:
            buf = f.read()
        # np.load returns an NpzFile that holds the archive open; close it
        # once every member has been materialized into an ndarray.
        with np.load(io.BytesIO(buf)) as deserialized:
            out.extend(deserialized[name] for name in deserialized)
    return out
def read_ckpt(pytree, leaves_names, leaves_names_reduced, dir, shards_in, shards_out):
    """Load a sharded checkpoint into the structure of ``pytree``.

    Each of the ``shards_in`` directories ``{dir}shard_{i}/`` is read with
    ``read_shard``. Leaves are then matched positionally against
    ``leaves_names``. Only leaves whose name starts with ``/params`` or
    ``/step`` are kept. Those must appear, in order, in
    ``leaves_names_reduced`` and line up with the leaves of ``pytree``.

    Args:
        pytree: target pytree whose leaves give the expected shapes.
        leaves_names: path name of every leaf stored in the checkpoint.
        leaves_names_reduced: path names of the leaves of ``pytree``.
        dir: checkpoint root prefix (with trailing separator).
        shards_in: number of shards in the checkpoint.
        shards_out: number of shards the target expects. When it differs
            from ``shards_in``, each leaf is merged with ``reshard``.

    Returns:
        pytree with the structure of ``pytree`` and the loaded arrays as leaves.

    Raises:
        AssertionError: on a name mismatch or an incompatible leaf shape.
    """
    # `import multiprocessing` alone does not guarantee that the `pool`
    # submodule is loaded, so import ThreadPool explicitly.
    from multiprocessing.pool import ThreadPool

    old_flattened, structure = jax.tree_util.tree_flatten(pytree)
    print(len(old_flattened))
    print(len(leaves_names))
    print(len(leaves_names_reduced))

    # Read all shard directories concurrently; results keep shard order.
    with ThreadPool(shards_in) as p:
        start = time.time()
        shards = p.map(read_shard, [f"{dir}shard_{i}/" for i in range(shards_in)])
        print(f"read checkpoint in {time.time() - start:.06}s")

    unsharded = []
    old_i = 0
    for i, (leave_name, *all_shards) in enumerate(zip(leaves_names, *shards)):
        # Skip everything that is not a parameter or the step (e.g. /opt_state).
        if not leave_name.startswith(('/params', '/step')):
            continue
        assert leave_name == leaves_names_reduced[old_i], f'{leave_name} {leaves_names_reduced[old_i]}'
        old = old_flattened[old_i]
        old_i += 1

        x = np.stack(all_shards)
        # A 2-byte void dtype is reinterpreted as bfloat16. Presumably these
        # are bfloat16 arrays that were saved without a native NumPy dtype.
        if x.dtype == np.dtype('V2'):
            x = x.view(jnp.bfloat16)
        print(i, old_i, leave_name)
        print(x.shape)
        print(old.shape)
        if shards_out != shards_in:
            x = reshard(
                x,
                old.shape,
                do_shard_bias=leave_name.endswith(('embedding_shard/~/linear/b', 'linear_5/b')),
                do_shard_ln=leave_name.endswith(('replicated_layer_norm/offset', 'replicated_layer_norm/scale')),
            )
        unsharded.append(x)
        assert x.shape == old.shape, f"Incompatible checkpoints {x.shape} vs {old.shape} ({leave_name})"

    return jax.tree_util.tree_unflatten(structure, unsharded)
def sample(tokenizer, per_replica_batch, seq, network, context, top_k=40, top_p=0.9, temp=1.0, gen_len=512):
    """Generate ``per_replica_batch`` completions of ``context`` and decode them.

    The encoded context is left-padded with zeros to ``seq`` tokens. It must
    therefore encode to at most ``seq`` tokens, otherwise ``np.pad`` raises.

    Args:
        tokenizer: object with ``encode(str)`` and ``decode(ids)``.
        per_replica_batch: number of identical prompts in the batch.
        seq: padded context length.
        network: object whose ``generate(tokens, lengths, gen_len, options)``
            result is indexed as ``output[1][0][:, :, 0]``. That result is
            assumed to be (batch, step, ...) token ids — TODO confirm.
        context: prompt string.
        top_k: per-row top-k value. A falsy value (None or 0) sends
            ``top_k=None`` to the sampler.
        top_p: per-row nucleus probability.
        temp: per-row sampling temperature.
        gen_len: number of tokens to generate.

    Returns:
        list of decoded strings, one per batch row.
    """
    tokens = tokenizer.encode(context)
    n_ctx = len(tokens)
    # Left-pad so the prompt ends at the last position of the window.
    padded_tokens = np.pad(tokens, ((seq - n_ctx, 0),)).astype(np.uint32)
    batched_tokens = np.array([padded_tokens] * per_replica_batch)
    length = np.ones(per_replica_batch, dtype=np.uint32) * n_ctx

    # Use an explicit conditional. `cond and array or None` evaluates the
    # truth value of a numpy array, which raises when per_replica_batch > 1.
    top_k_arr = (np.ones(per_replica_batch, dtype=np.int32) * top_k) if top_k else None
    sampler_options = {
        "top_p": np.ones(per_replica_batch) * top_p,
        "top_k": top_k_arr,
        "temp": np.ones(per_replica_batch) * temp,
    }

    start = time.time()
    output = network.generate(batched_tokens, length, gen_len, sampler_options)
    decoded_tokens = output[1][0]
    samples = [tokenizer.decode(row) for row in decoded_tokens[:, :, 0]]
    print(f"completion done in {time.time() - start:.06}s")
    return samples
def main(): | |
params = { | |
"layers": 28, | |
"d_model": 4096, | |
"n_heads": 16, | |
"n_vocab": 50400, | |
"norm": "layernorm", | |
"pe": "rotary", | |
"pe_rotary_dims": 64, | |
"early_cast": True, | |
"seq": 2048, | |
"cores_per_replica": 1, | |
"per_replica_batch": 1, | |
} | |
# leaves_names_original = get_tree_leaves_names_original(params) | |
# print(leaves_names_original) | |
leaves_names_original = ['/opt_state', '/opt_state/causal_transformer_shard/~/embedding_shard/~/linear/b', '/opt_state/causal_transformer_shard/~/embedding_shard/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_0/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_0/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_0/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_0/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_0/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_0/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_0/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_0/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_0/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_0/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_1/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_1/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_1/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_1/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_1/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_1/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_1/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_1/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_1/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_1/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_10/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_10/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_10/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_10/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_10/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_10/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_10/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_10/~/linear_5/w', 
'/opt_state/causal_transformer_shard/~/layer_10/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_10/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_11/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_11/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_11/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_11/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_11/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_11/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_11/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_11/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_11/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_11/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_12/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_12/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_12/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_12/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_12/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_12/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_12/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_12/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_12/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_12/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_13/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_13/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_13/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_13/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_13/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_13/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_13/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_13/~/linear_5/w', 
'/opt_state/causal_transformer_shard/~/layer_13/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_13/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_14/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_14/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_14/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_14/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_14/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_14/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_14/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_14/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_14/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_14/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_15/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_15/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_15/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_15/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_15/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_15/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_15/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_15/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_15/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_15/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_16/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_16/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_16/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_16/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_16/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_16/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_16/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_16/~/linear_5/w', 
'/opt_state/causal_transformer_shard/~/layer_16/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_16/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_17/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_17/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_17/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_17/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_17/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_17/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_17/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_17/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_17/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_17/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_18/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_18/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_18/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_18/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_18/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_18/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_18/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_18/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_18/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_18/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_19/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_19/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_19/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_19/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_19/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_19/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_19/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_19/~/linear_5/w', 
'/opt_state/causal_transformer_shard/~/layer_19/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_19/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_2/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_2/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_2/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_2/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_2/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_2/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_2/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_2/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_2/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_2/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_20/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_20/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_20/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_20/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_20/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_20/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_20/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_20/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_20/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_20/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_21/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_21/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_21/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_21/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_21/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_21/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_21/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_21/~/linear_5/w', 
'/opt_state/causal_transformer_shard/~/layer_21/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_21/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_22/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_22/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_22/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_22/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_22/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_22/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_22/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_22/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_22/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_22/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_23/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_23/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_23/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_23/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_23/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_23/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_23/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_23/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_23/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_23/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_24/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_24/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_24/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_24/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_24/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_24/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_24/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_24/~/linear_5/w', 
'/opt_state/causal_transformer_shard/~/layer_24/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_24/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_25/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_25/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_25/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_25/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_25/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_25/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_25/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_25/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_25/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_25/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_26/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_26/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_26/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_26/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_26/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_26/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_26/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_26/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_26/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_26/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_27/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_27/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_27/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_27/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_27/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_27/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_27/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_27/~/linear_5/w', 
'/opt_state/causal_transformer_shard/~/layer_27/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_27/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_3/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_3/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_3/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_3/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_3/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_3/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_3/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_3/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_3/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_3/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_4/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_4/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_4/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_4/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_4/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_4/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_4/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_4/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_4/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_4/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_5/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_5/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_5/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_5/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_5/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_5/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_5/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_5/~/linear_5/w', 
'/opt_state/causal_transformer_shard/~/layer_5/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_5/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_6/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_6/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_6/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_6/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_6/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_6/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_6/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_6/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_6/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_6/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_7/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_7/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_7/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_7/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_7/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_7/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_7/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_7/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_7/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_7/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_8/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_8/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_8/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_8/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_8/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_8/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_8/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_8/~/linear_5/w', 
'/opt_state/causal_transformer_shard/~/layer_8/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_8/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_9/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_9/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_9/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_9/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_9/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_9/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_9/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_9/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_9/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_9/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/projection_shard/~/linear/b', '/opt_state/causal_transformer_shard/~/projection_shard/~/linear/w', '/opt_state/causal_transformer_shard/~/projection_shard/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/projection_shard/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/embedding_shard/~/linear/b', '/opt_state/causal_transformer_shard/~/embedding_shard/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_0/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_0/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_0/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_0/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_0/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_0/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_0/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_0/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_0/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_0/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_1/~/linear/w', 
'/opt_state/causal_transformer_shard/~/layer_1/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_1/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_1/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_1/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_1/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_1/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_1/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_1/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_1/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_10/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_10/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_10/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_10/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_10/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_10/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_10/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_10/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_10/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_10/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_11/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_11/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_11/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_11/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_11/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_11/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_11/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_11/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_11/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_11/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_12/~/linear/w', 
'/opt_state/causal_transformer_shard/~/layer_12/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_12/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_12/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_12/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_12/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_12/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_12/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_12/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_12/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_13/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_13/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_13/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_13/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_13/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_13/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_13/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_13/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_13/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_13/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_14/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_14/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_14/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_14/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_14/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_14/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_14/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_14/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_14/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_14/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_15/~/linear/w', 
'/opt_state/causal_transformer_shard/~/layer_15/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_15/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_15/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_15/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_15/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_15/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_15/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_15/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_15/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_16/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_16/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_16/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_16/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_16/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_16/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_16/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_16/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_16/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_16/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_17/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_17/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_17/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_17/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_17/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_17/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_17/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_17/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_17/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_17/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_18/~/linear/w', 
'/opt_state/causal_transformer_shard/~/layer_18/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_18/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_18/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_18/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_18/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_18/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_18/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_18/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_18/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_19/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_19/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_19/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_19/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_19/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_19/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_19/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_19/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_19/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_19/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_2/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_2/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_2/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_2/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_2/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_2/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_2/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_2/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_2/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_2/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_20/~/linear/w', 
'/opt_state/causal_transformer_shard/~/layer_20/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_20/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_20/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_20/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_20/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_20/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_20/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_20/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_20/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_21/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_21/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_21/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_21/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_21/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_21/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_21/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_21/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_21/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_21/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_22/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_22/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_22/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_22/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_22/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_22/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_22/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_22/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_22/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_22/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_23/~/linear/w', 
'/opt_state/causal_transformer_shard/~/layer_23/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_23/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_23/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_23/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_23/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_23/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_23/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_23/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_23/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_24/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_24/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_24/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_24/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_24/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_24/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_24/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_24/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_24/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_24/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_25/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_25/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_25/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_25/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_25/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_25/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_25/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_25/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_25/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_25/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_26/~/linear/w', 
'/opt_state/causal_transformer_shard/~/layer_26/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_26/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_26/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_26/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_26/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_26/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_26/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_26/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_26/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_27/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_27/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_27/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_27/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_27/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_27/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_27/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_27/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_27/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_27/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_3/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_3/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_3/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_3/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_3/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_3/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_3/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_3/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_3/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_3/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_4/~/linear/w', 
'/opt_state/causal_transformer_shard/~/layer_4/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_4/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_4/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_4/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_4/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_4/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_4/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_4/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_4/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_5/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_5/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_5/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_5/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_5/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_5/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_5/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_5/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_5/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_5/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_6/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_6/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_6/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_6/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_6/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_6/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_6/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_6/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_6/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_6/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_7/~/linear/w', 
'/opt_state/causal_transformer_shard/~/layer_7/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_7/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_7/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_7/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_7/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_7/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_7/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_7/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_7/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_8/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_8/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_8/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_8/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_8/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_8/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_8/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_8/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_8/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_8/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_9/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_9/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_9/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_9/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_9/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_9/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_9/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_9/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_9/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_9/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/projection_shard/~/linear/b', 
'/opt_state/causal_transformer_shard/~/projection_shard/~/linear/w', '/opt_state/causal_transformer_shard/~/projection_shard/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/projection_shard/~/replicated_layer_norm/scale', '/opt_state', '/params/causal_transformer_shard/~/embedding_shard/~/linear/b', '/params/causal_transformer_shard/~/embedding_shard/~/linear/w', '/params/causal_transformer_shard/~/layer_0/~/linear/w', '/params/causal_transformer_shard/~/layer_0/~/linear_1/w', '/params/causal_transformer_shard/~/layer_0/~/linear_2/w', '/params/causal_transformer_shard/~/layer_0/~/linear_3/w', '/params/causal_transformer_shard/~/layer_0/~/linear_4/b', '/params/causal_transformer_shard/~/layer_0/~/linear_4/w', '/params/causal_transformer_shard/~/layer_0/~/linear_5/b', '/params/causal_transformer_shard/~/layer_0/~/linear_5/w', '/params/causal_transformer_shard/~/layer_0/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_0/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_1/~/linear/w', '/params/causal_transformer_shard/~/layer_1/~/linear_1/w', '/params/causal_transformer_shard/~/layer_1/~/linear_2/w', '/params/causal_transformer_shard/~/layer_1/~/linear_3/w', '/params/causal_transformer_shard/~/layer_1/~/linear_4/b', '/params/causal_transformer_shard/~/layer_1/~/linear_4/w', '/params/causal_transformer_shard/~/layer_1/~/linear_5/b', '/params/causal_transformer_shard/~/layer_1/~/linear_5/w', '/params/causal_transformer_shard/~/layer_1/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_1/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_10/~/linear/w', '/params/causal_transformer_shard/~/layer_10/~/linear_1/w', '/params/causal_transformer_shard/~/layer_10/~/linear_2/w', '/params/causal_transformer_shard/~/layer_10/~/linear_3/w', '/params/causal_transformer_shard/~/layer_10/~/linear_4/b', '/params/causal_transformer_shard/~/layer_10/~/linear_4/w', 
'/params/causal_transformer_shard/~/layer_10/~/linear_5/b', '/params/causal_transformer_shard/~/layer_10/~/linear_5/w', '/params/causal_transformer_shard/~/layer_10/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_10/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_11/~/linear/w', '/params/causal_transformer_shard/~/layer_11/~/linear_1/w', '/params/causal_transformer_shard/~/layer_11/~/linear_2/w', '/params/causal_transformer_shard/~/layer_11/~/linear_3/w', '/params/causal_transformer_shard/~/layer_11/~/linear_4/b', '/params/causal_transformer_shard/~/layer_11/~/linear_4/w', '/params/causal_transformer_shard/~/layer_11/~/linear_5/b', '/params/causal_transformer_shard/~/layer_11/~/linear_5/w', '/params/causal_transformer_shard/~/layer_11/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_11/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_12/~/linear/w', '/params/causal_transformer_shard/~/layer_12/~/linear_1/w', '/params/causal_transformer_shard/~/layer_12/~/linear_2/w', '/params/causal_transformer_shard/~/layer_12/~/linear_3/w', '/params/causal_transformer_shard/~/layer_12/~/linear_4/b', '/params/causal_transformer_shard/~/layer_12/~/linear_4/w', '/params/causal_transformer_shard/~/layer_12/~/linear_5/b', '/params/causal_transformer_shard/~/layer_12/~/linear_5/w', '/params/causal_transformer_shard/~/layer_12/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_12/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_13/~/linear/w', '/params/causal_transformer_shard/~/layer_13/~/linear_1/w', '/params/causal_transformer_shard/~/layer_13/~/linear_2/w', '/params/causal_transformer_shard/~/layer_13/~/linear_3/w', '/params/causal_transformer_shard/~/layer_13/~/linear_4/b', '/params/causal_transformer_shard/~/layer_13/~/linear_4/w', '/params/causal_transformer_shard/~/layer_13/~/linear_5/b', 
'/params/causal_transformer_shard/~/layer_13/~/linear_5/w', '/params/causal_transformer_shard/~/layer_13/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_13/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_14/~/linear/w', '/params/causal_transformer_shard/~/layer_14/~/linear_1/w', '/params/causal_transformer_shard/~/layer_14/~/linear_2/w', '/params/causal_transformer_shard/~/layer_14/~/linear_3/w', '/params/causal_transformer_shard/~/layer_14/~/linear_4/b', '/params/causal_transformer_shard/~/layer_14/~/linear_4/w', '/params/causal_transformer_shard/~/layer_14/~/linear_5/b', '/params/causal_transformer_shard/~/layer_14/~/linear_5/w', '/params/causal_transformer_shard/~/layer_14/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_14/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_15/~/linear/w', '/params/causal_transformer_shard/~/layer_15/~/linear_1/w', '/params/causal_transformer_shard/~/layer_15/~/linear_2/w', '/params/causal_transformer_shard/~/layer_15/~/linear_3/w', '/params/causal_transformer_shard/~/layer_15/~/linear_4/b', '/params/causal_transformer_shard/~/layer_15/~/linear_4/w', '/params/causal_transformer_shard/~/layer_15/~/linear_5/b', '/params/causal_transformer_shard/~/layer_15/~/linear_5/w', '/params/causal_transformer_shard/~/layer_15/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_15/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_16/~/linear/w', '/params/causal_transformer_shard/~/layer_16/~/linear_1/w', '/params/causal_transformer_shard/~/layer_16/~/linear_2/w', '/params/causal_transformer_shard/~/layer_16/~/linear_3/w', '/params/causal_transformer_shard/~/layer_16/~/linear_4/b', '/params/causal_transformer_shard/~/layer_16/~/linear_4/w', '/params/causal_transformer_shard/~/layer_16/~/linear_5/b', '/params/causal_transformer_shard/~/layer_16/~/linear_5/w', 
'/params/causal_transformer_shard/~/layer_16/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_16/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_17/~/linear/w', '/params/causal_transformer_shard/~/layer_17/~/linear_1/w', '/params/causal_transformer_shard/~/layer_17/~/linear_2/w', '/params/causal_transformer_shard/~/layer_17/~/linear_3/w', '/params/causal_transformer_shard/~/layer_17/~/linear_4/b', '/params/causal_transformer_shard/~/layer_17/~/linear_4/w', '/params/causal_transformer_shard/~/layer_17/~/linear_5/b', '/params/causal_transformer_shard/~/layer_17/~/linear_5/w', '/params/causal_transformer_shard/~/layer_17/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_17/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_18/~/linear/w', '/params/causal_transformer_shard/~/layer_18/~/linear_1/w', '/params/causal_transformer_shard/~/layer_18/~/linear_2/w', '/params/causal_transformer_shard/~/layer_18/~/linear_3/w', '/params/causal_transformer_shard/~/layer_18/~/linear_4/b', '/params/causal_transformer_shard/~/layer_18/~/linear_4/w', '/params/causal_transformer_shard/~/layer_18/~/linear_5/b', '/params/causal_transformer_shard/~/layer_18/~/linear_5/w', '/params/causal_transformer_shard/~/layer_18/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_18/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_19/~/linear/w', '/params/causal_transformer_shard/~/layer_19/~/linear_1/w', '/params/causal_transformer_shard/~/layer_19/~/linear_2/w', '/params/causal_transformer_shard/~/layer_19/~/linear_3/w', '/params/causal_transformer_shard/~/layer_19/~/linear_4/b', '/params/causal_transformer_shard/~/layer_19/~/linear_4/w', '/params/causal_transformer_shard/~/layer_19/~/linear_5/b', '/params/causal_transformer_shard/~/layer_19/~/linear_5/w', '/params/causal_transformer_shard/~/layer_19/~/replicated_layer_norm/offset', 
'/params/causal_transformer_shard/~/layer_19/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_2/~/linear/w', '/params/causal_transformer_shard/~/layer_2/~/linear_1/w', '/params/causal_transformer_shard/~/layer_2/~/linear_2/w', '/params/causal_transformer_shard/~/layer_2/~/linear_3/w', '/params/causal_transformer_shard/~/layer_2/~/linear_4/b', '/params/causal_transformer_shard/~/layer_2/~/linear_4/w', '/params/causal_transformer_shard/~/layer_2/~/linear_5/b', '/params/causal_transformer_shard/~/layer_2/~/linear_5/w', '/params/causal_transformer_shard/~/layer_2/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_2/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_20/~/linear/w', '/params/causal_transformer_shard/~/layer_20/~/linear_1/w', '/params/causal_transformer_shard/~/layer_20/~/linear_2/w', '/params/causal_transformer_shard/~/layer_20/~/linear_3/w', '/params/causal_transformer_shard/~/layer_20/~/linear_4/b', '/params/causal_transformer_shard/~/layer_20/~/linear_4/w', '/params/causal_transformer_shard/~/layer_20/~/linear_5/b', '/params/causal_transformer_shard/~/layer_20/~/linear_5/w', '/params/causal_transformer_shard/~/layer_20/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_20/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_21/~/linear/w', '/params/causal_transformer_shard/~/layer_21/~/linear_1/w', '/params/causal_transformer_shard/~/layer_21/~/linear_2/w', '/params/causal_transformer_shard/~/layer_21/~/linear_3/w', '/params/causal_transformer_shard/~/layer_21/~/linear_4/b', '/params/causal_transformer_shard/~/layer_21/~/linear_4/w', '/params/causal_transformer_shard/~/layer_21/~/linear_5/b', '/params/causal_transformer_shard/~/layer_21/~/linear_5/w', '/params/causal_transformer_shard/~/layer_21/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_21/~/replicated_layer_norm/scale', 
'/params/causal_transformer_shard/~/layer_22/~/linear/w', '/params/causal_transformer_shard/~/layer_22/~/linear_1/w', '/params/causal_transformer_shard/~/layer_22/~/linear_2/w', '/params/causal_transformer_shard/~/layer_22/~/linear_3/w', '/params/causal_transformer_shard/~/layer_22/~/linear_4/b', '/params/causal_transformer_shard/~/layer_22/~/linear_4/w', '/params/causal_transformer_shard/~/layer_22/~/linear_5/b', '/params/causal_transformer_shard/~/layer_22/~/linear_5/w', '/params/causal_transformer_shard/~/layer_22/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_22/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_23/~/linear/w', '/params/causal_transformer_shard/~/layer_23/~/linear_1/w', '/params/causal_transformer_shard/~/layer_23/~/linear_2/w', '/params/causal_transformer_shard/~/layer_23/~/linear_3/w', '/params/causal_transformer_shard/~/layer_23/~/linear_4/b', '/params/causal_transformer_shard/~/layer_23/~/linear_4/w', '/params/causal_transformer_shard/~/layer_23/~/linear_5/b', '/params/causal_transformer_shard/~/layer_23/~/linear_5/w', '/params/causal_transformer_shard/~/layer_23/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_23/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_24/~/linear/w', '/params/causal_transformer_shard/~/layer_24/~/linear_1/w', '/params/causal_transformer_shard/~/layer_24/~/linear_2/w', '/params/causal_transformer_shard/~/layer_24/~/linear_3/w', '/params/causal_transformer_shard/~/layer_24/~/linear_4/b', '/params/causal_transformer_shard/~/layer_24/~/linear_4/w', '/params/causal_transformer_shard/~/layer_24/~/linear_5/b', '/params/causal_transformer_shard/~/layer_24/~/linear_5/w', '/params/causal_transformer_shard/~/layer_24/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_24/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_25/~/linear/w', 
'/params/causal_transformer_shard/~/layer_25/~/linear_1/w', '/params/causal_transformer_shard/~/layer_25/~/linear_2/w', '/params/causal_transformer_shard/~/layer_25/~/linear_3/w', '/params/causal_transformer_shard/~/layer_25/~/linear_4/b', '/params/causal_transformer_shard/~/layer_25/~/linear_4/w', '/params/causal_transformer_shard/~/layer_25/~/linear_5/b', '/params/causal_transformer_shard/~/layer_25/~/linear_5/w', '/params/causal_transformer_shard/~/layer_25/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_25/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_26/~/linear/w', '/params/causal_transformer_shard/~/layer_26/~/linear_1/w', '/params/causal_transformer_shard/~/layer_26/~/linear_2/w', '/params/causal_transformer_shard/~/layer_26/~/linear_3/w', '/params/causal_transformer_shard/~/layer_26/~/linear_4/b', '/params/causal_transformer_shard/~/layer_26/~/linear_4/w', '/params/causal_transformer_shard/~/layer_26/~/linear_5/b', '/params/causal_transformer_shard/~/layer_26/~/linear_5/w', '/params/causal_transformer_shard/~/layer_26/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_26/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_27/~/linear/w', '/params/causal_transformer_shard/~/layer_27/~/linear_1/w', '/params/causal_transformer_shard/~/layer_27/~/linear_2/w', '/params/causal_transformer_shard/~/layer_27/~/linear_3/w', '/params/causal_transformer_shard/~/layer_27/~/linear_4/b', '/params/causal_transformer_shard/~/layer_27/~/linear_4/w', '/params/causal_transformer_shard/~/layer_27/~/linear_5/b', '/params/causal_transformer_shard/~/layer_27/~/linear_5/w', '/params/causal_transformer_shard/~/layer_27/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_27/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_3/~/linear/w', '/params/causal_transformer_shard/~/layer_3/~/linear_1/w', 
'/params/causal_transformer_shard/~/layer_3/~/linear_2/w', '/params/causal_transformer_shard/~/layer_3/~/linear_3/w', '/params/causal_transformer_shard/~/layer_3/~/linear_4/b', '/params/causal_transformer_shard/~/layer_3/~/linear_4/w', '/params/causal_transformer_shard/~/layer_3/~/linear_5/b', '/params/causal_transformer_shard/~/layer_3/~/linear_5/w', '/params/causal_transformer_shard/~/layer_3/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_3/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_4/~/linear/w', '/params/causal_transformer_shard/~/layer_4/~/linear_1/w', '/params/causal_transformer_shard/~/layer_4/~/linear_2/w', '/params/causal_transformer_shard/~/layer_4/~/linear_3/w', '/params/causal_transformer_shard/~/layer_4/~/linear_4/b', '/params/causal_transformer_shard/~/layer_4/~/linear_4/w', '/params/causal_transformer_shard/~/layer_4/~/linear_5/b', '/params/causal_transformer_shard/~/layer_4/~/linear_5/w', '/params/causal_transformer_shard/~/layer_4/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_4/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_5/~/linear/w', '/params/causal_transformer_shard/~/layer_5/~/linear_1/w', '/params/causal_transformer_shard/~/layer_5/~/linear_2/w', '/params/causal_transformer_shard/~/layer_5/~/linear_3/w', '/params/causal_transformer_shard/~/layer_5/~/linear_4/b', '/params/causal_transformer_shard/~/layer_5/~/linear_4/w', '/params/causal_transformer_shard/~/layer_5/~/linear_5/b', '/params/causal_transformer_shard/~/layer_5/~/linear_5/w', '/params/causal_transformer_shard/~/layer_5/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_5/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_6/~/linear/w', '/params/causal_transformer_shard/~/layer_6/~/linear_1/w', '/params/causal_transformer_shard/~/layer_6/~/linear_2/w', '/params/causal_transformer_shard/~/layer_6/~/linear_3/w', 
'/params/causal_transformer_shard/~/layer_6/~/linear_4/b', '/params/causal_transformer_shard/~/layer_6/~/linear_4/w', '/params/causal_transformer_shard/~/layer_6/~/linear_5/b', '/params/causal_transformer_shard/~/layer_6/~/linear_5/w', '/params/causal_transformer_shard/~/layer_6/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_6/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_7/~/linear/w', '/params/causal_transformer_shard/~/layer_7/~/linear_1/w', '/params/causal_transformer_shard/~/layer_7/~/linear_2/w', '/params/causal_transformer_shard/~/layer_7/~/linear_3/w', '/params/causal_transformer_shard/~/layer_7/~/linear_4/b', '/params/causal_transformer_shard/~/layer_7/~/linear_4/w', '/params/causal_transformer_shard/~/layer_7/~/linear_5/b', '/params/causal_transformer_shard/~/layer_7/~/linear_5/w', '/params/causal_transformer_shard/~/layer_7/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_7/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_8/~/linear/w', '/params/causal_transformer_shard/~/layer_8/~/linear_1/w', '/params/causal_transformer_shard/~/layer_8/~/linear_2/w', '/params/causal_transformer_shard/~/layer_8/~/linear_3/w', '/params/causal_transformer_shard/~/layer_8/~/linear_4/b', '/params/causal_transformer_shard/~/layer_8/~/linear_4/w', '/params/causal_transformer_shard/~/layer_8/~/linear_5/b', '/params/causal_transformer_shard/~/layer_8/~/linear_5/w', '/params/causal_transformer_shard/~/layer_8/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_8/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_9/~/linear/w', '/params/causal_transformer_shard/~/layer_9/~/linear_1/w', '/params/causal_transformer_shard/~/layer_9/~/linear_2/w', '/params/causal_transformer_shard/~/layer_9/~/linear_3/w', '/params/causal_transformer_shard/~/layer_9/~/linear_4/b', '/params/causal_transformer_shard/~/layer_9/~/linear_4/w', 
'/params/causal_transformer_shard/~/layer_9/~/linear_5/b', '/params/causal_transformer_shard/~/layer_9/~/linear_5/w', '/params/causal_transformer_shard/~/layer_9/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_9/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/projection_shard/~/linear/b', '/params/causal_transformer_shard/~/projection_shard/~/linear/w', '/params/causal_transformer_shard/~/projection_shard/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/projection_shard/~/replicated_layer_norm/scale', '/step'] | |
# leaves_names_reduced = get_tree_leaves_names_reduced(params) | |
# print(leaves_names_reduced) | |
leaves_names_reduced = ['/params/causal_transformer_shard/~/embedding_shard/~/linear/b', '/params/causal_transformer_shard/~/embedding_shard/~/linear/w', '/params/causal_transformer_shard/~/layer_0/~/linear/w', '/params/causal_transformer_shard/~/layer_0/~/linear_1/w', '/params/causal_transformer_shard/~/layer_0/~/linear_2/w', '/params/causal_transformer_shard/~/layer_0/~/linear_3/w', '/params/causal_transformer_shard/~/layer_0/~/linear_4/b', '/params/causal_transformer_shard/~/layer_0/~/linear_4/w', '/params/causal_transformer_shard/~/layer_0/~/linear_5/b', '/params/causal_transformer_shard/~/layer_0/~/linear_5/w', '/params/causal_transformer_shard/~/layer_0/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_0/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_1/~/linear/w', '/params/causal_transformer_shard/~/layer_1/~/linear_1/w', '/params/causal_transformer_shard/~/layer_1/~/linear_2/w', '/params/causal_transformer_shard/~/layer_1/~/linear_3/w', '/params/causal_transformer_shard/~/layer_1/~/linear_4/b', '/params/causal_transformer_shard/~/layer_1/~/linear_4/w', '/params/causal_transformer_shard/~/layer_1/~/linear_5/b', '/params/causal_transformer_shard/~/layer_1/~/linear_5/w', '/params/causal_transformer_shard/~/layer_1/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_1/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_10/~/linear/w', '/params/causal_transformer_shard/~/layer_10/~/linear_1/w', '/params/causal_transformer_shard/~/layer_10/~/linear_2/w', '/params/causal_transformer_shard/~/layer_10/~/linear_3/w', '/params/causal_transformer_shard/~/layer_10/~/linear_4/b', '/params/causal_transformer_shard/~/layer_10/~/linear_4/w', '/params/causal_transformer_shard/~/layer_10/~/linear_5/b', '/params/causal_transformer_shard/~/layer_10/~/linear_5/w', '/params/causal_transformer_shard/~/layer_10/~/replicated_layer_norm/offset', 
'/params/causal_transformer_shard/~/layer_10/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_11/~/linear/w', '/params/causal_transformer_shard/~/layer_11/~/linear_1/w', '/params/causal_transformer_shard/~/layer_11/~/linear_2/w', '/params/causal_transformer_shard/~/layer_11/~/linear_3/w', '/params/causal_transformer_shard/~/layer_11/~/linear_4/b', '/params/causal_transformer_shard/~/layer_11/~/linear_4/w', '/params/causal_transformer_shard/~/layer_11/~/linear_5/b', '/params/causal_transformer_shard/~/layer_11/~/linear_5/w', '/params/causal_transformer_shard/~/layer_11/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_11/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_12/~/linear/w', '/params/causal_transformer_shard/~/layer_12/~/linear_1/w', '/params/causal_transformer_shard/~/layer_12/~/linear_2/w', '/params/causal_transformer_shard/~/layer_12/~/linear_3/w', '/params/causal_transformer_shard/~/layer_12/~/linear_4/b', '/params/causal_transformer_shard/~/layer_12/~/linear_4/w', '/params/causal_transformer_shard/~/layer_12/~/linear_5/b', '/params/causal_transformer_shard/~/layer_12/~/linear_5/w', '/params/causal_transformer_shard/~/layer_12/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_12/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_13/~/linear/w', '/params/causal_transformer_shard/~/layer_13/~/linear_1/w', '/params/causal_transformer_shard/~/layer_13/~/linear_2/w', '/params/causal_transformer_shard/~/layer_13/~/linear_3/w', '/params/causal_transformer_shard/~/layer_13/~/linear_4/b', '/params/causal_transformer_shard/~/layer_13/~/linear_4/w', '/params/causal_transformer_shard/~/layer_13/~/linear_5/b', '/params/causal_transformer_shard/~/layer_13/~/linear_5/w', '/params/causal_transformer_shard/~/layer_13/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_13/~/replicated_layer_norm/scale', 
'/params/causal_transformer_shard/~/layer_14/~/linear/w', '/params/causal_transformer_shard/~/layer_14/~/linear_1/w', '/params/causal_transformer_shard/~/layer_14/~/linear_2/w', '/params/causal_transformer_shard/~/layer_14/~/linear_3/w', '/params/causal_transformer_shard/~/layer_14/~/linear_4/b', '/params/causal_transformer_shard/~/layer_14/~/linear_4/w', '/params/causal_transformer_shard/~/layer_14/~/linear_5/b', '/params/causal_transformer_shard/~/layer_14/~/linear_5/w', '/params/causal_transformer_shard/~/layer_14/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_14/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_15/~/linear/w', '/params/causal_transformer_shard/~/layer_15/~/linear_1/w', '/params/causal_transformer_shard/~/layer_15/~/linear_2/w', '/params/causal_transformer_shard/~/layer_15/~/linear_3/w', '/params/causal_transformer_shard/~/layer_15/~/linear_4/b', '/params/causal_transformer_shard/~/layer_15/~/linear_4/w', '/params/causal_transformer_shard/~/layer_15/~/linear_5/b', '/params/causal_transformer_shard/~/layer_15/~/linear_5/w', '/params/causal_transformer_shard/~/layer_15/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_15/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_16/~/linear/w', '/params/causal_transformer_shard/~/layer_16/~/linear_1/w', '/params/causal_transformer_shard/~/layer_16/~/linear_2/w', '/params/causal_transformer_shard/~/layer_16/~/linear_3/w', '/params/causal_transformer_shard/~/layer_16/~/linear_4/b', '/params/causal_transformer_shard/~/layer_16/~/linear_4/w', '/params/causal_transformer_shard/~/layer_16/~/linear_5/b', '/params/causal_transformer_shard/~/layer_16/~/linear_5/w', '/params/causal_transformer_shard/~/layer_16/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_16/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_17/~/linear/w', 
'/params/causal_transformer_shard/~/layer_17/~/linear_1/w', '/params/causal_transformer_shard/~/layer_17/~/linear_2/w', '/params/causal_transformer_shard/~/layer_17/~/linear_3/w', '/params/causal_transformer_shard/~/layer_17/~/linear_4/b', '/params/causal_transformer_shard/~/layer_17/~/linear_4/w', '/params/causal_transformer_shard/~/layer_17/~/linear_5/b', '/params/causal_transformer_shard/~/layer_17/~/linear_5/w', '/params/causal_transformer_shard/~/layer_17/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_17/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_18/~/linear/w', '/params/causal_transformer_shard/~/layer_18/~/linear_1/w', '/params/causal_transformer_shard/~/layer_18/~/linear_2/w', '/params/causal_transformer_shard/~/layer_18/~/linear_3/w', '/params/causal_transformer_shard/~/layer_18/~/linear_4/b', '/params/causal_transformer_shard/~/layer_18/~/linear_4/w', '/params/causal_transformer_shard/~/layer_18/~/linear_5/b', '/params/causal_transformer_shard/~/layer_18/~/linear_5/w', '/params/causal_transformer_shard/~/layer_18/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_18/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_19/~/linear/w', '/params/causal_transformer_shard/~/layer_19/~/linear_1/w', '/params/causal_transformer_shard/~/layer_19/~/linear_2/w', '/params/causal_transformer_shard/~/layer_19/~/linear_3/w', '/params/causal_transformer_shard/~/layer_19/~/linear_4/b', '/params/causal_transformer_shard/~/layer_19/~/linear_4/w', '/params/causal_transformer_shard/~/layer_19/~/linear_5/b', '/params/causal_transformer_shard/~/layer_19/~/linear_5/w', '/params/causal_transformer_shard/~/layer_19/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_19/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_2/~/linear/w', '/params/causal_transformer_shard/~/layer_2/~/linear_1/w', 
'/params/causal_transformer_shard/~/layer_2/~/linear_2/w', '/params/causal_transformer_shard/~/layer_2/~/linear_3/w', '/params/causal_transformer_shard/~/layer_2/~/linear_4/b', '/params/causal_transformer_shard/~/layer_2/~/linear_4/w', '/params/causal_transformer_shard/~/layer_2/~/linear_5/b', '/params/causal_transformer_shard/~/layer_2/~/linear_5/w', '/params/causal_transformer_shard/~/layer_2/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_2/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_20/~/linear/w', '/params/causal_transformer_shard/~/layer_20/~/linear_1/w', '/params/causal_transformer_shard/~/layer_20/~/linear_2/w', '/params/causal_transformer_shard/~/layer_20/~/linear_3/w', '/params/causal_transformer_shard/~/layer_20/~/linear_4/b', '/params/causal_transformer_shard/~/layer_20/~/linear_4/w', '/params/causal_transformer_shard/~/layer_20/~/linear_5/b', '/params/causal_transformer_shard/~/layer_20/~/linear_5/w', '/params/causal_transformer_shard/~/layer_20/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_20/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_21/~/linear/w', '/params/causal_transformer_shard/~/layer_21/~/linear_1/w', '/params/causal_transformer_shard/~/layer_21/~/linear_2/w', '/params/causal_transformer_shard/~/layer_21/~/linear_3/w', '/params/causal_transformer_shard/~/layer_21/~/linear_4/b', '/params/causal_transformer_shard/~/layer_21/~/linear_4/w', '/params/causal_transformer_shard/~/layer_21/~/linear_5/b', '/params/causal_transformer_shard/~/layer_21/~/linear_5/w', '/params/causal_transformer_shard/~/layer_21/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_21/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_22/~/linear/w', '/params/causal_transformer_shard/~/layer_22/~/linear_1/w', '/params/causal_transformer_shard/~/layer_22/~/linear_2/w', 
'/params/causal_transformer_shard/~/layer_22/~/linear_3/w', '/params/causal_transformer_shard/~/layer_22/~/linear_4/b', '/params/causal_transformer_shard/~/layer_22/~/linear_4/w', '/params/causal_transformer_shard/~/layer_22/~/linear_5/b', '/params/causal_transformer_shard/~/layer_22/~/linear_5/w', '/params/causal_transformer_shard/~/layer_22/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_22/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_23/~/linear/w', '/params/causal_transformer_shard/~/layer_23/~/linear_1/w', '/params/causal_transformer_shard/~/layer_23/~/linear_2/w', '/params/causal_transformer_shard/~/layer_23/~/linear_3/w', '/params/causal_transformer_shard/~/layer_23/~/linear_4/b', '/params/causal_transformer_shard/~/layer_23/~/linear_4/w', '/params/causal_transformer_shard/~/layer_23/~/linear_5/b', '/params/causal_transformer_shard/~/layer_23/~/linear_5/w', '/params/causal_transformer_shard/~/layer_23/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_23/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_24/~/linear/w', '/params/causal_transformer_shard/~/layer_24/~/linear_1/w', '/params/causal_transformer_shard/~/layer_24/~/linear_2/w', '/params/causal_transformer_shard/~/layer_24/~/linear_3/w', '/params/causal_transformer_shard/~/layer_24/~/linear_4/b', '/params/causal_transformer_shard/~/layer_24/~/linear_4/w', '/params/causal_transformer_shard/~/layer_24/~/linear_5/b', '/params/causal_transformer_shard/~/layer_24/~/linear_5/w', '/params/causal_transformer_shard/~/layer_24/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_24/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_25/~/linear/w', '/params/causal_transformer_shard/~/layer_25/~/linear_1/w', '/params/causal_transformer_shard/~/layer_25/~/linear_2/w', '/params/causal_transformer_shard/~/layer_25/~/linear_3/w', 
'/params/causal_transformer_shard/~/layer_25/~/linear_4/b', '/params/causal_transformer_shard/~/layer_25/~/linear_4/w', '/params/causal_transformer_shard/~/layer_25/~/linear_5/b', '/params/causal_transformer_shard/~/layer_25/~/linear_5/w', '/params/causal_transformer_shard/~/layer_25/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_25/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_26/~/linear/w', '/params/causal_transformer_shard/~/layer_26/~/linear_1/w', '/params/causal_transformer_shard/~/layer_26/~/linear_2/w', '/params/causal_transformer_shard/~/layer_26/~/linear_3/w', '/params/causal_transformer_shard/~/layer_26/~/linear_4/b', '/params/causal_transformer_shard/~/layer_26/~/linear_4/w', '/params/causal_transformer_shard/~/layer_26/~/linear_5/b', '/params/causal_transformer_shard/~/layer_26/~/linear_5/w', '/params/causal_transformer_shard/~/layer_26/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_26/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_27/~/linear/w', '/params/causal_transformer_shard/~/layer_27/~/linear_1/w', '/params/causal_transformer_shard/~/layer_27/~/linear_2/w', '/params/causal_transformer_shard/~/layer_27/~/linear_3/w', '/params/causal_transformer_shard/~/layer_27/~/linear_4/b', '/params/causal_transformer_shard/~/layer_27/~/linear_4/w', '/params/causal_transformer_shard/~/layer_27/~/linear_5/b', '/params/causal_transformer_shard/~/layer_27/~/linear_5/w', '/params/causal_transformer_shard/~/layer_27/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_27/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_3/~/linear/w', '/params/causal_transformer_shard/~/layer_3/~/linear_1/w', '/params/causal_transformer_shard/~/layer_3/~/linear_2/w', '/params/causal_transformer_shard/~/layer_3/~/linear_3/w', '/params/causal_transformer_shard/~/layer_3/~/linear_4/b', 
'/params/causal_transformer_shard/~/layer_3/~/linear_4/w', '/params/causal_transformer_shard/~/layer_3/~/linear_5/b', '/params/causal_transformer_shard/~/layer_3/~/linear_5/w', '/params/causal_transformer_shard/~/layer_3/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_3/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_4/~/linear/w', '/params/causal_transformer_shard/~/layer_4/~/linear_1/w', '/params/causal_transformer_shard/~/layer_4/~/linear_2/w', '/params/causal_transformer_shard/~/layer_4/~/linear_3/w', '/params/causal_transformer_shard/~/layer_4/~/linear_4/b', '/params/causal_transformer_shard/~/layer_4/~/linear_4/w', '/params/causal_transformer_shard/~/layer_4/~/linear_5/b', '/params/causal_transformer_shard/~/layer_4/~/linear_5/w', '/params/causal_transformer_shard/~/layer_4/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_4/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_5/~/linear/w', '/params/causal_transformer_shard/~/layer_5/~/linear_1/w', '/params/causal_transformer_shard/~/layer_5/~/linear_2/w', '/params/causal_transformer_shard/~/layer_5/~/linear_3/w', '/params/causal_transformer_shard/~/layer_5/~/linear_4/b', '/params/causal_transformer_shard/~/layer_5/~/linear_4/w', '/params/causal_transformer_shard/~/layer_5/~/linear_5/b', '/params/causal_transformer_shard/~/layer_5/~/linear_5/w', '/params/causal_transformer_shard/~/layer_5/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_5/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_6/~/linear/w', '/params/causal_transformer_shard/~/layer_6/~/linear_1/w', '/params/causal_transformer_shard/~/layer_6/~/linear_2/w', '/params/causal_transformer_shard/~/layer_6/~/linear_3/w', '/params/causal_transformer_shard/~/layer_6/~/linear_4/b', '/params/causal_transformer_shard/~/layer_6/~/linear_4/w', '/params/causal_transformer_shard/~/layer_6/~/linear_5/b', 
'/params/causal_transformer_shard/~/layer_6/~/linear_5/w', '/params/causal_transformer_shard/~/layer_6/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_6/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_7/~/linear/w', '/params/causal_transformer_shard/~/layer_7/~/linear_1/w', '/params/causal_transformer_shard/~/layer_7/~/linear_2/w', '/params/causal_transformer_shard/~/layer_7/~/linear_3/w', '/params/causal_transformer_shard/~/layer_7/~/linear_4/b', '/params/causal_transformer_shard/~/layer_7/~/linear_4/w', '/params/causal_transformer_shard/~/layer_7/~/linear_5/b', '/params/causal_transformer_shard/~/layer_7/~/linear_5/w', '/params/causal_transformer_shard/~/layer_7/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_7/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_8/~/linear/w', '/params/causal_transformer_shard/~/layer_8/~/linear_1/w', '/params/causal_transformer_shard/~/layer_8/~/linear_2/w', '/params/causal_transformer_shard/~/layer_8/~/linear_3/w', '/params/causal_transformer_shard/~/layer_8/~/linear_4/b', '/params/causal_transformer_shard/~/layer_8/~/linear_4/w', '/params/causal_transformer_shard/~/layer_8/~/linear_5/b', '/params/causal_transformer_shard/~/layer_8/~/linear_5/w', '/params/causal_transformer_shard/~/layer_8/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_8/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_9/~/linear/w', '/params/causal_transformer_shard/~/layer_9/~/linear_1/w', '/params/causal_transformer_shard/~/layer_9/~/linear_2/w', '/params/causal_transformer_shard/~/layer_9/~/linear_3/w', '/params/causal_transformer_shard/~/layer_9/~/linear_4/b', '/params/causal_transformer_shard/~/layer_9/~/linear_4/w', '/params/causal_transformer_shard/~/layer_9/~/linear_5/b', '/params/causal_transformer_shard/~/layer_9/~/linear_5/w', 
'/params/causal_transformer_shard/~/layer_9/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_9/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/projection_shard/~/linear/b', '/params/causal_transformer_shard/~/projection_shard/~/linear/w', '/params/causal_transformer_shard/~/projection_shard/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/projection_shard/~/replicated_layer_norm/scale', '/step'] | |
# TODO(nijkamp): overwriting the optimizer mutates the pytree in order to reduce memory alloc, but this will break the serialization format, serialize model into optim / param files separately to clean this mess | |
params['optimizer'] = optax.scale(0) | |
params['sampler'] = nucleaus_sample | |
devices = np.array([jax.devices()[0]]).reshape((1, 1)) | |
with jax.experimental.maps.mesh(devices, ('dp', 'mp')): | |
print('CausalTransformer', 'begin') | |
t0 = time.time() | |
network = CausalTransformer(params) | |
print('CausalTransformer', 'end', f'{time.time() - t0:06}') | |
print('read_ckpt', 'begin') | |
t0 = time.time() | |
network.state = read_ckpt(network.state, leaves_names_original, leaves_names_reduced, '/export/home/gptj/step_383500/', shards_in=8, shards_out=1) | |
print('read_ckpt', 'end', f'{time.time() - t0:06}') | |
print('device_put', 'begin') | |
t0 = time.time() | |
network.state = jax.device_put(network.state, jax.devices("cpu")[0]) | |
print('device_put', 'end', f'{time.time() - t0:06}') | |
tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2') | |
print('sample', 'begin') | |
t0 = time.time() | |
print(sample(tokenizer=tokenizer, per_replica_batch=params['per_replica_batch'], seq=params['seq'], network=network, context='EleutherAI is')) | |
print('sample', 'end', f'{time.time() - t0:06}') | |
# Entry point: run the resharding/sampling pipeline only when this file is
# executed as a script; importing the module (e.g. to reuse the tree helpers)
# must not trigger checkpoint loading.
if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment