Last active
December 9, 2024 09:14
-
-
Save rohitgarud/eb60c095a53cf5303fb3ae07b98e268b to your computer and use it in GitHub Desktop.
Custom JSON adapter for DSPy that uses ProcessSchema to simplify the JSON schema injected into the prompt when an InputField or OutputField of the signature has a Pydantic model as its type.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import enum | |
import inspect | |
import json | |
import re | |
import textwrap | |
from typing import Any, Dict, Literal | |
import json_repair | |
import pydantic | |
from dspy.adapters.image_utils import Image | |
from dspy.adapters.json_adapter import ( | |
FieldInfoWithName, | |
JSONAdapter, | |
_serialize_for_json, | |
enumerate_fields, | |
format_input_list_field_value, | |
format_turn, | |
parse_value, | |
) | |
from dspy.signatures.signature import SignatureMeta | |
from dspy.signatures.utils import get_dspy_field_type | |
from pydantic import TypeAdapter | |
from pydantic.fields import FieldInfo | |
from .process_schema import ProcessSchema | |
class CustomJSONAdapter(JSONAdapter):
    """JSON adapter whose system prompt is built by ``prepare_instructions``.

    ``format`` turns a signature, demos and inputs into a list of chat
    messages; ``parse`` turns one LM completion back into a dict keyed by the
    signature's output fields; ``__call__`` chains the two around the LM call.
    """

    def __init__(self):
        super().__init__()

    def __call__(
        self, lm, lm_kwargs, signature, demos, inputs, _parse_values=True
    ):
        """Format the prompt, call ``lm`` and parse every returned completion.

        Args:
            lm: Callable invoked as ``lm(prompt=...)`` or ``lm(messages=...)``
                together with ``lm_kwargs``; it must return an iterable of
                completions.
            lm_kwargs: Extra keyword arguments forwarded to ``lm``.
            signature: Signature exposing ``input_fields`` / ``output_fields``.
            demos: Example dicts rendered before the real inputs.
            inputs: Values for the final user turn.
            _parse_values: Forwarded unchanged to ``parse``.

        Returns:
            A list with one parsed dict per completion.
        """
        formatted = self.format(signature, demos, inputs)
        if isinstance(formatted, str):
            call_kwargs = dict(prompt=formatted)
        else:
            call_kwargs = dict(messages=formatted)

        outputs = lm(**call_kwargs, **lm_kwargs)

        values = []
        for output in outputs:
            value = self.parse(signature, output, _parse_values=_parse_values)
            # Sanity check kept from the original; ``parse`` below already
            # raises ValueError when the key sets differ.
            assert (
                set(value.keys()) == set(signature.output_fields.keys())
            ), f"Expected {signature.output_fields.keys()} but got {value.keys()}"
            values.append(value)
        return values

    def format(self, signature, demos, inputs):
        """Build the chat messages: system prompt, demo turns, final user turn.

        Demos missing at least one signature field are "incomplete"; they are
        kept only if they contain at least one input field and at least one
        output field, and they are placed before the complete demos.
        """
        # Extract demos where some of the output_fields are not filled in.
        incomplete_demos = [
            demo
            for demo in demos
            if not all(k in demo for k in signature.fields)
        ]
        complete_demos = [
            demo for demo in demos if demo not in incomplete_demos
        ]
        incomplete_demos = [
            demo
            for demo in incomplete_demos
            if any(k in demo for k in signature.input_fields)
            and any(k in demo for k in signature.output_fields)
        ]
        demos = incomplete_demos + complete_demos

        messages = [
            {"role": "system", "content": prepare_instructions(signature)}
        ]
        for demo in demos:
            is_incomplete = demo in incomplete_demos
            # Each demo contributes a user turn followed by an assistant turn.
            for role in ("user", "assistant"):
                messages.append(
                    format_turn(
                        signature,
                        demo,
                        role=role,
                        incomplete=is_incomplete,
                    )
                )
        messages.append(format_turn(signature, inputs, role="user"))
        return messages

    def parse(self, signature, completion, _parse_values=True):
        """Parse one completion into a dict of typed output-field values.

        Args:
            signature: Signature whose ``output_fields`` define the expected
                keys and their annotations.
            completion: Raw completion text returned by the LM.
            _parse_values: Accepted for interface compatibility; not consulted
                by this implementation.

        Returns:
            A dict mapping every output field name to its parsed value.

        Raises:
            ValueError: If the completion does not decode to a JSON object, or
                if its keys (restricted to known output fields) do not match
                the signature's output fields exactly.
        """
        # Prefer the contents of a fenced ```json block when one is present.
        if "```json" in completion:
            match = re.search(r"```json(.*?)```", completion, re.DOTALL)
            if match:
                completion = match.group(1).strip()

        # NOTE(review): the condition looks for a literal backslash-n sequence
        # while the replacement strips real newline characters. Behaviour is
        # kept as-is; confirm which of the two was intended.
        if r"\n" in completion:
            completion = completion.replace("\n", "")

        decoded = json_repair.loads(completion)
        if not isinstance(decoded, dict):
            # Without this guard a list/str/number result fails below with an
            # unhelpful AttributeError on ``.items()``.
            raise ValueError(
                f"Expected a JSON object with keys {signature.output_fields.keys()} "
                f"but got {type(decoded).__name__}"
            )

        # Drop unknown keys and cast each remaining value to the annotation
        # declared on the corresponding output field.
        fields = {
            k: parse_value(v, signature.output_fields[k].annotation)
            for k, v in decoded.items()
            if k in signature.output_fields
        }

        if fields.keys() != signature.output_fields.keys():
            raise ValueError(
                f"Expected {signature.output_fields.keys()} but got {fields.keys()}"
            )
        return fields
def prepare_instructions(signature: SignatureMeta):
    """Assemble the system-prompt text for ``signature``.

    The prompt lists the input and output fields, shows the layout of a user
    turn and of the JSON answer (each field annotated with a type note), and
    ends with the signature's own instructions.
    """

    def field_metadata(field_name, field_info):
        # Produce "{name}" followed by an optional note about the value's type.
        annotation = field_info.annotation
        if annotation is str:
            note = ""
        elif annotation is bool:
            note = "must be True or False"
        elif annotation in (int, float):
            note = f"must be a single {annotation.__name__} value"
        elif inspect.isclass(annotation) and issubclass(annotation, enum.Enum):
            note = f"must be one of: {'; '.join(annotation.__members__)}"
        elif getattr(annotation, "__origin__", None) is Literal:
            note = f"must be one of: {'; '.join([str(x) for x in annotation.__args__])}"  # noqa: E501
        else:
            # Any other annotation is described by its processed JSON schema.
            raw_schema = TypeAdapter(annotation).json_schema()
            note = "must be pareseable according to the following JSON schema: "
            note += ProcessSchema(schema=raw_schema).transform_schema()

        if note:
            note = (" " * 8) + f"# note: the value you produce {note}"

        if get_dspy_field_type(field_info) == "input":
            # Input fields are described rather than requested, so the note
            # wording is rewritten accordingly.
            note = note.replace(
                "# note: the value you produce must be", "#note: input will be"
            )
            note = note.replace("pareseable according to the", "having")
        return f"{{{field_name}}}{note}"

    def format_signature_fields_for_instructions(
        role, fields: Dict[str, FieldInfo]
    ):
        # Render each field with its metadata note standing in for the value.
        placeholders = {}
        for name, info in fields.items():
            key = FieldInfoWithName(name=name, info=info)
            placeholders[key] = field_metadata(name, info)
        rendered = format_fields(role=role, fields_with_values=placeholders)
        # Turn literal backslash-n sequences into real line breaks.
        return rendered.replace("\\n", "\n")

    parts = [
        "Your input fields are:\n" + enumerate_fields(signature.input_fields),
        "Your output fields are:\n" + enumerate_fields(signature.output_fields),
        "All interactions will be structured in the following way, with the appropriate values filled in.",
        "Inputs will have the following structure:",
        format_signature_fields_for_instructions("user", signature.input_fields),
        "Outputs will be a JSON object with the following fields.",
        format_signature_fields_for_instructions(
            "assistant", signature.output_fields
        ),
    ]

    # Put every instruction line on its own line, indented by eight spaces.
    line_prefix = "\n" + " " * 8
    instruction_lines = textwrap.dedent(signature.instructions).splitlines()
    objective = line_prefix.join(["", *instruction_lines])
    parts.append(
        f"In adhering to this structure, your objective is: {objective}"
    )

    return "\n\n".join(parts).strip()
def format_fields(
    role: str, fields_with_values: Dict[FieldInfoWithName, Any]
) -> str:
    """
    Render a mapping of fields to values as the text of one chat turn.

    For the "assistant" role the values are serialized into a single
    JSON object (indented by two spaces) keyed by field name. For any other
    role each field becomes a ``[[ ## name ## ]]`` header followed by its
    formatted value, with blocks separated by a blank line.

    Args:
        role: The chat role the text is produced for; only "assistant"
            selects the JSON layout.
        fields_with_values: A dictionary mapping information about a field to
            its corresponding value.

    Returns:
        The joined formatted values of the fields, represented as a string.
    """
    if role == "assistant":
        values_by_name = {
            field.name: _serialize_for_json(value)
            for field, value in fields_with_values.items()
        }
        return json.dumps(_serialize_for_json(values_by_name), indent=2)

    sections = []
    for field, value in fields_with_values.items():
        rendered = _format_field_value(field_info=field.info, value=value)
        sections.append(f"[[ ## {field.name} ## ]]\n{rendered}")
    return "\n\n".join(sections).strip()
def _format_field_value(field_info: FieldInfo, value: Any) -> str:
    """
    Turn a single field value into the string placed in the prompt.

    Args:
        field_info: Information about the field; only its annotation is read.
        value: The value of the field.

    Returns:
        The formatted value of the field, represented as a string.

    Raises:
        NotImplementedError: If the field is annotated as ``Image``.
    """
    annotation = field_info.annotation

    # A list supplied for a plain ``str`` field is handed to the list formatter.
    if isinstance(value, list) and annotation is str:
        return format_input_list_field_value(value)

    if annotation is Image:
        raise NotImplementedError("Images are not yet supported in JSON mode.")

    # Structured values are emitted as compact JSON.
    if isinstance(value, (pydantic.BaseModel, dict, list)):
        return json.dumps(_serialize_for_json(value))

    return str(value)
Experimenting with a locally hosted Llama 3.1 8B (Q4-quantized) model and a Llama 3.2 1B (Q8-quantized) model. Getting excellent adherence to the JSON schema without using any Structured Output support from the model or LM Studio.
Also, I am working on getting a response that follows the JSON schema but omits the newlines and quotes used for formatting, since json_repair can parse output that has the right structure even without strict JSON formatting.
Now, if the InputField type is something other than str, such as a Pydantic model, its schema is added to the system prompt. The default DSPy JSONAdapter does not add the InputField schema.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
STILL WIP: This came out of a requirement for less verbose DSPy prompts because the number of tokens does matter