Skip to content

Instantly share code, notes, and snippets.

@giacomorebecchi
Created March 6, 2024 13:57
Show Gist options
  • Save giacomorebecchi/023891b8160c75094e003e2c8d2846ea to your computer and use it in GitHub Desktop.
Save giacomorebecchi/023891b8160c75094e003e2c8d2846ea to your computer and use it in GitHub Desktop.
Pydantic v2-compatible Url, based on the behaviour of the v1 version
from typing import Any, Optional
from urllib.parse import parse_qs, unquote
from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler
from pydantic import v1 as pydantic_v1
from pydantic_core import CoreSchema, core_schema
from typing_extensions import Self
class Url(pydantic_v1.AnyUrl):
    """Pydantic v2-compatible URL type built on pydantic's bundled v1 ``AnyUrl``.

    Parsing and validation reuse the v1 ``AnyUrl`` class methods
    (``_match_url``, ``apply_default_parts``, ``validate_parts``,
    ``_build_url``), while ``__get_pydantic_core_schema__`` and
    ``__get_pydantic_json_schema__`` plug the type into pydantic v2.
    A few accessors (``username``, ``query_params``, ``build(port=int)``)
    use v2-style names and argument types.
    """

    # When True, ``str(url)`` returns the URL rebuilt with user, password and
    # query percent-decoded (see ``_stringify_url``); when False, ``str(url)``
    # returns the base ``AnyUrl`` string form.
    quoted: bool = False

    def query_params(self) -> dict[str, list[str]]:
        """Return the query string parsed with ``urllib.parse.parse_qs``.

        Each key maps to a *list* of values, because a key may appear more
        than once in the query string.
        """
        return parse_qs(self.query)

    @property
    def username(self) -> Optional[str]:
        """The user part of the URL (v2-style alias for v1's ``user``)."""
        return self.user

    @classmethod
    def build(
        cls,
        *,
        scheme: str,
        user: Optional[str] = None,
        password: Optional[str] = None,
        host: Optional[str] = None,
        port: Optional[int] = None,
        path: Optional[str] = None,
        query: Optional[str] = None,
        fragment: Optional[str] = None,
        **_kwargs: str,
    ) -> str:
        """Assemble a URL string from its parts via v1 ``AnyUrl.build``.

        Accepts a v2-style signature (``host`` optional, ``port`` as an int)
        and adapts the arguments before delegating to the v1 implementation.

        Returns:
            The assembled URL as returned by v1 ``AnyUrl.build`` -- a plain
            ``str``, not a validated ``Url`` instance.
        """
        return super().build(
            scheme=scheme,
            user=user,
            password=password,
            # A missing host is passed as "" rather than None, presumably
            # because the v1 builder expects a string host.
            host=host if host is not None else "",
            # The v1 builder takes the port as a string; convert int ports.
            port=str(port) if port is not None else None,
            path=path,
            query=query,
            fragment=fragment,
            **_kwargs,
        )

    def _stringify_url(self) -> str:
        """Rebuild the URL with user, password and query percent-decoded.

        Scheme, host, port, path and fragment are passed through unchanged.
        """
        return self.build(
            scheme=self.scheme,
            user=unquote(self.username) if self.username is not None else None,
            password=unquote(self.password) if self.password is not None else None,
            host=self.host,
            port=self.port,
            path=self.path,
            query=unquote(self.query) if self.query is not None else None,
            fragment=self.fragment,
        )

    @classmethod
    def _validate_from_str(cls, value: str) -> Self:
        """Parse and validate ``value`` using the v1 ``AnyUrl`` helpers.

        Raises:
            ValueError: if the URL regex does not match, or if characters
                remain after the matched URL. ``validate_parts`` may raise its
                own errors for invalid parts (scheme, host, port, ...).
        """
        # Redundant when called through the core schema (which already strips),
        # but kept so direct callers get the same behavior.
        if cls.strip_whitespace:
            value = value.strip()
        m = cls._match_url(value)
        # the regex should always match,
        # if it doesn't please report with details of the URL tried
        if m is None:
            raise ValueError("URL regex failed unexpectedly")
        original_parts = m.groupdict()
        parts = cls.apply_default_parts(original_parts)
        parts = cls.validate_parts(parts)
        if m.end() != len(value):
            raise ValueError(
                "URL invalid, extra characters found after valid URL:"
                f" {value[m.end() :]}"
            )
        return cls._build_url(m, value, parts)

    @classmethod
    def __get_pydantic_core_schema__(
        cls,
        _source_type: Any,
        _handler: GetCoreSchemaHandler,
    ) -> core_schema.CoreSchema:
        """Build the pydantic v2 core schema for ``Url``.

        Both JSON and Python input must be a string. It is whitespace-stripped
        (when ``strip_whitespace`` is set), checked against ``min_length`` and
        ``max_length``, and then parsed by ``_validate_from_str``.
        """
        from_str_schema = core_schema.chain_schema(
            [
                # Strip *before* the length check, so surrounding whitespace
                # does not count towards max_length (v1 strips first, too).
                core_schema.str_schema(
                    max_length=cls.max_length,
                    min_length=cls.min_length,
                    strip_whitespace=cls.strip_whitespace,
                ),
                core_schema.no_info_plain_validator_function(cls._validate_from_str),
            ]
        )
        return core_schema.json_or_python_schema(
            json_schema=from_str_schema,
            python_schema=from_str_schema,
            # Calls the base AnyUrl.__str__ explicitly, bypassing the
            # overridden __str__ below, so ``quoted`` does not affect output.
            serialization=core_schema.plain_serializer_function_ser_schema(
                lambda instance: pydantic_v1.AnyUrl.__str__(instance)
            ),
        )

    @classmethod
    def __get_pydantic_json_schema__(
        cls,
        core_schema: CoreSchema,  # NOTE: shadows the module name in this scope
        handler: GetJsonSchemaHandler,
    ) -> dict[str, Any]:
        """Generate the JSON schema and let v1 ``__modify_schema__`` extend it.

        The return value of ``__modify_schema__`` is ignored, so it is expected
        to update the resolved schema dict in place.
        """
        json_schema = handler(core_schema)
        json_schema = handler.resolve_ref_schema(json_schema)
        super().__modify_schema__(json_schema)
        return json_schema

    def __str__(self) -> str:
        """Return the decoded form if ``quoted`` is set, else the base string."""
        if self.quoted:
            return self._stringify_url()
        return pydantic_v1.AnyUrl.__str__(self)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment