Skip to content

Instantly share code, notes, and snippets.

@theacodes
Last active November 19, 2024 01:24
Show Gist options
  • Save theacodes/8537caa9d4d27518832c63679367c0cc to your computer and use it in GitHub Desktop.
Save theacodes/8537caa9d4d27518832c63679367c0cc to your computer and use it in GitHub Desktop.
Pydantic model_replace helper
from typing import Annotated, Any, overload
import pydantic
@overload
def model_replace[T: pydantic.BaseModel](val: T, /, **kwargs) -> T: ...
@overload
def model_replace[T: pydantic.BaseModel](val: T, _partial: T, /) -> T: ...
def model_replace[T: pydantic.BaseModel](val: T, _partial: T | None = None, /, **kwargs) -> T:
"""Helper for copying a model instance while replacing some fields.
This is similar to pydantic's `model_copy`, but is sightly different in behavior.
The first form updates fields from `**kwargs`:
```python
result = model_replace(val, a=42, b="hello")
```
This will copy `val` and update _only_ fields `a` and `b`. Unlike `model_copy`, it validates the new values.
The second form updates fields from another instance:
```python
result = model_replace(val, Example(a=42))
```
This will copy `val` and update _only_ field `a`. Only the fields explicitly set will be updated. This is most
useful for models that have defaults for every field.
Refs:
- https://github.com/pydantic/pydantic/discussions/3139
- https://github.com/pydantic/pydantic/discussions/3416
- https://github.com/pydantic/pydantic/discussions/8960
"""
if _partial:
return val.model_copy(update=_partial.model_dump(exclude_unset=True, round_trip=True))
return type(val).model_validate(val.model_dump(round_trip=True) | kwargs)
import pydantic
import pytest
from .model_replace import model_replace
class ExampleState(pydantic.BaseModel):
    """Fixture model in which every field has a default, so partial instances are easy to build."""

    # Integer field, 1 when not given.
    a: int = pydantic.Field(default=1)
    # String-to-string mapping; each instance gets its own fresh empty dict by default.
    b: dict[str, str] = pydantic.Field(default_factory=dict)
class TestModelUpdate:
    """Tests for ``model_replace``, covering both the keyword form and the partial-instance form."""

    def test_kwargs(self):
        original = ExampleState(b={"three": "four"})

        replaced = model_replace(original, a=42)

        # A new instance carries the replacement; untouched fields keep their values.
        assert replaced is not original
        assert replaced.a == 42
        assert replaced.b == {"three": "four"}

    def test_kwargs_validation(self):
        original = ExampleState(b={"three": "four"})
        # A non-numeric string for the int field ``a`` must be rejected on re-validation.
        with pytest.raises(pydantic.ValidationError):
            model_replace(original, a="adfafas")

    def test_partial(self):
        original = ExampleState(b={"three": "four"})
        patch = ExampleState(a=42)

        replaced = model_replace(original, patch)

        # Only ``a`` was explicitly set on ``patch``, so ``b`` keeps the original's value
        # rather than the patch's default.
        assert replaced is not original
        assert replaced.a == 42
        assert replaced.b == {"three": "four"}
@mpkocher
Copy link

This looks like a case where a `+` operator might be worth considering. Perhaps implement `__add__` with the custom semantics your use case needs?

a = ExampleState(a=2)
b = ExampleState(b={"x": "y"})
c = a + b

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment