Skip to content

Instantly share code, notes, and snippets.

@YouKnow-sys
Last active December 2, 2024 19:50
Show Gist options
  • Save YouKnow-sys/322a41e131d7e8964e6c1c82b880bd99 to your computer and use it in GitHub Desktop.
Save YouKnow-sys/322a41e131d7e8964e6c1c82b880bd99 to your computer and use it in GitHub Desktop.
An async SQLite storage for aiogram
import json
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Self, cast
import aiosqlite
from aiogram.fsm.state import State
from aiogram.fsm.storage.base import (
BaseStorage,
DefaultKeyBuilder,
KeyBuilder,
StateType,
StorageKey,
)
from lru import LRU
# Signature of the callable that turns FSM data into the string written to
# the database (``json.dumps`` is the default used by ``connect``).
_SERIALIZER = Callable[..., str]
# Signature of the callable that turns a stored string back into FSM data
# (``json.loads`` is the default used by ``connect``).
_DESERIALIZER = Callable[..., Any]
class SQLiteStorage(BaseStorage):
    """FSM storage that keeps aiogram states and data in a SQLite database.

    States are stored in the ``aiogram_state`` table and data in the
    ``aiogram_data`` table; both are keyed by the string produced by the
    configured key builder. States that were written or read are also kept
    in an in-memory LRU cache so repeated ``get_state`` calls for the same
    key can skip the database.

    Create instances with :meth:`connect`; it opens the connection and
    assigns every attribute declared below.
    """

    _state_cache: LRU  # storage id -> state string
    _conn: aiosqlite.Connection
    _key_builder: KeyBuilder
    _data_serializer: _SERIALIZER
    _data_deserializer: _DESERIALIZER

    @classmethod
    async def connect(
        cls,
        db_path: str | Path = "fsm_starage.db",
        key_builder: Optional[KeyBuilder] = None,
        data_serializer: _SERIALIZER = json.dumps,
        data_deserializer: _DESERIALIZER = json.loads,
    ) -> Self:
        """Open the database, create the tables if needed and return a storage.

        Args:
            db_path: Location of the SQLite file. The default file name is
                misspelled ("starage"); it is kept as is because changing it
                would make existing deployments open a different file.
            key_builder: Builds the string key for a ``StorageKey``. A
                ``DefaultKeyBuilder`` is used when omitted.
            data_serializer: Converts the data dict into the stored string.
            data_deserializer: Converts the stored string back into data.

        Returns:
            A ready-to-use storage instance.

        Raises:
            Whatever ``aiosqlite`` raises when connecting or creating the
            tables fails; in the latter case the connection is closed first.
        """
        self = cls()
        self._key_builder = (
            DefaultKeyBuilder() if key_builder is None else key_builder
        )
        self._data_serializer = data_serializer
        self._data_deserializer = data_deserializer
        self._state_cache = LRU(200)
        self._conn = await aiosqlite.connect(db_path)
        try:
            # NOTE: STRICT tables need a reasonably recent SQLite build
            # (3.37.0+ according to the SQLite documentation).
            await self._conn.execute("""
                CREATE TABLE
                IF NOT EXISTS `aiogram_state` (
                    `key` TEXT NOT NULL PRIMARY KEY,
                    `state` TEXT NOT NULL
                ) STRICT
            """)
            await self._conn.execute("""
                CREATE TABLE
                IF NOT EXISTS `aiogram_data`(
                    `key` TEXT NOT NULL PRIMARY KEY,
                    `data` TEXT
                ) STRICT
            """)
        except BaseException:
            # Do not leak the freshly opened connection when the schema
            # cannot be created (or the task is cancelled meanwhile).
            await self._conn.close()
            raise
        return self

    def _resolve_state(self, value: StateType) -> Optional[str]:
        """Convert a state argument into the string that gets stored.

        ``None`` stays ``None``, a ``State`` object yields its ``state``
        attribute, and anything else is passed through ``str``.
        """
        if value is None:
            return None
        if isinstance(value, State):
            return value.state
        return str(value)

    async def set_state(self, key: StorageKey, state: StateType = None) -> None:
        """Store the state for ``key``, or remove it when the state is empty.

        A state that resolves to ``None`` or an empty string deletes the row
        and drops the cached entry; any other value is upserted and cached.
        """
        storage_id = self._key_builder.build(key)
        resolved = self._resolve_state(state)
        if resolved:
            await self._conn.execute(
                "INSERT OR REPLACE INTO `aiogram_state` (`key`, `state`) VALUES (?, ?)",
                (storage_id, resolved),
            )
            self._state_cache[storage_id] = resolved
        else:
            await self._conn.execute(
                "DELETE FROM `aiogram_state` WHERE `key` = ?", (storage_id,)
            )
            # The key may never have been cached (e.g. clearing a state that
            # was not read or set through this instance), so a default is
            # required to avoid a KeyError that would also skip the commit.
            self._state_cache.pop(storage_id, None)
        await self._conn.commit()

    async def get_state(self, key: StorageKey) -> Optional[str]:
        """Return the stored state for ``key`` or ``None`` when there is none.

        The cache is consulted first; a value loaded from the database is
        cached before it is returned.
        """
        storage_id = self._key_builder.build(key)
        if storage_id in self._state_cache:
            return self._state_cache[storage_id]
        async with self._conn.execute(
            "SELECT `state` FROM `aiogram_state` WHERE `key` = ?", (storage_id,)
        ) as cursor:
            row = await cursor.fetchone()
        if row is None:
            return None
        self._state_cache[storage_id] = row[0]
        return row[0]

    async def set_data(self, key: StorageKey, data: Dict[str, Any]) -> None:
        """Store ``data`` for ``key``; empty data deletes the stored row."""
        storage_id = self._key_builder.build(key)
        if not data:
            await self._conn.execute(
                "DELETE FROM `aiogram_data` WHERE `key` = ?", (storage_id,)
            )
        else:
            serialized = self._data_serializer(data)
            await self._conn.execute(
                "INSERT OR REPLACE INTO `aiogram_data` (`key`, `data`) VALUES (?, ?)",
                (storage_id, serialized),
            )
        await self._conn.commit()

    async def get_data(self, key: StorageKey) -> Dict[str, Any]:
        """Return the data stored for ``key``.

        An empty dict is returned when there is no row or the stored value
        is empty; otherwise the value is run through the deserializer.
        """
        storage_id = self._key_builder.build(key)
        async with self._conn.execute(
            "SELECT `data` FROM `aiogram_data` WHERE `key` = ?",
            (storage_id,),
        ) as cursor:
            row = await cursor.fetchone()
        if row is None or not row[0]:
            return {}
        value = row[0]
        # Defensive: decode in case the driver hands back bytes.
        if isinstance(value, bytes):
            value = value.decode("utf-8")
        return cast(Dict[str, Any], self._data_deserializer(value))

    async def close(self) -> None:
        """Close the underlying database connection."""
        await self._conn.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment