Created
June 26, 2023 14:47
-
-
Save guibeira/53f815482119816e58136c4883554637 to your computer and use it in GitHub Desktop.
Sample of base client
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio
import json
from functools import partial
from typing import Any, ClassVar, Dict, List, Optional, Type, TypeVar, Union

import httpx
from pydantic import BaseModel, Field
# Type variable standing for any pydantic model class; used in the
# client's annotations for response models.
T = TypeVar("T", bound=BaseModel)
class DefaultSuccessModel(BaseModel):
    """Fallback shape for a 200 response: a message plus a data mapping."""

    message: str
    data: dict
class DefaultErrorModel(BaseModel):
    """Fallback shape for 400 and 500 responses: a single error string."""

    error: str
class DefaultRequestModel(BaseModel):
    """Empty request model with no fields."""
class DefaultDeleteResponseModel(BaseModel):
    """Field-less model registered as the fallback for 204 responses."""
class RequestMethod(BaseModel):
    """Declarative description of one endpoint exposed by a client.

    ``BaseAsyncClient`` iterates a list of these (``REQUEST_METHODS``) and
    binds one callable per entry onto each client instance.
    """

    # HTTP verb handed to the underlying client, e.g. "GET" or "POST".
    method: str = Field(...)
    # Path appended to the client's base URL; may contain ``str.format``
    # placeholders such as "{user_id}" that are filled from keyword arguments.
    endpoint: str = Field(...)
    # Attribute name under which the bound callable is set on the client.
    method_name: str = Field(...)
    # Explicit ``None`` defaults: pydantic v1 inferred them for ``Optional``
    # annotations, but pydantic v2 treats an ``Optional`` field without a
    # default as required, which would reject declarations that omit these.
    request_model: Optional[Type[BaseModel]] = None
    response_model: Optional[Type[BaseModel]] = None
    # Maps a status code to the model used to parse that error response.
    response_model_errors: Optional[Dict[int, Type[BaseModel]]] = None
class BaseAsyncClient:
    """Declarative asynchronous HTTP client built on ``httpx``.

    Subclasses describe their endpoints in ``REQUEST_METHODS``. When an
    instance is created, every entry is bound onto the instance under its
    ``method_name`` as a callable that forwards to :meth:`request` with the
    HTTP method, endpoint and response models already filled in.

    Class attributes:
        RESPONSE_MODELS: Fallback mapping of status code to the model used
            to parse a response when no more specific model applies.
        REQUEST_METHODS: Endpoint descriptions bound on each instance.
        RETRIES: Value passed as ``retries`` to ``httpx.AsyncHTTPTransport``.
    """

    RESPONSE_MODELS: ClassVar[Dict[int, Type[BaseModel]]] = {
        200: DefaultSuccessModel,
        204: DefaultDeleteResponseModel,
        400: DefaultErrorModel,
        500: DefaultErrorModel,
    }
    REQUEST_METHODS: ClassVar[List[RequestMethod]] = []
    RETRIES = 3

    def __init__(
        self,
        base_url: str,
        headers=None,
        trust_env=False,
        timeout=10,
        follow_redirects=True,
        *args,
        **kwargs,
    ):
        """Store connection settings and bind the declared request methods.

        Args:
            base_url: URL prefix prepended to every endpoint.
            headers: Headers given to every per-request client; ``None``
                (the default) means no default headers.
            trust_env: Forwarded to ``httpx.AsyncClient``.
            timeout: Forwarded to ``httpx.AsyncClient``.
            follow_redirects: Forwarded to ``httpx.AsyncClient``.
            *args: Accepted and ignored.
            **kwargs: Accepted and ignored.
        """
        self.base_url = base_url
        self.trust_env = trust_env
        self.timeout = timeout
        self.follow_redirects = follow_redirects
        # A ``{}`` default in the signature would be one dict shared by every
        # instance; create a fresh one per client instead.
        self.headers = {} if headers is None else headers
        # NOTE(review): this client is never used by ``request`` (which opens
        # its own via ``_get_client``) and is never closed. It is kept only so
        # the public ``client`` attribute keeps existing — confirm whether it
        # can be dropped.
        self.client = httpx.AsyncClient(base_url=self.base_url)
        for request_method in self.REQUEST_METHODS:
            self._bind_request_method(request_method)

    def _get_client(self):
        """Build a new ``httpx.AsyncClient`` from the stored settings."""
        return httpx.AsyncClient(
            base_url=self.base_url,
            headers=self.headers,
            trust_env=self.trust_env,
            timeout=self.timeout,
            follow_redirects=self.follow_redirects,
            transport=httpx.AsyncHTTPTransport(retries=self.RETRIES),
        )

    def _bind_request_method(self, request_method: RequestMethod):
        """Attach one declared endpoint to this instance as a callable.

        The HTTP method and endpoint are bound positionally; the response
        model and the per-status error models are bound as keywords so the
        caller only supplies ``params``, ``data``, ``headers`` and any path
        parameters.
        """
        bound = partial(
            self.request,
            request_method.method,
            request_method.endpoint,
            response_model=request_method.response_model,
            # Without this the endpoint's own error models were never
            # consulted and only the class-wide fallbacks applied.
            response_model_errors=request_method.response_model_errors,
        )
        # ``request_model`` is deliberately not bound as ``data``: it is a
        # model *class*, not a payload, so using it as the default request
        # body would send a type object whenever the caller omitted ``data``.
        setattr(self, request_method.method_name, bound)

    async def request(
        self,
        method: str,
        endpoint: str,
        params: Optional[Dict[str, Any]] = None,
        data: Optional[Union[Dict[str, Any], BaseModel]] = None,
        response_model: Optional[Type[T]] = None,
        response_model_errors: Optional[Dict[int, Type[T]]] = None,
        *args,
        **kwargs,
    ):
        """Send one request and parse the response.

        Args:
            method: HTTP verb.
            endpoint: Path appended to ``base_url``; ``str.format``
                placeholders are filled from ``**kwargs``.
            params: Query parameters.
            data: Request payload; a pydantic model is converted with
                ``.dict()`` before being passed to ``httpx`` as ``data``.
            response_model: Model used when the status code is 200.
            response_model_errors: Status code to model mapping consulted
                before the class-wide ``RESPONSE_MODELS``.
            *args: Accepted and ignored.
            **kwargs: ``headers`` (optional per-request headers) plus the
                values for the endpoint's path placeholders.

        Returns:
            An instance of the selected model, the decoded JSON itself when
            no model matches the status code, or ``None`` when the body is
            not valid JSON.
        """
        if isinstance(data, BaseModel):
            data = data.dict()
        # Remove ``headers`` first so it is not offered to ``str.format``.
        headers = kwargs.pop("headers", {})
        formatted_endpoint = f"{self.base_url}{endpoint.format(**kwargs)}"
        async with self._get_client() as client:
            response = await client.request(
                method,
                formatted_endpoint,
                params=params,
                data=data,
                headers=headers,
            )
            model = self._get_response_model(
                response, response_model, response_model_errors
            )
            response_data = self._extract_response_data(response)
            if response_data is None:
                return None
            return model(**response_data) if model else response_data

    def _get_response_model(
        self,
        response,
        response_model: Optional[Type[T]] = None,
        response_model_errors: Optional[Dict[int, Type[T]]] = None,
    ):
        """Pick the model class for ``response`` based on its status code.

        Order of precedence: ``response_model`` for a 200, then an entry in
        ``response_model_errors``, then ``RESPONSE_MODELS``; ``None`` when
        nothing matches.
        """
        status = response.status_code
        if response_model and status == 200:
            return response_model
        if response_model_errors and status in response_model_errors:
            return response_model_errors[status]
        if status in self.RESPONSE_MODELS:
            return self.RESPONSE_MODELS[status]
        return None

    def _extract_response_data(self, response):
        """Return the decoded JSON body, or ``None`` if it is not JSON."""
        try:
            return response.json()
        except json.decoder.JSONDecodeError:
            return None
class PokemonResponseModel(BaseModel):
    """Fields read from a single-pokemon response: its name and order."""

    name: str
    order: int
class PokemonAbilityResponseModel(BaseModel):
    """One entry of an ability listing: a name and a URL."""

    name: str
    url: str
class PokemonAbilitiesResponseModel(BaseModel):
    """Ability listing response: the ``results`` list of ability entries."""

    results: List[PokemonAbilityResponseModel]
class PokeApi(BaseAsyncClient):
    """Sample client declaring two GET endpoints of the pokemon API."""

    REQUEST_METHODS = [
        # Single pokemon, looked up by the ``pokemon_name`` path parameter.
        RequestMethod(
            method_name="get_pokemon",
            method="GET",
            endpoint="/pokemon/{pokemon_name}",
            response_model=PokemonResponseModel,
        ),
        # Ability listing.
        RequestMethod(
            method_name="get_abilities",
            method="GET",
            endpoint="/ability",
            response_model=PokemonAbilitiesResponseModel,
        ),
    ]
class ReqresUserResponseModel(BaseModel):
    """User response wrapper: the payload sits under a ``data`` mapping."""

    data: dict
class ReqresUserRequestModel(BaseModel):
    """User payload carrying a name and a job."""

    name: str
    job: str
class ReqresRegisterRequestModel(BaseModel):
    """Registration payload carrying only an email address."""

    email: str
class ReqresRegisterResponseError(BaseModel):
    """Error body of a failed registration: a single error string."""

    error: str
class Reqres(BaseAsyncClient):
    """Sample client declaring the user and register endpoints."""

    REQUEST_METHODS = [
        RequestMethod(
            method="POST",
            endpoint="/api/users",
            method_name="create_user",
            request_model=ReqresUserRequestModel,
            response_model=ReqresUserResponseModel,
        ),
        RequestMethod(
            method="GET",
            endpoint="/api/users/{user_id}",
            method_name="get_user",
            response_model=ReqresUserResponseModel,
        ),
        RequestMethod(
            method="POST",
            endpoint="/api/register",
            method_name="register",
            # Was ReqresUserRequestModel (name/job); the register payload
            # used with this endpoint is the email-only register model.
            request_model=ReqresRegisterRequestModel,
            response_model=ReqresUserResponseModel,
            # A 400 from this endpoint is parsed with its own error model.
            response_model_errors={400: ReqresRegisterResponseError},
        ),
        # No models declared: the client's fallback models apply.
        RequestMethod(
            method="DELETE",
            endpoint="/api/users/{user_id}",
            method_name="delete_user",
        ),
        RequestMethod(
            method="PUT",
            endpoint="/api/users/{user_id}",
            method_name="update_user",
            request_model=ReqresUserRequestModel,
            # The request model doubles as the response model here.
            response_model=ReqresUserRequestModel,
        ),
    ]
# Module-level client instances used by the sample calls in ``main``.
poke_api = PokeApi(base_url="https://pokeapi.co/api/v2")
reqres = Reqres(base_url="https://reqres.in")
async def main():
    """Call every endpoint bound on the two sample clients and print each result."""
    # GET with a path parameter supplied as a keyword argument.
    pokemon = await poke_api.get_pokemon(pokemon_name="pikachu")
    print(pokemon)

    # GET with query parameters.
    abilities = await poke_api.get_abilities(params={"limit": 2})
    print(abilities)

    # POST with a request model as the payload.
    new_user = ReqresUserRequestModel(name="John", job="Developer")
    created = await reqres.create_user(data=new_user)
    print(created)

    # POST expected to fail, exercising the error response model.
    # NOTE(review): the address literal below looks like an
    # email-obfuscation placeholder — confirm the intended value.
    registration = ReqresRegisterRequestModel(email="[email protected]")
    registered = await reqres.register(data=registration)
    print(registered)

    # GET with a path parameter.
    user = await reqres.get_user(user_id=2)
    print(user)

    # PUT with a path parameter and a request model.
    changes = ReqresUserRequestModel(name="John", job="Developer")
    updated = await reqres.update_user(user_id=2, data=changes)
    print(updated)

    # DELETE with a path parameter.
    deleted = await reqres.delete_user(user_id=2)
    print(deleted)
# Run the sample only when executed as a script, so importing this module
# for its client classes does not fire network requests.
if __name__ == "__main__":
    asyncio.run(main())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment