Skip to content

Instantly share code, notes, and snippets.

@noxecane
Last active June 9, 2019 21:30
Show Gist options
  • Save noxecane/5dfebdbf64fc1167fbc91a9cfcfcc8e9 to your computer and use it in GitHub Desktop.
Save noxecane/5dfebdbf64fc1167fbc91a9cfcfcc8e9 to your computer and use it in GitHub Desktop.
Simplify your life with Opaleye
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
module OpaleyeExt
( MonadDB(..)
, ConstraintError(..)
, safeInsertOne
, safeInsertOneReturningId
, safeInsertOneReturningId'
, select_
, selectOne_
, selectToString
) where
import Control.Exception (throwIO)
import Control.Monad.Catch (catch)
import Data.Int (Int64)
import Data.Maybe (fromJust, fromMaybe, listToMaybe)
import Data.Profunctor.Product.Default (Default)
import Data.Text (Text)
import Data.Text.Encoding (decodeUtf8, decodeUtf8With)
import Data.Text.Encoding.Error (lenientDecode)
import Database.PostgreSQL.Simple (Connection, SqlError (..))
import Opaleye
-- | Monads in which Opaleye selects and single-row inserts can be run.
--
-- NOTE(review): the method contracts below are inferred from the method
-- types and from the connection-passing helpers in this module; confirm
-- against the actual instances.
class Monad m => MonadDB m where
  -- | Insert a single row; presumably returns the affected-row count, as
  -- 'safeInsertOne' does.
  insertOne
    :: Default Constant hW pW
    => Table pW pR
    -> hW
    -> m Int64

  -- | Insert a single row and return the column picked by the accessor
  -- (typically the id), or 'Nothing' if no row came back.
  insertOneReturningId
    :: (Default Constant hW pW, QueryRunnerColumnDefault a a)
    => Table pW pR
    -> (pR -> Column a)
    -> hW
    -> m (Maybe a)

  -- | Run a select and return at most one row.
  selectOne :: Default FromFields a b => Select a -> m (Maybe b)

  -- | Run a select and return every row.
  select :: Default FromFields a b => Select a -> m [b]
-- | Constraint violations that the @safeInsert*@ helpers return as values
-- instead of throwing.
newtype ConstraintError
  = UniqueContraintError Text
    -- ^ A unique constraint was violated (PostgreSQL SQLSTATE @23505@); the
    -- payload is the server's error detail. The constructor name (including
    -- its spelling) is part of the exported API.
  deriving Show
-- | Render a select as PostgreSQL SQL text. When Opaleye produces no SQL
-- for the query, the placeholder @"Empty query"@ is returned instead.
selectToString :: Default Unpackspec a a => Select a -> String
selectToString query =
  case showSqlForPostgres query of
    Just sql -> sql
    Nothing  -> "Empty query"
-- | Run a select restricted to one row (@LIMIT 1@) and return that row,
-- if any.
selectOne_ :: Default FromFields a b => Connection -> Select a -> IO (Maybe b)
selectOne_ conn query = listToMaybe <$> runSelect conn (limit 1 query)
-- | Run a select on an explicit connection and return every row
-- (a direct alias of Opaleye's 'runSelect').
select_ :: Default FromFields a b => Connection -> Select a -> IO [b]
select_ conn query = runSelect conn query
-- | Insert one row without an @ON CONFLICT@ clause. A unique-constraint
-- violation is returned as 'Left'; on success the result is the row count
-- reported by 'rCount'.
safeInsertOne ::
  Default Constant hW pW
  => Connection
  -> Table pW pR
  -> hW
  -> IO (Either ConstraintError Int64)
safeInsertOne conn t h = runInsertSafe conn insertion
  where
    insertion = Insert
      { iTable      = t
      , iRows       = [toFields h]
      , iReturning  = rCount
      , iOnConflict = Nothing
      }
-- | Insert one row and return the column selected by @rId@ from the
-- inserted row (@INSERT ... RETURNING@). A unique-constraint violation is
-- returned as 'Left'; 'Right' 'Nothing' means no row was returned.
safeInsertOneReturningId ::
  (Default Constant hW pW, QueryRunnerColumnDefault a a)
  => Connection -> Table pW pR -> (pR -> Column a) -> hW -> IO (Either ConstraintError (Maybe a))
safeInsertOneReturningId conn t rId h = do
  outcome <- runInsertSafe conn Insert
    { iTable      = t
    , iRows       = [toFields h]
    , iReturning  = rReturning rId
    , iOnConflict = Nothing
    }
  -- Keep only the first returned value (if any) on the success side.
  return (listToMaybe <$> outcome)
-- | Like 'safeInsertOneReturningId', but unwraps the returned value.
--
-- The underlying insert writes exactly one row with no @ON CONFLICT@
-- clause, so a successful @INSERT ... RETURNING@ is expected to yield one
-- row. An empty result is treated as a broken invariant. The error is
-- raised lazily, only when the value inside 'Right' is demanded, which is
-- the same evaluation behaviour as the previous 'fromJust'. It now carries
-- a message that says what went wrong.
safeInsertOneReturningId' ::
  (Default Constant hW pW, QueryRunnerColumnDefault a a)
  => Connection -> Table pW pR -> (pR -> Column a) -> hW -> IO (Either ConstraintError a)
safeInsertOneReturningId' conn t rId h =
  fmap (fromMaybe (error noRowReturned)) <$> safeInsertOneReturningId conn t rId h
  where
    noRowReturned =
      "OpaleyeExt.safeInsertOneReturningId': INSERT ... RETURNING returned no row"
-- | Run an Opaleye 'Insert', turning a PostgreSQL unique-constraint
-- violation (SQLSTATE @23505@) into 'Left' 'UniqueContraintError'
-- carrying the server's error detail. Any other 'SqlError' is re-thrown
-- unchanged; successful results are wrapped in 'Right'.
runInsertSafe :: Connection -> Insert a -> IO (Either ConstraintError a)
runInsertSafe conn ins =
  (Right <$> runInsert_ conn ins) `catch` handleConstraintErrors
  where
    handleConstraintErrors :: SqlError -> IO (Either ConstraintError b)
    handleConstraintErrors err@SqlError { sqlState, sqlErrorDetail } =
      case sqlState of
        -- unique_violation
        "23505" ->
          -- 'decodeUtf8' throws on malformed UTF-8, and because the decoded
          -- Text would sit unevaluated inside the 'Left', that exception
          -- would surface later, far from here. Decode leniently instead
          -- (invalid bytes become U+FFFD).
          return (Left (UniqueContraintError (decodeUtf8With lenientDecode sqlErrorDetail)))
        _ -> throwIO err
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE TemplateHaskell #-}
module OpaleyeTH
( HasDefault, NonNullable, NullableDefault, Nullable
, Interpretation, TypeNamer
, mkTable, mkTable'
, mkTypes
) where
import Base.Prelude (capitalise, dashify, uncapitalise, unprefix)
import Control.Monad (foldM, replicateM)
import Data.List.Split (splitOn)
import Language.Haskell.TH
import Language.Haskell.TH.Syntax (VarBangType)
import Opaleye (Field, FieldNullable, Table)
-- | Descriptor for a "not null" column. @hask@ is the column's Haskell
-- type and @pg@ its PostgreSQL type.
data NonNullable hask pg

-- | Descriptor for a column that can be null.
data Nullable hask pg

-- | Descriptor for a column with default values (may be omitted on write).
data HasDefault hask pg

-- | Descriptor for a column that can be null with default values.
data NullableDefault hask pg
-- Descriptors as internal type: one constructor per exported descriptor
-- type, produced by 'fromDescriptor' and consumed by 'interpreteDescriptor'.
data ContraintType = NonNullableConst      -- NonNullable
                   | NullableConst         -- Nullable
                   | DefaultConst          -- HasDefault
                   | NullableDefaultConst  -- NullableDefault
-- | Internal dictionary for computation: everything the generators need
-- from one input record declaration (built by 'extractEnv').
data Env = Env
  { typeName        :: Name
    -- ^ name of the generated polymorphic record type
  , typeNameStr     :: String
    -- ^ string form of 'typeName'; used to build synonym, field and table names
  , typeConstructor :: Name
    -- ^ name of the record's data constructor
  , fields          :: [(Name, Bang, Type)]
    -- ^ the input fields; each 'Type' is expected to be a descriptor
    -- application such as @Nullable h p@
  , derivations     :: [DerivClause]
    -- ^ deriving clauses copied onto the generated polymorphic record
  } deriving Show
-- | Interpretation modes: the four type synonyms generated for each record.
data Interpretation
  = PostgresW -- ^ Interpretation for opaleye to write into postgres (the table's write type)
  | PostgresR -- ^ Interpretation for opaleye to read from postgres (the table's read type)
  | HaskW     -- ^ Interpretation for users to write to opaleye
  | HaskR     -- ^ Interpretation for users to read from opaleye
-- | Function to translate the model name for each interpretation mode.
-- For example, 'defaultTypeNamer' maps @"Route"@ under 'PostgresW' to
-- @"RoutePW"@.
type TypeNamer = Interpretation -> String -> String
-- | Declare Opaleye types from plain record declarations whose field types
-- are column descriptors ('NonNullable', 'Nullable', 'HasDefault',
-- 'NullableDefault'), each applied to a Haskell type and a PostgreSQL type.
-- For example:
--
-- > $(mkTypes
-- >   [d| data Route = Route
-- >         { id          :: HasDefault Int PGInt4
-- >         , createdAt   :: HasDefault UTCTime PGTimestamptz
-- >         , duration    :: Nullable String PGText
-- >         , terminal    :: NonNullable Int PGInt4
-- >         , destination :: NonNullable Text PGText
-- >         }
-- >     |])
--
-- This translates to:
--
-- > data Route a_a8Jc a_a8Jd a_a8Je a_a8Jf a_a8Jg
-- >   = Route {routeId :: a_a8Jc,
-- >            routeCreatedAt :: a_a8Jd,
-- >            routeDuration :: a_a8Je,
-- >            routeTerminal :: a_a8Jf,
-- >            routeDestination :: a_a8Jg}
-- > type RouteR = Route Int UTCTime (Maybe String) Int Text
-- > type RouteW =
-- >   Route (Maybe Int) (Maybe UTCTime) (Maybe String) Int Text
-- > type RoutePR =
-- >   Route (Field PGInt4) (Field PGTimestamptz) (FieldNullable PGText) (Field PGInt4) (Field PGText)
-- > type RoutePW =
-- >   Route (Maybe (Field PGInt4)) (Maybe (Field PGTimestamptz)) (FieldNullable PGText) (Field PGInt4) (Field PGText)
--
-- The polymorphic record keeps the @deriving@ clauses of the input
-- declaration. Every declaration in the quote must be a @data@ type with a
-- single record constructor.
mkTypes :: Q [Dec] -> Q [Dec]
mkTypes baseRecords = do
  records <- baseRecords
  concat <$> traverse mkType records
-- | Generate the table definition for a polymorphic record produced by
-- 'mkTypes'. The second argument names a product-profunctor adaptor for
-- the record (for example one generated by @makeAdaptorAndInstance@).
--
-- For example:
--
-- > $(mkTable "route" "pRoute" ''Route)
--
-- translates to
--
-- > routeTable :: Table RoutePW RoutePR
-- > routeTable = table "route" (pRoute Route
-- >   { routeId          = tableField "id"
-- >   , routeCreatedAt   = tableField "created_at"
-- >   , routeDuration    = tableField "duration"
-- >   , routeTerminal    = tableField "terminal"
-- >   , routeDestination = tableField "destination"
-- >   })
--
-- Column names are derived from the field names: the type-name prefix is
-- removed (via @'unprefix'@), @'dashify'@ is applied, and a single trailing
-- underscore is dropped. The named type must be a @data@ type with exactly
-- one record constructor.
mkTable :: String -> String -> Name -> Q [Dec]
mkTable tableName profunctor typeName = do
  info <- reify typeName
  let (typeNameStr, constructor, fieldNames) = recordShape info
  signature <- mkTableSignature typeNameStr
  body <- mkTableBody profunctor tableName typeNameStr constructor fieldNames
  let binding = ValD (VarP (haskTableName typeNameStr)) (NormalB body) []
  return [signature, binding]
  where
    -- Pull the type's base name, constructor and field names out of the
    -- reified declaration.
    recordShape (TyConI (DataD _ name _ _ [RecC con recordFields] _)) =
      (nameBase name, con, [ fieldName | (fieldName, _, _) <- recordFields ])
    recordShape _ =
      error "Expected a type"
-- | Like @'mkTable'@, but the table name is the base name of the given
-- type @Name@ with @'dashify'@ applied to it.
mkTable' :: String -> Name -> Q [Dec]
mkTable' profunctor typeName =
  mkTable (dashify (nameBase typeName)) profunctor typeName
-- | Expand one record declaration into the polymorphic record followed by
-- its four type synonyms (read/write, Haskell/Postgres), named with
-- 'defaultTypeNamer'.
mkType :: Dec -> Q [Dec]
mkType decl = do
  let env = extractEnv decl
  record <- mkPolymorhicRecord env
  synonyms <- traverse (mkTypeSynonym defaultTypeNamer env) [HaskR, HaskW, PostgresR, PostgresW]
  return (record : synonyms)
-- | Collect what the generators need from a @data@ declaration with exactly
-- one record constructor. Names are passed through 'stripRandomizer'; any
-- other shape of declaration is rejected with 'error'.
extractEnv :: Dec -> Env
extractEnv decl =
  case decl of
    DataD _ name _ _ [RecC cons recordFields] derivs ->
      let nameStr = stripRandomizer name
      in Env
           { typeName        = mkName nameStr
           , typeNameStr     = nameStr
           , typeConstructor = mkName (stripRandomizer cons)
           , fields          = recordFields
           , derivations     = derivs
           }
    _ -> error "Expected data declaration"
-- | Suffix-based naming: @Route@ becomes @RoutePW@, @RoutePR@, @RouteW@ or
-- @RouteR@ depending on the interpretation.
defaultTypeNamer :: TypeNamer
defaultTypeNamer interpretation =
  case interpretation of
    PostgresW -> (++ "PW")
    PostgresR -> (++ "PR")
    HaskW     -> (++ "W")
    HaskR     -> (++ "R")
-- | Build the @type@ synonym for one 'Interpretation' of a record, e.g.
-- @type RouteR = Route Int UTCTime (Maybe String) Int Text@.
--
-- The synonym's name comes from the 'TypeNamer'. Its right-hand side
-- applies the polymorphic record type to one interpreted column type per
-- field.
mkTypeSynonym :: TypeNamer -> Env -> Interpretation -> Q Dec
mkTypeSynonym namer Env { typeName, typeNameStr, fields } interpretation = do
  let synName = mkName (namer interpretation typeNameStr)
  -- Apply the polymorphic record's *type* name, not its data constructor
  -- name. The record is declared under 'typeName' (see mkPolymorhicRecord),
  -- and the two differ whenever the input is written as e.g.
  -- @data RouteT = Route { .. }@.
  synType <- mkTypeSynType (ConT typeName) interpretation fields
  return (TySynD synName [] synType)
-- | Declare the polymorphic record: one fresh type variable per field, and
-- field names prefixed with the uncapitalised type name (so @createdAt@ in
-- @Route@ becomes @routeCreatedAt@). The input's deriving clauses are kept.
mkPolymorhicRecord :: Env -> Q Dec
mkPolymorhicRecord Env { typeNameStr, typeName, fields, typeConstructor, derivations } = do
  tyVars <- traverse (const (newName "a")) fields
  let recordFields =
        [ (prefixed fieldName, fieldBang, VarT tyVar)
        | (tyVar, (fieldName, fieldBang, _)) <- zip tyVars fields
        ]
  return (DataD [] typeName (map PlainTV tyVars) Nothing [RecC typeConstructor recordFields] derivations)
  where
    prefixed = mkName . (uncapitalise typeNameStr ++) . capitalise . stripRandomizer
-- | The signature @<type>Table :: Table <Type>PW <Type>PR@, using the
-- synonym names produced by 'defaultTypeNamer'.
mkTableSignature :: String -> Q Dec
mkTableSignature typeNameStr =
  sigD (haskTableName typeNameStr) [t|Table $(synonym PostgresW) $(synonym PostgresR)|]
  where
    synonym interpretation = conT (mkName (defaultTypeNamer interpretation typeNameStr))
-- | The expression @table "<name>" (<profunctor> Cons { f = tableField "col", .. })@.
-- Each column name comes from the field name via 'unprefix', 'dashify' and
-- 'noKeywords'.
mkTableBody :: String -> String -> String -> Name -> [Name] -> Q Exp
mkTableBody profunctor tableName typeNameStr consName fieldNames =
  [e|table $(litE (StringL tableName)) ($(varE (mkName profunctor)) $(recordExp))|]
  where
    recordExp = recConE consName (map columnFor fieldNames)
    columnFor fieldName = fieldExp fieldName [e|tableField $(columnLit fieldName)|]
    columnLit = litE . StringL . noKeywords . dashify . unprefix typeNameStr . nameBase
-- | Drop a single trailing underscore from a column name, so that a field
-- name like @type_@ (presumably written that way to avoid a keyword clash)
-- maps to the column @type@. Names without a trailing underscore are
-- returned unchanged.
--
-- Total: the empty string is returned as-is. A @last@/@init@ formulation
-- would crash on it.
noKeywords :: String -> String
noKeywords column =
  case reverse column of
    '_' : rest -> reverse rest
    _          -> column
-- | Name of the generated table binding: @Route@ becomes @routeTable@
-- (via 'uncapitalise').
haskTableName :: String -> Name
haskTableName typeNameStr = mkName (uncapitalise typeNameStr ++ "Table")
-- | The base name up to (not including) its first underscore, or the whole
-- base name if it has none.
--
-- NOTE(review): presumably meant to drop a uniqueness suffix such as
-- @_a8Jc@; it also truncates names that legitimately contain an underscore.
stripRandomizer :: Name -> String
stripRandomizer = takeWhile (/= '_') . nameBase
-- | Apply the record type constructor to the interpreted type of each
-- field, left to right: @T t1 t2 ... tn@.
--
-- Every field type must be a descriptor applied to a Haskell type and a
-- PostgreSQL type (e.g. @Nullable String PGText@); otherwise 'error'.
mkTypeSynType :: Type -> Interpretation -> [VarBangType] -> Q Type
mkTypeSynType synTyCons interpretation fields =
  foldl AppT synTyCons <$> traverse interpretField fields
  where
    interpretField (_, _, AppT (AppT descriptor hType) pType) =
      interpreteDescriptor interpretation (fromDescriptor descriptor) hType pType
    interpretField _ =
      error "Expected descriptor(Nullable, NonNullable...)"
-- | The type of one column under one interpretation.
--
-- Haskell-side interpretations use the Haskell type @hType@; Postgres-side
-- ones wrap the PostgreSQL type @pType@ in 'Field' or 'FieldNullable'.
-- Reads are optional ('Maybe' / 'FieldNullable') for nullable columns.
-- Writes are wrapped in 'Maybe' for every column that is not 'NonNullable'
-- on the Haskell side. On the Postgres side, only columns with defaults get
-- the 'Maybe' wrapper.
interpreteDescriptor :: Interpretation -> ContraintType -> Type -> Type -> Q Type
interpreteDescriptor interpretation constraint hType pType =
  case interpretation of
    HaskR
      | nullable  -> maybeOf hType
      | otherwise -> return hType
    HaskW ->
      case constraint of
        NonNullableConst -> return hType
        _                -> maybeOf hType
    PostgresR
      | nullable  -> [t|FieldNullable $(return pType)|]
      | otherwise -> [t|Field $(return pType)|]
    PostgresW ->
      case constraint of
        NonNullableConst     -> [t|Field $(return pType)|]
        NullableConst        -> [t|FieldNullable $(return pType)|]
        DefaultConst         -> [t|Maybe (Field $(return pType))|]
        NullableDefaultConst -> [t|Maybe (FieldNullable $(return pType))|]
  where
    nullable =
      case constraint of
        NullableConst        -> True
        NullableDefaultConst -> True
        _                    -> False
    maybeOf ty = [t|Maybe $(return ty)|]
-- | Map a column descriptor type constructor (as written in the record
-- passed to 'mkTypes') to its internal 'ContraintType'.
--
-- Only the constructor's base name is inspected, so descriptors may be
-- referenced qualified or unqualified. Anything else is rejected with
-- 'error'.
fromDescriptor :: Type -> ContraintType
fromDescriptor (ConT desc) =
  case nameBase desc of
    "NonNullable"     -> NonNullableConst
    "Nullable"        -> NullableConst
    "HasDefault"      -> DefaultConst
    -- The exported descriptor is 'NullableDefault'. It was previously
    -- unrecognised, and 'NullableDefaultConst' could never be produced.
    "NullableDefault" -> NullableDefaultConst
    -- Legacy spelling previously matched here (and wrongly mapped to
    -- 'NullableConst'); kept as an alias of the descriptor it was meant for.
    "NullDefault"     -> NullableDefaultConst
    _                 -> error "Uninterpretable Descriptor"
fromDescriptor _ = error "Uninterpretable Descriptor"
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment