Skip to content

Instantly share code, notes, and snippets.

@alvarobartt
Created May 8, 2024 12:41
Show Gist options
  • Save alvarobartt/e6611cf229c4fafcbe923a6da25b6bfe to your computer and use it in GitHub Desktop.
import time
from typing import Any, Dict, Literal
from distilabel.llms import vLLM
from distilabel.llms.typing import ChatType
from distilabel.pipeline import Pipeline
from distilabel.steps import LoadDataFromDicts
from distilabel.steps.tasks.prometheus_eval import PrometheusEval
# Maps a rubric name to the rubric prompt text that gets rendered into the
# Prometheus evaluation template. "..." is a placeholder — replace it with a
# real scoring rubric before running the pipeline.
_CUSTOM_RUBRICS = {
"custom-rubric": "...",
}
class CustomPrometheusEval(PrometheusEval):
    """`PrometheusEval` variant that looks its rubric up in the local
    `_CUSTOM_RUBRICS` mapping instead of the task's built-in rubric set.
    """

    # Restrict accepted rubric names to the keys of `_CUSTOM_RUBRICS`.
    rubric: Literal["custom-rubric"]

    def format_input(self, input: Dict[str, Any]) -> ChatType:
        """Render one input row into the system/user chat pair for the judge.

        Raises:
            ValueError: if ``generation`` is not a string (absolute mode) or
                ``generations`` is not a list of exactly two strings
                (relative mode).
        """
        render_args: Dict[str, Any] = {
            "instruction": input["instruction"],
            "rubric": _CUSTOM_RUBRICS[self.rubric],
        }
        # Only absolute-with-reference evaluation consumes a reference answer.
        if self.reference:
            render_args["reference"] = input["reference"]

        if self.mode == "absolute":
            generation = input["generation"]
            if not isinstance(generation, str):
                raise ValueError(
                    f"Provided `generation` is of type {type(generation)} but a string"
                    " should be provided instead.",
                )
            render_args["generation"] = generation
            system_message = (
                "You are a fair judge assistant tasked with providing clear, objective feedback based"
                " on specific criteria, ensuring each assessment reflects the absolute standards set"
                " for performance."
            )
        else:  # self.mode == "relative"
            generations = input["generations"]
            is_valid_pair = (
                isinstance(generations, list)
                and len(generations) == 2
                and all(isinstance(item, str) for item in generations)
            )
            if not is_valid_pair:
                raise ValueError(
                    f"Provided `generations` is of type {type(generations)} but a list of strings with length 2 should be provided instead."
                )
            render_args["generations"] = generations
            system_message = (
                "You are a fair judge assistant assigned to deliver insightful feedback that compares"
                " individual performances, highlighting how each stands relative to others within the"
                " same cohort."
            )

        system_turn = {"role": "system", "content": system_message}
        user_turn = {
            "role": "user",
            "content": self._template.render(**render_args),  # type: ignore
        }
        return [system_turn, user_turn]
if __name__ == "__main__":
    start_time = time.time()

    # Minimal single-step pipeline: load one hard-coded row, then judge it
    # with the custom-rubric Prometheus task in absolute mode.
    with Pipeline(name="prometheus") as pipeline:
        load_dataset = LoadDataFromDicts(
            name="load_dataset",
            data=[
                {
                    "instruction": "What's 2+2?",
                    "generation": "The answer is 4",
                    "generations": ["The answer is 4", "The answer is clearly 42"],
                },
            ],
        )
        task = CustomPrometheusEval(
            name="task",
            llm=vLLM(
                model="prometheus-eval/prometheus-7b-v2.0",
                chat_template="[INST] {{ messages[0]['content'] }}\n{{ messages[1]['content'] }}[/INST]",
            ),
            mode="absolute",
            rubric="custom-rubric",
            reference=False,
            num_generations=1,
            group_generations=False,
        )
        load_dataset >> task  # type: ignore

    # BUG FIX: runtime parameters are keyed by step name. The original used
    # "abs_task"/"rel_task", which match no step in this pipeline, so the
    # generation kwargs were silently ignored. The single task step above is
    # named "task".
    distiset = pipeline.run(
        parameters={
            "task": {
                "llm": {
                    "generation_kwargs": {
                        "max_new_tokens": 1024,
                        "temperature": 0.7,
                    },
                },
            },
        },
    )
    print("--- %s seconds ---" % (time.time() - start_time))

    # `run` may return None (e.g. on a dry run); only push a real distiset.
    if distiset is not None:
        distiset.push_to_hub("prometheus-eval-distilabel-custom-rubric")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment