Created
November 22, 2020 20:12
-
-
Save MaxGabriel/aae96f6f8a72d0cfb5e8f98f426e29a1 to your computer and use it in GitHub Desktop.
Template Haskell to load all Persistent models, stream them from the database, and validate they deserialize correctly
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE AllowAmbiguousTypes #-} | |
module Mercury.Database.Persist.DeriveLoadAllModels (mkLoadAllModels) where | |
import ClassyPrelude | |
import Control.Monad.Logger (MonadLogger, logInfoN) | |
import Data.Acquire (with) | |
import Data.Conduit (fuse, runConduit) | |
import qualified Data.Conduit.Combinators as Conduit | |
import qualified Data.Kind as K | |
import qualified Data.Text as T | |
import Database.Persist (DBName (..), PersistEntity, PersistEntityBackend, PersistQueryRead (..), PersistValue (..), selectSourceRes) | |
import Database.Persist.Sql (Filter (..), Single (..), SqlBackend, rawSql) | |
import Database.Persist.Types (EntityDef (..), HaskellName (..)) | |
import Language.Haskell.TH.Syntax | |
import Mercury.Timing (timeAction) | |
-- | Entity attribute that opts a model out of the generated load-all check.
--
-- 'mkLoadAllModels' looks for this exact text in an entity's attribute list;
-- when present, the table is logged as skipped instead of being streamed.
skipLoadAllModelsAttribute :: Text
skipLoadAllModelsAttribute = "!skipLoadAllModels"
-- | Type of the function that 'mkLoadAllModels' generates.
--
-- The generated function takes a "run this in a transaction" callback and
-- uses it once per table.
--
-- (Declared as a plain alias so the Template Haskell code can refer to it by
-- name instead of assembling the type out of TH primitives.)
type LoadAllModelsSignature =
  forall (m :: K.Type -> K.Type).
  (MonadUnliftIO m, MonadLogger m) =>
  (ReaderT SqlBackend m () -> m ()) ->
  m ()
-- | Runs 'processTable' for @record@ inside the supplied transaction runner.
--
-- The record type is chosen with a type application
-- (@processTableIO \@MyModel ...@); the remaining arguments are forwarded to
-- 'processTable' unchanged.
--
-- (This could probably be merged into 'processTable', but keeping the two
-- apart sidestepped a long fight with compiler errors.)
processTableIO ::
  forall record (m :: K.Type -> K.Type).
  (MonadLogger m, MonadUnliftIO m, PersistEntity record, PersistEntityBackend record ~ SqlBackend) =>
  HaskellName ->
  DBName ->
  Bool ->
  (ReaderT SqlBackend m () -> m ()) ->
  m ()
processTableIO haskellName dbName shouldSkip runTrx =
  runTrx (processTable @record haskellName dbName shouldSkip)
-- | In a transaction, do the following for one table:
--
-- * Log an estimate of the table's row count
-- * Stream every row from the table, validating that the model deserializes
-- * Log how many rows were loaded and how long it took
--
-- When the skip flag is 'True' nothing is queried; a single "Skipping" line
-- is logged instead.
processTable ::
  forall record (m :: K.Type -> K.Type).
  (MonadUnliftIO m, PersistEntity record, PersistEntityBackend record ~ SqlBackend, MonadLogger m) =>
  -- | Haskell-side entity name (used in the "Skipping"/"Starting" log lines)
  HaskellName ->
  -- | Database table name (used for the estimate and the final log line)
  DBName ->
  -- | Whether the entity carried 'skipLoadAllModelsAttribute'
  Bool ->
  ReaderT SqlBackend m ()
processTable haskellName tableName skip = do
  let name = unHaskellName haskellName
      tableNameText = unDBName tableName
  if skip
    then
      -- The attribute name is surrounded by spaces so the message reads
      -- "... it had the !skipLoadAllModels attribute." rather than running
      -- the words together.
      logInfoN ("Skipping " <> name <> "; it had the " <> skipLoadAllModelsAttribute <> " attribute.")
    else do
      estimate <- getEstimatedRowCount tableName
      logInfoN ("Starting to load rows for " <> name <> ". Estimated count: " <> tshow estimate)
      -- Potential improvement: Catch exceptions when loading, so if there are issues with multiple tables you find them all.
      (actualRowCount :: Int64, seconds) <- timeAction (streamRowCount @record)
      logInfoN $
        T.concat
          [ "Loaded ",
            tshow actualRowCount,
            " rows from ",
            tableNameText,
            " in ",
            tshow seconds, -- Future improvement: format to 3 decimal places or something
            " seconds."
          ]
-- | Fetches a rough row count for a table from @pg_class.reltuples@.
--
-- The point is only to give the user a feel for how long the load will take;
-- it avoids running a @COUNT(*)@ over the table.
--
-- The lookup matches on @relname@ alone. Anything other than exactly one row
-- holding a 'PersistInt64' aborts with 'error', showing what came back.
getEstimatedRowCount :: forall (m :: K.Type -> K.Type). (MonadIO m) => DBName -> ReaderT SqlBackend m Int64
getEstimatedRowCount tableName = do
  result <-
    rawSql
      "SELECT (reltuples :: bigint) FROM pg_class WHERE relname = ?"
      [PersistText (unDBName tableName)]
  case (result :: [Single PersistValue]) of
    [Single (PersistInt64 rowEstimate)] -> pure rowEstimate
    other -> error ("Expected a single row containing an integer; got: " <> show other)
-- | Streams every row of @record@'s table through a conduit and counts them.
--
-- Each row is discarded once counted. The intent is constant memory use on
-- the Haskell side (and a friendlier access pattern for the database), though
-- that has not been confirmed in practice: loading all tables this way was
-- observed to take about a gigabyte.
streamRowCount :: forall record m. (MonadIO m, PersistEntity record, PersistEntityBackend record ~ SqlBackend) => ReaderT SqlBackend m Int64
streamRowCount = do
  acquireRows <- selectSourceRes ([] :: [Filter record]) []
  liftIO . with acquireRows $ \rows ->
    runConduit (rows `fuse` Conduit.foldl countRow 0)
  where
    -- Bump the running total, ignoring the row itself.
    countRow :: Int64 -> a -> Int64
    countRow seen _ = seen + 1
-- | Creates a function of the given name that loads every model from our database.
--
-- The goal of the generated function is to check that our deserialization code is valid for every model.
-- You can ignore a given model by adding the @!skipLoadAllModels@ attribute to the entity.
--
-- The generated declaration has type 'LoadAllModelsSignature': it takes a
-- transaction runner and calls 'processTableIO' once per entity, in the order
-- the entities were given. With no entities it is just @return ()@.
--
-- An alternative approach to flagging tables to skip is to flag tables based on how slow they are to load.
-- So e.g. QueuedJob might load at "Glacial" speed, FrontEventWebhook might load at "VerySlow" speed,
-- and the user can choose what threshold to run this for.
mkLoadAllModels :: String -> [EntityDef] -> Q [Dec]
mkLoadAllModels fnName entityDefs = do
  runTrxName <- newName "runTransaction"
  body <- mkBody runTrxName
  pure
    [ SigD generatedName (ConT ''LoadAllModelsSignature),
      FunD generatedName [Clause [VarP runTrxName] (NormalB body) []]
    ]
  where
    -- Name of the function being generated.
    generatedName :: Name
    generatedName = mkName fnName

    -- Body of the generated function: run the per-table loaders in sequence.
    mkBody :: Name -> Q Exp
    mkBody runTrxName =
      case entityDefs of
        [] -> [|return ()|]
        _ -> do
          tableLoaders <- mapM (loadAllForTable runTrxName) entityDefs
          sequenceE <- [|sequence_|]
          pure $ sequenceE `AppE` ListE tableLoaders

    -- Generates the expression that loads all rows for a single table:
    -- @processTableIO \@Record haskellName dbName shouldSkip runTransaction@.
    loadAllForTable :: Name -> EntityDef -> Q Exp
    loadAllForTable runTrxName entityDef = do
      -- Pull the fields out BEFORE quoting. Quoting @entityHaskell entityDef@
      -- directly would lift the entire EntityDef into the generated code (once
      -- per quote) and apply the accessor at runtime; lifting the extracted
      -- values embeds only the two names, keeping generated code small.
      let haskellName = entityHaskell entityDef
          dbName = entityDB entityDef
          recordType = ConT $ mkName $ T.unpack $ unHaskellName haskellName
          shouldSkip = skipLoadAllModelsAttribute `elem` entityAttrs entityDef
      -- This is assembled by hand rather than in one [| ... |] quote because
      -- a TH 'Type' (what recordType is) has no Lift instance, so it cannot be
      -- referenced from inside a quote.
      fn <- [|processTableIO|]
      haskellNameE <- [|haskellName|]
      dbNameE <- [|dbName|]
      shouldSkipE <- [|shouldSkip|]
      pure $
        (fn `AppTypeE` recordType)
          `AppE` haskellNameE
          `AppE` dbNameE
          `AppE` shouldSkipE
          `AppE` VarE runTrxName
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- | Functions to help time actions | |
module Mercury.Timing | |
( StartTime (..), | |
getStartTime, | |
getElapsedSeconds, | |
timeAction, | |
) | |
where | |
import ClassyPrelude.Yesod | |
import Data.Ratio ((%)) | |
import System.Clock (Clock (..), TimeSpec, diffTimeSpec, getTime, toNanoSecs) | |
-- | A 'TimeSpec' tagged as the moment a measurement began.
--
-- Obtain one with 'getStartTime' and hand it to 'getElapsedSeconds' to find
-- out how long an action took.
newtype StartTime = StartTime TimeSpec
-- | Reads the 'Monotonic' clock and wraps the reading as a 'StartTime'.
getStartTime :: MonadIO m => m StartTime
getStartTime =
  -- TODO: would using CoarseMonotonic on linux be good for a speedup?
  StartTime <$> liftIO (getTime Monotonic)
-- | Seconds between the given 'StartTime' and now, read from the 'Monotonic' clock.
--
-- The difference is taken in whole nanoseconds and converted to a 'Double'
-- number of seconds.
getElapsedSeconds :: MonadIO m => StartTime -> m Double
getElapsedSeconds (StartTime start) = do
  -- TODO: would using CoarseMonotonic on linux be good for a speedup?
  now <- liftIO (getTime Monotonic)
  let elapsedNanos = toNanoSecs (now `diffTimeSpec` start)
  -- Conversion copied from https://github.com/fimad/prometheus-haskell/blob/ec1e3d30bd59113b0184869fc12e7d6fb7251248/wai-middleware-prometheus/src/Network/Wai/Middleware/Prometheus.hs#L154
  pure (fromRational (elapsedNanos % 1000000000))
-- | Runs an action and returns its result paired with how long it took to
-- complete, in seconds.
timeAction :: MonadIO m => m a -> m (a, Double)
timeAction action = do
  startedAt <- getStartTime
  outcome <- action
  (,) outcome <$> getElapsedSeconds startedAt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- Example usage: hand 'mkLoadAllModels' to 'share' alongside the parsed model
-- files; this generates a top-level function named @loadAllPersistentModels@.
share
  [mkLoadAllModels "loadAllPersistentModels"]
  $(persistManyFileWith lowerCaseSettings allModelFiles)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment