Structured JSON Generation using Ollama on Haystack, demo: https://asciinema.org/a/7e6Lu7GYcD3HMvuuFFnjzZcjA
# python 3.12 is needed:
# pyenv install 3.12
# pyenv exec python3.12 -m venv venv
# . venv/bin/activate
# python -V
# pip install -r requirements.txt
haystack-ai
ollama-haystack
colorama
import logging
logging.basicConfig()
# Pipeline Debug
# from haystack import tracing
# from haystack.tracing.logging_tracer import LoggingTracer
# tracing.tracer.is_content_tracing_enabled = True # to enable tracing/logging content (inputs/outputs)
# tracing.enable_tracing(LoggingTracer(tags_color_strings={"haystack.component.input": "\x1b[1;31m", "haystack.component.name": "\x1b[1;34m"}))
# logging.getLogger("haystack").setLevel(logging.DEBUG)
# from haystack.telemetry import tutorial_running
# tutorial_running(28) # Allow telemetry
# Data Structures
from typing import List
from pydantic import BaseModel
class City(BaseModel):
    name: str
    country: str
    population: int

class CitiesData(BaseModel):
    cities: List[City]
data_model = CitiesData
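# Illustrative example (not executed): an instance that validates against CitiesData,
# built from the Berlin sentence in the passage used further below:
# CitiesData.model_validate({"cities": [{"name": "Berlin", "country": "Germany", "population": 3850809}]})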
# Output Validator
import json
import random
import pydantic
from pydantic import ValidationError
from typing import Optional, List
from colorama import Fore
from haystack import component
from haystack.dataclasses import ChatMessage
@component
class OutputValidator:
    # Define the component input parameters
    def __init__(self, pydantic_model: pydantic.BaseModel):
        self.pydantic_model = pydantic_model
        self.iteration_counter = 0

    # Define the component output
    @component.output_types(valid_replies=List[str], invalid_replies=Optional[List[str]], error_message=Optional[str])
    def run(self, replies: List[ChatMessage]):
        print(Fore.YELLOW + replies[0].text)
        self.iteration_counter += 1

        ## Try to parse the LLM's reply ##
        # If the LLM's reply is a valid object, return `"valid_replies"`
        try:
            output_dict = json.loads(replies[0].text)
            self.pydantic_model.model_validate(output_dict)
            print(
                Fore.GREEN
                + f"OutputValidator at Iteration {self.iteration_counter}: Valid JSON from LLM - No need for looping: {replies[0]}"
            )
            return {"valid_replies": replies}

        except (ValueError, ValidationError) as e:
            # If the LLM's reply is corrupted or not valid, return "invalid_replies" and the "error_message" for the LLM to try again
            print(
                Fore.RED
                + f"OutputValidator at Iteration {self.iteration_counter}: Invalid JSON from LLM - Let's try again.\n"
                # f"Output from LLM:\n {replies[0]} \n"
                f"Error from OutputValidator: {e}"
            )
            return {"invalid_replies": replies, "error_message": str(e)}
output_validator = OutputValidator(pydantic_model=data_model)
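# Quick sanity check (illustrative sketch, not part of the pipeline run below):
# a reply that is valid JSON but violates the schema should come back as "invalid_replies"
# together with the Pydantic error message.
# output_validator.run([ChatMessage.from_assistant('{"cities": "not a list"}')])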
# Prompt
from haystack.components.builders import ChatPromptBuilder
prompt_template = [
    ChatMessage.from_user(
        """
Create a JSON object from the information present in this passage: {{passage}}.
Only use information that is present in the passage. Follow this JSON schema, but only return the actual instances without any additional schema definition:
{{schema}}
Make sure your response is a dict and not a list.
{% if invalid_replies and error_message %}
You already created the following output in a previous attempt: {{invalid_replies}}
However, this doesn't comply with the format requirements from above and triggered this Python exception: {{error_message}}
Correct the output and try again. Just return the corrected output without any extra explanations.
{% endif %}
"""
    )
]
print(prompt_template)
prompt_builder = ChatPromptBuilder(template=prompt_template)
# Chat Generator
from haystack_integrations.components.generators.ollama import OllamaChatGenerator
chat_generator = OllamaChatGenerator(
    model="llama3.2",
    # model="phi3",
    # model="orca-mini",
    # model="deepseek-r1",
    # url="http://localhost:11434/api/chat",
    # # streaming_callback=print_streaming_chunk,
    # generation_kwargs={
    #     "num_predict": 100,
    #     "temperature": 0.9,
    #     "max_tokens": 131072,
    # },
)
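# Note: this assumes an Ollama server is running locally (default: http://localhost:11434)
# and that the chosen model has been pulled beforehand, e.g.:
#   ollama pull llama3.2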
# Pipeline
from haystack import Pipeline
pipeline = Pipeline()  # max_runs_per_component=12
# Add components to your pipeline
pipeline.add_component(instance=prompt_builder, name="prompt_builder")
pipeline.add_component(instance=chat_generator, name="llm")
pipeline.add_component(instance=output_validator, name="output_validator")
# Now, connect the components to each other
pipeline.connect("prompt_builder.prompt", "llm.messages")
pipeline.connect("llm.replies", "output_validator")
# If a component has more than one output or input, explicitly specify the connections:
pipeline.connect("output_validator.invalid_replies", "prompt_builder.invalid_replies")
pipeline.connect("output_validator.error_message", "prompt_builder.error_message")
# pipeline.draw("auto-correct-pipeline.png")
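# Resulting self-correcting loop (sketch of the connections above):
#   prompt_builder -> llm -> output_validator
#     - valid_replies: final pipeline output
#     - invalid_replies / error_message: fed back into prompt_builder for another attempt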
# Run
passage = """
Berlin is the capital of Germany. It has a population of 3,850,809.
Paris, France's capital, has 2.161 million residents.
Lisbon is the capital and the largest city of Portugal with the population of 504,718.
"""
result = pipeline.run({"prompt_builder": {
    "passage": passage,
    "schema": json.dumps(data_model.model_json_schema()),
}})
json_string = result['output_validator']['valid_replies'][0].text
print(
    Fore.RESET,
    '\nResult:\n',
    json.dumps(json.loads(json_string), indent=4))
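# Example of the kind of output to expect (illustrative; exact values depend on the model's reply):
# {
#     "cities": [
#         {"name": "Berlin", "country": "Germany", "population": 3850809},
#         {"name": "Paris", "country": "France", "population": 2161000},
#         {"name": "Lisbon", "country": "Portugal", "population": 504718}
#     ]
# }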