from tqdm import tqdm
from transformers import AutoTokenizer, pipeline
from pathlib import Path
import torch

model_name = "kirancodes/llemma-7b-lora-sft-pretrained-RQ2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
llama_pipeline = pipeline(
    "text-generation",
    tokenizer=tokenizer,
    model=model_name,
    torch_dtype=torch.float16,
    device_map="auto",
)
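# Note: device_map="auto" relies on the `accelerate` package to place the fp16
# weights across whatever GPU(s) are visible.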

PREFIX = "Write a specification in Coq that ensures the following property: "

def process_data(obj):
    program_text = obj["program"].strip()
    context_text = obj["context"] or ""
    long_desc = obj["long_description"].strip()  # extracted but not used in the prompt below
    specification = obj["specification"].strip()
    tests = obj["tests"] or []
    tests = [test_obj["coq_statement"]
             for test_obj in tests
             if isinstance(test_obj, dict) and "coq_statement" in test_obj]
    test_text = "\n####\n".join(tests)
    # Build the prompt body: the program (preceded by its context), the tests,
    # then a trailing "### Specification:" header for the model to complete.
    def build_query(program):
        return "\n".join([
            "### Program:",
            f"{context_text}\n\n{program}",
            "",
            "### Tests:",
            test_text,
            "",
            "### Specification:",
        ])

    query = build_query(program_text)
    if len(query) > 5000:
        # Over-long prompt: trim the program so that it, the tests, and the
        # 48 characters of fixed headers/newlines stay within ~5000 characters
        # (the context text is not counted here, so a long context can still overshoot).
        program_text = program_text[:5000 - len(test_text) - 48]
        query = build_query(program_text)
    # Completion must start with a space for OpenAI-style fine-tuning
    completion = " " + specification
    return {'instruction': PREFIX, 'input': query, 'specification': specification,
            'expected_output': completion, 'file_path': obj['file_path']}
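
# Each input record is expected to provide at least: "program", "context",
# "long_description", "specification", "tests" (a list of {"coq_statement": ...}
# dicts), and "file_path"; process_data turns one record into a prompt/completion pair.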

# prompt: grab the dataset (jsonl) from hugging face under the repo
# priyamsahoo/specification-synthesis, the file called 'v7-training-data.jsonl',
# then run `process_data` on each object; the data is structured weirdly so
# downloading it via `datasets` doesn't really work.
# (Note: the code below actually fetches 'v7-evaluation-data.jsonl'.)
import requests
import json

url = "https://huggingface.co/datasets/priyamsahoo/specification-synthesis/resolve/main/v7-evaluation-data.jsonl"
response = requests.get(url)
response.raise_for_status()  # Raise an exception for bad status codes (4xx or 5xx)
data_lines = response.text.strip().split('\n')
processed_data = []
for line in data_lines:
    try:
        obj = json.loads(line)
        processed_data.append(process_data(obj))
    except json.JSONDecodeError as e:
        print(f"Skipping invalid JSON line: {line} - Error: {e}")
    except Exception as e:
        print(f"Skipping line due to processing error: {line} - Error: {e}")
# Now processed_data contains the results of running process_data on each object
def batched(items, n_batches):
    # Split `items` into roughly `n_batches` equally sized chunks (note: the
    # second argument is the number of batches, not the per-batch size).
    ls = list(items)
    no_elts = len(ls)
    batch_size = max(no_elts // n_batches, 1)
    no_batches = no_elts // batch_size + (0 if no_elts % batch_size == 0 else 1)
    return (ls[i * batch_size:(i + 1) * batch_size] for i in range(no_batches))
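# For example (placeholder data, traced from the logic above):
#   list(batched(range(10), 3)) -> [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]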
def format_prompt(elt):
    return elt['instruction'] + '\n' + elt['input']

sequences = []
n_batches = 50  # number of chunks handed to `batched` above, not items per batch
for i, batch in tqdm(list(enumerate(batched(processed_data, n_batches)))):
    try:
        res = llama_pipeline(
            [format_prompt(elt) for elt in batch],
            do_sample=True,
            top_k=10,
            num_return_sequences=1,
            eos_token_id=tokenizer.eos_token_id,
            max_new_tokens=256,
            return_full_text=False,
        )
    except Exception as e:
        res = [{"error": str(e), "batch_index": i}]
    sequences.extend(res)
output = Path('/u/kirang/scratch/llemma-7b-lora-sft-pretrained-RQ2-inferred.json')
with output.open('w') as f:
    json.dump(sequences, f)
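
# A quick sanity check over the saved output. This is a minimal sketch: it assumes
# the standard text-generation pipeline output ([{"generated_text": ...}] per prompt)
# and the {"error": ..., "batch_index": ...} placeholder used for failed batches above.
results = json.loads(output.read_text())
for entry in results:
    if isinstance(entry, dict) and "error" in entry:
        print(f"batch {entry['batch_index']} failed: {entry['error']}")
    else:
        print(entry[0]["generated_text"][:200])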