Last active
June 9, 2019 21:30
-
-
Save noxecane/5dfebdbf64fc1167fbc91a9cfcfcc8e9 to your computer and use it in GitHub Desktop.
Simplify your life with Opaleye
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE FlexibleContexts #-} | |
{-# LANGUAGE NamedFieldPuns #-} | |
{-# LANGUAGE OverloadedStrings #-} | |
module OpaleyeExt | |
( MonadDB(..) | |
, ConstraintError(..) | |
, safeInsertOne | |
, safeInsertOneReturningId | |
, safeInsertOneReturningId' | |
, select_ | |
, selectOne_ | |
, selectToString | |
) where | |
import Control.Exception (throwIO)
import Control.Monad.Catch (catch)
import Data.Int (Int64)
import Data.Maybe (fromJust, fromMaybe, listToMaybe)
import Data.Profunctor.Product.Default (Default)
import Data.Text (Text)
import Data.Text.Encoding (decodeUtf8, decodeUtf8With)
import Data.Text.Encoding.Error (lenientDecode)
import Database.PostgreSQL.Simple (Connection, SqlError (..))
import Opaleye
-- | Monads able to run Opaleye inserts and selects.
--
-- NOTE(review): no instance is visible here; the per-method notes below are
-- read off the method types only — confirm against the real instances.
class Monad m => MonadDB m where
  -- | Insert one Haskell-side row into the table. The 'Int64' result is
  -- presumably the affected-row count.
  insertOne
    :: Default Constant hW pW
    => Table pW pR
    -> hW
    -> m Int64

  -- | Insert one row and return the column picked by the projection,
  -- wrapped in 'Maybe'.
  insertOneReturningId
    :: (Default Constant hW pW, QueryRunnerColumnDefault a a)
    => Table pW pR
    -> (pR -> Column a)
    -> hW
    -> m (Maybe a)

  -- | Run a query, yielding at most one decoded row.
  selectOne :: Default FromFields a b => Select a -> m (Maybe b)

  -- | Run a query, yielding every decoded row.
  select :: Default FromFields a b => Select a -> m [b]
-- | Constraint violations that the @safeInsert*@ helpers return as 'Left'
-- instead of throwing.
--
-- The constructor name is misspelled ("Contraint"), but it is exported, so it
-- is kept as-is for compatibility.
newtype ConstraintError
  = UniqueContraintError Text -- ^ unique violation; carries the server's error detail text
  deriving (Show)
-- | Render the SQL Opaleye generates for a query. Returns @"Empty query"@
-- when Opaleye produces no SQL for it.
selectToString :: Default Unpackspec a a => Select a -> String
selectToString query =
  case showSqlForPostgres query of
    Just sql -> sql
    Nothing  -> "Empty query"
-- | Run a query restricted to @LIMIT 1@ and return its only row, if any.
selectOne_ :: (Default FromFields a b) => Connection -> Select a -> IO (Maybe b)
selectOne_ conn query = listToMaybe <$> runSelect conn (limit 1 query)
-- | Run a query on the given connection and return every row
-- (a direct alias for Opaleye's 'runSelect').
select_ :: (Default FromFields a b) => Connection -> Select a -> IO [b]
select_ conn query = runSelect conn query
-- | Insert a single row and return the inserted-row count.
--
-- A unique-constraint violation comes back as 'Left' (see 'runInsertSafe');
-- other SQL errors are rethrown.
safeInsertOne
  :: Default Constant hW pW
  => Connection
  -> Table pW pR
  -> hW
  -> IO (Either ConstraintError Int64)
safeInsertOne conn t h = runInsertSafe conn insertion
  where
    -- No ON CONFLICT clause: conflicts surface as SQL errors.
    insertion = Insert
      { iTable      = t
      , iRows       = [toFields h]
      , iReturning  = rCount
      , iOnConflict = Nothing
      }
-- | Insert a single row with @RETURNING@ on the projected column, and yield
-- the first returned value (if any).
--
-- A unique-constraint violation comes back as 'Left'; other SQL errors are
-- rethrown.
safeInsertOneReturningId
  :: (Default Constant hW pW, QueryRunnerColumnDefault a a)
  => Connection -> Table pW pR -> (pR -> Column a) -> hW -> IO (Either ConstraintError (Maybe a))
safeInsertOneReturningId conn t rId h =
  fmap (fmap listToMaybe) (runInsertSafe conn insertion)
  where
    insertion = Insert
      { iTable      = t
      , iRows       = [toFields h]
      , iReturning  = rReturning rId
      , iOnConflict = Nothing
      }
-- | Like 'safeInsertOneReturningId', but unwraps the returned value.
--
-- A unique-constraint violation is still returned as 'Left'. If the
-- @INSERT ... RETURNING@ yields no row at all, an 'IOError' is thrown in
-- 'IO' at this call. Previously a 'fromJust' thunk crashed later with an
-- unhelpful message.
safeInsertOneReturningId'
  :: (Default Constant hW pW, QueryRunnerColumnDefault a a)
  => Connection -> Table pW pR -> (pR -> Column a) -> hW -> IO (Either ConstraintError a)
safeInsertOneReturningId' conn t rId h =
  safeInsertOneReturningId conn t rId h >>= traverse requireRow
  where
    -- Only the 'Right' branch is traversed, so constraint errors pass through.
    requireRow = maybe noRow return
    noRow = ioError $ userError
      "safeInsertOneReturningId': INSERT ... RETURNING produced no row"
-- | Run an Opaleye 'Insert'. A PostgreSQL unique violation
-- (SQLSTATE @23505@) is returned as 'Left' 'UniqueContraintError'; any other
-- 'SqlError' is rethrown unchanged.
--
-- The server's error detail is decoded leniently. A detail that is not valid
-- UTF-8 therefore cannot smuggle a pure decoding exception into the returned
-- 'Left', where it would only explode later when the message is inspected.
runInsertSafe :: Connection -> Insert a -> IO (Either ConstraintError a)
runInsertSafe conn insertion =
  (Right <$> runInsert_ conn insertion) `catch` handleConstraintErrors
  where
    handleConstraintErrors :: SqlError -> IO (Either ConstraintError a)
    handleConstraintErrors err@SqlError{ sqlState, sqlErrorDetail } =
      case sqlState of
        "23505" ->
          return . Left . UniqueContraintError $
            decodeUtf8With lenientDecode sqlErrorDetail
        _ -> throwIO err
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE NamedFieldPuns #-} | |
{-# LANGUAGE TemplateHaskell #-} | |
module OpaleyeTH | |
( HasDefault, NonNullable, NullableDefault, Nullable | |
, Interpretation, TypeNamer | |
, mkTable, mkTable' | |
, mkTypes | |
) where | |
import Base.Prelude (capitalise, dashify, uncapitalise, unprefix) | |
import Control.Monad (foldM, replicateM) | |
import Data.List.Split (splitOn) | |
import Language.Haskell.TH | |
import Language.Haskell.TH.Syntax (VarBangType) | |
import Opaleye (Field, FieldNullable, Table) | |
-- | Field descriptor for a @NOT NULL@ column without a default.
--
-- Like every descriptor below, the first parameter is the Haskell-side type
-- and the second the PostgreSQL-side type.
data NonNullable haskType pgType

-- | Field descriptor for a column that may hold @NULL@.
data Nullable haskType pgType

-- | Field descriptor for a column with a database-side default value.
data HasDefault haskType pgType

-- | Field descriptor for a nullable column that also has a default value.
data NullableDefault haskType pgType
-- Internal tag for the field descriptors, used when computing the types of
-- each interpretation.
data ContraintType
  = NonNullableConst      -- NOT NULL, no default
  | NullableConst         -- may be NULL
  | DefaultConst          -- has a default
  | NullableDefaultConst  -- may be NULL and has a default
-- Working state extracted from a quoted record declaration (see 'extractEnv').
data Env = Env
  { typeName        :: Name                 -- cleaned type name
  , typeNameStr     :: String               -- cleaned type name as a string
  , typeConstructor :: Name                 -- cleaned data-constructor name
  , fields          :: [(Name, Bang, Type)] -- record fields exactly as quoted
  , derivations     :: [DerivClause]        -- deriving clauses to carry over
  } deriving Show
-- | Which side of the Haskell/PostgreSQL boundary a generated type synonym
-- describes.
data Interpretation
  = PostgresW -- ^ fields Opaleye writes into PostgreSQL
  | PostgresR -- ^ fields Opaleye reads from PostgreSQL
  | HaskW     -- ^ Haskell values users hand to Opaleye for writing
  | HaskR     -- ^ Haskell values users get back from Opaleye
-- | Builds the name of the type synonym for an interpretation, given the
-- base model name (see 'defaultTypeNamer' for the default suffixes).
type TypeNamer = Interpretation -> (String -> String)
-- | Declare Opaleye types from a single quoted record declaration.
--
-- Every field's type must be one of this module's descriptors
-- (e.g. 'NonNullable', 'Nullable', 'HasDefault') applied to the Haskell type
-- first and the PostgreSQL type second. For example:
--
-- > $(mkTypes
-- >   [d| data Route = Route
-- >         { id          :: HasDefault Int PGInt4
-- >         , createdAt   :: HasDefault UTCTime PGTimestamptz
-- >         , duration    :: Nullable String PGText
-- >         , terminal    :: NonNullable Int PGInt4
-- >         , destination :: NonNullable Text PGText
-- >         }
-- >     |])
--
-- This produces a polymorphic record whose fields are prefixed with the
-- uncapitalised type name, plus four type synonyms named with the suffixes
-- @R@, @W@, @PR@ and @PW@:
--
-- > data Route a_a8Jc a_a8Jd a_a8Je a_a8Jf a_a8Jg
-- >   = Route {routeId :: a_a8Jc,
-- >            routeCreatedAt :: a_a8Jd,
-- >            routeDuration :: a_a8Je,
-- >            routeTerminal :: a_a8Jf,
-- >            routeDestination :: a_a8Jg}
-- > type RouteR = Route Int UTCTime (Maybe String) Int Text
-- > type RouteW =
-- >   Route (Maybe Int) (Maybe UTCTime) (Maybe String) Int Text
-- > type RoutePR =
-- >   Route (Field PGInt4) (Field PGTimestamptz) (FieldNullable PGText) (Field PGInt4) (Field PGText)
-- > type RoutePW =
-- >   Route (Maybe (Field PGInt4)) (Maybe (Field PGTimestamptz)) (FieldNullable PGText) (Field PGInt4) (Field PGText)
mkTypes :: Q [Dec] -> Q [Dec]
mkTypes baseRecords = do
  records <- baseRecords
  concat <$> traverse mkType records
-- | Generate a table definition for a polymorphic record produced by
-- @'mkTypes'@, using an adaptor produced by @makeAdaptorAndInstance@.
--
-- For example:
--
-- > $(mkTable "route" "pRoute" ''Route)
--
-- translates to
--
-- > routeTable :: Table RoutePW RoutePR
-- > routeTable = table "route" (pRoute Route
-- >   { routeId          = tableField "id"
-- >   , routeCreatedAt   = tableField "created_at"
-- >   , routeDuration    = tableField "duration"
-- >   , routeTerminal    = tableField "terminal"
-- >   , routeDestination = tableField "destination"
-- >   })
--
-- Column names are derived from the record field names:
--
-- * the type-name prefix is removed with 'unprefix';
-- * the result is passed through 'dashify';
-- * a single trailing underscore is dropped.
--
-- The splice fails with a compile-time error if the name does not refer to
-- a data type with exactly one record constructor.
mkTable :: String -> String -> Name -> Q [Dec]
mkTable tableName profunctor typeName = do
  typeInfo <- reify typeName
  (typeNameStr, constructor, fields) <- getCons typeInfo
  tableSig <- mkTableSignature typeNameStr
  tableDefn <- mkTableBody profunctor tableName typeNameStr constructor fields
  return $ tableSig : [ValD (VarP $ haskTableName typeNameStr) (NormalB tableDefn) []]
  where
    -- Report a proper splice error (via 'fail' in Q) rather than a lazy
    -- 'error' thunk with a misleading message.
    getCons (TyConI (DataD _ name _ _ [RecC constructor fields] _)) =
      return (nameBase name, constructor, fmap getFieldName fields)
    getCons _ =
      fail $ "mkTable: expected " ++ show typeName
          ++ " to be a data type with a single record constructor"
    getFieldName (n, _, _) = n
-- | Variant of @'mkTable'@ that derives the SQL table name from the type's
-- base name by applying @'dashify'@ to it.
mkTable' :: String -> Name -> Q [Dec]
mkTable' profunctor typeName =
  mkTable (dashify (nameBase typeName)) profunctor typeName
-- Expand one quoted record declaration into the polymorphic record followed
-- by its R, W, PR and PW type synonyms, in that order.
mkType :: Dec -> Q [Dec]
mkType decl = sequence (recordDecl : synonymDecls)
  where
    env          = extractEnv decl
    recordDecl   = mkPolymorhicRecord env
    synonymDecls =
      map (mkTypeSynonym defaultTypeNamer env) [HaskR, HaskW, PostgresR, PostgresW]
-- Pull the pieces needed for code generation out of a record declaration
-- with exactly one record constructor. Names are cleaned via
-- 'stripRandomizer'; the fields and deriving clauses are kept verbatim.
extractEnv :: Dec -> Env
extractEnv (DataD _ rawName _ _ [RecC rawCons recordFields] derivs) =
  Env
    { typeName        = mkName baseName
    , typeNameStr     = baseName
    , typeConstructor = mkName (stripRandomizer rawCons)
    , fields          = recordFields
    , derivations     = derivs
    }
  where
    baseName = stripRandomizer rawName
extractEnv _ = error "Expected data declaration"
-- | Default naming scheme: append a suffix identifying the interpretation
-- (@PW@, @PR@, @W@ or @R@) to the base model name.
defaultTypeNamer :: TypeNamer
defaultTypeNamer interpretation baseName = baseName ++ suffix
  where
    suffix = case interpretation of
      PostgresW -> "PW"
      PostgresR -> "PR"
      HaskW     -> "W"
      HaskR     -> "R"
-- Declare a parameterless type synonym, named by the given namer, that
-- applies the polymorphic record's type constructor to the per-field types
-- of the chosen interpretation.
mkTypeSynonym :: TypeNamer -> Env -> Interpretation -> Q Dec
mkTypeSynonym namer Env { typeNameStr, typeConstructor, fields } interpretation =
  TySynD (mkName (namer interpretation typeNameStr)) []
    <$> mkTypeSynType (ConT typeConstructor) interpretation fields
-- Build the polymorphic record: one fresh type variable per field. Each
-- field is renamed to @<uncapitalised type name><Capitalised field name>@.
-- Strictness annotations and deriving clauses are kept from the quoted
-- declaration.
mkPolymorhicRecord :: Env -> Q Dec
mkPolymorhicRecord Env { typeNameStr, typeName, fields, typeConstructor, derivations } = do
  polyVars <- replicateM (length fields) (newName "a")
  -- Use the 'plainTV' smart constructor rather than the raw 'PlainTV'
  -- constructor: since template-haskell 2.17, 'PlainTV' takes an extra
  -- flag argument, and 'plainTV' fills it in for us.
  let conTypeVars = fmap plainTV polyVars
      polyFields  = zipWith polify polyVars fields
  return $ DataD [] typeName conTypeVars Nothing [RecC typeConstructor polyFields] derivations
  where
    recField prefix = mkName . (prefix ++) . capitalise . stripRandomizer
    polyRecField    = recField (uncapitalise typeNameStr)
    -- The quoted descriptor type is discarded; the field becomes a type variable.
    polify varName (fieldName, bang', _) = (polyRecField fieldName, bang', VarT varName)
-- Signature for the generated table binding, typed by the model's PW (write)
-- and PR (read) synonyms under the default naming scheme.
mkTableSignature :: String -> Q Dec
mkTableSignature typeNameStr =
  sigD tableBinding [t|Table $writeSyn $readSyn|]
  where
    tableBinding   = haskTableName typeNameStr
    synonym interp = conT (mkName (defaultTypeNamer interp typeNameStr))
    writeSyn       = synonym PostgresW
    readSyn        = synonym PostgresR
-- Expression @table "<name>" (<adaptor> Con { field = tableField "<column>", ... })@.
-- Each column name is the field's base name with the type prefix removed,
-- then passed through 'dashify' and 'noKeywords'. @table@ and @tableField@
-- are not imported here and are resolved at the splice site.
mkTableBody :: String -> String -> String -> Name -> [Name] -> Q Exp
mkTableBody profunctor tableName typeNameStr consName fieldNames =
  [e|table $(strLit tableName) ($(adaptor) $(recordOfColumns))|]
  where
    adaptor         = varE (mkName profunctor)
    recordOfColumns = recConE consName (map columnFor fieldNames)
    columnFor name  = fieldExp name [e|tableField $(strLit (columnName name))|]
    columnName      = noKeywords . dashify . unprefix typeNameStr . nameBase
    strLit          = litE . StringL
-- | Drop a single trailing underscore from a column name, presumably so a
-- field such as @type_@ can map to a column named after a reserved word.
--
-- Total: the empty string is returned unchanged. The previous version used
-- 'last'/'init', which crash on @""@.
noKeywords :: String -> String
noKeywords w = case reverse w of
  '_' : rest -> reverse rest
  _          -> w
-- Name of the generated table binding: uncapitalised model name + "Table"
-- (e.g. "Route" becomes routeTable).
haskTableName :: String -> Name
haskTableName typeNameStr = mkName (uncapitalise typeNameStr ++ "Table")
-- Keep the part of a name's base before its first underscore.
--
-- NOTE(review): this also truncates legitimate names containing '_'
-- (e.g. "type_" becomes "type", "user_id" becomes "user"). 'nameBase'
-- normally carries no unique suffix, so confirm the truncation is intended.
stripRandomizer :: Name -> String
stripRandomizer = takeWhile (/= '_') . nameBase
-- Apply the record's type constructor to each field's interpreted type, in
-- field order. Every field type must have the shape
-- @Descriptor haskType pgType@.
mkTypeSynType :: Type -> Interpretation -> [VarBangType] -> Q Type
mkTypeSynType synTyCons interpretation fields = do
  argTypes <- mapM interpreteField fields
  foldM (\applied arg -> return (AppT applied arg)) synTyCons argTypes
  where
    interpreteField (_, _, AppT (AppT descriptor hType) pType) =
      interpreteDescriptor interpretation (fromDescriptor descriptor) hType pType
    interpreteField _ =
      error "Expected descriptor(Nullable, NonNullable...)"
-- Compute a single field's type under an interpretation. @hType@ is the
-- descriptor's Haskell-side type and @pType@ its PostgreSQL-side type.
interpreteDescriptor :: Interpretation -> ContraintType -> Type -> Type -> Q Type
interpreteDescriptor interpretation constraint hType pType =
  case interpretation of
    -- Haskell read side: nullable columns are wrapped in Maybe.
    HaskR
      | nullable  -> maybeOf hask
      | otherwise -> hask
    -- Haskell write side: everything except plain NOT NULL is wrapped in Maybe.
    HaskW -> case constraint of
      NonNullableConst -> hask
      _                -> maybeOf hask
    -- Postgres read side: nullability picks the field wrapper.
    PostgresR
      | nullable  -> nullableField
      | otherwise -> field
    -- Postgres write side: nullability picks the wrapper; a default adds Maybe.
    PostgresW -> case constraint of
      NonNullableConst     -> field
      NullableConst        -> nullableField
      DefaultConst         -> maybeOf field
      NullableDefaultConst -> maybeOf nullableField
  where
    hask          = return hType
    field         = [t|Field $(return pType)|]
    nullableField = [t|FieldNullable $(return pType)|]
    maybeOf inner = [t|Maybe $inner|]
    nullable      = case constraint of
      NullableConst        -> True
      NullableDefaultConst -> True
      _                    -> False
-- Map a descriptor type constructor from a quoted record to its internal tag.
--
-- Only the base name is inspected. The exported @NullableDefault@
-- descriptor was previously not recognised: the code matched
-- @"NullDefault"@ and mapped it to 'NullableConst', so 'NullableDefaultConst'
-- could never be produced. Both spellings now map to 'NullableDefaultConst'.
fromDescriptor :: Type -> ContraintType
fromDescriptor (ConT desc) =
  case nameBase desc of
    "NonNullable"     -> NonNullableConst
    "Nullable"        -> NullableConst
    "HasDefault"      -> DefaultConst
    "NullableDefault" -> NullableDefaultConst
    "NullDefault"     -> NullableDefaultConst
    _                 -> error "Uninterpretable Descriptor"
fromDescriptor _ = error "Uninterpretable Descriptor"
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment