Created
November 27, 2016 18:47
-
-
Save ShigekiKarita/4db4fd506f6fe6167840b8395224c677 to your computer and use it in GitHub Desktop.
accumulate config of attention-lvcsr
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import six
from os.path import basename
import yaml

# Checkout of the attention-lvcsr repository; the "$LVSR" placeholder in
# config paths expands to this prefix.
root = "/home/karita/tool/attention-lvcsr"
# Parsed YAML configs, appended child-first as read() walks the parent chain.
cfgs = []
def recursive_update(original, update):
    """Merge *original* into *update* in place and return *update*.

    Keys already present in *update* take precedence; nested dicts are
    merged recursively instead of being replaced wholesale.
    """
    for key, value in original.items():
        if key in update:
            # Existing entries win; only descend into nested dicts.
            if isinstance(value, dict):
                recursive_update(value, update[key])
            continue
        try:
            update[key] = value
        except Exception as e:
            # Surface the offending key/value before re-raising.
            print(key)
            print(str(value))
            raise e
    return update
def read(path, config=None):
    """Load a YAML config file and, recursively, all of its ancestors.

    Walks the ``parent`` chain of attention-lvcsr config files, appending
    each parsed YAML dict to the module-level ``cfgs`` list (child first,
    ancestors after). Recursion stops at ``prototype_speech.yaml``.

    Parameters
    ----------
    path : str
        Config path, possibly prefixed with the ``$LVSR`` placeholder,
        which is replaced by the module-level ``root``.
    config : dict, optional
        Kept for backward compatibility; it is only threaded through the
        recursion and never read or written here.
    """
    # None instead of a mutable {} default: the original shared a single
    # dict object across every call (classic Python pitfall), harmless
    # here only because `config` is never actually used.
    if config is None:
        config = {}
    p = root + path.replace("$LVSR", "")
    print(p)
    # NOTE(review): yaml.Loader can construct arbitrary Python objects;
    # acceptable only because these are trusted local config files.
    with open(p, "r") as f:  # close the handle; the original leaked it
        y = yaml.load(f, Loader=yaml.Loader)
    cfgs.append(y)
    if basename(p) == "prototype_speech.yaml":
        return
    if "parent" in y.keys():
        read(y["parent"], config)
# Entry point: load the target experiment config plus its ancestor chain
# (appended to `cfgs` child-first by read()), then fold them into one dict.
# Because each child precedes its ancestors in `cfgs` and recursive_update
# lets keys already in `merge` win, child settings override parent defaults.
read("/exp/wsj/configs/wsj_paper7.yaml")
merge = {}
for c in cfgs:
    recursive_update(c, merge)
def tostr(d):
    """Recursively convert every leaf value of *d* to ``str``, in place.

    Parameters
    ----------
    d : dict
        Arbitrarily nested dict; non-dict leaves are replaced by ``str(v)``.

    Returns
    -------
    dict
        The same dict *d*, so the call composes directly with consumers
        such as ``yaml.dump``. (The original returned ``None``, which made
        ``yaml.dump(tostr(merge), ...)`` emit ``null``.)
    """
    for k, v in d.items():
        if isinstance(v, dict):
            tostr(v)  # recurse into sub-configs
        else:
            d[k] = str(v)
    return d
print(merge)
# Stringify every leaf in place, then dump the merged dict itself.
# The original passed tostr(merge) straight to yaml.dump; since tostr
# returned None, result.yaml contained only "null".
tostr(merge)
with open("result.yaml", "w") as f:  # close the handle; original leaked it
    yaml.dump(merge, f)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
data: | |
add_bos: '1' | |
add_eos: 'True' | |
batch_size: '10' | |
dataset_class: <class 'lvsr.datasets.h5py.H5PYAudioDataset'> | |
dataset_filename: wsj.h5 | |
default_sources: '[''recordings'', ''labels'']' | |
name_mapping: {test: test_eval92, train: train_si284, valid: test_dev93} | |
normalization: '' | |
sources_map: {labels: characters, recordings: fbank_dd, uttids: uttids} | |
initialization: | |
/recognizer: {biases_init: Constant(_constant=0.0), rec_weights_init: 'IsotropicGaussian(_mean=0,_std=0.1)', | |
weights_init: 'IsotropicGaussian(_mean=0,_std=0.1)'} | |
monitoring: | |
search: {beam_size: '10', char_discount: '0.1', round_to_inf: '1000000000.0', stop_on: optimistic_future_cost} | |
search_every_batches: '0' | |
search_every_epochs: '1' | |
validate_every_batches: '0' | |
validate_every_epochs: '1' | |
net: | |
attention_type: content_and_conv | |
bottom: {activation: '<blocks.bricks.simple.Rectifier object at 0x7fcef1e2a890: | |
name=rectifier>', bottom_class: <class 'lvsr.bricks.recognizer.SpeechBottom'>, | |
dims: '[]'} | |
conv_n: '100' | |
criterion: {name: log_likelihood} | |
dec_transition: <class 'blocks.bricks.recurrent.GatedRecurrent'> | |
dim_dec: '250' | |
dims_bidir: '[250, 250, 250, 250]' | |
enc_transition: <class 'blocks.bricks.recurrent.GatedRecurrent'> | |
lm: {} | |
max_decoded_length_scale: '3.0' | |
post_merge_activation: '<blocks.bricks.simple.Rectifier object at 0x7fcef1e2a1d0: | |
name=rectifier>' | |
post_merge_dims: '[250]' | |
prior: {after: '100', before: '100', initial_begin: '0', initial_end: '80', max_speed: '4.4', | |
min_speed: '2.4', type: window_around_median} | |
subsample: '[1, 1, 2, 2]' | |
use_states_for_readout: 'True' | |
parent: $LVSR/exp/wsj/configs/wsj_paper.yaml | |
regularization: {dropout: 'False', max_norm: '1.0'} | |
stages: | |
annealing1: | |
number: '200' | |
training: {epsilon: 1e-10, num_epochs: '3', restart_from: _best} | |
annealing2: | |
number: '300' | |
training: {epsilon: 1e-12, num_epochs: '3', restart_from: _best_ll} | |
main: | |
number: '100' | |
training: {num_epochs: '15', restart_from: _best} | |
pretraining: | |
net: | |
prior: {initial_begin: '0', initial_end: '40', max_speed: '2.2', min_speed: '1.2', | |
type: expanding} | |
number: '0' | |
training: {num_epochs: '4'} | |
training: {decay_rate: '0.95', epsilon: 1e-08, gradient_threshold: '100.0', momentum: '0.0', | |
rules: '[''momentum'', ''adadelta'']', scale: '0.1'} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment