Skip to content

Instantly share code, notes, and snippets.

@enijkamp
Last active August 2, 2021 00:25
Show Gist options
  • Save enijkamp/79ea43b81d0e9af232cccf96c3d13168 to your computer and use it in GitHub Desktop.
Save enijkamp/79ea43b81d0e9af232cccf96c3d13168 to your computer and use it in GitHub Desktop.
resharding.py
'''
python3.8 -m venv .venv
source .venv/bin/activate
pip install --upgrade pip setuptools
pip install -r requirements.txt
pip install --upgrade jax==0.2.12 jaxlib==0.1.67+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html
'''
import os
# xla
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform'
import io
import time
import multiprocessing
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import maps
import optax
import transformers
from mesh_transformer.sampling import nucleaus_sample
from mesh_transformer.transformer_shard import CausalTransformer
import mesh_transformer.util as util
def tree_flatten_with_names(pytree, is_leaf, path='', to_id=id):
    """Map an identifier of every leaf in ``pytree`` to a '/'-separated path name.

    The tree is walked recursively:

    * objects exposing ``items`` (mappings) extend the path with ``/<key>``;
    * other objects exposing ``__getitem__`` (sequences) are iterated and
      their children inherit the parent's path unchanged;
    * anything else is recorded directly under the current path.

    Args:
        pytree: Nested structure of mappings / sequences / leaves.
        is_leaf: Callable deciding whether a child is a leaf (recursion stops).
        path: Path prefix accumulated so far.
        to_id: Callable producing the (hashable) dictionary key for a leaf;
            defaults to the builtin ``id``.

    Returns:
        dict mapping ``to_id(leaf)`` to the leaf's path string. When two leaves
        produce the same key, the one visited last wins.
    """
    id_to_name = {}
    if getattr(pytree, 'items', None):
        for k, v in pytree.items():
            k_path = f'{path}/{k}'
            if is_leaf(v):
                id_to_name[to_id(v)] = k_path
            else:
                # Forward to_id so nested leaves are keyed the same way as top-level ones.
                id_to_name.update(
                    tree_flatten_with_names(v, is_leaf=is_leaf, path=k_path, to_id=to_id)
                )
    elif getattr(pytree, '__getitem__', None):
        for v in pytree:
            if is_leaf(v):
                id_to_name[to_id(v)] = path
            else:
                id_to_name.update(
                    tree_flatten_with_names(v, is_leaf=is_leaf, path=path, to_id=to_id)
                )
    else:
        id_to_name[to_id(pytree)] = path
    return id_to_name
def tree_leaves_with_names(pytree, to_id=id):
    """Return a ``{to_id(leaf): path}`` mapping for every JAX leaf of ``pytree``.

    The leaves are those reported by ``jax.tree_leaves``; the path names are
    produced by ``tree_flatten_with_names``.

    Args:
        pytree: JAX pytree to inspect.
        to_id: Callable producing a hashable key for a leaf (the result is
            used as a dictionary key downstream); defaults to ``id``.

    Returns:
        dict mapping leaf keys to '/'-separated path strings.
    """
    # Keep `leaves` bound for the whole call so the objects (and hence their
    # id() values) stay alive while the tree is walked.
    leaves = jax.tree_leaves(pytree)
    # Compute the key set once instead of rebuilding a list on every is_leaf call.
    leaf_ids = {to_id(leaf) for leaf in leaves}

    def is_leaf(node):
        # Lists are never treated as leaves, even if their key happens to match.
        return not isinstance(node, list) and to_id(node) in leaf_ids

    return tree_flatten_with_names(pytree, is_leaf, to_id=to_id)
def get_tree_leaves_names_original(params):
    """Return the path names of all state leaves of a full-optimizer network.

    Side effects: switches JAX to the CPU platform and overwrites
    ``params['optimizer']`` with the optax chain built below (the caller's
    dict is modified in place).

    Args:
        params: Model configuration dict passed to ``CausalTransformer``.

    Returns:
        List of path names, ordered like ``jax.tree_leaves(network.state)``.
    """
    jax.config.update('jax_platform_name', 'cpu')

    optimizer_stages = (
        optax.scale(1),
        util.clip_by_global_norm(1),
        optax.scale_by_adam(),
        optax.additive_weight_decay(0),
        optax.scale(-1),
        optax.scale_by_schedule(util.gpt3_schedule(0, 1, 0, 0)),
    )
    params['optimizer'] = optax.chain(*optimizer_stages)

    # A 1x1 ('dp', 'mp') mesh built from the first available device.
    mesh_devices = np.array([jax.devices()[0]]).reshape((1, 1))
    with jax.experimental.maps.mesh(mesh_devices, ('dp', 'mp')):
        network = CausalTransformer(params)
        name_by_id = tree_leaves_with_names(network.state, to_id=id)
        state_leaves = jax.tree_leaves(network.state)
        return [name_by_id[id(leaf)] for leaf in state_leaves]
def get_tree_leaves_names_reduced(params):
    """Return the path names of all state leaves of a network built with ``optax.scale(0)``.

    Side effects: switches JAX to the CPU platform and overwrites
    ``params['optimizer']`` with ``optax.scale(0)`` (the caller's dict is
    modified in place).

    Args:
        params: Model configuration dict passed to ``CausalTransformer``.

    Returns:
        List of path names, ordered like ``jax.tree_leaves(network.state)``.
    """
    jax.config.update('jax_platform_name', 'cpu')
    params['optimizer'] = optax.scale(0)

    # A 1x1 ('dp', 'mp') mesh built from the first available device.
    mesh_devices = np.array([jax.devices()[0]]).reshape((1, 1))
    with jax.experimental.maps.mesh(mesh_devices, ('dp', 'mp')):
        network = CausalTransformer(params)
        name_by_id = tree_leaves_with_names(network.state, to_id=id)
        state_leaves = jax.tree_leaves(network.state)
        return [name_by_id[id(leaf)] for leaf in state_leaves]
# TODO(nijkamp): rewrite this mess
def reshard(x, old_shape, do_shard_ln, do_shard_bias):
    """Collapse a stack of per-shard arrays into a single array of ``old_shape``.

    The strategy is chosen from the rank of ``x`` (whose leading axis is
    presumably the shard axis, as produced by ``np.stack`` in ``read_ckpt``
    -- TODO confirm for other callers):

    * rank 1: keep only the first entry (``x[0:1]``);
    * rank 2: for layer-norm parameters keep the first row, for sharded
      biases sum over axis 0, otherwise plain reshape;
    * rank 3: either swap the first two axes before reshaping (when
      ``x.shape[0] * x.shape[2] == old_shape[2]``) or reshape directly (when
      ``x.shape[0] * x.shape[1] == old_shape[1]``).

    Args:
        x: Array-like with a ``shape`` attribute holding the stacked shards.
        old_shape: Target shape of the result.
        do_shard_ln: Treat a rank-2 ``x`` as a replicated layer-norm parameter.
            Takes precedence over ``do_shard_bias``.
        do_shard_bias: Treat a rank-2 ``x`` as a bias whose shards are summed.

    Returns:
        The merged array. The rank-3 transpose path goes through ``jnp`` and
        therefore returns a JAX array; the other paths keep the input's type.

    Raises:
        Exception: If the rank or the shape combination is not handled.
    """
    rank = len(x.shape)

    if rank == 1:
        # Presumably the step counter: every shard holds the same value.
        return x[0:1]

    if rank == 2:
        if do_shard_ln:
            # TODO(nijkamp): for this case, the expression
            # (x[1:] == 0).all() or (x[1:] == 1).all() should hold
            return x[0:1]
        if do_shard_bias:
            # TODO(nijkamp): sum() bias terms, is this correct?
            return np.reshape(np.sum(x, axis=0), old_shape)
        return x.reshape(old_shape)

    if rank == 3:
        if x.shape[0] * x.shape[2] == old_shape[2]:
            # Shards are joined along the last axis: move the shard axis next to it first.
            return jnp.transpose(x, (1, 0, 2)).reshape(old_shape)
        if x.shape[0] * x.shape[1] == old_shape[1]:
            # Shards are joined along the middle axis: memory order already matches.
            return x.reshape(old_shape)
        raise Exception(f"unimplemented, {x.shape}, {old_shape}")

    # Report the shape, not the whole array contents.
    raise Exception(f"unimplemented, {x.shape}")
def read_shard(ckpt_dir, pieces=16):
    """Load every array stored in ``<ckpt_dir>0.npz`` .. ``<ckpt_dir>{pieces-1}.npz``.

    ``ckpt_dir`` is used as a plain string prefix, so it must already end
    with a path separator.

    Args:
        ckpt_dir: Directory prefix of the shard.
        pieces: Number of consecutive ``.npz`` files to read.

    Returns:
        Flat list of arrays, in file order and then in archive iteration order.
    """
    arrays = []
    for piece in range(pieces):
        # Read the whole file into memory first, then parse it from the buffer.
        with open(ckpt_dir + f"{piece}.npz", "rb") as handle:
            raw = handle.read()
        archive = np.load(io.BytesIO(raw))
        arrays.extend(archive[key] for key in archive)
    return arrays
def read_ckpt(pytree, leaves_names, leaves_names_reduced, dir, shards_in, shards_out):
    """Load a sharded checkpoint into the structure of ``pytree``.

    Only leaves whose name starts with ``/params`` or ``/step`` are kept;
    everything else stored in the checkpoint (e.g. optimizer state) is
    skipped. The kept leaves are stacked across shards and, when the shard
    count changes, merged with ``reshard``.

    Args:
        pytree: Template pytree; supplies the tree structure and target shapes.
        leaves_names: Names of all leaves stored in the checkpoint, in
            on-disk order.
        leaves_names_reduced: Names of the leaves of ``pytree``, in flattened
            order; used to verify the alignment.
        dir: Checkpoint directory prefix (must end with a path separator);
            shards are read from ``{dir}shard_{i}/``.
        shards_in: Number of shards in the checkpoint.
        shards_out: Number of shards of the target model.

    Returns:
        A pytree with the structure of ``pytree`` holding the loaded arrays.

    Raises:
        AssertionError: If leaf names or shapes do not line up with ``pytree``.
    """
    # `import multiprocessing` alone does not guarantee that the `pool`
    # submodule is loaded, so import ThreadPool explicitly.
    from multiprocessing.pool import ThreadPool

    old_flattened, structure = jax.tree_flatten(pytree)
    print(len(old_flattened))
    print(len(leaves_names))
    print(len(leaves_names_reduced))

    # TODO(nijkamp): rewrite this mess
    with ThreadPool(shards_in) as p:
        start = time.time()
        shards = list(p.imap(read_shard, [f"{dir}shard_{i}/" for i in range(shards_in)]))
        print(f"read checkpoint in {time.time() - start:.06}s")

    unsharded = []
    old_i = 0
    # zip() lines up the i-th stored array of every shard with its leaf name.
    for i, (leave_name, *all_shards) in enumerate(zip(leaves_names, *shards)):
        if not leave_name.startswith(('/params', '/step')):
            continue
        assert leave_name == leaves_names_reduced[old_i], f'{leave_name} {leaves_names_reduced[old_i]}'
        old = old_flattened[old_i]
        old_i += 1

        x = np.stack(all_shards)
        # Presumably bfloat16 data comes back from np.load as 2-byte void
        # ('V2'); reinterpret the raw bytes as bfloat16 -- TODO confirm.
        if x.dtype == np.dtype('V2'):
            x = x.view(jnp.bfloat16)

        print(i, old_i, leave_name)
        print(x.shape)
        print(old.shape)

        if shards_out != shards_in:
            x = reshard(
                x,
                old.shape,
                do_shard_bias=leave_name.endswith(('embedding_shard/~/linear/b', 'linear_5/b')),
                do_shard_ln=leave_name.endswith(
                    ('replicated_layer_norm/offset', 'replicated_layer_norm/scale')
                ),
            )
        unsharded.append(x)
        assert x.shape == old.shape, f"Incompatible checkpoints {x.shape} vs {old.shape} ({leave_name})"

    return jax.tree_unflatten(structure, unsharded)
def sample(tokenizer, per_replica_batch, seq, network, context, top_k=40, top_p=0.9, temp=1.0, gen_len=512):
    """Generate completions for ``context`` and return them as decoded strings.

    The encoded context is left-padded with zeros to ``seq`` tokens,
    replicated ``per_replica_batch`` times and passed to ``network.generate``
    together with per-example sampling options.

    Args:
        tokenizer: Object providing ``encode(str)`` and ``decode(tokens)``.
        per_replica_batch: Number of identical prompts in the batch.
        seq: Length the token sequence is padded to.
        network: Object providing ``generate(tokens, length, gen_len, options)``.
        context: Prompt text.
        top_k: Top-k cutoff, or ``None`` to pass ``None`` as the option.
        top_p: Nucleus sampling probability mass.
        temp: Sampling temperature.
        gen_len: Number of tokens to generate.

    Returns:
        List of decoded strings, one per row of ``output[1][0][:, :, 0]``.

    Raises:
        ValueError: If the encoded context is longer than ``seq``.
    """
    tokens = tokenizer.encode(context)
    provided_ctx = len(tokens)
    pad_amount = seq - provided_ctx
    if pad_amount < 0:
        raise ValueError(
            f"context has {provided_ctx} tokens but seq is only {seq}"
        )

    padded_tokens = np.pad(tokens, ((pad_amount, 0),)).astype(np.uint32)
    batched_tokens = np.array([padded_tokens] * per_replica_batch)
    length = np.ones(per_replica_batch, dtype=np.uint32) * provided_ctx

    # An explicit conditional: `cond and array or None` would evaluate the
    # truthiness of the array, which raises for batches larger than one.
    if top_k is None:
        top_k_option = None
    else:
        top_k_option = np.ones(per_replica_batch, dtype=np.int32) * top_k

    sampler_options = {
        "top_p": np.ones(per_replica_batch) * top_p,
        "top_k": top_k_option,
        "temp": np.ones(per_replica_batch) * temp,
    }

    start = time.time()
    output = network.generate(batched_tokens, length, gen_len, sampler_options)

    decoded_tokens = output[1][0]
    samples = [tokenizer.decode(o) for o in decoded_tokens[:, :, 0]]

    print(f"completion done in {time.time() - start:.06}s")
    return samples
def main():
params = {
"layers": 28,
"d_model": 4096,
"n_heads": 16,
"n_vocab": 50400,
"norm": "layernorm",
"pe": "rotary",
"pe_rotary_dims": 64,
"early_cast": True,
"seq": 2048,
"cores_per_replica": 1,
"per_replica_batch": 1,
}
# leaves_names_original = get_tree_leaves_names_original(params)
# print(leaves_names_original)
leaves_names_original = ['/opt_state', '/opt_state/causal_transformer_shard/~/embedding_shard/~/linear/b', '/opt_state/causal_transformer_shard/~/embedding_shard/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_0/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_0/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_0/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_0/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_0/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_0/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_0/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_0/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_0/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_0/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_1/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_1/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_1/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_1/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_1/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_1/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_1/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_1/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_1/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_1/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_10/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_10/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_10/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_10/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_10/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_10/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_10/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_10/~/linear_5/w', 
'/opt_state/causal_transformer_shard/~/layer_10/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_10/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_11/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_11/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_11/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_11/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_11/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_11/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_11/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_11/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_11/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_11/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_12/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_12/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_12/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_12/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_12/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_12/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_12/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_12/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_12/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_12/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_13/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_13/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_13/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_13/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_13/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_13/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_13/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_13/~/linear_5/w', 
'/opt_state/causal_transformer_shard/~/layer_13/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_13/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_14/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_14/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_14/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_14/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_14/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_14/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_14/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_14/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_14/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_14/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_15/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_15/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_15/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_15/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_15/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_15/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_15/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_15/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_15/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_15/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_16/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_16/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_16/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_16/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_16/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_16/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_16/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_16/~/linear_5/w', 
'/opt_state/causal_transformer_shard/~/layer_16/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_16/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_17/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_17/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_17/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_17/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_17/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_17/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_17/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_17/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_17/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_17/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_18/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_18/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_18/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_18/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_18/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_18/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_18/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_18/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_18/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_18/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_19/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_19/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_19/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_19/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_19/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_19/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_19/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_19/~/linear_5/w', 
'/opt_state/causal_transformer_shard/~/layer_19/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_19/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_2/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_2/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_2/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_2/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_2/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_2/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_2/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_2/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_2/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_2/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_20/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_20/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_20/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_20/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_20/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_20/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_20/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_20/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_20/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_20/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_21/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_21/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_21/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_21/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_21/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_21/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_21/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_21/~/linear_5/w', 
'/opt_state/causal_transformer_shard/~/layer_21/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_21/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_22/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_22/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_22/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_22/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_22/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_22/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_22/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_22/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_22/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_22/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_23/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_23/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_23/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_23/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_23/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_23/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_23/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_23/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_23/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_23/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_24/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_24/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_24/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_24/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_24/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_24/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_24/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_24/~/linear_5/w', 
'/opt_state/causal_transformer_shard/~/layer_24/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_24/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_25/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_25/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_25/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_25/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_25/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_25/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_25/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_25/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_25/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_25/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_26/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_26/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_26/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_26/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_26/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_26/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_26/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_26/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_26/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_26/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_27/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_27/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_27/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_27/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_27/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_27/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_27/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_27/~/linear_5/w', 
'/opt_state/causal_transformer_shard/~/layer_27/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_27/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_3/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_3/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_3/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_3/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_3/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_3/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_3/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_3/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_3/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_3/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_4/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_4/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_4/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_4/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_4/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_4/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_4/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_4/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_4/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_4/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_5/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_5/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_5/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_5/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_5/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_5/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_5/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_5/~/linear_5/w', 
'/opt_state/causal_transformer_shard/~/layer_5/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_5/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_6/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_6/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_6/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_6/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_6/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_6/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_6/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_6/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_6/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_6/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_7/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_7/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_7/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_7/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_7/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_7/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_7/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_7/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_7/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_7/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_8/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_8/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_8/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_8/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_8/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_8/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_8/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_8/~/linear_5/w', 
'/opt_state/causal_transformer_shard/~/layer_8/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_8/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_9/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_9/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_9/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_9/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_9/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_9/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_9/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_9/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_9/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_9/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/projection_shard/~/linear/b', '/opt_state/causal_transformer_shard/~/projection_shard/~/linear/w', '/opt_state/causal_transformer_shard/~/projection_shard/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/projection_shard/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/embedding_shard/~/linear/b', '/opt_state/causal_transformer_shard/~/embedding_shard/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_0/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_0/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_0/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_0/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_0/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_0/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_0/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_0/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_0/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_0/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_1/~/linear/w', 
'/opt_state/causal_transformer_shard/~/layer_1/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_1/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_1/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_1/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_1/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_1/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_1/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_1/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_1/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_10/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_10/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_10/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_10/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_10/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_10/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_10/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_10/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_10/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_10/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_11/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_11/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_11/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_11/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_11/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_11/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_11/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_11/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_11/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_11/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_12/~/linear/w', 
'/opt_state/causal_transformer_shard/~/layer_12/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_12/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_12/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_12/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_12/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_12/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_12/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_12/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_12/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_13/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_13/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_13/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_13/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_13/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_13/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_13/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_13/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_13/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_13/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_14/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_14/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_14/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_14/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_14/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_14/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_14/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_14/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_14/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_14/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_15/~/linear/w', 
'/opt_state/causal_transformer_shard/~/layer_15/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_15/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_15/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_15/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_15/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_15/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_15/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_15/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_15/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_16/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_16/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_16/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_16/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_16/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_16/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_16/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_16/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_16/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_16/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_17/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_17/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_17/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_17/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_17/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_17/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_17/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_17/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_17/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_17/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_18/~/linear/w', 
'/opt_state/causal_transformer_shard/~/layer_18/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_18/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_18/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_18/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_18/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_18/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_18/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_18/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_18/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_19/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_19/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_19/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_19/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_19/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_19/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_19/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_19/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_19/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_19/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_2/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_2/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_2/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_2/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_2/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_2/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_2/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_2/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_2/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_2/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_20/~/linear/w', 
'/opt_state/causal_transformer_shard/~/layer_20/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_20/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_20/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_20/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_20/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_20/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_20/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_20/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_20/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_21/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_21/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_21/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_21/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_21/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_21/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_21/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_21/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_21/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_21/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_22/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_22/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_22/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_22/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_22/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_22/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_22/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_22/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_22/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_22/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_23/~/linear/w', 
'/opt_state/causal_transformer_shard/~/layer_23/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_23/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_23/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_23/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_23/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_23/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_23/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_23/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_23/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_24/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_24/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_24/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_24/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_24/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_24/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_24/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_24/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_24/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_24/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_25/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_25/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_25/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_25/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_25/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_25/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_25/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_25/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_25/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_25/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_26/~/linear/w', 
'/opt_state/causal_transformer_shard/~/layer_26/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_26/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_26/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_26/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_26/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_26/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_26/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_26/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_26/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_27/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_27/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_27/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_27/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_27/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_27/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_27/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_27/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_27/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_27/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_3/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_3/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_3/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_3/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_3/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_3/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_3/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_3/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_3/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_3/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_4/~/linear/w', 
'/opt_state/causal_transformer_shard/~/layer_4/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_4/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_4/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_4/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_4/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_4/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_4/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_4/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_4/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_5/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_5/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_5/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_5/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_5/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_5/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_5/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_5/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_5/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_5/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_6/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_6/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_6/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_6/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_6/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_6/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_6/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_6/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_6/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_6/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_7/~/linear/w', 
'/opt_state/causal_transformer_shard/~/layer_7/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_7/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_7/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_7/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_7/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_7/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_7/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_7/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_7/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_8/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_8/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_8/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_8/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_8/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_8/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_8/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_8/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_8/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_8/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/layer_9/~/linear/w', '/opt_state/causal_transformer_shard/~/layer_9/~/linear_1/w', '/opt_state/causal_transformer_shard/~/layer_9/~/linear_2/w', '/opt_state/causal_transformer_shard/~/layer_9/~/linear_3/w', '/opt_state/causal_transformer_shard/~/layer_9/~/linear_4/b', '/opt_state/causal_transformer_shard/~/layer_9/~/linear_4/w', '/opt_state/causal_transformer_shard/~/layer_9/~/linear_5/b', '/opt_state/causal_transformer_shard/~/layer_9/~/linear_5/w', '/opt_state/causal_transformer_shard/~/layer_9/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/layer_9/~/replicated_layer_norm/scale', '/opt_state/causal_transformer_shard/~/projection_shard/~/linear/b', 
'/opt_state/causal_transformer_shard/~/projection_shard/~/linear/w', '/opt_state/causal_transformer_shard/~/projection_shard/~/replicated_layer_norm/offset', '/opt_state/causal_transformer_shard/~/projection_shard/~/replicated_layer_norm/scale', '/opt_state', '/params/causal_transformer_shard/~/embedding_shard/~/linear/b', '/params/causal_transformer_shard/~/embedding_shard/~/linear/w', '/params/causal_transformer_shard/~/layer_0/~/linear/w', '/params/causal_transformer_shard/~/layer_0/~/linear_1/w', '/params/causal_transformer_shard/~/layer_0/~/linear_2/w', '/params/causal_transformer_shard/~/layer_0/~/linear_3/w', '/params/causal_transformer_shard/~/layer_0/~/linear_4/b', '/params/causal_transformer_shard/~/layer_0/~/linear_4/w', '/params/causal_transformer_shard/~/layer_0/~/linear_5/b', '/params/causal_transformer_shard/~/layer_0/~/linear_5/w', '/params/causal_transformer_shard/~/layer_0/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_0/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_1/~/linear/w', '/params/causal_transformer_shard/~/layer_1/~/linear_1/w', '/params/causal_transformer_shard/~/layer_1/~/linear_2/w', '/params/causal_transformer_shard/~/layer_1/~/linear_3/w', '/params/causal_transformer_shard/~/layer_1/~/linear_4/b', '/params/causal_transformer_shard/~/layer_1/~/linear_4/w', '/params/causal_transformer_shard/~/layer_1/~/linear_5/b', '/params/causal_transformer_shard/~/layer_1/~/linear_5/w', '/params/causal_transformer_shard/~/layer_1/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_1/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_10/~/linear/w', '/params/causal_transformer_shard/~/layer_10/~/linear_1/w', '/params/causal_transformer_shard/~/layer_10/~/linear_2/w', '/params/causal_transformer_shard/~/layer_10/~/linear_3/w', '/params/causal_transformer_shard/~/layer_10/~/linear_4/b', '/params/causal_transformer_shard/~/layer_10/~/linear_4/w', 
'/params/causal_transformer_shard/~/layer_10/~/linear_5/b', '/params/causal_transformer_shard/~/layer_10/~/linear_5/w', '/params/causal_transformer_shard/~/layer_10/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_10/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_11/~/linear/w', '/params/causal_transformer_shard/~/layer_11/~/linear_1/w', '/params/causal_transformer_shard/~/layer_11/~/linear_2/w', '/params/causal_transformer_shard/~/layer_11/~/linear_3/w', '/params/causal_transformer_shard/~/layer_11/~/linear_4/b', '/params/causal_transformer_shard/~/layer_11/~/linear_4/w', '/params/causal_transformer_shard/~/layer_11/~/linear_5/b', '/params/causal_transformer_shard/~/layer_11/~/linear_5/w', '/params/causal_transformer_shard/~/layer_11/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_11/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_12/~/linear/w', '/params/causal_transformer_shard/~/layer_12/~/linear_1/w', '/params/causal_transformer_shard/~/layer_12/~/linear_2/w', '/params/causal_transformer_shard/~/layer_12/~/linear_3/w', '/params/causal_transformer_shard/~/layer_12/~/linear_4/b', '/params/causal_transformer_shard/~/layer_12/~/linear_4/w', '/params/causal_transformer_shard/~/layer_12/~/linear_5/b', '/params/causal_transformer_shard/~/layer_12/~/linear_5/w', '/params/causal_transformer_shard/~/layer_12/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_12/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_13/~/linear/w', '/params/causal_transformer_shard/~/layer_13/~/linear_1/w', '/params/causal_transformer_shard/~/layer_13/~/linear_2/w', '/params/causal_transformer_shard/~/layer_13/~/linear_3/w', '/params/causal_transformer_shard/~/layer_13/~/linear_4/b', '/params/causal_transformer_shard/~/layer_13/~/linear_4/w', '/params/causal_transformer_shard/~/layer_13/~/linear_5/b', 
'/params/causal_transformer_shard/~/layer_13/~/linear_5/w', '/params/causal_transformer_shard/~/layer_13/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_13/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_14/~/linear/w', '/params/causal_transformer_shard/~/layer_14/~/linear_1/w', '/params/causal_transformer_shard/~/layer_14/~/linear_2/w', '/params/causal_transformer_shard/~/layer_14/~/linear_3/w', '/params/causal_transformer_shard/~/layer_14/~/linear_4/b', '/params/causal_transformer_shard/~/layer_14/~/linear_4/w', '/params/causal_transformer_shard/~/layer_14/~/linear_5/b', '/params/causal_transformer_shard/~/layer_14/~/linear_5/w', '/params/causal_transformer_shard/~/layer_14/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_14/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_15/~/linear/w', '/params/causal_transformer_shard/~/layer_15/~/linear_1/w', '/params/causal_transformer_shard/~/layer_15/~/linear_2/w', '/params/causal_transformer_shard/~/layer_15/~/linear_3/w', '/params/causal_transformer_shard/~/layer_15/~/linear_4/b', '/params/causal_transformer_shard/~/layer_15/~/linear_4/w', '/params/causal_transformer_shard/~/layer_15/~/linear_5/b', '/params/causal_transformer_shard/~/layer_15/~/linear_5/w', '/params/causal_transformer_shard/~/layer_15/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_15/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_16/~/linear/w', '/params/causal_transformer_shard/~/layer_16/~/linear_1/w', '/params/causal_transformer_shard/~/layer_16/~/linear_2/w', '/params/causal_transformer_shard/~/layer_16/~/linear_3/w', '/params/causal_transformer_shard/~/layer_16/~/linear_4/b', '/params/causal_transformer_shard/~/layer_16/~/linear_4/w', '/params/causal_transformer_shard/~/layer_16/~/linear_5/b', '/params/causal_transformer_shard/~/layer_16/~/linear_5/w', 
'/params/causal_transformer_shard/~/layer_16/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_16/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_17/~/linear/w', '/params/causal_transformer_shard/~/layer_17/~/linear_1/w', '/params/causal_transformer_shard/~/layer_17/~/linear_2/w', '/params/causal_transformer_shard/~/layer_17/~/linear_3/w', '/params/causal_transformer_shard/~/layer_17/~/linear_4/b', '/params/causal_transformer_shard/~/layer_17/~/linear_4/w', '/params/causal_transformer_shard/~/layer_17/~/linear_5/b', '/params/causal_transformer_shard/~/layer_17/~/linear_5/w', '/params/causal_transformer_shard/~/layer_17/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_17/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_18/~/linear/w', '/params/causal_transformer_shard/~/layer_18/~/linear_1/w', '/params/causal_transformer_shard/~/layer_18/~/linear_2/w', '/params/causal_transformer_shard/~/layer_18/~/linear_3/w', '/params/causal_transformer_shard/~/layer_18/~/linear_4/b', '/params/causal_transformer_shard/~/layer_18/~/linear_4/w', '/params/causal_transformer_shard/~/layer_18/~/linear_5/b', '/params/causal_transformer_shard/~/layer_18/~/linear_5/w', '/params/causal_transformer_shard/~/layer_18/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_18/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_19/~/linear/w', '/params/causal_transformer_shard/~/layer_19/~/linear_1/w', '/params/causal_transformer_shard/~/layer_19/~/linear_2/w', '/params/causal_transformer_shard/~/layer_19/~/linear_3/w', '/params/causal_transformer_shard/~/layer_19/~/linear_4/b', '/params/causal_transformer_shard/~/layer_19/~/linear_4/w', '/params/causal_transformer_shard/~/layer_19/~/linear_5/b', '/params/causal_transformer_shard/~/layer_19/~/linear_5/w', '/params/causal_transformer_shard/~/layer_19/~/replicated_layer_norm/offset', 
'/params/causal_transformer_shard/~/layer_19/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_2/~/linear/w', '/params/causal_transformer_shard/~/layer_2/~/linear_1/w', '/params/causal_transformer_shard/~/layer_2/~/linear_2/w', '/params/causal_transformer_shard/~/layer_2/~/linear_3/w', '/params/causal_transformer_shard/~/layer_2/~/linear_4/b', '/params/causal_transformer_shard/~/layer_2/~/linear_4/w', '/params/causal_transformer_shard/~/layer_2/~/linear_5/b', '/params/causal_transformer_shard/~/layer_2/~/linear_5/w', '/params/causal_transformer_shard/~/layer_2/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_2/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_20/~/linear/w', '/params/causal_transformer_shard/~/layer_20/~/linear_1/w', '/params/causal_transformer_shard/~/layer_20/~/linear_2/w', '/params/causal_transformer_shard/~/layer_20/~/linear_3/w', '/params/causal_transformer_shard/~/layer_20/~/linear_4/b', '/params/causal_transformer_shard/~/layer_20/~/linear_4/w', '/params/causal_transformer_shard/~/layer_20/~/linear_5/b', '/params/causal_transformer_shard/~/layer_20/~/linear_5/w', '/params/causal_transformer_shard/~/layer_20/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_20/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_21/~/linear/w', '/params/causal_transformer_shard/~/layer_21/~/linear_1/w', '/params/causal_transformer_shard/~/layer_21/~/linear_2/w', '/params/causal_transformer_shard/~/layer_21/~/linear_3/w', '/params/causal_transformer_shard/~/layer_21/~/linear_4/b', '/params/causal_transformer_shard/~/layer_21/~/linear_4/w', '/params/causal_transformer_shard/~/layer_21/~/linear_5/b', '/params/causal_transformer_shard/~/layer_21/~/linear_5/w', '/params/causal_transformer_shard/~/layer_21/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_21/~/replicated_layer_norm/scale', 
'/params/causal_transformer_shard/~/layer_22/~/linear/w', '/params/causal_transformer_shard/~/layer_22/~/linear_1/w', '/params/causal_transformer_shard/~/layer_22/~/linear_2/w', '/params/causal_transformer_shard/~/layer_22/~/linear_3/w', '/params/causal_transformer_shard/~/layer_22/~/linear_4/b', '/params/causal_transformer_shard/~/layer_22/~/linear_4/w', '/params/causal_transformer_shard/~/layer_22/~/linear_5/b', '/params/causal_transformer_shard/~/layer_22/~/linear_5/w', '/params/causal_transformer_shard/~/layer_22/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_22/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_23/~/linear/w', '/params/causal_transformer_shard/~/layer_23/~/linear_1/w', '/params/causal_transformer_shard/~/layer_23/~/linear_2/w', '/params/causal_transformer_shard/~/layer_23/~/linear_3/w', '/params/causal_transformer_shard/~/layer_23/~/linear_4/b', '/params/causal_transformer_shard/~/layer_23/~/linear_4/w', '/params/causal_transformer_shard/~/layer_23/~/linear_5/b', '/params/causal_transformer_shard/~/layer_23/~/linear_5/w', '/params/causal_transformer_shard/~/layer_23/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_23/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_24/~/linear/w', '/params/causal_transformer_shard/~/layer_24/~/linear_1/w', '/params/causal_transformer_shard/~/layer_24/~/linear_2/w', '/params/causal_transformer_shard/~/layer_24/~/linear_3/w', '/params/causal_transformer_shard/~/layer_24/~/linear_4/b', '/params/causal_transformer_shard/~/layer_24/~/linear_4/w', '/params/causal_transformer_shard/~/layer_24/~/linear_5/b', '/params/causal_transformer_shard/~/layer_24/~/linear_5/w', '/params/causal_transformer_shard/~/layer_24/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_24/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_25/~/linear/w', 
'/params/causal_transformer_shard/~/layer_25/~/linear_1/w', '/params/causal_transformer_shard/~/layer_25/~/linear_2/w', '/params/causal_transformer_shard/~/layer_25/~/linear_3/w', '/params/causal_transformer_shard/~/layer_25/~/linear_4/b', '/params/causal_transformer_shard/~/layer_25/~/linear_4/w', '/params/causal_transformer_shard/~/layer_25/~/linear_5/b', '/params/causal_transformer_shard/~/layer_25/~/linear_5/w', '/params/causal_transformer_shard/~/layer_25/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_25/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_26/~/linear/w', '/params/causal_transformer_shard/~/layer_26/~/linear_1/w', '/params/causal_transformer_shard/~/layer_26/~/linear_2/w', '/params/causal_transformer_shard/~/layer_26/~/linear_3/w', '/params/causal_transformer_shard/~/layer_26/~/linear_4/b', '/params/causal_transformer_shard/~/layer_26/~/linear_4/w', '/params/causal_transformer_shard/~/layer_26/~/linear_5/b', '/params/causal_transformer_shard/~/layer_26/~/linear_5/w', '/params/causal_transformer_shard/~/layer_26/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_26/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_27/~/linear/w', '/params/causal_transformer_shard/~/layer_27/~/linear_1/w', '/params/causal_transformer_shard/~/layer_27/~/linear_2/w', '/params/causal_transformer_shard/~/layer_27/~/linear_3/w', '/params/causal_transformer_shard/~/layer_27/~/linear_4/b', '/params/causal_transformer_shard/~/layer_27/~/linear_4/w', '/params/causal_transformer_shard/~/layer_27/~/linear_5/b', '/params/causal_transformer_shard/~/layer_27/~/linear_5/w', '/params/causal_transformer_shard/~/layer_27/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_27/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_3/~/linear/w', '/params/causal_transformer_shard/~/layer_3/~/linear_1/w', 
'/params/causal_transformer_shard/~/layer_3/~/linear_2/w', '/params/causal_transformer_shard/~/layer_3/~/linear_3/w', '/params/causal_transformer_shard/~/layer_3/~/linear_4/b', '/params/causal_transformer_shard/~/layer_3/~/linear_4/w', '/params/causal_transformer_shard/~/layer_3/~/linear_5/b', '/params/causal_transformer_shard/~/layer_3/~/linear_5/w', '/params/causal_transformer_shard/~/layer_3/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_3/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_4/~/linear/w', '/params/causal_transformer_shard/~/layer_4/~/linear_1/w', '/params/causal_transformer_shard/~/layer_4/~/linear_2/w', '/params/causal_transformer_shard/~/layer_4/~/linear_3/w', '/params/causal_transformer_shard/~/layer_4/~/linear_4/b', '/params/causal_transformer_shard/~/layer_4/~/linear_4/w', '/params/causal_transformer_shard/~/layer_4/~/linear_5/b', '/params/causal_transformer_shard/~/layer_4/~/linear_5/w', '/params/causal_transformer_shard/~/layer_4/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_4/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_5/~/linear/w', '/params/causal_transformer_shard/~/layer_5/~/linear_1/w', '/params/causal_transformer_shard/~/layer_5/~/linear_2/w', '/params/causal_transformer_shard/~/layer_5/~/linear_3/w', '/params/causal_transformer_shard/~/layer_5/~/linear_4/b', '/params/causal_transformer_shard/~/layer_5/~/linear_4/w', '/params/causal_transformer_shard/~/layer_5/~/linear_5/b', '/params/causal_transformer_shard/~/layer_5/~/linear_5/w', '/params/causal_transformer_shard/~/layer_5/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_5/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_6/~/linear/w', '/params/causal_transformer_shard/~/layer_6/~/linear_1/w', '/params/causal_transformer_shard/~/layer_6/~/linear_2/w', '/params/causal_transformer_shard/~/layer_6/~/linear_3/w', 
'/params/causal_transformer_shard/~/layer_6/~/linear_4/b', '/params/causal_transformer_shard/~/layer_6/~/linear_4/w', '/params/causal_transformer_shard/~/layer_6/~/linear_5/b', '/params/causal_transformer_shard/~/layer_6/~/linear_5/w', '/params/causal_transformer_shard/~/layer_6/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_6/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_7/~/linear/w', '/params/causal_transformer_shard/~/layer_7/~/linear_1/w', '/params/causal_transformer_shard/~/layer_7/~/linear_2/w', '/params/causal_transformer_shard/~/layer_7/~/linear_3/w', '/params/causal_transformer_shard/~/layer_7/~/linear_4/b', '/params/causal_transformer_shard/~/layer_7/~/linear_4/w', '/params/causal_transformer_shard/~/layer_7/~/linear_5/b', '/params/causal_transformer_shard/~/layer_7/~/linear_5/w', '/params/causal_transformer_shard/~/layer_7/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_7/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_8/~/linear/w', '/params/causal_transformer_shard/~/layer_8/~/linear_1/w', '/params/causal_transformer_shard/~/layer_8/~/linear_2/w', '/params/causal_transformer_shard/~/layer_8/~/linear_3/w', '/params/causal_transformer_shard/~/layer_8/~/linear_4/b', '/params/causal_transformer_shard/~/layer_8/~/linear_4/w', '/params/causal_transformer_shard/~/layer_8/~/linear_5/b', '/params/causal_transformer_shard/~/layer_8/~/linear_5/w', '/params/causal_transformer_shard/~/layer_8/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_8/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_9/~/linear/w', '/params/causal_transformer_shard/~/layer_9/~/linear_1/w', '/params/causal_transformer_shard/~/layer_9/~/linear_2/w', '/params/causal_transformer_shard/~/layer_9/~/linear_3/w', '/params/causal_transformer_shard/~/layer_9/~/linear_4/b', '/params/causal_transformer_shard/~/layer_9/~/linear_4/w', 
'/params/causal_transformer_shard/~/layer_9/~/linear_5/b', '/params/causal_transformer_shard/~/layer_9/~/linear_5/w', '/params/causal_transformer_shard/~/layer_9/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_9/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/projection_shard/~/linear/b', '/params/causal_transformer_shard/~/projection_shard/~/linear/w', '/params/causal_transformer_shard/~/projection_shard/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/projection_shard/~/replicated_layer_norm/scale', '/step']
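# NOTE(review): the hard-coded list below appears to be the saved output of the
# call in the next two lines -- confirm before regenerating. To rebuild it,
# uncomment them and paste the printed result over the literal.
# leaves_names_reduced = get_tree_leaves_names_reduced(params)
# print(leaves_names_reduced)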
leaves_names_reduced = ['/params/causal_transformer_shard/~/embedding_shard/~/linear/b', '/params/causal_transformer_shard/~/embedding_shard/~/linear/w', '/params/causal_transformer_shard/~/layer_0/~/linear/w', '/params/causal_transformer_shard/~/layer_0/~/linear_1/w', '/params/causal_transformer_shard/~/layer_0/~/linear_2/w', '/params/causal_transformer_shard/~/layer_0/~/linear_3/w', '/params/causal_transformer_shard/~/layer_0/~/linear_4/b', '/params/causal_transformer_shard/~/layer_0/~/linear_4/w', '/params/causal_transformer_shard/~/layer_0/~/linear_5/b', '/params/causal_transformer_shard/~/layer_0/~/linear_5/w', '/params/causal_transformer_shard/~/layer_0/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_0/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_1/~/linear/w', '/params/causal_transformer_shard/~/layer_1/~/linear_1/w', '/params/causal_transformer_shard/~/layer_1/~/linear_2/w', '/params/causal_transformer_shard/~/layer_1/~/linear_3/w', '/params/causal_transformer_shard/~/layer_1/~/linear_4/b', '/params/causal_transformer_shard/~/layer_1/~/linear_4/w', '/params/causal_transformer_shard/~/layer_1/~/linear_5/b', '/params/causal_transformer_shard/~/layer_1/~/linear_5/w', '/params/causal_transformer_shard/~/layer_1/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_1/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_10/~/linear/w', '/params/causal_transformer_shard/~/layer_10/~/linear_1/w', '/params/causal_transformer_shard/~/layer_10/~/linear_2/w', '/params/causal_transformer_shard/~/layer_10/~/linear_3/w', '/params/causal_transformer_shard/~/layer_10/~/linear_4/b', '/params/causal_transformer_shard/~/layer_10/~/linear_4/w', '/params/causal_transformer_shard/~/layer_10/~/linear_5/b', '/params/causal_transformer_shard/~/layer_10/~/linear_5/w', '/params/causal_transformer_shard/~/layer_10/~/replicated_layer_norm/offset', 
'/params/causal_transformer_shard/~/layer_10/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_11/~/linear/w', '/params/causal_transformer_shard/~/layer_11/~/linear_1/w', '/params/causal_transformer_shard/~/layer_11/~/linear_2/w', '/params/causal_transformer_shard/~/layer_11/~/linear_3/w', '/params/causal_transformer_shard/~/layer_11/~/linear_4/b', '/params/causal_transformer_shard/~/layer_11/~/linear_4/w', '/params/causal_transformer_shard/~/layer_11/~/linear_5/b', '/params/causal_transformer_shard/~/layer_11/~/linear_5/w', '/params/causal_transformer_shard/~/layer_11/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_11/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_12/~/linear/w', '/params/causal_transformer_shard/~/layer_12/~/linear_1/w', '/params/causal_transformer_shard/~/layer_12/~/linear_2/w', '/params/causal_transformer_shard/~/layer_12/~/linear_3/w', '/params/causal_transformer_shard/~/layer_12/~/linear_4/b', '/params/causal_transformer_shard/~/layer_12/~/linear_4/w', '/params/causal_transformer_shard/~/layer_12/~/linear_5/b', '/params/causal_transformer_shard/~/layer_12/~/linear_5/w', '/params/causal_transformer_shard/~/layer_12/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_12/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_13/~/linear/w', '/params/causal_transformer_shard/~/layer_13/~/linear_1/w', '/params/causal_transformer_shard/~/layer_13/~/linear_2/w', '/params/causal_transformer_shard/~/layer_13/~/linear_3/w', '/params/causal_transformer_shard/~/layer_13/~/linear_4/b', '/params/causal_transformer_shard/~/layer_13/~/linear_4/w', '/params/causal_transformer_shard/~/layer_13/~/linear_5/b', '/params/causal_transformer_shard/~/layer_13/~/linear_5/w', '/params/causal_transformer_shard/~/layer_13/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_13/~/replicated_layer_norm/scale', 
'/params/causal_transformer_shard/~/layer_14/~/linear/w', '/params/causal_transformer_shard/~/layer_14/~/linear_1/w', '/params/causal_transformer_shard/~/layer_14/~/linear_2/w', '/params/causal_transformer_shard/~/layer_14/~/linear_3/w', '/params/causal_transformer_shard/~/layer_14/~/linear_4/b', '/params/causal_transformer_shard/~/layer_14/~/linear_4/w', '/params/causal_transformer_shard/~/layer_14/~/linear_5/b', '/params/causal_transformer_shard/~/layer_14/~/linear_5/w', '/params/causal_transformer_shard/~/layer_14/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_14/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_15/~/linear/w', '/params/causal_transformer_shard/~/layer_15/~/linear_1/w', '/params/causal_transformer_shard/~/layer_15/~/linear_2/w', '/params/causal_transformer_shard/~/layer_15/~/linear_3/w', '/params/causal_transformer_shard/~/layer_15/~/linear_4/b', '/params/causal_transformer_shard/~/layer_15/~/linear_4/w', '/params/causal_transformer_shard/~/layer_15/~/linear_5/b', '/params/causal_transformer_shard/~/layer_15/~/linear_5/w', '/params/causal_transformer_shard/~/layer_15/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_15/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_16/~/linear/w', '/params/causal_transformer_shard/~/layer_16/~/linear_1/w', '/params/causal_transformer_shard/~/layer_16/~/linear_2/w', '/params/causal_transformer_shard/~/layer_16/~/linear_3/w', '/params/causal_transformer_shard/~/layer_16/~/linear_4/b', '/params/causal_transformer_shard/~/layer_16/~/linear_4/w', '/params/causal_transformer_shard/~/layer_16/~/linear_5/b', '/params/causal_transformer_shard/~/layer_16/~/linear_5/w', '/params/causal_transformer_shard/~/layer_16/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_16/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_17/~/linear/w', 
'/params/causal_transformer_shard/~/layer_17/~/linear_1/w', '/params/causal_transformer_shard/~/layer_17/~/linear_2/w', '/params/causal_transformer_shard/~/layer_17/~/linear_3/w', '/params/causal_transformer_shard/~/layer_17/~/linear_4/b', '/params/causal_transformer_shard/~/layer_17/~/linear_4/w', '/params/causal_transformer_shard/~/layer_17/~/linear_5/b', '/params/causal_transformer_shard/~/layer_17/~/linear_5/w', '/params/causal_transformer_shard/~/layer_17/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_17/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_18/~/linear/w', '/params/causal_transformer_shard/~/layer_18/~/linear_1/w', '/params/causal_transformer_shard/~/layer_18/~/linear_2/w', '/params/causal_transformer_shard/~/layer_18/~/linear_3/w', '/params/causal_transformer_shard/~/layer_18/~/linear_4/b', '/params/causal_transformer_shard/~/layer_18/~/linear_4/w', '/params/causal_transformer_shard/~/layer_18/~/linear_5/b', '/params/causal_transformer_shard/~/layer_18/~/linear_5/w', '/params/causal_transformer_shard/~/layer_18/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_18/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_19/~/linear/w', '/params/causal_transformer_shard/~/layer_19/~/linear_1/w', '/params/causal_transformer_shard/~/layer_19/~/linear_2/w', '/params/causal_transformer_shard/~/layer_19/~/linear_3/w', '/params/causal_transformer_shard/~/layer_19/~/linear_4/b', '/params/causal_transformer_shard/~/layer_19/~/linear_4/w', '/params/causal_transformer_shard/~/layer_19/~/linear_5/b', '/params/causal_transformer_shard/~/layer_19/~/linear_5/w', '/params/causal_transformer_shard/~/layer_19/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_19/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_2/~/linear/w', '/params/causal_transformer_shard/~/layer_2/~/linear_1/w', 
'/params/causal_transformer_shard/~/layer_2/~/linear_2/w', '/params/causal_transformer_shard/~/layer_2/~/linear_3/w', '/params/causal_transformer_shard/~/layer_2/~/linear_4/b', '/params/causal_transformer_shard/~/layer_2/~/linear_4/w', '/params/causal_transformer_shard/~/layer_2/~/linear_5/b', '/params/causal_transformer_shard/~/layer_2/~/linear_5/w', '/params/causal_transformer_shard/~/layer_2/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_2/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_20/~/linear/w', '/params/causal_transformer_shard/~/layer_20/~/linear_1/w', '/params/causal_transformer_shard/~/layer_20/~/linear_2/w', '/params/causal_transformer_shard/~/layer_20/~/linear_3/w', '/params/causal_transformer_shard/~/layer_20/~/linear_4/b', '/params/causal_transformer_shard/~/layer_20/~/linear_4/w', '/params/causal_transformer_shard/~/layer_20/~/linear_5/b', '/params/causal_transformer_shard/~/layer_20/~/linear_5/w', '/params/causal_transformer_shard/~/layer_20/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_20/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_21/~/linear/w', '/params/causal_transformer_shard/~/layer_21/~/linear_1/w', '/params/causal_transformer_shard/~/layer_21/~/linear_2/w', '/params/causal_transformer_shard/~/layer_21/~/linear_3/w', '/params/causal_transformer_shard/~/layer_21/~/linear_4/b', '/params/causal_transformer_shard/~/layer_21/~/linear_4/w', '/params/causal_transformer_shard/~/layer_21/~/linear_5/b', '/params/causal_transformer_shard/~/layer_21/~/linear_5/w', '/params/causal_transformer_shard/~/layer_21/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_21/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_22/~/linear/w', '/params/causal_transformer_shard/~/layer_22/~/linear_1/w', '/params/causal_transformer_shard/~/layer_22/~/linear_2/w', 
'/params/causal_transformer_shard/~/layer_22/~/linear_3/w', '/params/causal_transformer_shard/~/layer_22/~/linear_4/b', '/params/causal_transformer_shard/~/layer_22/~/linear_4/w', '/params/causal_transformer_shard/~/layer_22/~/linear_5/b', '/params/causal_transformer_shard/~/layer_22/~/linear_5/w', '/params/causal_transformer_shard/~/layer_22/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_22/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_23/~/linear/w', '/params/causal_transformer_shard/~/layer_23/~/linear_1/w', '/params/causal_transformer_shard/~/layer_23/~/linear_2/w', '/params/causal_transformer_shard/~/layer_23/~/linear_3/w', '/params/causal_transformer_shard/~/layer_23/~/linear_4/b', '/params/causal_transformer_shard/~/layer_23/~/linear_4/w', '/params/causal_transformer_shard/~/layer_23/~/linear_5/b', '/params/causal_transformer_shard/~/layer_23/~/linear_5/w', '/params/causal_transformer_shard/~/layer_23/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_23/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_24/~/linear/w', '/params/causal_transformer_shard/~/layer_24/~/linear_1/w', '/params/causal_transformer_shard/~/layer_24/~/linear_2/w', '/params/causal_transformer_shard/~/layer_24/~/linear_3/w', '/params/causal_transformer_shard/~/layer_24/~/linear_4/b', '/params/causal_transformer_shard/~/layer_24/~/linear_4/w', '/params/causal_transformer_shard/~/layer_24/~/linear_5/b', '/params/causal_transformer_shard/~/layer_24/~/linear_5/w', '/params/causal_transformer_shard/~/layer_24/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_24/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_25/~/linear/w', '/params/causal_transformer_shard/~/layer_25/~/linear_1/w', '/params/causal_transformer_shard/~/layer_25/~/linear_2/w', '/params/causal_transformer_shard/~/layer_25/~/linear_3/w', 
'/params/causal_transformer_shard/~/layer_25/~/linear_4/b', '/params/causal_transformer_shard/~/layer_25/~/linear_4/w', '/params/causal_transformer_shard/~/layer_25/~/linear_5/b', '/params/causal_transformer_shard/~/layer_25/~/linear_5/w', '/params/causal_transformer_shard/~/layer_25/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_25/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_26/~/linear/w', '/params/causal_transformer_shard/~/layer_26/~/linear_1/w', '/params/causal_transformer_shard/~/layer_26/~/linear_2/w', '/params/causal_transformer_shard/~/layer_26/~/linear_3/w', '/params/causal_transformer_shard/~/layer_26/~/linear_4/b', '/params/causal_transformer_shard/~/layer_26/~/linear_4/w', '/params/causal_transformer_shard/~/layer_26/~/linear_5/b', '/params/causal_transformer_shard/~/layer_26/~/linear_5/w', '/params/causal_transformer_shard/~/layer_26/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_26/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_27/~/linear/w', '/params/causal_transformer_shard/~/layer_27/~/linear_1/w', '/params/causal_transformer_shard/~/layer_27/~/linear_2/w', '/params/causal_transformer_shard/~/layer_27/~/linear_3/w', '/params/causal_transformer_shard/~/layer_27/~/linear_4/b', '/params/causal_transformer_shard/~/layer_27/~/linear_4/w', '/params/causal_transformer_shard/~/layer_27/~/linear_5/b', '/params/causal_transformer_shard/~/layer_27/~/linear_5/w', '/params/causal_transformer_shard/~/layer_27/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_27/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_3/~/linear/w', '/params/causal_transformer_shard/~/layer_3/~/linear_1/w', '/params/causal_transformer_shard/~/layer_3/~/linear_2/w', '/params/causal_transformer_shard/~/layer_3/~/linear_3/w', '/params/causal_transformer_shard/~/layer_3/~/linear_4/b', 
'/params/causal_transformer_shard/~/layer_3/~/linear_4/w', '/params/causal_transformer_shard/~/layer_3/~/linear_5/b', '/params/causal_transformer_shard/~/layer_3/~/linear_5/w', '/params/causal_transformer_shard/~/layer_3/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_3/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_4/~/linear/w', '/params/causal_transformer_shard/~/layer_4/~/linear_1/w', '/params/causal_transformer_shard/~/layer_4/~/linear_2/w', '/params/causal_transformer_shard/~/layer_4/~/linear_3/w', '/params/causal_transformer_shard/~/layer_4/~/linear_4/b', '/params/causal_transformer_shard/~/layer_4/~/linear_4/w', '/params/causal_transformer_shard/~/layer_4/~/linear_5/b', '/params/causal_transformer_shard/~/layer_4/~/linear_5/w', '/params/causal_transformer_shard/~/layer_4/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_4/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_5/~/linear/w', '/params/causal_transformer_shard/~/layer_5/~/linear_1/w', '/params/causal_transformer_shard/~/layer_5/~/linear_2/w', '/params/causal_transformer_shard/~/layer_5/~/linear_3/w', '/params/causal_transformer_shard/~/layer_5/~/linear_4/b', '/params/causal_transformer_shard/~/layer_5/~/linear_4/w', '/params/causal_transformer_shard/~/layer_5/~/linear_5/b', '/params/causal_transformer_shard/~/layer_5/~/linear_5/w', '/params/causal_transformer_shard/~/layer_5/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_5/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_6/~/linear/w', '/params/causal_transformer_shard/~/layer_6/~/linear_1/w', '/params/causal_transformer_shard/~/layer_6/~/linear_2/w', '/params/causal_transformer_shard/~/layer_6/~/linear_3/w', '/params/causal_transformer_shard/~/layer_6/~/linear_4/b', '/params/causal_transformer_shard/~/layer_6/~/linear_4/w', '/params/causal_transformer_shard/~/layer_6/~/linear_5/b', 
'/params/causal_transformer_shard/~/layer_6/~/linear_5/w', '/params/causal_transformer_shard/~/layer_6/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_6/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_7/~/linear/w', '/params/causal_transformer_shard/~/layer_7/~/linear_1/w', '/params/causal_transformer_shard/~/layer_7/~/linear_2/w', '/params/causal_transformer_shard/~/layer_7/~/linear_3/w', '/params/causal_transformer_shard/~/layer_7/~/linear_4/b', '/params/causal_transformer_shard/~/layer_7/~/linear_4/w', '/params/causal_transformer_shard/~/layer_7/~/linear_5/b', '/params/causal_transformer_shard/~/layer_7/~/linear_5/w', '/params/causal_transformer_shard/~/layer_7/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_7/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_8/~/linear/w', '/params/causal_transformer_shard/~/layer_8/~/linear_1/w', '/params/causal_transformer_shard/~/layer_8/~/linear_2/w', '/params/causal_transformer_shard/~/layer_8/~/linear_3/w', '/params/causal_transformer_shard/~/layer_8/~/linear_4/b', '/params/causal_transformer_shard/~/layer_8/~/linear_4/w', '/params/causal_transformer_shard/~/layer_8/~/linear_5/b', '/params/causal_transformer_shard/~/layer_8/~/linear_5/w', '/params/causal_transformer_shard/~/layer_8/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_8/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/layer_9/~/linear/w', '/params/causal_transformer_shard/~/layer_9/~/linear_1/w', '/params/causal_transformer_shard/~/layer_9/~/linear_2/w', '/params/causal_transformer_shard/~/layer_9/~/linear_3/w', '/params/causal_transformer_shard/~/layer_9/~/linear_4/b', '/params/causal_transformer_shard/~/layer_9/~/linear_4/w', '/params/causal_transformer_shard/~/layer_9/~/linear_5/b', '/params/causal_transformer_shard/~/layer_9/~/linear_5/w', 
'/params/causal_transformer_shard/~/layer_9/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/layer_9/~/replicated_layer_norm/scale', '/params/causal_transformer_shard/~/projection_shard/~/linear/b', '/params/causal_transformer_shard/~/projection_shard/~/linear/w', '/params/causal_transformer_shard/~/projection_shard/~/replicated_layer_norm/offset', '/params/causal_transformer_shard/~/projection_shard/~/replicated_layer_norm/scale', '/step']
# TODO(nijkamp): overwriting the optimizer mutates the pytree to reduce memory allocation, but this breaks the serialization format; serialize the model into separate optimizer / parameter files to clean this up.
params['optimizer'] = optax.scale(0)
params['sampler'] = nucleaus_sample
devices = np.array([jax.devices()[0]]).reshape((1, 1))
with jax.experimental.maps.mesh(devices, ('dp', 'mp')):
print('CausalTransformer', 'begin')
t0 = time.time()
network = CausalTransformer(params)
print('CausalTransformer', 'end', f'{time.time() - t0:06}')
print('read_ckpt', 'begin')
t0 = time.time()
network.state = read_ckpt(network.state, leaves_names_original, leaves_names_reduced, '/export/home/gptj/step_383500/', shards_in=8, shards_out=1)
print('read_ckpt', 'end', f'{time.time() - t0:06}')
print('device_put', 'begin')
t0 = time.time()
network.state = jax.device_put(network.state, jax.devices("cpu")[0])
print('device_put', 'end', f'{time.time() - t0:06}')
tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
print('sample', 'begin')
t0 = time.time()
print(sample(tokenizer=tokenizer, per_replica_batch=params['per_replica_batch'], seq=params['seq'], network=network, context='EleutherAI is'))
print('sample', 'end', f'{time.time() - t0:06}')
# Entry guard: invoke main() only when this file is executed as a script,
# so importing the module does not trigger the run.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment