Devin SWE-Bench Analysis
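A small script that measures how well Devin's SWE-bench diffs localize the right files: it extracts the file paths touched by each predicted diff under output_diffs/pass/ and output_diffs/fail/, compares them against the files touched by the gold patch in the princeton-nlp/SWE-bench test split, and reports file-level precision and recall (kept as exact Fractions internally).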
import collections
from fractions import Fraction
import os
import re
from typing import Dict, List, Set, Tuple, Union

from datasets import load_dataset
import pandas as pd  # required by Dataset.to_pandas()


def set_recall(predicted: Union[Set[str], List[str]], actual: Union[Set[str], List[str]]) -> Fraction:
    """Returns the fraction of actually changed files that were predicted."""
    predicted, actual = set(predicted), set(actual)
    if len(actual) == 0:
        return Fraction(0, 1)
    return Fraction(len(predicted & actual), len(actual))


def set_precision(predicted: Union[Set[str], List[str]], actual: Union[Set[str], List[str]]) -> Fraction:
    """Returns the fraction of predicted changed files that were actually changed."""
    predicted, actual = set(predicted), set(actual)
    if len(predicted) == 0:
        return Fraction(0, 1)
    return Fraction(len(predicted & actual), len(predicted))


def extract_filepaths_from_diff(diff: str) -> Set[str]:
    # Match 'diff --git a/<path> b/<path>' headers and capture both paths.
    filepaths = re.findall(r'diff --git a/(.*?) b/(.*?)\s', diff)
    # Flatten the list of (a-path, b-path) tuples and deduplicate.
    return {path for pair in filepaths for path in pair}


def get_predicted_filepaths() -> Dict[str, Set[str]]:
    directories = ['output_diffs/pass/', 'output_diffs/fail/']
    predicted_filepaths = collections.defaultdict(set)
    for directory in directories:
        for filename in os.listdir(directory):
            match = re.match(r'(.*)-diff\.txt', filename)
            if match:
                instance_id = match.group(1)
                with open(os.path.join(directory, filename), 'r') as file:
                    diff = file.read()
                predicted_filepaths[instance_id] = extract_filepaths_from_diff(diff)
    return predicted_filepaths


def get_actual_filepaths() -> Dict[str, Set[str]]:
    # Load the dataset and map each instance to the files its gold patch touches.
    dataset = load_dataset("princeton-nlp/SWE-bench")
    instance_to_actual_diff = collections.defaultdict(set)
    test_data = dataset['test'].to_pandas()
    for _, row in test_data.iterrows():
        instance_id = row['instance_id']
        patch = row['patch']
        test_patch = row['test_patch']
        patch_filepaths = extract_filepaths_from_diff(patch)
        instance_to_actual_diff[instance_id] = patch_filepaths
        # Uncomment the following lines to include the test patch in the actual diff.
        # test_patch_filepaths = extract_filepaths_from_diff(test_patch)
        # instance_to_actual_diff[instance_id] = patch_filepaths.union(test_patch_filepaths)
    return instance_to_actual_diff


def get_metrics(predicted_filepaths_by_instance: Dict[str, Set[str]],
                actual_filepaths_by_instance: Dict[str, Set[str]]) -> Tuple[Fraction, Fraction]:
    print('Number of instances in actual:', len(actual_filepaths_by_instance))
    print('Number of instances in predicted:', len(predicted_filepaths_by_instance))
    precisions = {}
    recalls = {}
    actual_num_files_distribution = collections.Counter()
    instances_not_present = 0
    for instance_id in actual_filepaths_by_instance:
        if instance_id not in predicted_filepaths_by_instance:
            instances_not_present += 1
            continue
        predicted_files = predicted_filepaths_by_instance[instance_id]
        actual_files = actual_filepaths_by_instance[instance_id]
        actual_num_files_distribution[len(actual_files)] += 1
        precisions[instance_id] = set_precision(predicted=predicted_files, actual=actual_files)
        recalls[instance_id] = set_recall(predicted=predicted_files, actual=actual_files)
    print('Actual number of files distribution:')
    print('(Number of files, Number of SWE-bench instances with that number of files in the actual diff)')
    for key in sorted(actual_num_files_distribution):
        print(f'{key}: {actual_num_files_distribution[key]}')
    print(f'Instances missing from predicted filepaths: {instances_not_present}')
    print(f'Number of examples to evaluate on: {len(precisions)}')
    # Fractions keep the averages exact; they are converted to float only for display.
    average_precision = Fraction(sum(precisions.values()), len(precisions)) if precisions else Fraction(0, 1)
    average_recall = Fraction(sum(recalls.values()), len(recalls)) if recalls else Fraction(0, 1)
    return average_precision, average_recall


def get_metrics_by_num_files(predicted_filepaths: Dict[str, Set[str]],
                             actual_filepaths: Dict[str, Set[str]]) -> Dict[int, Tuple[Fraction, Fraction]]:
    # Bucket per-instance precision/recall by the number of files in the actual diff.
    samples_by_num_files = collections.defaultdict(lambda: ([], []))
    for instance_id, actual_files in actual_filepaths.items():
        if instance_id not in predicted_filepaths:
            continue
        predicted_files = predicted_filepaths[instance_id]
        precisions, recalls = samples_by_num_files[len(actual_files)]
        precisions.append(set_precision(predicted=predicted_files, actual=actual_files))
        recalls.append(set_recall(predicted=predicted_files, actual=actual_files))
    # Average each bucket; building a fresh dict avoids mutating one while iterating it.
    metrics_by_num_files = {}
    for num_files, (precisions, recalls) in samples_by_num_files.items():
        average_precision = Fraction(sum(precisions), len(precisions)) if precisions else Fraction(0, 1)
        average_recall = Fraction(sum(recalls), len(recalls)) if recalls else Fraction(0, 1)
        metrics_by_num_files[num_files] = (average_precision, average_recall)
    return metrics_by_num_files


if __name__ == '__main__':
    cognition_predicted_filepaths = get_predicted_filepaths()
    actual_filepaths = get_actual_filepaths()
    precision, recall = get_metrics(cognition_predicted_filepaths, actual_filepaths)
    print('*' * 80)
    print(f'Precision: {float(precision) * 100:.2f}%')
    print(f'Recall: {float(recall) * 100:.2f}%')
    print('*' * 80)
    metrics_by_num_files = get_metrics_by_num_files(cognition_predicted_filepaths, actual_filepaths)
    print('*' * 80)
    print('Metrics by number of files:')
    print('(Number of files, Average precision, Average recall)')
    for num_files in sorted(metrics_by_num_files):
        precision, recall = metrics_by_num_files[num_files]
        print(f'{num_files}: {float(precision) * 100:.2f}%, {float(recall) * 100:.2f}%')
    print('*' * 80)
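As a quick sanity check, here is a minimal sketch of how the helpers behave on a made-up diff header (the paths below are illustrative, not taken from SWE-bench):

sample_diff = 'diff --git a/src/app.py b/src/app.py\n--- a/src/app.py\n+++ b/src/app.py\n'
predicted = extract_filepaths_from_diff(sample_diff)      # {'src/app.py'}
actual = {'src/app.py', 'tests/test_app.py'}
print(set_precision(predicted=predicted, actual=actual))  # 1   (every predicted file was changed)
print(set_recall(predicted=predicted, actual=actual))     # 1/2 (one of the two changed files was predicted)

The gist's second file captures the per-bucket output of this script: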
Metrics by number of files: | |
(Number of files, Average precision, Average recall) | |
1: 70.12%, 76.53% | |
2: 67.98%, 40.00% | |
3: 61.74%, 25.76% | |
4: 53.57%, 17.86% | |
5: 100.00%, 22.00% | |
6: 95.00%, 23.33% | |
7: 100.00%, 19.05% | |
9: 50.00%, 11.11% | |
11: 100.00%, 9.09% | |
14: 100.00%, 7.14% | |
15: 100.00%, 6.67% | |
17: 66.67%, 5.88% | |
18: 100.00%, 16.67% | |
21: 100.00%, 7.14% | |
23: 100.00%, 4.35% | |
31: 100.00%, 3.23% |
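In several multi-file buckets, precision is 100% while recall equals exactly 1/num_files (9.09% at 11 files, 7.14% at 14, 3.23% at 31), which is consistent with the predicted diff touching a single file that was indeed among the changed ones.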