Created
March 16, 2025 20:25
-
-
Save lukehinds/2e52a62d9709db5b6497028f439717be to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import nltk | |
import ast | |
import difflib | |
# Make sure you have the required NLTK resource punkt | |
# cat punkt.py ✔ 3694 19:31:56 | |
# import nltk | |
# nltk.download('punkt_tab') | |
# [nltk_data] Downloading package punkt_tab to | |
# [nltk_data] /Users/lhinds/nltk_data... | |
# [nltk_data] Unzipping tokenizers/punkt_tab.zip. | |
nltk.download('punkt') | |
def compute_bleu(reference, candidate):
    """
    Compute the BLEU score between two code snippets using NLTK's sentence_bleu.

    Tokenizes both snippets with nltk.word_tokenize and measures the n-gram
    overlap of the candidate against the single reference.

    Args:
        reference: Reference code as a string.
        candidate: Candidate (e.g. model-generated) code as a string.

    Returns:
        A float BLEU score in [0.0, 1.0].
    """
    reference_tokens = nltk.word_tokenize(reference)
    candidate_tokens = nltk.word_tokenize(candidate)
    # Short code snippets often have zero higher-order n-gram matches, which
    # makes unsmoothed BLEU collapse to ~0 and emit runtime warnings. Use
    # method1 smoothing (epsilon on zero counts) so scores stay informative.
    smoothing = nltk.translate.bleu_score.SmoothingFunction().method1
    score = nltk.translate.bleu_score.sentence_bleu(
        [reference_tokens],
        candidate_tokens,
        smoothing_function=smoothing,
    )
    return score
def get_normalized_ast(code):
    """
    Parse Python source into an AST and return a normalized string dump.

    The dump omits field names and node attributes (line numbers, column
    offsets), so two snippets that differ only in formatting or layout
    produce identical output.

    Args:
        code: Python source code as a string.

    Returns:
        The normalized ast.dump string, or None if the code cannot be parsed.
    """
    try:
        tree = ast.parse(code)
    except (SyntaxError, ValueError, RecursionError):
        # SyntaxError: invalid Python; ValueError: e.g. source containing
        # null bytes; RecursionError: pathologically deep nesting. Return
        # None so callers can treat the snippet as "no AST available".
        return None
    return ast.dump(tree, annotate_fields=False, include_attributes=False)
def compute_syntax_similarity(reference, candidate):
    """
    Score the structural similarity of two code snippets via their ASTs.

    Each snippet is parsed and dumped in normalized form (formatting and
    line numbers stripped), and the two dumps are compared with
    difflib.SequenceMatcher.

    Args:
        reference: Reference code string.
        candidate: Candidate code string.

    Returns:
        A similarity ratio in [0.0, 1.0]; 0.0 if either snippet fails
        to parse entirely.
    """
    dumps = [get_normalized_ast(src) for src in (reference, candidate)]
    if None in dumps:
        # Without both ASTs there is nothing to compare.
        return 0.0
    return difflib.SequenceMatcher(None, *dumps).ratio()
def compute_codebleu(reference, candidate, bleu_weight=0.5, syntax_weight=0.5):
    """
    Combine token-level BLEU and AST similarity into one CodeBLEU score.

    Args:
        reference: Reference code string.
        candidate: Candidate code string.
        bleu_weight: Weight applied to the BLEU component.
        syntax_weight: Weight applied to the AST-similarity component.

    Returns:
        The weighted sum of the two component scores.
    """
    components = (
        (bleu_weight, compute_bleu(reference, candidate)),
        (syntax_weight, compute_syntax_similarity(reference, candidate)),
    )
    # Weighted linear combination; adjust weights based on validation results.
    return sum(weight * value for weight, value in components)
# Example usage: compare a reference implementation against a
# model-generated variant and report the composite score.
if __name__ == '__main__':
    reference_snippet = """async def detect(self, request: Request) -> bool:
    try:
        data = await request.json()
        for message in data.get("messages", []):
            message_content = str(message.get("content", ""))
            if self.pattern in message_content:
                return True
        system_content = str(data.get("system", ""))
        if self.pattern in system_content:
            return True
        return False
    except Exception as e:
        logger.error(f"Error in content detection: {str(e)}")
        return False
"""
    generated_snippet = """async def detect_content(self, req: Request) -> bool:
    try:
        payload = await req.json()
        for msg in payload.get("messages", []):
            content = str(msg.get("content", ""))
            if self.pattern in content:
                return True
        system_message = str(payload.get("system", ""))
        return self.pattern in system_message
    except Exception as error:
        return False
"""
    result = compute_codebleu(reference_snippet, generated_snippet)
    print("CodeBLEU Score:", result)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.