Created
July 26, 2024 00:10
-
-
Save a-r-d/aab8e5efa02a60d959f13f77fc4e3664 to your computer and use it in GitHub Desktop.
Generate Supabase dataclasses from your tables; the generated classes are subscriptable like dicts.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Note: run the SQL file (which creates get_public_schema_columns) in your database first.
import os | |
from supabase import create_client, Client | |
from typing import List, Optional, Any, Union | |
# Ensure the models directory exists before anything is written to it.
os.makedirs("./models", exist_ok=True)

# Connection settings come from the environment. os.environ.get returns None
# for unset variables, so check up front and name the missing variable(s)
# rather than passing None into create_client.
url: Optional[str] = os.environ.get("SUPABASE_URL")
key: Optional[str] = os.environ.get("SUPABASE_KEY")
_missing_env = [
    name
    for name, value in (("SUPABASE_URL", url), ("SUPABASE_KEY", key))
    if not value
]
if _missing_env:
    raise RuntimeError(
        "Missing required environment variable(s): " + ", ".join(_missing_env)
    )
supabase: Client = create_client(url, key)
def generate_datetime_parser() -> str:
    """Return source for a ``parse_datetime`` staticmethod to embed in a generated class.

    The emitted method returns ``None`` for empty input. It first tries
    ``datetime.fromisoformat`` (after rewriting ``Z`` to ``+00:00``). When the
    stdlib parser rejects the string, it falls back to
    ``dateutil.parser.isoparse``. If that also fails, or dateutil is not
    installed, it prints the offending string and returns ``None``.

    Returns:
        Method source indented four spaces, ready to append inside a class body.
    """
    # The inner handler catches only the failures it expects:
    # - ImportError: python-dateutil is not installed;
    # - ValueError/OverflowError: dateutil cannot parse the string either.
    # The previous bare ``except:`` also swallowed KeyboardInterrupt and
    # SystemExit and masked genuine programming errors.
    return """
    @staticmethod
    def parse_datetime(dt_string: str) -> Optional[datetime]:
        if not dt_string:
            return None
        try:
            return datetime.fromisoformat(dt_string.replace('Z', '+00:00'))
        except ValueError:
            try:
                from dateutil import parser
                return parser.isoparse(dt_string)
            except (ImportError, ValueError, OverflowError):
                print(f"Error parsing datetime: {dt_string}")
                return None
"""
def generate_from_dict_method(table_name: str, columns: List[dict]) -> str:
    """Return source for a ``from_dict`` classmethod to embed in a generated class.

    The emitted classmethod accepts a dict, or a JSON string that decodes to
    one. It prints a message and returns ``None`` when the string is not valid
    JSON or the value is not a dict. Otherwise it builds an instance with one
    keyword argument per column:

    * timestamp columns go through ``cls.parse_datetime``;
    * ``date`` columns go through ``cls.parse_datetime`` and are narrowed with
      ``.date()``, matching their ``Optional[date]`` annotation;
    * time columns are parsed with ``time.fromisoformat`` (print + ``None`` on
      failure);
    * ``numeric`` columns become ``Decimal``;
    * all other columns are copied through unchanged.

    Args:
        table_name: Table name; ``table_name.capitalize()`` is the class name
            used in the return annotation.
        columns: Rows with at least ``column_name`` and ``data_type`` keys.

    Returns:
        Method source indented four spaces, ready to append inside a class body.
    """
    # information_schema.columns spells these types out in full
    # (e.g. 'timestamp without time zone', 'time without time zone');
    # the bare 'time' is kept as an accepted alias.
    timestamp_types = {'timestamp with time zone', 'timestamp without time zone'}
    time_types = {'time', 'time without time zone', 'time with time zone'}

    lines = [f"""
    @classmethod
    def from_dict(cls, data: Union[dict, str]) -> Optional['{table_name.capitalize()}']:
        if isinstance(data, str):
            try:
                data = json.loads(data)
            except json.JSONDecodeError:
                print(f"Error: Unable to parse string as JSON: {{data[:100]}}")
                return None
        if not isinstance(data, dict):
            print(f"Error: from_dict received unexpected type: {{type(data)}}")
            return None
"""]
    if any(column['data_type'] in time_types for column in columns):
        # Emitted only when needed: parse_datetime cannot read time-only
        # strings such as '12:34:56'.
        lines.append("""        def _parse_time(value):
            if not value:
                return None
            try:
                return time.fromisoformat(value)
            except ValueError:
                print(f"Error parsing time: {value}")
                return None
""")
    lines.append("        parsed_data = {}\n")
    for column in columns:
        # repr() yields a valid Python string literal even when the column
        # name contains quotes or backslashes.
        key = repr(column['column_name'])
        data_type = column['data_type']
        if data_type in timestamp_types:
            lines.append(f"        parsed_data[{key}] = cls.parse_datetime(data.get({key}))\n")
        elif data_type == 'date':
            lines.append(f"        _value = cls.parse_datetime(data.get({key}))\n")
            lines.append(f"        parsed_data[{key}] = _value.date() if _value is not None else None\n")
        elif data_type in time_types:
            lines.append(f"        parsed_data[{key}] = _parse_time(data.get({key}))\n")
        elif data_type == 'numeric':
            # Convert via str(): Decimal(0.1) keeps the float's binary rounding
            # noise, whereas Decimal('0.1') is exact.
            lines.append(f"        _value = data.get({key})\n")
            lines.append(f"        parsed_data[{key}] = Decimal(str(_value)) if _value is not None else None\n")
        else:
            lines.append(f"        parsed_data[{key}] = data.get({key})\n")
    lines.append("        return cls(**parsed_data)\n")
    return "".join(lines)
def generate_class(table_name: str, columns: List[dict]) -> str:
    """Return the source of a ``@dataclass`` describing one table.

    The class is named ``table_name.capitalize()`` and has one annotated field
    per column, typed via ``get_python_type``. It also contains the generated
    ``parse_datetime`` and ``from_dict`` methods, plus ``__getitem__``,
    ``__setitem__`` and ``get`` so instances can be used like dicts.

    Args:
        table_name: Name of the database table.
        columns: Rows with at least ``column_name`` and ``data_type`` keys.

    Returns:
        The class source, followed by a trailing blank line.

    Raises:
        ValueError: if a column name is not a valid Python identifier or is a
            keyword. It could not be a dataclass field, and writing it would
            make the generated module unparseable.
    """
    import keyword

    parts = [f"@dataclass\nclass {table_name.capitalize()}:\n"]
    for column in columns:
        col_name = column['column_name']
        if not col_name.isidentifier() or keyword.iskeyword(col_name):
            raise ValueError(
                f"Column {col_name!r} of table {table_name!r} is not a valid Python field name"
            )
        parts.append(f"    {col_name}: {get_python_type(column['data_type'])}\n")
    parts.append(generate_datetime_parser())
    parts.append(generate_from_dict_method(table_name, columns))
    # Dict-style access goes through the instance __dict__, which holds the
    # dataclass fields and anything assigned via __setitem__, but not methods.
    # A missing key therefore raises KeyError, as it would for a dict, instead
    # of AttributeError, and obj['from_dict'] no longer returns a method.
    parts.append('''
    def __getitem__(self, key):
        return self.__dict__[key]

    def __setitem__(self, key, value):
        setattr(self, key, value)

    def get(self, key, default=None):
        return self.__dict__.get(key, default)
''')
    parts.append("\n")
    return "".join(parts)
# Maps PostgreSQL type names, as reported in information_schema.columns.data_type,
# to the annotation text written into the generated dataclasses. Built once
# at import instead of on every call.
_PG_TO_PYTHON_TYPE = {
    'integer': 'int',
    'bigint': 'int',
    'smallint': 'int',
    'text': 'str',
    'character varying': 'str',
    'character': 'str',
    'uuid': 'str',
    'boolean': 'bool',
    'timestamp with time zone': 'Optional[datetime]',
    'timestamp without time zone': 'Optional[datetime]',
    'date': 'Optional[date]',
    # information_schema reports time columns with their full names; the
    # bare 'time' key is kept as an alias.
    'time': 'Optional[time]',
    'time without time zone': 'Optional[time]',
    'time with time zone': 'Optional[time]',
    'numeric': 'Optional[Decimal]',
    'real': 'float',
    'double precision': 'float',
    'json': 'dict',
    'jsonb': 'dict',
    'ARRAY': 'list',
}


def get_python_type(pg_type: str) -> str:
    """Return the Python annotation text for PostgreSQL type ``pg_type``.

    Unknown types map to ``'Any'``.
    """
    return _PG_TO_PYTHON_TYPE.get(pg_type, 'Any')
def generate_classes() -> None:
    """Fetch the public-schema column list and write one dataclass per table.

    Calls the ``get_public_schema_columns`` RPC and groups the returned rows by
    ``table_name``, keeping the order received. It then writes the module
    imports, followed by each generated class, to ``./models/schema.py``,
    replacing any existing file.
    """
    response = supabase.rpc('get_public_schema_columns', {}).execute()

    # dict keeps insertion order, so tables appear in the order the RPC
    # returned them (the SQL orders by table_name, ordinal_position).
    tables: dict = {}
    for row in response.data or []:
        tables.setdefault(row['table_name'], []).append(row)

    header = (
        "from __future__ import annotations\n"
        "from typing import Optional, Any, Union\n"
        "from datetime import datetime, date, time\n"
        "from decimal import Decimal\n"
        "from dataclasses import dataclass\n"
        "import json\n\n"
    )
    # Build the complete module text before opening the file. If any class
    # fails to generate, the existing schema.py is left untouched rather than
    # truncated and half-written.
    source = header + "".join(
        generate_class(table_name, columns) for table_name, columns in tables.items()
    )
    # Explicit UTF-8 so non-ASCII table/column names do not depend on the
    # platform's default encoding.
    with open("./models/schema.py", "w", encoding="utf-8") as f:
        f.write(source)


generate_classes()
print("Classes have been written to ./models/schema.py")
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- Returns (table_name, column_name, data_type) for every column in the
-- public schema, ordered by table name and then column position. The Python
-- generator calls it via supabase.rpc('get_public_schema_columns').
-- Declared STABLE because it only reads catalog views; a plain SQL body is
-- enough, so plpgsql and RETURN QUERY are unnecessary. PostgREST can then run
-- it in a read-only transaction.
-- NOTE(review): information_schema.columns only lists columns the calling
-- role has privileges on, so results depend on which key is used; confirm
-- the key used has access.
CREATE OR REPLACE FUNCTION get_public_schema_columns()
RETURNS TABLE (
    table_name text,
    column_name text,
    data_type text
)
LANGUAGE sql
STABLE
AS $$
    SELECT c.table_name::text, c.column_name::text, c.data_type::text
    FROM information_schema.columns c
    WHERE c.table_schema = 'public'
    ORDER BY c.table_name, c.ordinal_position;
$$;
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import List, Optional | |
from main import supabaseClient | |
from models.schema import Workflows | |
async def getActiveWorkflowsForUser(user_id: int) -> Optional[List[Workflows]]:
    """Return the live workflows belonging to ``user_id``.

    Selects all columns of ``workflows`` rows where ``user_id`` matches and
    ``is_live`` is true, then converts each row with ``Workflows.from_dict``.
    ``from_dict`` returns ``None`` for rows it cannot convert; those rows are
    dropped, so every element is a ``Workflows`` instance. Returns an empty
    list when the query yields no rows.
    """
    # NOTE(review): .execute() is called without await, so if supabaseClient
    # is the synchronous client, this blocks the event loop despite the
    # ``async def`` — confirm which client ``main`` creates.
    response = (
        supabaseClient.table("workflows")
        .select("*")
        .eq("user_id", user_id)
        .eq("is_live", True)
        .execute()
    )
    rows = response.data if response else None
    if not rows:
        return []
    converted = (Workflows.from_dict(row) for row in rows)
    return [workflow for workflow in converted if workflow is not None]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment