Last active
February 15, 2025 18:25
-
-
Save damiencorpataux/4fbedf4ea66ae21c9e6138b2d0297e15 to your computer and use it in GitHub Desktop.
Structured JSON Generation using Ollama on Haystack, demo: https://asciinema.org/a/7e6Lu7GYcD3HMvuuFFnjzZcjA
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# python 3.12 is needed:
# pyenv install 3.12
# pyenv exec python3.12 -m venv venv
# . venv/bin/activate
# python -V
# pip install -r requirements.txt
haystack-ai
ollama-haystack
colorama
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Logging: attach the default stderr handler (root logger stays at WARNING).
import logging
logging.basicConfig()

# Pipeline Debug — uncomment to trace component inputs/outputs:
# from haystack import tracing
# from haystack.tracing.logging_tracer import LoggingTracer
# tracing.tracer.is_content_tracing_enabled = True  # to enable tracing/logging content (inputs/outputs)
# tracing.enable_tracing(LoggingTracer(tags_color_strings={"haystack.component.input": "\x1b[1;31m", "haystack.component.name": "\x1b[1;34m"}))
# logging.getLogger("haystack").setLevel(logging.DEBUG)
# from haystack.telemetry import tutorial_running
# tutorial_running(28)  # Allow telemetry
# Data Structures: the schema the LLM is asked to produce.
from typing import List
from pydantic import BaseModel


class City(BaseModel):
    """A single city record extracted from the input passage."""
    name: str
    country: str
    population: int


class CitiesData(BaseModel):
    """Top-level container: the LLM must return an object with a "cities" list."""
    cities: List[City]


# The pydantic model used downstream to validate the LLM's JSON output.
data_model = CitiesData
# Output Validator
import json
import random
import pydantic
from pydantic import ValidationError
from typing import Optional, List
from colorama import Fore
from haystack import component
from haystack.dataclasses import ChatMessage


@component
class OutputValidator:
    """Haystack component that validates an LLM reply against a pydantic model.

    On success it emits ``valid_replies``; on failure it emits
    ``invalid_replies`` plus ``error_message`` so the pipeline can loop the
    error back into the prompt builder for another attempt.
    """

    def __init__(self, pydantic_model: pydantic.BaseModel):
        self.pydantic_model = pydantic_model
        # Counts validation attempts across the retry loop (for log messages).
        self.iteration_counter = 0

    # Define the component output sockets.
    @component.output_types(valid_replies=List[str], invalid_replies=Optional[List[str]], error_message=Optional[str])
    def run(self, replies: List[ChatMessage]):
        print(Fore.YELLOW + replies[0].text)
        self.iteration_counter += 1
        try:
            ## Try to parse the LLM's reply ##
            # If the LLM's reply is a valid object, return `"valid_replies"`
            output_dict = json.loads(replies[0].text)
            self.pydantic_model.model_validate(output_dict)
            print(
                Fore.GREEN
                + f"OutputValidator at Iteration {self.iteration_counter}: Valid JSON from LLM - No need for looping: {replies[0]}"
            )
            return {"valid_replies": replies}
        except (ValueError, ValidationError) as e:
            # If the LLM's reply is corrupted or not valid, return "invalid_replies"
            # and the "error_message" for the LLM to try again.
            print(
                Fore.RED
                + f"OutputValidator at Iteration {self.iteration_counter}: Invalid JSON from LLM - Let's try again.\n"
                # f"Output from LLM:\n {replies[0]} \n"
                f"Error from OutputValidator: {e}"
            )
            return {"invalid_replies": replies, "error_message": str(e)}


output_validator = OutputValidator(pydantic_model=data_model)
# Prompt: Jinja template; the {% if %} branch feeds the previous invalid
# output and validator error back to the LLM on retries.
from haystack.components.builders import ChatPromptBuilder

prompt_template = [
    ChatMessage.from_user(
        """
Create a JSON object from the information present in this passage: {{passage}}.
Only use information that is present in the passage. Follow this JSON schema, but only return the actual instances without any additional schema definition:
{{schema}}
Make sure your response is a dict and not a list.
{% if invalid_replies and error_message %}
You already created the following output in a previous attempt: {{invalid_replies}}
However, this doesn't comply with the format requirements from above and triggered this Python exception: {{error_message}}
Correct the output and try again. Just return the corrected output without any extra explanations.
{% endif %}
"""
    )
]
print(prompt_template)
prompt_builder = ChatPromptBuilder(template=prompt_template)
# Chat Generator: local Ollama model producing the structured JSON.
from haystack_integrations.components.generators.ollama import OllamaChatGenerator

chat_generator = OllamaChatGenerator(
    model="llama3.2",
    # model="phi3",
    # model="orca-mini",
    # model="deepseek-r1",
    # url="http://localhost:11434/api/chat",
    # streaming_callback=print_streaming_chunk,
    # generation_kwargs={
    #     "num_predict": 100,
    #     "temperature": 0.9,
    #     "max_tokens": 131072,
    # },
)
# Pipeline: prompt -> LLM -> validator, with the validator's failure outputs
# looped back into the prompt builder for self-correcting retries.
from haystack import Pipeline

pipeline = Pipeline()  # Pipeline(max_runs_per_component=12) to cap retry loops

# Add components to your pipeline
pipeline.add_component(instance=prompt_builder, name="prompt_builder")
pipeline.add_component(instance=chat_generator, name="llm")
pipeline.add_component(instance=output_validator, name="output_validator")

# Now, connect the components to each other
pipeline.connect("prompt_builder.prompt", "llm.messages")
pipeline.connect("llm.replies", "output_validator")
# If a component has more than one output or input, explicitly specify the connections:
pipeline.connect("output_validator.invalid_replies", "prompt_builder.invalid_replies")
pipeline.connect("output_validator.error_message", "prompt_builder.error_message")
# pipeline.draw("auto-correct-pipeline.png")
# Run: extract structured city data from a free-text passage.
passage = """
Berlin is the capital of Germany. It has a population of 3,850,809.
Paris, France's capital, has 2.161 million residents.
Lisbon is the capital and the largest city of Portugal with the population of 504,718.
"""
result = pipeline.run({"prompt_builder": {
    "passage": passage,
    "schema": json.dumps(data_model.model_json_schema())
}})

# Use the public ChatMessage.text accessor (as OutputValidator.run does)
# instead of reaching into the private `_content` attribute.
json_string = result['output_validator']['valid_replies'][0].text
print(
    Fore.RESET,
    '\nResult:\n',
    json.dumps(json.loads(json_string), indent=4))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment