@kiranandcode
Created July 4, 2025 20:01
from tqdm import tqdm
from transformers import AutoTokenizer, pipeline
from pathlib import Path
import torch
model_name = "kirancodes/llemma-7b-lora-sft-pretrained-RQ2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
llama_pipeline = pipeline(
    "text-generation",
    tokenizer=tokenizer,
    model=model_name,
    torch_dtype=torch.float16,
    device_map="auto",
)
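# Note: device_map="auto" relies on the accelerate library to place the model
# across whatever GPUs (and CPU) are available, and torch_dtype=torch.float16
# loads the weights in half precision, so a 7B model needs roughly 14 GB of
# device memory (a rough estimate, not measured here).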
PREFIX = "Write a specification in Coq that ensures the following property: "
def process_data(obj):
    program_text = obj["program"].strip()
    context_text = obj["context"] or ""
    long_desc = obj["long_description"].strip()  # parsed but not used in the prompt below
    specification = obj["specification"].strip()
    tests = obj["tests"] or []
    tests = [test_obj["coq_statement"] for test_obj in tests
             if isinstance(test_obj, dict) and "coq_statement" in test_obj]
    test_text = "\n####\n".join(tests)
    # Build the prompt: a "Program:" section (context + program), then the
    # tests, then an open "Specification:" header for the model to complete
    def build_query(prog):
        return "\n".join([
            "### Program:",
            f"{context_text}\n\n{prog}",
            "",
            "### Tests:",
            test_text,
            "",
            "### Specification:",
        ])
    query = build_query(program_text)
    # If the prompt is too long (measured in characters, not tokens), truncate
    # the program text and rebuild; the 48 roughly covers the fixed header
    # lines above, though context_text is not accounted for
    if len(query) > 5000:
        program_text = program_text[:max(0, 5000 - len(test_text) - 48)]
        query = build_query(program_text)
    # Completion must start with a space for OpenAI-style fine-tuning
    completion = " " + specification
    return {
        'instruction': PREFIX,
        'input': query,
        'specification': specification,
        'expected_output': completion,
        'file_path': obj['file_path'],
    }
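# Illustrative example (hypothetical record; the values are made up, but the
# keys match what process_data reads):
#   obj = {"program": "Definition double (n : nat) := n + n.",
#          "context": "", "long_description": "Doubles a natural number.",
#          "specification": "Lemma double_spec : forall n, double n = 2 * n.",
#          "tests": [{"coq_statement": "Example t1 : double 2 = 4."}],
#          "file_path": "examples/double.v"}
#   process_data(obj)['input'] starts with "### Program:" and ends with the
#   open "### Specification:" header; 'expected_output' is the specification
#   prefixed with a single space.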
# prompt: grab the dataset (jsonl) from Hugging Face under the repo
# priyamsahoo/specification-synthesis, the file 'v7-evaluation-data.jsonl',
# then run `process_data` on each object; the data is structured in a way
# that `datasets.load_dataset` can't parse, so fetch the raw file instead
import requests
import json
url = "https://huggingface.co/datasets/priyamsahoo/specification-synthesis/resolve/main/v7-evaluation-data.jsonl"
response = requests.get(url)
response.raise_for_status() # Raise an exception for bad status codes (4xx or 5xx)
data_lines = response.text.strip().split('\n')
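# An alternative (sketch, not used here since the raw HTTP fetch above works):
#   from huggingface_hub import hf_hub_download
#   path = hf_hub_download(repo_id="priyamsahoo/specification-synthesis",
#                          filename="v7-evaluation-data.jsonl",
#                          repo_type="dataset")
# which fetches the same raw file with local caching.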
processed_data = []
for line in data_lines:
    try:
        obj = json.loads(line)
        processed_data.append(process_data(obj))
    except json.JSONDecodeError as e:
        print(f"Skipping invalid JSON line: {line} - Error: {e}")
    except Exception as e:
        print(f"Skipping line due to processing error: {line} - Error: {e}")
# Now processed_data contains the results of running process_data on each object
def batched(iterable, n_batches):
    # Split `iterable` into `n_batches` roughly equal chunks; note that the
    # second argument is the number of batches, not the batch size
    ls = list(iterable)
    no_elts = len(ls)
    batch_size = max(no_elts // n_batches, 1)
    no_batches = no_elts // batch_size + (0 if no_elts % batch_size == 0 else 1)
    return (ls[i*batch_size:(i+1)*batch_size] for i in range(no_batches))
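# For example (a quick check of the chunking, not run as part of this script):
#   list(batched(range(10), 3)) -> [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
# i.e. batch_size = max(10 // 3, 1) = 3, giving four chunks rather than three,
# with the remainder in a short final chunk.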
def format_prompt(elt):
    return elt['instruction'] + '\n' + elt['input']
sequences = []
n_batches = 50  # number of batches to split the data into (see `batched` above)
for i, batch in tqdm(list(enumerate(batched(processed_data, n_batches)))):
    try:
        res = llama_pipeline(
            [format_prompt(elt) for elt in batch],
            do_sample=True,
            top_k=10,
            num_return_sequences=1,
            eos_token_id=tokenizer.eos_token_id,
            max_new_tokens=256,
            return_full_text=False,
        )
    except Exception as e:
        # Record the failure in place of the batch's outputs so the rest of
        # `sequences` stays interpretable
        res = [{"error": str(e), 'batch_index': i}]
    sequences.extend(res)
output = Path('/u/kirang/scratch/llemma-7b-lora-sft-pretrained-RQ2-inferred.json')
with output.open('w') as f:
    json.dump(sequences, f)
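# To inspect the results later (a sketch, assuming the dump above succeeded):
#   results = json.loads(output.read_text())
# Each entry is either a list of {"generated_text": ...} dicts returned by the
# pipeline for one prompt, or an {"error": ..., "batch_index": ...} marker for
# a batch that failed.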