
@a-r-d
Created July 26, 2024 00:10
Generate Supabase dataclasses from tables; the generated classes are subscriptable like dicts.
# note: run the sql file in your database first.
import os
from typing import List, Optional, Any, Union

from supabase import create_client, Client

# Ensure the models directory exists
os.makedirs("./models", exist_ok=True)

url: str = os.environ.get("SUPABASE_URL")
key: str = os.environ.get("SUPABASE_KEY")
supabase: Client = create_client(url, key)


def generate_datetime_parser():
    # Emits a parse_datetime helper that is copied into every generated class.
    return """
    @staticmethod
    def parse_datetime(dt_string: str) -> Optional[datetime]:
        if not dt_string:
            return None
        try:
            return datetime.fromisoformat(dt_string.replace('Z', '+00:00'))
        except ValueError:
            try:
                from dateutil import parser
                return parser.isoparse(dt_string)
            except Exception:
                print(f"Error parsing datetime: {dt_string}")
                return None
"""


def generate_from_dict_method(table_name: str, columns: List[dict]):
    # Emits a from_dict classmethod that turns a raw Supabase row (dict or JSON
    # string) into an instance of the generated dataclass.
    method = f"""
    @classmethod
    def from_dict(cls, data: Union[dict, str]) -> Optional['{table_name.capitalize()}']:
        if isinstance(data, str):
            try:
                data = json.loads(data)
            except json.JSONDecodeError:
                print(f"Error: Unable to parse string as JSON: {{data[:100]}}")
                return None
        if not isinstance(data, dict):
            print(f"Error: from_dict received unexpected type: {{type(data)}}")
            return None
        parsed_data = {{}}
"""
    for column in columns:
        col_name = column['column_name']
        data_type = column['data_type']
        # Coerce timestamps and numerics; everything else passes through unchanged.
        if data_type in ['timestamp with time zone', 'date', 'time']:
            method += f"        parsed_data['{col_name}'] = cls.parse_datetime(data.get('{col_name}'))\n"
        elif data_type == 'numeric':
            method += f"        parsed_data['{col_name}'] = Decimal(data.get('{col_name}')) if data.get('{col_name}') is not None else None\n"
        else:
            method += f"        parsed_data['{col_name}'] = data.get('{col_name}')\n"
    method += "        return cls(**parsed_data)\n"
    return method


def generate_class(table_name: str, columns: List[dict]) -> str:
    class_str = f"""@dataclass
class {table_name.capitalize()}:
"""
    for column in columns:
        col_name = column['column_name']
        data_type = column['data_type']
        python_type = get_python_type(data_type)
        class_str += f"    {col_name}: {python_type}\n"
    class_str += generate_datetime_parser()
    class_str += generate_from_dict_method(table_name, columns)
    # Add methods to make the class subscriptable
    class_str += """
    def __getitem__(self, key):
        return getattr(self, key)

    def __setitem__(self, key, value):
        setattr(self, key, value)

    def get(self, key, default=None):
        return getattr(self, key, default)
"""
    class_str += "\n"
    return class_str


def get_python_type(pg_type: str) -> str:
    # Map Postgres column types to Python annotations; unknown types fall back to Any.
    type_map = {
        'integer': 'int',
        'bigint': 'int',
        'smallint': 'int',
        'text': 'str',
        'character varying': 'str',
        'boolean': 'bool',
        'timestamp with time zone': 'Optional[datetime]',
        'date': 'Optional[date]',
        'time': 'Optional[time]',
        'numeric': 'Optional[Decimal]',
        'real': 'float',
        'double precision': 'float',
        'json': 'dict',
        'jsonb': 'dict',
    }
    return type_map.get(pg_type, 'Any')


def generate_classes():
    # Fetch every column in the public schema via the SQL helper, group rows by
    # table, then write one dataclass per table into ./models/schema.py.
    response = supabase.rpc('get_public_schema_columns', {}).execute()
    tables = {}
    for row in response.data:
        table_name = row['table_name']
        if table_name not in tables:
            tables[table_name] = []
        tables[table_name].append(row)

    with open("./models/schema.py", "w") as f:
        f.write("from __future__ import annotations\n")
        f.write("from typing import Optional, Any, Union\n")
        f.write("from datetime import datetime, date, time\n")
        f.write("from decimal import Decimal\n")
        f.write("from dataclasses import dataclass\n")
        f.write("import json\n\n")
        for table_name, columns in tables.items():
            class_def = generate_class(table_name, columns)
            f.write(class_def)


generate_classes()
print("Classes have been written to ./models/schema.py")
CREATE OR REPLACE FUNCTION get_public_schema_columns()
RETURNS TABLE (
    table_name text,
    column_name text,
    data_type text
) AS $$
BEGIN
    RETURN QUERY
    SELECT c.table_name::text, c.column_name::text, c.data_type::text
    FROM information_schema.columns c
    WHERE c.table_schema = 'public'
    ORDER BY c.table_name, c.ordinal_position;
END;
$$ LANGUAGE plpgsql;
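
Once this function is installed, every row the RPC returns is a plain dict with table_name, column_name, and data_type keys, which is exactly what the grouping loop in generate_classes() expects. A minimal sanity-check sketch, reusing the supabase client created in the generator script (the values shown in the comment are only illustrative):

# Call the helper directly and eyeball a few rows before generating classes.
rows = supabase.rpc('get_public_schema_columns', {}).execute().data
for row in rows[:5]:
    # e.g. {'table_name': 'workflows', 'column_name': 'id', 'data_type': 'bigint'}
    print(row['table_name'], row['column_name'], row['data_type'])

Finally, an example of using the generated classes in application code: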
from typing import List, Optional

from main import supabaseClient
from models.schema import Workflows


async def getActiveWorkflowsForUser(user_id: int) -> Optional[List[Workflows]]:
    response = (
        supabaseClient.table("workflows")
        .select("*")
        .eq("user_id", user_id)
        .eq("is_live", True)
        .execute()
    )
    if response and response.data:
        return [Workflows.from_dict(workflow_data) for workflow_data in response.data]
    else:
        return []
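
Because the generated classes implement __getitem__, __setitem__, and get, code written against raw Supabase row dicts keeps working after the switch, and regular attribute access is available too. A minimal sketch, assuming it runs inside an async context (user_id and is_live come from the query above):

# Inside an async function:
workflows = await getActiveWorkflowsForUser(user_id=42)
for wf in workflows:
    # Dict-style access, same as with the raw Supabase row...
    print(wf["user_id"], wf.get("is_live", False))
    # ...and plain attribute access on the dataclass.
    print(wf.user_id, wf.is_live)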