Created
March 3, 2019 20:43
-
-
Save hartikainen/1226bd3ab5cddc6ae6d8fc373b42a9bd to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import sys | |
from shutil import copyfile | |
import numpy as np | |
import pandas as pd | |
def get_result_path(trial_dir):
    """Return the path of the ``progress.csv`` file inside *trial_dir*."""
    result_filename = "progress.csv"
    return os.path.join(trial_dir, result_filename)
def get_result_backup_path(result_path):
    """Return the backup location for *result_path*: the same path + ``.backup``."""
    return "{}.backup".format(result_path)
def result_backed_up(trial_dir):
    """Return True if *trial_dir* already contains a ``progress.csv.backup``."""
    backup_path = get_result_backup_path(
        result_path=get_result_path(trial_dir=trial_dir))
    return os.path.exists(backup_path)
def backup_result(trial_dir):
    """Copy *trial_dir*'s ``progress.csv`` to ``progress.csv.backup``.

    ``shutil.copyfile`` overwrites the destination if it already exists.
    """
    source_path = get_result_path(trial_dir=trial_dir)
    copyfile(source_path, get_result_backup_path(result_path=source_path))
def prune_restore_dataframes(dataframes):
    """Return the dataframe with the most rows among *dataframes*.

    Ties are broken in favour of the earliest dataframe, exactly as the
    previous recursive implementation did.

    Args:
        dataframes: Non-empty iterable of dataframes (one per restore segment).

    Returns:
        The longest dataframe (the same object, not a copy).

    Raises:
        ValueError: If *dataframes* is empty.
    """
    # ``max`` returns the first maximal item, so earlier segments win ties.
    # A single pass replaces the old recursion, which raised RecursionError
    # for trials with ~1000+ restores and re-sliced the tuple at every level.
    longest = max(
        dataframes, key=lambda dataframe: dataframe.shape[0], default=None)
    if longest is None:
        raise ValueError(
            "prune_restore_dataframes() needs at least one dataframe")
    return longest
def clean_dataframe(dataframe):
    """Keep only the longest run of rows between restores.

    Every row whose ``iterations_since_restore`` equals 1 starts a new run.
    A run extends up to the next such row, or to the end of *dataframe*.
    Rows before the first run start are dropped. The run to keep is chosen
    by ``prune_restore_dataframes``, and its original index labels are
    preserved.

    If no row has ``iterations_since_restore == 1`` (including an empty
    dataframe), there is nothing to split. In that case a copy of
    *dataframe* is returned unchanged instead of failing.

    Args:
        dataframe: Frame with an ``iterations_since_restore`` column.

    Returns:
        A dataframe holding the selected run's rows.
    """
    # Positional (0-based) row numbers where a run (re)starts.
    run_starts = tuple(
        np.where(dataframe['iterations_since_restore'] == 1)[0])
    if not run_starts:
        return dataframe.copy()
    run_ends = (*run_starts[1:], dataframe.shape[0])
    # ``iloc`` makes the slicing explicitly positional, matching np.where.
    runs = tuple(
        dataframe.iloc[start:end] for start, end in zip(run_starts, run_ends))
    return prune_restore_dataframes(runs)
def clean_trial(trial_dir):
    """Rebuild *trial_dir*'s ``progress.csv`` from its backup copy.

    The backup path is printed first. The backup CSV is then read and passed
    through ``clean_dataframe``, and the result overwrites ``progress.csv``
    without writing the index column.
    """
    csv_path = get_result_path(trial_dir=trial_dir)
    backup_csv_path = get_result_backup_path(result_path=csv_path)
    print(backup_csv_path)
    cleaned = clean_dataframe(pd.read_csv(backup_csv_path))
    cleaned.to_csv(csv_path, index=False)
def fix_ray_results(experiment_dir):
    """Back up and clean ``progress.csv`` in every trial directory.

    Each immediate subdirectory of *experiment_dir* is treated as a trial
    directory and processed in sorted order.

    The first time a trial is seen, its ``progress.csv`` is copied to
    ``progress.csv.backup``. The cleaned file is always rebuilt from that
    backup, so re-running never re-cleans already-cleaned data.

    Subdirectories with neither ``progress.csv`` nor a backup are skipped,
    with a message on stderr, instead of aborting the whole run.

    Raises:
        FileNotFoundError, NotADirectoryError: Raised by ``os.scandir`` when
            *experiment_dir* is missing or not a directory. ``os.walk`` hid
            these errors and surfaced a bare ``StopIteration`` instead.
    """
    with os.scandir(experiment_dir) as entries:
        trial_dirs = sorted(entry.path for entry in entries if entry.is_dir())

    for trial_dir in trial_dirs:
        if not result_backed_up(trial_dir):
            if not os.path.exists(get_result_path(trial_dir=trial_dir)):
                print(f"Skipping {trial_dir}: no progress.csv found",
                      file=sys.stderr)
                continue
            backup_result(trial_dir)
        clean_trial(trial_dir)
def main(argv=None):
    """Command-line entry point taking a single EXPERIMENT_DIR argument.

    Args:
        argv: Arguments to parse, without the program name. When None,
            argparse reads ``sys.argv[1:]``.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Back up each trial's progress.csv and rewrite it to "
                    "keep only the longest run since a restore.")
    parser.add_argument(
        "experiment_dir",
        help="directory whose immediate subdirectories are trial directories")
    # A missing argument now yields a usage message and exit status 2
    # instead of an IndexError traceback from sys.argv[1].
    args = parser.parse_args(argv)
    fix_ray_results(args.experiment_dir)
# Only run the command-line tool when executed as a script, not on import.
if __name__ == '__main__':
    main()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import unittest | |
import numpy as np | |
import pandas as pd | |
from fix_results import clean_dataframe | |
def _progress_frame(iterations, index=None):
    """Build a one-column ``iterations_since_restore`` frame for a test case."""
    return pd.DataFrame(
        {'iterations_since_restore': tuple(iterations)}, index=index)


TEST_CASES = (
    # Two runs (3 rows, then 4 rows): the longer second run is kept together
    # with its original index labels 3..6.
    {
        'kwargs': {
            'dataframe': _progress_frame((*range(1, 4), *range(1, 5))),
        },
        'expected_output': _progress_frame(
            range(1, 5), index=tuple(range(3, 7))),
    },
    # Three shrinking runs (3, 2, 1 rows): the first run is kept.
    {
        'kwargs': {
            'dataframe': _progress_frame(
                (*range(1, 4), *range(1, 3), *range(1, 2))),
        },
        'expected_output': _progress_frame(
            range(1, 4), index=tuple(range(0, 3))),
    },
    # One uninterrupted run: every row survives.
    {
        'kwargs': {
            'dataframe': _progress_frame(range(1, 11)),
        },
        'expected_output': _progress_frame(
            range(1, 11), index=tuple(range(0, 10))),
    },
)
class TestFixResults(unittest.TestCase):
    """Checks ``clean_dataframe`` against every entry in ``TEST_CASES``."""

    def test_fix_results(self):
        """Each case yields the expected frame, with a counter running 1..N."""
        for case_number, test_case in enumerate(TEST_CASES):
            # subTest reports every failing case, not just the first one.
            with self.subTest(case_number=case_number):
                output = clean_dataframe(**test_case['kwargs'])
                pd.testing.assert_frame_equal(
                    output, test_case['expected_output'])
                # The kept rows must form a single uninterrupted run.
                np.testing.assert_equal(
                    output['iterations_since_restore'],
                    np.arange(1, output.shape[0] + 1))
# Allow running the tests directly with ``python <this file>``.
if __name__ == '__main__':
    unittest.main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment