This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Ranges for the three hyperparameters: a continuous learning rate and two
# categorical dimension choices.
hyperparameter_ranges = {
    "lr": ContinuousParameter(0.01, 0.1),
    "embedding-dim": CategoricalParameter([6, 12]),
    "hidden-dim": CategoricalParameter([6, 12]),
}

# Name of the objective metric and the direction in which it is optimized.
objective_metric_name = "validation loss"
objective_type = "Minimize"
metric_definitions = [ | |
{"Name": objective_metric_name, |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
estimator = PyTorch(entry_point="tuning.py", | |
source_dir="../../allennlp-sagemaker-tuning", | |
dependencies=[from_root("example"), from_root(".venv")], | |
role=role, | |
framework_version="1.0.0", | |
train_instance_count=1, | |
train_instance_type="ml.p2.8xlarge", | |
hyperparameters={ | |
"train-file-name": os.path.basename(s3_paths[0]), | |
"validation": os.path.basename(s3_paths[1]), |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Estimator that runs tuning.py on one ml.p2.8xlarge instance.
estimator = PyTorch(
    entry_point="tuning.py",
    role=role,
    framework_version="1.0.0",
    train_instance_count=1,
    train_instance_type="ml.p2.8xlarge",
    hyperparameters={
        # Only the base names of the two S3 paths are passed on.
        "train-file-name": os.path.basename(s3_paths[0]),
        "validation": os.path.basename(s3_paths[1]),
        "epochs": 10,
    },
)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Build the tagger from the embedder, the LSTM encoder and the vocabulary.
model = LstmTagger(word_embeddings, lstm, vocab)

# Plain SGD over all model parameters.
optimizer = optim.SGD(model.parameters(), lr=0.1)

# Batches of two instances, with ("sentence", "num_tokens") as the sorting key.
iterator = BucketIterator(
    batch_size=2,
    sorting_keys=[("sentence", "num_tokens")],
)
iterator.index_with(vocab)
trainer = Trainer(model=model, | |
optimizer=optimizer, | |
iterator=iterator, | |
train_dataset=train_dataset, | |
validation_dataset=validation_dataset, |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class LstmTagger(Model): | |
def __init__(self, | |
word_embeddings: TextFieldEmbedder, | |
encoder: Seq2SeqEncoder, | |
vocab: Vocabulary) -> None: | |
super().__init__(vocab) | |
self.word_embeddings = word_embeddings | |
self.encoder = encoder | |
self.hidden2tag = torch.nn.Linear(in_features=encoder.get_output_dim(), |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class EDINETGetDocumentSensor(BaseSensorOperator): | |
@apply_defaults | |
def __init__(self, document_type="xbrl", *args, **kwargs): | |
self.document_type = document_type | |
self._next_document_index = -1 | |
super().__init__(*args, **kwargs) | |
def poke(self, context): | |
document_ids = context["task_instance"].xcom_pull( |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class EDINETGetDocumentsOperator(BaseOperator): | |
@apply_defaults | |
def __init__(self, filter_func=None, *args, **kwargs): | |
self.filter_func = filter_func | |
super().__init__(*args, **kwargs) | |
def execute(self, context): | |
self.log.info("Retreave list of documents from EDINET @ {}.".format( | |
self.start_date.strftime("%Y/%m/%d"))) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def update(self, states, actions, rewards, values):
    """Compute the policy loss from value estimates supplied by the caller.

    Args:
        states: Input handed to ``self.actor``.
        actions: Handed to ``self.to_one_hot``; the result is used to index
            the actor output.
        rewards: Rewards the advantage is derived from.
        values: Value estimates, calculated outside of this update step.

    NOTE(review): the snippet ends once ``policy_loss`` is computed; nothing
    is returned or applied here.
    """
    # The values (or the advantage itself) are calculated outside of the
    # update process, so the critic is not referenced in this method.
    advantages = rewards - values

    action_probs = self.actor(states)
    selected_action_probs = action_probs[self.to_one_hot(actions)]
    neg_logs = -log(selected_action_probs)
    policy_loss = reduce_mean(neg_logs * advantages)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def update(self, states, actions, rewards):
    """Compute the policy loss with the critic's values held constant.

    Args:
        states: Input handed to both ``self.critic`` and ``self.actor``.
        actions: Handed to ``self.to_one_hot``; the result is used to index
            the actor output.
        rewards: Rewards the advantage is derived from.

    NOTE(review): the snippet ends once ``policy_loss`` is computed; nothing
    is returned or applied here.
    """
    values = self.critic(states)
    # stop_gradient prevents the policy loss gradient from flowing into the
    # critic.
    advantages = rewards - tf.stop_gradient(values)

    action_probs = self.actor(states)
    selected_action_probs = action_probs[self.to_one_hot(actions)]
    neg_logs = -log(selected_action_probs)
    policy_loss = reduce_mean(neg_logs * advantages)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def update(self, states, actions, rewards):
    """Compute the policy loss directly from the critic's values.

    Args:
        states: Input handed to both ``self.critic`` and ``self.actor``.
        actions: Handed to ``self.to_one_hot``; the result is used to index
            the actor output.
        rewards: Rewards the advantage is derived from.

    NOTE(review): the snippet ends once ``policy_loss`` is computed; nothing
    is returned or applied here.
    """
    values = self.critic(states)
    advantages = rewards - values

    action_probs = self.actor(states)
    selected_action_probs = action_probs[self.to_one_hot(actions)]
    neg_logs = -log(selected_action_probs)
    # WARNING: `values` is used without a gradient stop, so if backprop is
    # executed on policy_loss its gradient will also affect the critic.
    policy_loss = reduce_mean(neg_logs * advantages)