Created August 25, 2020 20:40
-
-
Save viswanathgs/8f5d52e6cece04f9d74c3ace5dc891f7 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
commit 3d4383d88346b60d669ce72fa4e2a6d4a5e4861d | |
Author: Viswanath Sivakumar <[email protected]> | |
Date: Tue Aug 25 10:49:59 2020 -0400 | |
TODO nocommit | |
diff --git a/platform/pybmi/pybmi/modeling/keystrokes/corpus.py b/platform/pybmi/pybmi/modeling/keystrokes/corpus.py | |
index 30b4b8fd9..9f1bd7431 100644 | |
--- a/platform/pybmi/pybmi/modeling/keystrokes/corpus.py | |
+++ b/platform/pybmi/pybmi/modeling/keystrokes/corpus.py | |
@@ -89,6 +89,7 @@ def get_single_user_split(corpus: pd.DataFrame, | |
seed=None) -> Tuple[pd.DataFrame, pd.DataFrame]: | |
"""Get a train/validation split for a single-user experiment""" | |
corpus = corpus[corpus["user"] == user] | |
+ # TODO slog corpus = corpus[corpus['condition'].str.contains('off_keyboard')] | |
val = corpus \ | |
.pipe(get_possible_validation_sets) \ | |
.pipe(stratified_sample, n=1, seed=seed) | |
diff --git a/scripts/experiments/typing/common.py b/scripts/experiments/typing/common.py | |
index 81f6d0e2d..a42fc4016 100644 | |
--- a/scripts/experiments/typing/common.py | |
+++ b/scripts/experiments/typing/common.py | |
@@ -7,4 +7,5 @@ def parse_args(): | |
'--arch', type=str, choices=['lstm', 'w2l', 'tds'], default='lstm') | |
parser.add_argument('--batch-size', type=int, default=16) | |
parser.add_argument('--seed', type=int, default=1701) | |
+ parser.add_argument('--pp-stride', type=int, default=80) | |
return parser.parse_args() | |
diff --git a/scripts/experiments/typing/generate_cross_session.py b/scripts/experiments/typing/generate_cross_session.py | |
index f3d95d27e..bee8af472 100644 | |
--- a/scripts/experiments/typing/generate_cross_session.py | |
+++ b/scripts/experiments/typing/generate_cross_session.py | |
@@ -19,7 +19,8 @@ from common import parse_args | |
args = parse_args() | |
corpus = k.get_typing_corpus().loc[lambda df: df["ok?"] == "Y"] | |
-loader_config = k.config_utils.get_cospectrum_loader_config() | |
+loader_config = k.config_utils.get_cospectrum_loader_config( | |
+ pp_stride=args.pp_stride) | |
val_metrics = [ | |
"character_error_rate", | |
@@ -31,12 +32,13 @@ val_metrics = [ | |
# Grid search hyperparameter options | |
training_fixed = { | |
"batch_size": [args.batch_size], | |
- "learning_rate": np.logspace(-3.75, -3.75, 1).tolist(), | |
- "max_epochs": [200], | |
+ "learning_rate": np.logspace(-2.5, -2.5, 1).tolist(), | |
+ "max_epochs": [300], | |
"patience": [30], | |
- "scheduler_step_size": [200], | |
- "scheduler_gamma": [1.0], | |
+ "scheduler_step_size": [100], | |
+ "scheduler_gamma": [0.1], | |
"window_length": [200], | |
+ "scheduler_warmup_epochs": [10], | |
"temporal_jitter": [3], | |
"band_rotations": [(-1, 0, 1)], | |
"val_metrics": [val_metrics], | |
@@ -45,6 +47,7 @@ training_fixed = { | |
"n_f_masks": [2], | |
"freq_mask_param": [2], | |
"time_mask_param": [20], | |
+ "seed": [args.seed], | |
} | |
if args.arch == 'lstm': | |
@@ -57,7 +60,6 @@ elif args.arch == 'tds': | |
"L1": [3], | |
"L2": [1], | |
"kw": [11], | |
- "rpad": [-1] | |
} | |
hyperparameter_fixed = { |
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment