Skip to content

Instantly share code, notes, and snippets.

@MaximeRivest
Created August 3, 2025 03:06
Show Gist options
  • Save MaximeRivest/b30defc6d797691cf6178d03b97e1ba0 to your computer and use it in GitHub Desktop.
Save MaximeRivest/b30defc6d797691cf6178d03b97e1ba0 to your computer and use it in GitHub Desktop.
import dspy
from pydantic import BaseModel, Field
from typing import Literal
# from pydantic_v2_adapter import PydanticV2Adapter # Assuming you saved the code
# 1. Define your Pydantic models
# Postal address; used as the (nullable) `address` field of PatientDetails.
class PatientAddress(BaseModel):
    street: str
    city: str
    # Restricted by the Literal annotation to exactly these two codes.
    country: Literal["US", "CA"]
# Structured result the LLM must produce; used as the output field type of
# ExtractPatientInfo.
class PatientDetails(BaseModel):
    # The Field description is rendered as a `# ...` comment in the simplified
    # schema the adapter puts into the prompt.
    name: str = Field(description="Full name of the patient.")
    age: int
    # No default is given, so the key must be present, but its value may be null.
    address: PatientAddress | None
# 2. Define a signature using the Pydantic model as an output field
class ExtractPatientInfo(dspy.Signature):
    '''Extract patient information from the clinical note.'''
    # NOTE(review): DSPy uses a Signature's docstring as the task instructions
    # sent to the LLM, so the docstring above is prompt text — edit with care.

    # Free-text note supplied by the caller.
    clinical_note: str = dspy.InputField()
    # Parsed back into a PatientDetails instance.
    patient_info: PatientDetails = dspy.OutputField()
# 3. Configure dspy to use the new adapter.
# PydanticV2Adapter is defined at the end of this gist; save that code as
# pydantic_v2_adapter.py next to this script so the import below resolves
# (without it, PydanticV2Adapter() raises NameError).
from pydantic_v2_adapter import PydanticV2Adapter

llm = dspy.LM(model="gpt-4.1")
dspy.configure(lm=llm, adapter=PydanticV2Adapter())

# 4. Use your program
extractor = dspy.Predict(ExtractPatientInfo)
note = "John Doe, 45 years old, lives at 123 Main St, Anytown. Resident of the US."
result = extractor(clinical_note=note)
# repr() includes the class names, matching the expected output below.
print(repr(result.patient_info))
# Expected output:
# PatientDetails(name='John Doe', age=45, address=PatientAddress(street='123 Main St', city='Anytown', country='US'))
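# ---------------------------------------------------------------------------
# Example 2: a deeply nested schema
# ---------------------------------------------------------------------------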
import dspy
from pydantic import BaseModel, Field
from typing import Literal, Optional, Union, List, Dict
# from pydantic_v2_adapter import PydanticV2Adapter
# 1. Define your Pydantic models
# Contact details; both fields are optional and default to None.
class ContactInfo(BaseModel):
    email: Optional[str] = Field(default=None, description="Email address if available.")
    phone: Optional[str] = None
# Insurance coverage for the patient.
class InsurancePolicy(BaseModel):
    provider: str = Field(description="Insurance company name.")
    policy_number: str
    # Plain string: the "YYYY-MM-DD" description is a hint for the LLM only;
    # nothing here validates the date format.
    valid_until: Optional[str] = Field(default=None, description="YYYY-MM-DD")
# One allergy entry.
class Allergy(BaseModel):
    substance: str = Field(description="Allergen substance (e.g., penicillin, peanuts).")
    severity: Literal["mild", "moderate", "severe"]
    # No default is given, so this list is required (it may be empty).
    reactions: List[str] = Field(description="Typical reactions (e.g., rash, anaphylaxis).")
# A single lab measurement.
class LabResult(BaseModel):
    name: str
    # Either a number (e.g. 13.2) or free text (e.g. a blood type like "O+").
    value: Union[float, str]
    units: Optional[str] = None
    reference_range: Optional[str] = None
# A person to contact about the patient.
class EmergencyContact(BaseModel):
    name: str
    # Closed set: relationships such as "husband" must be mapped onto one of
    # these values (e.g. "spouse").
    relationship: Literal["parent", "spouse", "sibling", "friend", "other"]
    contact_info: ContactInfo
# Postal address; `country` defaults to "US" when omitted.
class PatientAddress(BaseModel):
    street: str
    city: str
    # Province/state as free text (presumably a short code such as "QC"; not validated).
    region: Optional[str] = None
    country: Literal["US", "CA", "MX"] = "US"
    postal_code: Optional[str] = None
# Top-level record the LLM must fill in; used as the output field type of
# ExtractPatientInfo.
class PatientDetails(BaseModel):
    name: str = Field(description="Full name of the patient.")
    gender: Literal["male", "female", "other", "unknown"]
    age: int
    height_cm: Optional[float] = Field(default=None, description="Height in centimeters.")
    weight_kg: Optional[float] = None
    smoker: bool = False
    address: PatientAddress
    # Required (no default): the LLM must return at least an empty list.
    contacts: List[EmergencyContact]
    # NOTE(review): pydantic copies mutable defaults per instance, so [] and {}
    # below are not shared; Field(default_factory=...) would make that explicit.
    allergies: List[Allergy] = []
    insurance: Optional[InsurancePolicy] = None
    lab_results: Optional[List[LabResult]] = None
    notes: Optional[str] = None
    # Free-form extras; values are limited to JSON scalar types.
    metadata: Dict[str, Union[str, int, float, bool]] = {}
# 2. Define a signature using the Pydantic model as an output field
class ExtractPatientInfo(dspy.Signature):
    """
Extract as much patient information as possible from the clinical note,
including all nested and structured information.
"""
    # NOTE(review): DSPy uses a Signature's docstring as the task instructions
    # sent to the LLM, so the docstring above is prompt text — edit with care.

    # Free-text note supplied by the caller.
    clinical_note: str = dspy.InputField()
    # Parsed back into a PatientDetails instance.
    patient_info: PatientDetails = dspy.OutputField()
# 3. Configure dspy to use the new adapter.
# PydanticV2Adapter is defined at the end of this gist; save that code as
# pydantic_v2_adapter.py next to this script so the import below resolves
# (without it, PydanticV2Adapter() raises NameError).
from pydantic_v2_adapter import PydanticV2Adapter

llm = dspy.LM(model="gpt-4.1")
dspy.configure(lm=llm, adapter=PydanticV2Adapter())

# 4. Use your program
extractor = dspy.Predict(ExtractPatientInfo)
note = (
    "Patient: Jane Smith (female), 38, 172 cm, 65 kg. "
    "Address: 87 River Rd, Montreal, QC, Canada H2X 2A3. "
    "Contact: husband, John Smith (john.smith@example.com, 555-2323). "
    "Allergies: penicillin (severe, anaphylaxis), cat dander (mild, rash). "
    "Insurance: SunLife, policy 89392Z, valid until 2027-12-31. "
    "Lab: Hemoglobin 13.2 g/dL (12-16), Blood Type O+. "
    "Smoker: No. "
    "Extra note: Patient is training for a marathon. "
    "Metadata: language=en, cohort=studyB."
)
result = extractor(clinical_note=note)
# repr() includes the class names, matching the expected output below.
print(repr(result.patient_info))
# Example expected output (not literal, but shows off all fields):
# PatientDetails(
#     name='Jane Smith',
#     gender='female',
#     age=38,
#     height_cm=172.0,
#     weight_kg=65.0,
#     smoker=False,
#     address=PatientAddress(
#         street='87 River Rd',
#         city='Montreal',
#         region='QC',
#         country='CA',
#         postal_code='H2X 2A3'
#     ),
#     contacts=[
#         EmergencyContact(
#             name='John Smith',
#             relationship='spouse',
#             contact_info=ContactInfo(email='john.smith@example.com', phone='555-2323')
#         )
#     ],
#     allergies=[
#         Allergy(substance='penicillin', severity='severe', reactions=['anaphylaxis']),
#         Allergy(substance='cat dander', severity='mild', reactions=['rash'])
#     ],
#     insurance=InsurancePolicy(
#         provider='SunLife',
#         policy_number='89392Z',
#         valid_until='2027-12-31'
#     ),
#     lab_results=[
#         LabResult(name='Hemoglobin', value=13.2, units='g/dL', reference_range='12-16'),
#         LabResult(name='Blood Type', value='O+', units=None, reference_range=None)
#     ],
#     notes='Patient is training for a marathon.',
#     metadata={'language': 'en', 'cohort': 'studyB'}
# )
import inspect
import types
from typing import Any, get_args, get_origin, Literal, Union
import pydantic
from pydantic import BaseModel
from pydantic.fields import FieldInfo
from dspy.adapters.json_adapter import JSONAdapter
from dspy.signatures.signature import Signature
from dspy.adapters.utils import format_field_value as original_format_field_value
# Helper functions for rendering Pydantic schemas
def _render_type_str(t: Any) -> str:
"""Recursively renders a type annotation into a simplified string."""
origin = get_origin(t)
args = get_args(t)
# Handle Optional[T] or T | None
if origin in (types.UnionType, Union):
non_none_args = [arg for arg in args if type(arg) is not type(None)]
# Render the non-None part of the union
type_render = " or ".join([_render_type_str(arg) for arg in non_none_args])
# Add 'or null' if None was part of the union
if len(non_none_args) < len(args):
return f"{type_render} or null"
return type_render
# Base types
if t is str: return "string"
if t is int: return "int"
if t is float: return "float"
if t is bool: return "boolean"
# Composite types
if origin is Literal:
return " or ".join(f'"{arg}"' for arg in args)
if origin is list:
return f"{_render_type_str(args[0])}[]"
if origin is dict:
return f"dict[{_render_type_str(args[0])}, {_render_type_str(args[1])}]"
# Pydantic models (we'll recurse in the main function)
if inspect.isclass(t) and issubclass(t, BaseModel):
return _build_simplified_schema(t)
# Fallback
if hasattr(t, "__name__"):
return t.__name__
return str(t)
def _build_simplified_schema(model: type[BaseModel], indent: int = 0) -> str:
"""Builds a simplified, human-readable schema from a Pydantic model."""
lines = []
idt = ' ' * indent
idt_p1 = ' ' * (indent + 1)
lines.append(f"{idt}{{")
fields = model.model_fields
for i, (name, field) in enumerate(fields.items()):
alias = field.alias or name
if field.description:
lines.append(f"{idt_p1}# {field.description}")
# Check for a nested Pydantic model
field_type_to_render = field.annotation
# Unpack Optional[T] to get T
origin = get_origin(field_type_to_render)
if origin in (types.UnionType, Union):
non_none_args = [arg for arg in get_args(field_type_to_render) if type(arg) is not type(None)]
if len(non_none_args) == 1:
field_type_to_render = non_none_args[0]
# Unpack list[T] to get T
origin = get_origin(field_type_to_render)
if origin is list:
field_type_to_render = get_args(field_type_to_render)[0]
if inspect.isclass(field_type_to_render) and issubclass(field_type_to_render, BaseModel):
# Recursively build schema for nested models
nested_schema = _build_simplified_schema(field_type_to_render, indent + 1)
rendered_type = _render_type_str(field.annotation).replace(field_type_to_render.__name__, nested_schema)
else:
rendered_type = _render_type_str(field.annotation)
line = f"{idt_p1}{alias}: {rendered_type}"
if i < len(fields) - 1:
line += ","
lines.append(line)
lines.append(f"{idt}}}")
return "\n".join(lines)
class PydanticV2Adapter(JSONAdapter):
    """
    A DSPy adapter that improves the rendering of Pydantic models for LLMs.

    This adapter describes the expected reply with a simplified, human-readable
    schema instead of a full JSON schema, which is more token-efficient and
    easier for models to follow. The schema is a single JSON object with one
    key per output field of the signature, so signatures with several output
    fields are described consistently. Pydantic input instances are formatted
    as clean, indented JSON.

    Example usage::

        import dspy
        from pydantic import BaseModel, Field
        from typing import Literal

        # 1. Define your Pydantic models
        class PatientAddress(BaseModel):
            street: str
            city: str
            country: Literal["US", "CA"]

        class PatientDetails(BaseModel):
            name: str = Field(description="Full name of the patient.")
            age: int
            address: PatientAddress | None

        # 2. Define a signature using the Pydantic model as an output field
        class ExtractPatientInfo(dspy.Signature):
            '''Extract patient information from the clinical note.'''
            clinical_note: str = dspy.InputField()
            patient_info: PatientDetails = dspy.OutputField()

        # 3. Configure dspy to use this adapter
        dspy.configure(lm=dspy.LM("openai/gpt-4o-mini"), adapter=PydanticV2Adapter())

        # 4. Use your program
        extractor = dspy.Predict(ExtractPatientInfo)
        note = "John Doe, 45 years old, lives at 123 Main St, Anytown. Resident of the US."
        result = extractor(clinical_note=note)
        print(repr(result.patient_info))
        # PatientDetails(name='John Doe', age=45, address=PatientAddress(street='123 Main St', city='Anytown', country='US'))
    """

    def format_field_structure(self, signature: type[Signature]) -> str:
        """Describe the required reply as a simplified schema.

        Every output field of ``signature`` becomes one key of a single JSON
        object. Each field's annotation is rendered with ``_render_type_str``,
        which expands Pydantic models into nested ``{ ... }`` blocks and
        renders other types as short strings such as ``string or null``.

        Args:
            signature: The DSPy signature whose output fields are described.

        Returns:
            The instruction sentence, a blank line, ``Schema:`` and the schema.
        """
        instruction = "You must produce a single, valid JSON object that strictly adheres to the following schema. Do not output anything else."
        pad = "  "
        output_fields = list(signature.output_fields.items())

        # One object keyed by output-field name, mirroring the
        # {field_name: value} reply that JSONAdapter parses.
        # NOTE(review): confirm against the pinned DSPy version's JSONAdapter.parse.
        lines = ["{"]
        for position, (name, field) in enumerate(output_fields):
            # Rendered schemas start at column 0; shift their continuation
            # lines one level so they sit under this key.
            rendered = _render_type_str(field.annotation).replace("\n", "\n" + pad)
            separator = "," if position < len(output_fields) - 1 else ""
            lines.append(f"{pad}{name}: {rendered}{separator}")
        lines.append("}")
        return f"{instruction}\n\nSchema:\n" + "\n".join(lines)

    def format_user_message_content(
        self,
        signature: type[Signature],
        inputs: dict[str, Any],
        prefix: str = "",
        suffix: str = "",
        main_request: bool = False,
    ) -> str:
        """Render the user message, showing Pydantic input instances as indented JSON.

        Args:
            signature: The DSPy signature whose input fields are rendered.
            inputs: Input values keyed by field name; fields missing from it are skipped.
            prefix: Text placed before the input sections.
            suffix: Text placed after everything else.
            main_request: When true, append ``user_message_output_requirements``
                (if it returns something).

        Returns:
            The non-empty parts joined by blank lines, stripped of outer whitespace.
        """
        messages = [prefix]
        for key, field_info in signature.input_fields.items():
            if key not in inputs:
                continue
            value = inputs[key]
            if isinstance(value, BaseModel):
                # Clean, indented JSON (using field aliases) for Pydantic instances.
                formatted_value = value.model_dump_json(indent=2, by_alias=True)
            else:
                # Everything else goes through DSPy's standard formatter.
                formatted_value = original_format_field_value(field_info=field_info, value=value)
            messages.append(f"[[ ## {key} ## ]]\n{formatted_value}")

        if main_request:
            output_requirements = self.user_message_output_requirements(signature)
            if output_requirements is not None:
                messages.append(output_requirements)

        messages.append(suffix)
        return "\n\n".join(m for m in messages if m).strip()
@prrao87
Copy link

prrao87 commented Aug 3, 2025

Great example! A blog post on this would be amazing, as this is something a lot of people will face. A couple of small nits:

  • In `_build_simplified_schema` (L66 of the original gist), `//` comments tend to work better than `#` on average for the JSON-like schema, and they are easier for humans to read too:
    lines.append(f"{idt_p1}# {field.description}") could be changed back to lines.append(f"{idt_p1}// {field.description}") (as your screenshot on Twitter showed)
  • In `format_field_structure` (L177 of the original gist), the section header should be [[ ## schema ## ]] to be consistent with the rest of DSPy's header formatting (which has no trailing colon), and to make it more explicit to the LLM where this section begins:
    return f"{instruction}\n\n[[ ## schema ## ]]\n" + "\n\n".join(output_schemas)

Would be really cool if we could work this upstream to be exposed to users! I'd be happy to contribute benchmarks to help evaluate quantitatively that this format is objectively better than the JSON schema for all sizes of models and datasets. BAML has also extensively benchmarked this format, showing that it's better than raw JSON schema.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment