Last active
January 15, 2022 23:33
-
-
Save ysdede/0e3283385eed5953d085082e5c3614f1 to your computer and use it in GitHub Desktop.
Pick best parameters from Optuna database
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import csv
import json
import os
import statistics

import optuna
# This script filters Optuna trials by their performance metrics and by the
# spread (standard deviation) of those metrics, then writes every surviving
# trial to Results.csv and the hyperparameters of each one to SEQ.py.

# The storage URL can be overridden with the OPTUNA_STORAGE environment
# variable so database credentials do not have to live in the source.
STORAGE_URL = os.environ.get(
    "OPTUNA_STORAGE",
    "postgresql://optuna_user:optuna_password@localhost/optuna_db_3",
)

# load_study() fails if the study does not exist. create_study(load_if_exists=True)
# would instead silently create an empty study, and the report would be empty.
# The optimisation directions are read back from storage.
study = optuna.load_study(study_name="Band5min-LongOnly", storage=STORAGE_URL)
# Column headers of Results.csv, in the same order as the rows built by
# print_best_params(). "Score3" holds the trial's ``sharpe3`` user attribute.
_RESULT_FIELDS = ['Trial #', 'Score1', 'Score2', 'Score3',
                  'Trades1', 'Trades2', 'Trades3',
                  'Fees1', 'Fees2', 'Fees3',
                  'Winrate1', 'Winrate2', 'Winrate3',
                  'Average', 'Deviation',
                  'Parameters']

# Trial user attributes reported in the Trades1..Winrate3 columns.
_REPORT_ATTRS = ('trades1', 'trades2', 'trades3',
                 'fees1', 'fees2', 'fees3',
                 'wr1', 'wr2', 'wr3')


def _candidate_key(params):
    """Return the SEQ.py dictionary key for one ordered parameter set.

    Each value is left-padded with ``0`` to the length of the longest
    ``str(value)``. The padded values are concatenated in order, and that
    length is appended as a suffix.
    """
    width = max((len(str(value)) for value in params.values()), default=0)
    digits = ''.join(f'{value:0>{width}}' for value in params.values())
    return f'{digits}{width}'


def _write_results_csv(rows, path):
    """Write *rows* to *path* as tab-separated values, preceded by a header line."""
    # newline='' lets the csv module control line endings itself, so '\n' is
    # not translated to '\r\n' on Windows.
    with open(path, 'w', newline='', encoding='utf-8') as f:
        writer = csv.writer(f, delimiter='\t', lineterminator='\n')
        writer.writerow(_RESULT_FIELDS)
        writer.writerows(rows)


def _write_seq_file(candidates, path):
    """Write *candidates* to *path* as the Python assignment ``hps = {...}``."""
    # NOTE(review): json.dumps renders True/False/None as true/false/null, which
    # is not valid Python. This is fine only while hyperparameters are numbers or
    # strings; confirm for this strategy.
    with open(path, 'w', encoding='utf-8') as f:
        f.write("hps = ")
        f.write(json.dumps(candidates, indent=1))


def print_best_params(*, score_threshold=1, std_dev_threshold=4, min_trades=50,
                      min_sharpe3=2, results_path='Results.csv', seq_path='SEQ.py'):
    """Filter the module-level ``study``'s trials and export the survivors.

    A trial is kept when all of the following hold:

    - its state is COMPLETE;
    - it has a ``sharpe3`` user attribute;
    - its ``trades1`` user attribute is truthy and at least *min_trades*;
    - no objective value is below -1 and the first objective is not below 1;
    - the mean of (objective values + ``sharpe3``) exceeds *score_threshold*;
    - their sample standard deviation is below *std_dev_threshold*;
    - ``sharpe3`` exceeds *min_sharpe3*;
    - its ``trial.params`` were not already seen in an earlier kept trial.

    Kept trials are written to *results_path* as tab-separated values, sorted
    by the second objective in descending order. Their hyperparameters, in the
    order the jesse strategy declares them, are written to *seq_path* as
    ``hps = {...}``.

    All arguments are keyword-only. Their defaults reproduce the original
    hard-coded thresholds and file names.
    """
    from jesse.routes import router
    import jesse.helpers as jh

    # Fetch the trials once. No deepcopy is needed because they are only read.
    trials = study.get_trials(deepcopy=False)
    print("Number of trials: ", len(trials))

    # The first route's strategy is instantiated only to learn the order in
    # which it declares its hyperparameters.
    # NOTE(review): this raises IndexError if no routes are configured.
    route = router.routes[0]
    strategy = jh.get_strategy_class(route.strategy_name)()
    hp_names = [hp['name'] for hp in strategy.hyperparameters()]

    results = []
    seen_params = set()  # frozenset(trial.params.items()) of each kept trial
    candidates = {}

    for trial in trials:
        if trial.state != optuna.trial.TrialState.COMPLETE:
            continue
        attrs = trial.user_attrs
        trades1 = attrs.get('trades1')
        sharpe3 = attrs.get('sharpe3')
        # A trial missing the required attributes cannot be scored, so skip it
        # rather than raise KeyError.
        if sharpe3 is None or not trades1 or trades1 < min_trades:
            continue
        if any(v < -1 for v in trial.values) or trial.values[0] < 1:
            continue

        scores = (*trial.values, sharpe3)
        mean_value = round(statistics.mean(scores), 3)
        std_dev = round(statistics.stdev(scores), 5)
        # Written as a positive test so that NaN scores are still rejected.
        if not (mean_value > score_threshold and std_dev < std_dev_threshold
                and sharpe3 > min_sharpe3):
            continue
        # A set of frozensets gives O(1) duplicate checks. It has the same
        # equality semantics as comparing the params dicts directly.
        signature = frozenset(trial.params.items())
        if signature in seen_params:
            continue
        seen_params.add(signature)

        ordered_params = {name: trial.params[name] for name in hp_names}
        # The columns line up with _RESULT_FIELDS when the study has exactly
        # two objectives.
        results.append([trial.number, *trial.values, sharpe3,
                        *(attrs.get(name) for name in _REPORT_ATTRS),
                        mean_value, std_dev, ordered_params])
        candidates[_candidate_key(ordered_params)] = ordered_params

    # Best first, ranked by the second objective (the Score2 column).
    results.sort(key=lambda row: row[2], reverse=True)
    print("Number of candidates: ", len(results))

    _write_results_csv(results, results_path)
    _write_seq_file(candidates, seq_path)
# Run the report only when executed as a script, not when imported.
if __name__ == "__main__":
    print_best_params()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment