Created
May 8, 2024 12:41
-
-
Save alvarobartt/e6611cf229c4fafcbe923a6da25b6bfe to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time | |
from typing import Any, Dict, Literal | |
from distilabel.llms import vLLM | |
from distilabel.llms.typing import ChatType | |
from distilabel.pipeline import Pipeline | |
from distilabel.steps import LoadDataFromDicts | |
from distilabel.steps.tasks.prometheus_eval import PrometheusEval | |
_CUSTOM_RUBRICS = { | |
"custom-rubric": "...", | |
} | |
class CustomPrometheusEval(PrometheusEval):
    """`PrometheusEval` variant that renders its prompt from the module-level
    `_CUSTOM_RUBRICS` mapping instead of the task's built-in rubrics.

    Only `format_input` is overridden; everything else (output parsing,
    generation handling) is inherited unchanged.
    """

    # Restricts accepted rubric names to the keys present in _CUSTOM_RUBRICS.
    rubric: Literal["custom-rubric"]

    def format_input(self, input: Dict[str, Any]) -> ChatType:
        """Build the system/user chat messages for one input row.

        In "absolute" mode the row must carry a string `generation`; otherwise
        ("relative" mode) it must carry `generations`, a list of exactly two
        strings. Raises `ValueError` when the expected field has the wrong shape.
        """
        render_kwargs: Dict[str, Any] = {
            "instruction": input["instruction"],
            "rubric": _CUSTOM_RUBRICS[self.rubric],
        }
        if self.reference:
            # NOTE(review): assumes the row has a "reference" key whenever the
            # task was configured with reference=True — confirm against callers.
            render_kwargs["reference"] = input["reference"]

        if self.mode == "absolute":
            generation = input["generation"]
            if not isinstance(generation, str):
                raise ValueError(
                    f"Provided `generation` is of type {type(input['generation'])} but a string"
                    " should be provided instead.",
                )
            render_kwargs["generation"] = generation
            system_message = (
                "You are a fair judge assistant tasked with providing clear, objective feedback based"
                " on specific criteria, ensuring each assessment reflects the absolute standards set"
                " for performance."
            )
        else:  # self.mode == "relative"
            generations = input["generations"]
            generations_ok = (
                isinstance(generations, list)
                and len(generations) == 2
                and all(isinstance(item, str) for item in generations)
            )
            if not generations_ok:
                raise ValueError(
                    f"Provided `generations` is of type {type(input['generations'])} but a list of strings with length 2 should be provided instead."
                )
            render_kwargs["generations"] = generations
            system_message = (
                "You are a fair judge assistant assigned to deliver insightful feedback that compares"
                " individual performances, highlighting how each stands relative to others within the"
                " same cohort."
            )

        return [
            {"role": "system", "content": system_message},
            {"role": "user", "content": self._template.render(**render_kwargs)},  # type: ignore
        ]
if __name__ == "__main__":
    start_time = time.time()

    with Pipeline(name="prometheus") as pipeline:
        # Single toy row: "generation" is consumed in absolute mode;
        # "generations" would be consumed in relative mode.
        load_dataset = LoadDataFromDicts(
            name="load_dataset",
            data=[
                {
                    "instruction": "What's 2+2?",
                    "generation": "The answer is 4",
                    "generations": ["The answer is 4", "The answer is clearly 42"],
                },
            ],
        )

        task = CustomPrometheusEval(
            name="task",
            llm=vLLM(
                model="prometheus-eval/prometheus-7b-v2.0",
                chat_template="[INST] {{ messages[0]['content'] }}\n{{ messages[1]['content'] }}[/INST]",
            ),
            mode="absolute",
            rubric="custom-rubric",
            reference=False,
            num_generations=1,
            group_generations=False,
        )

        load_dataset >> task  # type: ignore

    # BUG FIX: runtime parameters are keyed by *step name*. The original code
    # used "abs_task" / "rel_task", which match no step in this pipeline, so
    # the generation kwargs (max_new_tokens, temperature) were silently
    # ignored. The only evaluation step above is named "task".
    distiset = pipeline.run(
        parameters={
            "task": {
                "llm": {
                    "generation_kwargs": {
                        "max_new_tokens": 1024,
                        "temperature": 0.7,
                    },
                },
            },
        },
    )

    print("--- %s seconds ---" % (time.time() - start_time))

    # push_to_hub is skipped when the run produced no dataset.
    if distiset is not None:
        distiset.push_to_hub("prometheus-eval-distilabel-custom-rubric")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment