Skip to content

Instantly share code, notes, and snippets.

@lukehinds
Created March 16, 2025 20:25
Show Gist options
  • Save lukehinds/2e52a62d9709db5b6497028f439717be to your computer and use it in GitHub Desktop.
Save lukehinds/2e52a62d9709db5b6497028f439717be to your computer and use it in GitHub Desktop.
import nltk
import ast
import difflib
# nltk.word_tokenize() needs the Punkt tokenizer data to be installed locally.
# Older NLTK releases load it from the 'punkt' package, while recent releases
# look for 'punkt_tab' instead (equivalent to running
# `nltk.download('punkt_tab')` once by hand), so fetch both to work with
# either. nltk.download() is a no-op when the package is already up to date.
nltk.download('punkt')
nltk.download('punkt_tab')
def compute_bleu(reference, candidate, smoothing_function=None):
    """
    Compute the BLEU score of a candidate snippet using NLTK's sentence_bleu.

    Both snippets are tokenized with nltk.word_tokenize and the candidate is
    scored against the single reference by n-gram overlap.

    Args:
        reference: Ground-truth source code, as a string.
        candidate: Generated source code to score, as a string.
        smoothing_function: Optional callable forwarded unchanged to
            sentence_bleu (for example
            nltk.translate.bleu_score.SmoothingFunction().method1).
            The default, None, keeps NLTK's unsmoothed behaviour.

    Returns:
        The BLEU score returned by sentence_bleu.
    """
    reference_tokens = nltk.word_tokenize(reference)
    candidate_tokens = nltk.word_tokenize(candidate)
    # Without smoothing, a short snippet that shares no higher-order n-gram
    # with the reference scores (close to) zero, so callers can opt in to a
    # smoothing method via the smoothing_function argument.
    # TODO: which smoothing method (if any) should become the default is
    # still to be reviewed by Pankaj and Nigel.
    return nltk.translate.bleu_score.sentence_bleu(
        [reference_tokens],
        candidate_tokens,
        smoothing_function=smoothing_function,
    )
def get_normalized_ast(code):
    """
    Return a position-independent string dump of the AST for `code`.

    The source is parsed with ast.parse and dumped without field names and
    without node attributes such as line numbers and column offsets, so two
    snippets that differ only in layout or comments produce the same string.
    (Same approach as the one taken in bandit.)

    Returns None when the code cannot be parsed or dumped for any reason.
    """
    try:
        return ast.dump(
            ast.parse(code),
            annotate_fields=False,
            include_attributes=False,
        )
    except Exception:
        # Best effort: any failure is reported to the caller as "no AST".
        return None
def compute_syntax_similarity(reference, candidate):
    """
    Compare the ASTs of the reference and candidate code.

    Each snippet is turned into a normalized AST dump by get_normalized_ast
    and the two dumps are compared character by character with
    difflib.SequenceMatcher.

    Args:
        reference: Ground-truth source code, as a string.
        candidate: Generated source code to score, as a string.

    Returns:
        The SequenceMatcher similarity ratio in [0.0, 1.0], or 0.0 when
        either snippet fails to parse.
    """
    ast_ref = get_normalized_ast(reference)
    ast_cand = get_normalized_ast(candidate)
    if ast_ref is None or ast_cand is None:
        # Cannot compute similarity if either code snippet fails to parse.
        return 0.0
    # autojunk must be disabled: for sequences of 200+ items SequenceMatcher
    # otherwise ignores every item that makes up more than 1% of the sequence.
    # In an AST dump that is almost every character, which leaves the matcher
    # with no anchors and collapses the ratio for near-identical trees.
    matcher = difflib.SequenceMatcher(None, ast_ref, ast_cand, autojunk=False)
    return matcher.ratio()
def compute_codebleu(reference, candidate, bleu_weight=0.5, syntax_weight=0.5):
    """
    Return a composite CodeBLEU-style score for `candidate` against `reference`.

    The score is the weighted sum of the token-level BLEU score
    (compute_bleu) and the AST similarity ratio (compute_syntax_similarity).
    The weights are used exactly as given and can be tuned based on
    validation results.
    """
    weighted_bleu = bleu_weight * compute_bleu(reference, candidate)
    weighted_syntax = syntax_weight * compute_syntax_similarity(reference, candidate)
    return weighted_bleu + weighted_syntax
# Example usage: score a model-generated function against the real one.
if __name__ == '__main__':
    # The snippet bodies must keep their indentation inside the string
    # literals: without it they are not valid Python, parsing fails and the
    # syntax component of the score silently drops to 0.0.
    real_code = """async def detect(self, request: Request) -> bool:
    try:
        data = await request.json()
        for message in data.get("messages", []):
            message_content = str(message.get("content", ""))
            if self.pattern in message_content:
                return True
        system_content = str(data.get("system", ""))
        if self.pattern in system_content:
            return True
        return False
    except Exception as e:
        logger.error(f"Error in content detection: {str(e)}")
        return False
"""
    model_code = """async def detect_content(self, req: Request) -> bool:
    try:
        payload = await req.json()
        for msg in payload.get("messages", []):
            content = str(msg.get("content", ""))
            if self.pattern in content:
                return True
        system_message = str(payload.get("system", ""))
        return self.pattern in system_message
    except Exception as error:
        return False
"""
    score = compute_codebleu(real_code, model_code)
    print("CodeBLEU Score:", score)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment