Created
September 22, 2019 15:08
-
-
Save Jackenmen/c7267322ccd90aecf3d5caa8afec77da to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import NamedTuple, Optional, Union | |
import asyncio | |
import json | |
import time | |
import aiohttp | |
class RedditException(Exception):
    """Base exception class for Reddit API.

    Every other exception defined for the Reddit client derives from this
    class, so callers can catch it to handle any client-specific failure.
    """
class UnsupportedTokenType(RedditException):
    """Raised when an OAuth2 token request returns an unsupported token type."""
class HTTPException(RedditException):
    """Exception that's thrown when an HTTP request operation fails.

    Attributes
    ----------
    response: aiohttp.ClientResponse
        The response of the failed HTTP request.
    status: int
        The status code of the HTTP request.
    data: dict
        Details about error (the decoded body passed in by the caller).
    """

    def __init__(self, response: aiohttp.ClientResponse, data: dict):
        self.response = response
        self.status = response.status
        self.data = data
        # The exception message combines the HTTP reason phrase, the status
        # code and the error payload.
        super().__init__(
            f"{self.response.reason} (status code: {self.status}): {self.data}"
        )
class TokenInfo(NamedTuple):
    """OAuth2 access token together with its expiry time.

    Attributes
    ----------
    access_token: str
        The access token string.
    expires_at: float
        Point in time, on the ``time.time()`` clock (seconds), at which the
        token expires. It is a float because it is derived from
        ``time.time()``.
    """

    access_token: str
    expires_at: float
class RedditClient:
    """Asynchronous client that obtains OAuth2 access tokens from Reddit.

    The client authenticates with the ``password`` grant type and caches the
    resulting token, requesting a new one shortly before the cached token
    expires.

    The client owns an ``aiohttp.ClientSession``; call :meth:`close` (or use
    the client as an ``async with`` context manager) to release it.

    Parameters
    ----------
    client_id: str
        OAuth2 client ID, sent as the HTTP basic-auth login.
    client_secret: str
        OAuth2 client secret, sent as the HTTP basic-auth password.
    username: str
        Username sent with the ``password`` grant.
    password: str
        Password sent with the ``password`` grant.
    loop: Optional[asyncio.AbstractEventLoop]
        Event loop passed to the HTTP session. Defaults to
        ``asyncio.get_event_loop()``.
    """

    # A cached token is reused only while it has more than this many
    # seconds of validity left.
    _EXPIRY_MARGIN = 60

    def __init__(
        self,
        client_id: str,
        client_secret: str,
        username: str,
        password: str,
        *,
        loop: Optional[asyncio.AbstractEventLoop] = None,
    ):
        self.loop = asyncio.get_event_loop() if loop is None else loop
        self._session = aiohttp.ClientSession(loop=self.loop)
        self._token_info: Optional[TokenInfo] = None
        self._username = username
        self._password = password
        self.client_id = client_id
        self._client_secret = client_secret

    async def close(self) -> None:
        """Close the underlying HTTP session."""
        await self._session.close()

    async def __aenter__(self) -> "RedditClient":
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        await self.close()

    async def _token(self, *, force_refresh: bool = False) -> str:
        """Return an access token, requesting a new one when needed.

        Parameters
        ----------
        force_refresh: bool
            When true, always request a new token instead of reusing the
            cached one.

        Returns
        -------
        str
            The access token.
        """
        cached = self._token_info
        if (
            cached is not None
            and cached.expires_at - time.time() > self._EXPIRY_MARGIN
            and not force_refresh
        ):
            return cached.access_token
        self._token_info = await self._request_token()
        return self._token_info.access_token

    async def _request_token(self) -> "TokenInfo":
        """Request a new access token from Reddit.

        Returns
        -------
        TokenInfo
            The access token and the time at which it expires.

        Raises
        ------
        HTTPException
            The response status was not 200, or the response body reported
            an error.
        UnsupportedTokenType
            The returned token type is not ``bearer``.
        """
        async with self._session.post(
            "https://www.reddit.com/api/v1/access_token",
            data={
                "grant_type": "password",
                "username": self._username,
                "password": self._password,
            },
            auth=aiohttp.BasicAuth(self.client_id, self._client_secret),
        ) as resp:
            data = await resp.json()
            # An OAuth failure can be reported in the body of a 200 response
            # (e.g. {"error": "invalid_grant"}); without this check it would
            # surface below as a bare KeyError.
            if resp.status != 200 or "error" in data:
                raise HTTPException(resp, data)

        # NOTE: the payload contains the access token, so it is deliberately
        # not printed or logged.
        if data["token_type"].lower() != "bearer":
            raise UnsupportedTokenType(
                f"Token type `{data['token_type']}` is not supported"
            )
        expires_at = time.time() + data["expires_in"]
        return TokenInfo(data["access_token"], expires_at)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment