import os

from distilabel.llms import (
    AnthropicLLM,
    InferenceEndpointsLLM,
    OpenAILLM,
)
from distilabel.pipeline import Pipeline
from distilabel.steps import (
    CombineColumns,
    KeepColumns,
    LoadDataFromDicts,
    LoadHubDataset,
    StepInput,
    step,
)
from distilabel.steps.tasks import TextGeneration, UltraFeedback
from distilabel.steps.typing import StepOutput


@step(inputs=["poll_ratings"], outputs=["avg_poll_ratings"])
def AveragePooling(*inputs: StepInput) -> StepOutput:
    # For every row, `poll_ratings` holds one list of ratings per judge; `zip`
    # transposes them so that each generation's ratings are averaged across
    # the whole panel of judges.
    for input in inputs:
        for item in input:
            item["avg_poll_ratings"] = [
                sum(col) / len(col) for col in zip(*item["poll_ratings"])
            ]
        yield input
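

# Build and run the pipeline: generate responses with four open models, rate
# them with a panel of UltraFeedback judges, and pool the ratings.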
if __name__ == "__main__":
    with Pipeline(name="replacing-judges-with-juries") as pipeline:
        # load_dataset = LoadDataFromDicts(
        #     name="load_dataset",
        #     data=[
        #         {
        #             "instruction": "Arianna has 12 chocolates more than Danny. Danny has 6 chocolates more than Robbie. Arianna has twice as many chocolates as Robbie has. How many chocolates does Danny have?"
        #         },
        #     ],
        # )
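        # Load 10 test prompts from `HuggingFaceH4/instruction-dataset`,
        # renaming the `prompt` column to `instruction`.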
        load_dataset = LoadHubDataset(
            name="load_dataset",
            repo_id="HuggingFaceH4/instruction-dataset",
            split="test",
            num_examples=10,
            output_mappings={"prompt": "instruction"},
        )
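
        # Generate a completion for every instruction with four open models
        # served through Hugging Face Inference Endpoints.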
        text_generation_llama3 = TextGeneration(
            name="text_generation_llama3",
            llm=InferenceEndpointsLLM(
                model_id="meta-llama/Meta-Llama-3-8B-Instruct",
                tokenizer_id="meta-llama/Meta-Llama-3-8B-Instruct",
                api_key=os.getenv("HF_TOKEN"),  # type: ignore
            ),
            input_batch_size=10,
            output_mappings={"model_name": "generation_model"},
        )
        text_generation_gemma = TextGeneration(
            name="text_generation_gemma",
            llm=InferenceEndpointsLLM(
                model_id="google/gemma-1.1-7b-it",
                api_key=os.getenv("HF_TOKEN"),  # type: ignore
            ),
            input_batch_size=10,
            output_mappings={"model_name": "generation_model"},
        )
        text_generation_phi3 = TextGeneration(
            name="text_generation_phi3",
            llm=InferenceEndpointsLLM(
                model_id="microsoft/Phi-3-mini-4k-instruct",
                api_key=os.getenv("HF_TOKEN"),  # type: ignore
            ),
            input_batch_size=10,
            output_mappings={"model_name": "generation_model"},
        )
        text_generation_mistral = TextGeneration(
            name="text_generation_mistral",
            llm=InferenceEndpointsLLM(
                model_id="mistralai/Mistral-7B-Instruct-v0.2",
                api_key=os.getenv("HF_TOKEN"),  # type: ignore
            ),
            input_batch_size=10,
            output_mappings={"model_name": "generation_model"},
        )
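
        # Merge the four generations and their model names into single list
        # columns so that every judge rates all of them at once.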
        combine_generation_columns = CombineColumns(
            name="combine_generation_columns",
            columns=["generation", "generation_model"],
            output_columns=["generations", "generation_models"],
        )
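
        # Panel of judges: each UltraFeedback task rates every generation on
        # instruction-following (Claude 3 Haiku is kept below as an optional
        # third judge, commented out).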
        # ultrafeedback_haiku = UltraFeedback(
        #     name="ultrafeedback_haiku",
        #     llm=AnthropicLLM(
        #         model="claude-3-haiku-20240307",
        #         api_key=os.getenv("ANTHROPIC_API_KEY"),  # type: ignore
        #     ),
        #     input_batch_size=5,
        #     aspect="instruction-following",
        # )
        ultrafeedback_cmdr_plus = UltraFeedback(
            name="ultrafeedback_cmdr_plus",
            llm=InferenceEndpointsLLM(
                model_id="CohereForAI/c4ai-command-r-plus",
                api_key=os.getenv("HF_TOKEN"),  # type: ignore
            ),
            input_batch_size=5,
            aspect="instruction-following",
        )
        ultrafeedback_gpt35 = UltraFeedback(
            name="ultrafeedback_gpt35",
            llm=OpenAILLM(
                model="gpt-3.5-turbo-0125",
                api_key=os.getenv("OPENAI_API_KEY"),  # type: ignore
            ),
            input_batch_size=5,
            aspect="instruction-following",
        )
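
        # Gather the ratings and rationales from every judge into `poll_*`
        # list columns, then average the ratings per generation.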
        combine_ultrafeedback_columns = CombineColumns(
            name="combine_ultrafeedback_columns",
            columns=["ratings", "rationales", "model_name"],
            output_columns=["poll_ratings", "poll_rationales", "poll_models"],
        )
        avg_pooling = AveragePooling(name="avg_pooling", input_batch_size=1)
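
        # Wire the DAG: fan out to the generators, combine their outputs,
        # fan out to the judges, combine their ratings, and pool them.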
        (
            load_dataset
            >> [text_generation_llama3, text_generation_gemma, text_generation_phi3, text_generation_mistral]
            >> combine_generation_columns
            # >> [ultrafeedback_haiku, ultrafeedback_cmdr_plus, ultrafeedback_gpt35]
            >> [ultrafeedback_cmdr_plus, ultrafeedback_gpt35]
            >> combine_ultrafeedback_columns
            >> avg_pooling
        )
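
    # Run the pipeline, overriding each LLM's generation kwargs at runtime
    # (temperature, max new tokens, and the model-specific stop sequences).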
    distiset = pipeline.run(
        parameters={
            "text_generation_llama3": {
                "llm": {
                    "generation_kwargs": {
                        "temperature": 0.7,
                        "max_new_tokens": 1024,
                        "stop_sequences": ["<|eot_id|>", "<|end_of_text|>"],
                    },
                },
            },
            "text_generation_gemma": {
                "llm": {
                    "generation_kwargs": {
                        "temperature": 0.7,
                        "max_new_tokens": 1024,
                        "stop_sequences": ["<eos>", "<end_of_turn>"],
                    },
                },
            },
            "text_generation_phi3": {
                "llm": {
                    "generation_kwargs": {
                        "temperature": 0.7,
                        "max_new_tokens": 1024,
                        "stop_sequences": ["</s>", "<|endoftext|>"],
                    },
                },
            },
            "text_generation_mistral": {
                "llm": {
                    "generation_kwargs": {
                        "temperature": 0.7,
                        "max_new_tokens": 1024,
                        "stop_sequences": ["</s>"],
                    },
                },
            },
            # "ultrafeedback_haiku": {
            #     "llm": {"generation_kwargs": {"temperature": 1.0, "max_tokens": 4096}},
            # },
            "ultrafeedback_cmdr_plus": {
                "llm": {
                    "generation_kwargs": {
                        "temperature": 1.0,
                        "max_new_tokens": 4096,
                        "stop_sequences": ["<EOS_TOKEN>", "<|END_OF_TURN_TOKEN|>"],
                    },
                },
            },
            "ultrafeedback_gpt35": {
                "llm": {
                    "generation_kwargs": {"temperature": 1.0, "max_new_tokens": 4096}
                },
            },
        }
    )
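
    # Push the resulting Distiset to the Hugging Face Hub if the run succeeded.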
    if distiset is not None:
        distiset.push_to_hub(
            "replacing-judges-with-juries-distilabel",
            token=os.getenv("HF_TOKEN"),
        )