# config.gin for learning trax
import trax.layers
import trax.models
import trax.optimizers
import trax.data.inputs
import trax.supervised.trainer_lib
# Parameters that will vary between experiments:
# ==============================================================================
max_len = 128
input_vocab_size = 13500
output_vocab_size = 3
d_model = 512
d_ff = 512
# Parameters for multifactor:
# ==============================================================================
multifactor.constant = 0.01
multifactor.factors = 'constant * linear_warmup * cosine_decay'
multifactor.warmup_steps = 100
multifactor.steps_per_cycle = 900
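# The named factors are multiplied together: the learning rate warms up
# linearly to 0.01 over the first 100 steps, then follows a cosine decay
# with a 900-step cycle.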
# Parameters for Adam:
# ==============================================================================
Adam.weight_decay_rate = 0.0
Adam.b1 = 0.86
Adam.b2 = 0.92
Adam.eps = 1e-9
# Parameters for Transformer:
# ==============================================================================
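# The commented-out settings below are alternative models; to try one,
# uncomment its block and point train.model at it instead of Reformer.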
# TransformerEncoder.vocab_size = %input_vocab_size
# TransformerEncoder.n_classes = 2
# TransformerEncoder.d_ff = 512
# TransformerEncoder.n_layers = 2
# TransformerEncoder.max_len = %max_len
# Transformer.input_vocab_size = %input_vocab_size
# Transformer.output_vocab_size = %output_vocab_size
# Transformer.max_len = %max_len
# Transformer.d_model = %d_model
# Transformer.d_ff = %d_ff
Reformer.input_vocab_size = %input_vocab_size
Reformer.output_vocab_size = %output_vocab_size
Reformer.max_len = %max_len
Reformer.d_model = %d_model
Reformer.d_ff = %d_ff
# Parameters for inputs:
# ==============================================================================
train/get_ds_tfrec.folds = [0]
eval/get_ds_tfrec.folds = [1]
get_ds_tfrec.len_seq = %max_len
batcher.data_streams = [@train/get_ds_tfrec(), @eval/get_ds_tfrec()]
batcher.variable_shapes = False
batcher.batch_size_per_device = 8
batcher.id_to_mask = 0
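# The train/ and eval/ prefixes are gin scopes binding the same user-defined
# get_ds_tfrec function (not part of trax; registered with gin elsewhere) to
# fold 0 for training and fold 1 for evaluation. id_to_mask = 0 excludes
# padding id 0 from the loss.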
# Parameters for train:
# ==============================================================================
train.eval_frequency = 100
train.eval_steps = 10
train.optimizer = @trax.optimizers.Adam
train.steps = 1000
train.model = @trax.models.Reformer
train.use_memory_efficient_trainer = True
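# Usage sketch (an assumption, not part of the original gist): save this file
# as config.gin, make sure the module defining get_ds_tfrec is imported so gin
# can resolve it, then run:
#
#   import gin
#   import trax
#
#   gin.parse_config_file('config.gin')
#   trax.supervised.trainer_lib.train(output_dir='./model_dir')  # hypothetical path
#
# train() picks up the model, optimizer, schedule, and step counts configured
# above.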