Created
November 17, 2023 09:03
-
-
Save S0PEX/7431a73231c3ea3f9b0e832fc2e73af4 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def test_all_Steps_of_pipeline(
    dummy_pipeline: PipelineTuple, expected_dronology_ypred_ytrue
):
    """
    Exercise the pipeline end to end from a stored prediction CSV.

    The test loads predictions from a CSV file into the pipeline, checks the
    first ten predicted and true labels of the first prediction result against
    the expected fixture values, and then exports the metrics with both the
    LaTeX and the CSV exporter. Each exported file must match its reference
    file in the ``expected_exporter_output`` directory next to this module.

    The metrics are not asserted directly; they are covered through the
    exporter output, which is produced from the same ``Metrics`` object.
    """
    # Only the first element of the fixture tuple (the pipeline) is needed.
    pipeline, _, _ = dummy_pipeline

    # Load the prediction
    prediction_csv = Path(
        f"{dir_path}/csvs/NoRBERT_Task4_IsFunctional_e10_NoSampling_Dronology-IsFunctional.csv"
    )
    pipeline.load_prediction_from_csv(prediction_csv)

    # Assert the prediction: only the first ten entries are compared
    expected_y_pred, expected_y_true = expected_dronology_ypred_ytrue
    first_result = pipeline._prediction_results[0].result
    assert first_result.y_pred[:10] == expected_y_pred
    assert first_result.y_true[:10] == expected_y_true

    # Assert the metrics by way of the exporters
    metrics = Metrics(pipeline._prediction_results)
    with tempfile.TemporaryDirectory() as tmp_dir:
        output_dir = Path(tmp_dir)
        reference_dir = Path(
            f'{os.path.dirname(__file__)}/expected_exporter_output')

        # LaTeX first, then CSV; each exporter is instantiated right before
        # it is used and writes to its own file in the temporary directory.
        exporter_cases = (
            (LatexExporter, 'metrics.tex'),
            (CsvExporter, 'metrics.csv'),
        )
        for exporter_cls, file_name in exporter_cases:
            output_path = output_dir / file_name
            exporter_cls().export_all_metrics(metrics, str(output_path), None)
            expected_content = (reference_dir / file_name).read_text()
            actual_content = output_path.read_text()
            assert actual_content == expected_content
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment