Last active
December 2, 2024 19:50
-
-
Save YouKnow-sys/322a41e131d7e8964e6c1c82b880bd99 to your computer and use it in GitHub Desktop.
An async SQLite storage for aiogram
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json | |
from pathlib import Path | |
from typing import Any, Callable, Dict, Optional, Self, cast | |
import aiosqlite | |
from aiogram.fsm.state import State | |
from aiogram.fsm.storage.base import ( | |
BaseStorage, | |
DefaultKeyBuilder, | |
KeyBuilder, | |
StateType, | |
StorageKey, | |
) | |
from lru import LRU | |
# Callable that turns a data dict into the text stored in the database
# (``json.dumps`` is the default used by ``SQLiteStorage.connect``).
_SERIALIZER = Callable[..., str]
# Callable that turns the stored text back into Python data
# (``json.loads`` is the default used by ``SQLiteStorage.connect``).
_DESERIALIZER = Callable[..., Any]
class SQLiteStorage(BaseStorage):
    """aiogram FSM storage that persists states and data in an SQLite file.

    Create instances with :meth:`connect`; it opens the database and creates
    the tables. An instance built with a plain ``SQLiteStorage()`` has no
    connection, key builder or cache attached.

    States are stored in the ``aiogram_state`` table and data payloads in the
    ``aiogram_data`` table. Both are keyed by the string the key builder
    produces for a ``StorageKey``. States that were written or read are also
    kept in an in-memory LRU cache, so repeated :meth:`get_state` calls for
    the same key do not hit the database.
    """

    _state_cache: LRU
    _conn: aiosqlite.Connection
    _key_builder: KeyBuilder
    _data_serializer: _SERIALIZER
    _data_deserializer: _DESERIALIZER

    @classmethod
    async def connect(
        cls,
        db_path: str | Path = "fsm_starage.db",
        key_builder: Optional[KeyBuilder] = None,
        data_serializer: _SERIALIZER = json.dumps,
        data_deserializer: _DESERIALIZER = json.loads,
        state_cache_size: int = 200,
    ) -> Self:
        """Open the database and return a ready-to-use storage.

        Args:
            db_path: Path of the SQLite database file. The misspelled default
                name is kept on purpose so databases created by earlier
                versions are still found.
            key_builder: Converts a ``StorageKey`` into the string used as
                primary key. ``DefaultKeyBuilder()`` is used when omitted.
            data_serializer: Callable turning the data dict into text.
            data_deserializer: Callable turning stored text back into data.
            state_cache_size: Maximum number of states kept in the in-memory
                LRU cache.

        Returns:
            A storage instance connected to ``db_path`` with both tables
            created if they did not exist yet.
        """
        storage = cls()
        storage._key_builder = (
            DefaultKeyBuilder() if key_builder is None else key_builder
        )
        storage._data_serializer = data_serializer
        storage._data_deserializer = data_deserializer
        storage._state_cache = LRU(state_cache_size)
        storage._conn = await aiosqlite.connect(db_path)
        try:
            await storage._conn.execute("""
                CREATE TABLE
                IF NOT EXISTS `aiogram_state` (
                    `key` TEXT NOT NULL PRIMARY KEY,
                    `state` TEXT NOT NULL
                ) STRICT
            """)
            await storage._conn.execute("""
                CREATE TABLE
                IF NOT EXISTS `aiogram_data`(
                    `key` TEXT NOT NULL PRIMARY KEY,
                    `data` TEXT
                ) STRICT
            """)
        except Exception:
            # Do not leak the open connection when the schema cannot be
            # created (e.g. an SQLite build without STRICT table support).
            await storage._conn.close()
            raise
        return storage

    def _resolve_state(self, value: StateType) -> Optional[str]:
        """Return the textual form of ``value``, or ``None`` when it is ``None``.

        ``State`` objects contribute their ``.state`` attribute; any other
        value is converted with ``str()``.
        """
        if value is None:
            return None
        if isinstance(value, State):
            return value.state
        return str(value)

    async def set_state(self, key: StorageKey, state: StateType = None) -> None:
        """Store ``state`` for ``key``, or remove the stored state.

        A state that resolves to ``None`` or an empty string deletes the row
        and drops the cache entry; anything else is upserted and cached.
        """
        storage_id = self._key_builder.build(key)
        resolved = self._resolve_state(state)
        if resolved:
            await self._conn.execute(
                "INSERT OR REPLACE INTO `aiogram_state` (`key`, `state`) VALUES (?, ?)",
                (storage_id, resolved),
            )
            self._state_cache[storage_id] = resolved
        else:
            await self._conn.execute(
                "DELETE FROM `aiogram_state` WHERE `key` = ?", (storage_id,)
            )
            # The key may never have been cached (or may have been evicted);
            # a default avoids a KeyError that would also skip the commit.
            self._state_cache.pop(storage_id, None)
        await self._conn.commit()

    async def get_state(self, key: StorageKey) -> Optional[str]:
        """Return the state stored for ``key``, or ``None`` if there is none.

        The cache is consulted first; a value found in the database is put
        into the cache before being returned. Missing states are not cached.
        """
        storage_id = self._key_builder.build(key)
        if storage_id in self._state_cache:
            return self._state_cache[storage_id]
        async with self._conn.execute(
            "SELECT `state` FROM `aiogram_state` WHERE `key` = ?", (storage_id,)
        ) as cursor:
            row = await cursor.fetchone()
        if row is None:
            return None
        self._state_cache[storage_id] = row[0]
        return row[0]

    async def set_data(self, key: StorageKey, data: Dict[str, Any]) -> None:
        """Store ``data`` for ``key``; an empty dict deletes the stored row."""
        storage_id = self._key_builder.build(key)
        if not data:
            await self._conn.execute(
                "DELETE FROM `aiogram_data` WHERE `key` = ?", (storage_id,)
            )
        else:
            payload = self._data_serializer(data)
            await self._conn.execute(
                "INSERT OR REPLACE INTO `aiogram_data` (`key`, `data`) VALUES (?, ?)",
                (storage_id, payload),
            )
        await self._conn.commit()

    async def get_data(self, key: StorageKey) -> Dict[str, Any]:
        """Return the data stored for ``key``, or ``{}`` when nothing is stored."""
        storage_id = self._key_builder.build(key)
        async with self._conn.execute(
            "SELECT `data` FROM `aiogram_data` WHERE `key` = ?",
            (storage_id,),
        ) as cursor:
            row = await cursor.fetchone()
        if row is None or not row[0]:
            return {}
        value = row[0]
        if isinstance(value, bytes):
            # Decode byte payloads before handing them to the deserializer.
            value = value.decode("utf-8")
        return cast(Dict[str, Any], self._data_deserializer(value))

    async def close(self) -> None:
        """Close the underlying database connection."""
        await self._conn.close()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment