Skip to content

Instantly share code, notes, and snippets.

@rohitgarud
Last active December 9, 2024 09:14
Show Gist options
  • Save rohitgarud/eb60c095a53cf5303fb3ae07b98e268b to your computer and use it in GitHub Desktop.
Save rohitgarud/eb60c095a53cf5303fb3ae07b98e268b to your computer and use it in GitHub Desktop.
Custom JSON Adapter for DSPy which uses ProcessSchema to simplify the JSON schema injected in the prompt when InputField or OutputField of the signature has Pydantic model as a type
import enum
import inspect
import json
import re
import textwrap
from typing import Any, Dict, Literal
import json_repair
import pydantic
from dspy.adapters.image_utils import Image
from dspy.adapters.json_adapter import (
FieldInfoWithName,
JSONAdapter,
_serialize_for_json,
enumerate_fields,
format_input_list_field_value,
format_turn,
parse_value,
)
from dspy.signatures.signature import SignatureMeta
from dspy.signatures.utils import get_dspy_field_type
from pydantic import TypeAdapter
from pydantic.fields import FieldInfo
from .process_schema import ProcessSchema
class CustomJSONAdapter(JSONAdapter):
    """JSON adapter whose system prompt is built by ``prepare_instructions``.

    ``format`` assembles the chat messages (system prompt, demo turns, final
    user turn) and ``parse`` turns a raw completion back into a dict keyed by
    the signature's output field names.
    """

    def __init__(self):
        super().__init__()

    def __call__(
        self, lm, lm_kwargs, signature, demos, inputs, _parse_values=True
    ):
        """Format the request, call ``lm`` and parse every returned completion.

        Args:
            lm: Callable language model; invoked with either ``prompt=`` or
                ``messages=`` plus ``lm_kwargs``.
            lm_kwargs: Extra keyword arguments forwarded to ``lm``.
            signature: Signature describing the input and output fields.
            demos: Few-shot examples to include before the actual inputs.
            inputs: Values for the current request.
            _parse_values: Forwarded to ``parse``.

        Returns:
            A list with one parsed ``dict`` per completion returned by ``lm``.
        """
        formatted = self.format(signature, demos, inputs)
        # A plain string is sent as a prompt, anything else as chat messages.
        request = (
            dict(prompt=formatted)
            if isinstance(formatted, str)
            else dict(messages=formatted)
        )
        outputs = lm(**request, **lm_kwargs)

        values = []
        for output in outputs:
            value = self.parse(signature, output, _parse_values=_parse_values)
            # Sanity check only: ``parse`` already raises ValueError when the
            # keys do not match the signature's output fields.
            assert (
                set(value.keys()) == set(signature.output_fields.keys())
            ), f"Expected {signature.output_fields.keys()} but got {value.keys()}"
            values.append(value)
        return values

    def format(self, signature, demos, inputs):
        """Build the list of chat messages for one request.

        Demos that contain every signature field are "complete". Demos that
        miss some fields are kept only when they have at least one input field
        and at least one output field; they are emitted first and flagged as
        incomplete. Each kept demo contributes a user and an assistant turn.

        Returns:
            A list of messages: the system prompt, the demo turns, and a final
            user turn built from ``inputs``.
        """
        # Single pass instead of repeated ``demo in list`` equality scans.
        complete_demos = []
        incomplete_demos = []
        for demo in demos:
            if all(k in demo for k in signature.fields):
                complete_demos.append(demo)
            elif any(k in demo for k in signature.input_fields) and any(
                k in demo for k in signature.output_fields
            ):
                incomplete_demos.append(demo)

        messages = [
            {"role": "system", "content": prepare_instructions(signature)}
        ]

        ordered_demos = [(demo, True) for demo in incomplete_demos]
        ordered_demos += [(demo, False) for demo in complete_demos]
        for demo, is_incomplete in ordered_demos:
            for role in ("user", "assistant"):
                messages.append(
                    format_turn(
                        signature,
                        demo,
                        role=role,
                        incomplete=is_incomplete,
                    )
                )

        messages.append(format_turn(signature, inputs, role="user"))
        return messages

    def parse(self, signature, completion, _parse_values=True):
        """Parse one completion into a dict of output field values.

        Args:
            signature: Signature whose ``output_fields`` define the expected
                keys and annotations.
            completion: Raw text returned by the LM.
            _parse_values: When true, each value is cast with ``parse_value``
                to the field's annotation; when false the repaired JSON value
                is returned untouched.

        Returns:
            A dict containing exactly the signature's output fields.

        Raises:
            ValueError: If the completion is not a JSON object, or its keys
                (restricted to output fields) do not match the signature.
        """
        # Prefer the content of a fenced ```json block when one is present.
        if "```json" in completion:
            match = re.search(r"```json(.*?)```", completion, re.DOTALL)
            if match:
                completion = match.group(1).strip()

        # NOTE(review): the guard looks for a literal backslash-n sequence
        # while the replacement strips real newline characters. Kept as-is
        # because the intent is unclear -- confirm with the author.
        if r"\n" in completion:
            completion = completion.replace("\n", "")

        parsed = json_repair.loads(completion)
        if not isinstance(parsed, dict):
            # json_repair can hand back a list, string, number, ... for
            # malformed output; fail with the same error type as a key mismatch
            # instead of an AttributeError on ``.items()``.
            raise ValueError(
                f"Expected a JSON object with keys {signature.output_fields.keys()} "
                f"but got {type(parsed).__name__}"
            )

        fields = {}
        for k, v in parsed.items():
            if k not in signature.output_fields:
                continue  # ignore keys the signature does not declare
            if _parse_values:
                # attempt to cast the value to the field's annotation
                v = parse_value(v, signature.output_fields[k].annotation)
            fields[k] = v

        if fields.keys() != signature.output_fields.keys():
            raise ValueError(
                f"Expected {signature.output_fields.keys()} but got {fields.keys()}"
            )
        return fields
def prepare_instructions(signature: SignatureMeta):
    """Build the system-prompt text describing a signature to the LM.

    The text lists the input and output fields, shows the structure of the
    user turn and of the JSON object expected in reply (each field annotated
    with a type hint or a processed JSON schema), and ends with the
    signature's own instructions.

    Args:
        signature: Signature providing ``input_fields``, ``output_fields``
            and ``instructions``.

    Returns:
        The prompt sections joined by blank lines, stripped of surrounding
        whitespace.
    """
    parts = [
        "Your input fields are:\n" + enumerate_fields(signature.input_fields),
        "Your output fields are:\n" + enumerate_fields(signature.output_fields),
        "All interactions will be structured in the following way, with the appropriate values filled in.",
    ]

    def field_metadata(field_name, field_info):
        """Return ``{field_name}`` followed by a note on the allowed value."""
        type_ = field_info.annotation
        if type_ is str:
            desc = ""
        elif type_ is bool:
            desc = "must be True or False"
        elif type_ in (int, float):
            desc = f"must be a single {type_.__name__} value"
        elif inspect.isclass(type_) and issubclass(type_, enum.Enum):
            desc = f"must be one of: {'; '.join(type_.__members__)}"
        elif getattr(type_, "__origin__", None) is Literal:
            desc = f"must be one of: {'; '.join([str(x) for x in type_.__args__])}"  # noqa: E501
        else:
            # Any other annotation is described through its JSON schema,
            # condensed by ProcessSchema (assumed to return a string, since it
            # is concatenated to one here).
            desc = "must be parseable according to the following JSON schema: "
            desc += ProcessSchema(
                schema=TypeAdapter(type_).json_schema()
            ).transform_schema()

        # Plain ``str`` fields get no note at all.
        desc = (
            (" " * 8) + f"# note: the value you produce {desc}" if desc else ""
        )
        if get_dspy_field_type(field_info) == "input":
            # Inputs are given to the model, not produced by it: reword.
            desc = desc.replace(
                "# note: the value you produce must be", "#note: input will be"
            )
            desc = desc.replace("parseable according to the", "having")
        return f"{{{field_name}}}{desc}"

    def format_signature_fields_for_instructions(
        role, fields: Dict[str, FieldInfo]
    ):
        """Render ``fields`` for ``role`` using their metadata as the values."""
        formatted_fields = format_fields(
            role=role,
            fields_with_values={
                FieldInfoWithName(
                    name=field_name, info=field_info
                ): field_metadata(field_name, field_info)
                for field_name, field_info in fields.items()
            },
        )
        # Turn escaped newlines (e.g. produced by json.dumps) into real ones.
        return formatted_fields.replace("\\n", "\n")

    parts.append("Inputs will have the following structure:")
    parts.append(
        format_signature_fields_for_instructions(
            "user", signature.input_fields
        )
    )
    parts.append("Outputs will be a JSON object with the following fields.")
    parts.append(
        format_signature_fields_for_instructions(
            "assistant", signature.output_fields
        )
    )

    # Indent every line of the objective by eight spaces, starting on a
    # fresh line after the lead-in sentence.
    instructions = textwrap.dedent(signature.instructions)
    objective = ("\n" + " " * 8).join([""] + instructions.splitlines())
    parts.append(
        f"In adhering to this structure, your objective is: {objective}"
    )
    return "\n\n".join(parts).strip()
def format_fields(
    role: str, fields_with_values: Dict[FieldInfoWithName, Any]
) -> str:
    """
    Render a set of fields and their values as text for one chat turn.

    For the ``"assistant"`` role the fields are emitted as a single JSON
    object (indented by two spaces) keyed by field name. For any other role
    each field becomes a ``[[ ## name ## ]]`` header followed by its formatted
    value, with the sections separated by blank lines.

    Args:
        role: The chat role the text is produced for.
        fields_with_values: A dictionary mapping information about a field to
            its corresponding value.

    Returns:
        The formatted fields, represented as a single string.
    """
    if role != "assistant":
        sections = [
            f"[[ ## {field.name} ## ]]\n{_format_field_value(field_info=field.info, value=value)}"
            for field, value in fields_with_values.items()
        ]
        return "\n\n".join(sections).strip()

    payload = {}
    for field, value in fields_with_values.items():
        payload[field.name] = _serialize_for_json(value)
    return json.dumps(_serialize_for_json(payload), indent=2)
def _format_field_value(field_info: FieldInfo, value: Any) -> str:
    """
    Turn a single field value into the string placed in the prompt.

    Args:
        field_info: Information about the field; only its annotation is read.
        value: The value of the field.

    Returns:
        The formatted value of the field, represented as a string.

    Raises:
        NotImplementedError: If the field is annotated as ``Image``.
    """
    annotation = field_info.annotation
    value_is_list = isinstance(value, list)

    # A list supplied for a plain ``str`` field is rendered with the
    # dedicated list formatter for the LM.
    if value_is_list and annotation is str:
        return format_input_list_field_value(value)

    if annotation is Image:
        raise NotImplementedError("Images are not yet supported in JSON mode.")

    # Structured values are serialised to JSON; everything else is stringified.
    if value_is_list or isinstance(value, (pydantic.BaseModel, dict)):
        return json.dumps(_serialize_for_json(value))
    return str(value)
@rohitgarud
Copy link
Author

rohitgarud commented Dec 8, 2024

STILL WIP: This came out of a requirement for less verbose DSPy prompts because the number of tokens does matter

@rohitgarud
Copy link
Author

Experimenting with locally hosted Llama 3.1 8B Q4 quantized model and Llama 3.2 1B Q8 quantized model. Getting excellent performance in adhering to the JSON schema without using any Structured Output support from the model/LM Studio.

Also, I am working on getting a response that follows the JSON schema but omits the newlines and quotes used for formatting, since json_repair can parse output that has the right structure even without strict JSON formatting

@rohitgarud
Copy link
Author

rohitgarud commented Dec 9, 2024

Now, if the InputField type is something other than str, such as a Pydantic model, its schema will be added to the system prompt. The default DSPy JSONAdapter does not add the InputField schema

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment