Last active
May 25, 2022 18:09
-
-
Save andrevdm/d564d7261eafdfded27db923c8b28cea to your computer and use it in GitHub Desktop.
Haskell RabbitMQ wrapper. Should handle channel and connection failures. Includes e.g. RPC with timeouts
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE NoImplicitPrelude #-} | |
{-# LANGUAGE OverloadedStrings #-} | |
{-# LANGUAGE ScopedTypeVariables #-} | |
{-# LANGUAGE NumericUnderscores #-} | |
{-# LANGUAGE RankNTypes #-} | |
module Rabbit | |
( defaultOpts | |
, mkConnection | |
, runTopicConsumer | |
, createTopicPublisher | |
, runFanoutConsumer | |
, createFanoutPublisher | |
, callRpc | |
, handleRpc | |
, createRpcCaller | |
, Connection | |
, Channel | |
, Opts(..) | |
, RpcResult(..) | |
) where | |
import Verset

import Control.Concurrent.STM (atomically)
import Control.Exception.Safe (catch, throwM, catches, throwString, fromException, Handler(..))
import qualified Control.Concurrent.MVar as Mv
import qualified Control.Concurrent.STM.TVar as Tv
import qualified Control.Exception.Safe as Ex
import qualified Data.ByteString.Lazy as BSL
import qualified Data.Map.Strict as Map
import qualified Data.UUID.V4 as UU
import qualified Network.AMQP as Rb
import qualified Network.AMQP.Types as Rb
-- | The part of a 'Connection' that changes over time; it lives inside the
-- 'conState' MVar.
data ConnectionState = ConnectionState
  { csConnection :: !Rb.Connection
    -- ^ The AMQP connection currently in use. Replaced after a reconnect.
  , csTerminate :: !Bool
    -- ^ When 'True', 'getRabbitConnection' throws and 'onConnectionClosed'
    -- does not reconnect. Only ever initialised to 'False' in the code
    -- visible here.
  }
-- | A self-healing wrapper around an AMQP connection.
data Connection = Connection
  { conState :: !(Mv.MVar ConnectionState)
    -- ^ Current AMQP connection plus the terminate flag.
  , conOpts :: !Opts
    -- ^ Retry settings used when (re)connecting.
  , conRabbitOpts :: !Rb.ConnectionOpts
    -- ^ AMQP connection options, kept so the connection can be re-opened.
  , conChannels :: !(Tv.TVar (Map Int Channel))
    -- ^ Every 'Channel' created through 'mkChannel', keyed by 'chId'.
    -- These are re-opened after a reconnect.
  , conNextId :: !(Tv.TVar Int)
    -- ^ Counter used to allocate 'chId' values.
  }
-- | Reconnection settings.
data Opts = Opts
  { opRetyConnectionBackoffMilliseconds :: ![Int]
    -- ^ Delay, in milliseconds, before retry number @n@ (0-based index into
    -- the list). Once the list is exhausted the last entry is reused; an
    -- empty list falls back to 200ms.
  , opRetyConnectionMaxAttempts :: !(Maybe Int)
    -- ^ Number of retries allowed after a failed connection attempt.
    -- 'Nothing' is treated as 0, i.e. fail on the first error.
  }
-- | A restartable wrapper around an AMQP channel.
data Channel = Channel
  { chConnection :: !Connection
    -- ^ The owning connection.
  , chOnRestart :: !(Channel -> IO ())
    -- ^ Action run (with this channel) each time the underlying AMQP channel
    -- has been replaced, so that exchanges/queues/consumers can be re-declared.
  , chRabbitChan :: !(Tv.TVar Rb.Channel)
    -- ^ The AMQP channel currently in use; read it via 'getRabbitChannel'.
  , chId :: !Int
    -- ^ Key of this channel in 'conChannels'.
  }
-- | Default reconnection settings: up to 2000 retries, with a delay that grows
-- from 100ms to roughly 5s and then stays at the last value.
defaultOpts :: Opts
defaultOpts = Opts
  { opRetyConnectionMaxAttempts = Just 2_000
  , opRetyConnectionBackoffMilliseconds = backoffMs
  }
  where
    -- Milliseconds to wait before each successive connection attempt.
    backoffMs = [100, 250, 475, 813, 1_319, 2_078, 3_217, 4_926]
-- | Open an AMQP connection, retrying with back-off when the attempt throws.
--
-- @retriesLeft@ is the number of further attempts allowed after a failure;
-- when it reaches 0 the last exception is re-thrown. @atRetry@ is the 0-based
-- index of the current retry and selects the delay from
-- 'opRetyConnectionBackoffMilliseconds'.
tryNewRabbitConnection :: Rb.ConnectionOpts -> Opts -> Int -> Int -> IO Rb.Connection
tryNewRabbitConnection ropts copts retriesLeft atRetry =
  Rb.openConnection'' ropts `catch` onFailure
  where
    onFailure :: SomeException -> IO Rb.Connection
    onFailure ex
      | retriesLeft <= 0 = throwM ex
      | otherwise = do
          -- threadDelay takes microseconds, the schedule is in milliseconds
          threadDelay (1_000 * backoffMs)
          tryNewRabbitConnection ropts copts (retriesLeft - 1) (atRetry + 1)

    schedule = opRetyConnectionBackoffMilliseconds copts

    -- Delay for this retry; past the end of the schedule use its last entry,
    -- and 200ms if the schedule is empty.
    backoffMs = fromMaybe (fromMaybe 200 (lastMay schedule)) (atMay schedule atRetry)
-- | Connect to RabbitMQ (retrying as configured in the 'Opts') and return a
-- 'Connection' that re-establishes itself when the AMQP connection closes.
mkConnection :: Rb.ConnectionOpts -> Opts -> IO Connection
mkConnection ropts copts = do
  rconn <- tryNewRabbitConnection ropts copts maxAttempts 0
  stateVar <- Mv.newMVar ConnectionState { csConnection = rconn, csTerminate = False }
  channelsVar <- Tv.newTVarIO mempty
  nextIdVar <- Tv.newTVarIO 0
  let conn = Connection
        { conState = stateVar
        , conOpts = copts
        , conRabbitOpts = ropts
        , conChannels = channelsVar
        , conNextId = nextIdVar
        }
  -- Reconnect whenever the AMQP connection is closed.
  Rb.addConnectionClosedHandler rconn True (onConnectionClosed conn)
  pure conn
  where
    -- 'Nothing' means no retries at all.
    maxAttempts = fromMaybe 0 (opRetyConnectionMaxAttempts copts)
-- | Connection-closed handler: re-establish the AMQP connection and re-open
-- every registered channel on it.
--
-- The connection state MVar is updated with 'Mv.modifyMVar', so it is always
-- refilled:
--
-- * if the connection was marked as terminated the state is put back
--   unchanged and nothing else happens (previously the MVar was left empty,
--   which made every later 'getRabbitConnection' block forever);
-- * if reconnecting throws (retries exhausted) the old state is restored and
--   the exception propagates to the caller of this handler.
--
-- While the reconnect is in progress the MVar is held, so callers of
-- 'getRabbitConnection' wait until a new connection is available.
onConnectionClosed :: Connection -> IO ()
onConnectionClosed conn = do
  reconnected <- Mv.modifyMVar (conState conn) $ \csOld ->
    if csTerminate csOld
      then pure (csOld, Nothing)
      else do
        rconn <- tryNewRabbitConnection ropts copts maxAttempts 0
        let csNew =
              ConnectionState
                { csConnection = rconn
                , csTerminate = False
                }
        pure (csNew, Just rconn)
  -- Only runs when a new connection was actually created.
  for_ reconnected $ \rconn -> do
    Rb.addConnectionClosedHandler rconn True (onConnectionClosed conn)
    chans <- Tv.readTVarIO $ conChannels conn
    for_ (Map.elems chans) $ \ch -> do
      -- Replace the dead AMQP channel, then let the owner re-declare
      -- its exchanges / queues / consumers.
      rch <- Rb.openChannel rconn
      atomically $ Tv.writeTVar (chRabbitChan ch) rch
      chOnRestart ch ch
  where
    copts = conOpts conn
    ropts = conRabbitOpts conn
    maxAttempts = fromMaybe 0 $ opRetyConnectionMaxAttempts copts
-- | Open a new AMQP channel on the connection and register it in
-- 'conChannels' so that it is re-opened after a reconnect.
--
-- The optional callback becomes the channel's 'chOnRestart' action; when it
-- is 'Nothing' a restart just swaps the underlying AMQP channel.
--
-- NOTE(review): the exception handler closes over the AMQP connection that
-- was current when the channel was created, and channels re-opened by
-- 'onConnectionClosed' are not given this handler again. Confirm that this
-- is the intended behaviour after a reconnect.
mkChannel :: Connection -> Maybe (Channel -> IO ()) -> IO Channel
mkChannel conn onChanRestart = do
  newId <- atomically $ Tv.stateTVar (conNextId conn) bump
  rconn <- getRabbitConnection conn
  rch <- Rb.openChannel rconn
  rchVar <- Tv.newTVarIO rch
  let ch = Channel
        { chConnection = conn
        , chOnRestart = fromMaybe (const pass) onChanRestart
        , chRabbitChan = rchVar
        , chId = newId
        }
  atomically $ Tv.modifyTVar' (conChannels conn) (Map.insert newId ch)
  Rb.addChannelExceptionHandler rch (recover rconn ch)
  pure ch
  where
    -- Increment the counter and hand out the incremented value.
    bump i = let next = i + 1 in (next, next)

    -- Channel died: unless it was a normal close, open a replacement on the
    -- same AMQP connection, watch it too, and run the restart action.
    recover rconn ch ex
      | isExpectedChannelCloseEx ex = pass
      | otherwise = do
          fresh <- Rb.openChannel rconn
          atomically $ Tv.writeTVar (chRabbitChan ch) fresh
          Rb.addChannelExceptionHandler fresh (recover rconn ch)
          chOnRestart ch ch
-- | Create a channel whose setup action is run once now and again after every
-- channel restart. Returns the channel.
withRestartableChannel :: Connection -> (Channel -> IO ()) -> IO Channel
withRestartableChannel conn fn =
  mkChannel conn (Just fn) >>= \ch -> fn ch >> pure ch
-- | The AMQP channel currently backing this 'Channel'. Read it again before
-- each use, since it is replaced when the channel is restarted.
getRabbitChannel :: Channel -> IO Rb.Channel
getRabbitChannel ch = Tv.readTVarIO (chRabbitChan ch)
-- | The AMQP connection currently in use.
--
-- Throws (via 'throwString') when the connection has been marked as terminated.
getRabbitConnection :: Connection -> IO Rb.Connection
getRabbitConnection con =
  Mv.readMVar (conState con) >>= \cs ->
    if csTerminate cs
      then throwString "Connection was terminated"
      else pure (csConnection cs)
-- | 'True' only for an AMQP 'Rb.ChannelClosedException' with reason
-- 'Rb.Normal'; any other exception is considered unexpected.
isExpectedChannelCloseEx :: SomeException -> Bool
isExpectedChannelCloseEx e
  | Just (Rb.ChannelClosedException Rb.Normal _) <- fromException e = True
  | otherwise = False
------------------------------------------------------------------------------------------------------------------------------ | |
-- Topic consumer | |
------------------------------------------------------------------------------------------------------------------------------ | |
-- | Consume messages from a durable queue bound to a durable topic exchange.
--
-- Arguments: connection, exchange name, queue name, routing (binding)
-- expression and the message handler. The exchange, queue and binding are
-- declared on start and again after every channel restart.
--
-- Messages are consumed in 'Rb.Ack' mode with @qos 0 1 True@ (presumably a
-- global prefetch count of 1; see the amqp docs):
--
-- * when the handler returns, the message is acked;
-- * when the handler throws, the exception is printed and the message is
--   rejected with requeue. Previously it was left un-acked, which at a
--   prefetch of 1 stops the broker from delivering anything further on this
--   channel.
--
-- NOTE: a message whose handler always fails will be redelivered
-- repeatedly; configure a dead-letter / delivery-limit policy on the queue
-- if that is a concern.
runTopicConsumer :: Connection -> Text -> Text -> Text -> (Rb.Message -> IO ()) -> IO ()
runTopicConsumer conn exchangeName queueName routingExpr cfn =
  void . withRestartableChannel conn $ \ch -> do
    let ex = Rb.newExchange { Rb.exchangeName = exchangeName
                            , Rb.exchangeType = "topic"
                            , Rb.exchangeDurable = True
                            , Rb.exchangeAutoDelete = False
                            }
    rch <- getRabbitChannel ch
    Rb.qos rch 0 1 True
    Rb.declareExchange rch ex
    (queue, _, _) <- Rb.declareQueue rch Rb.newQueue
      { Rb.queueName = queueName
      , Rb.queueAutoDelete = False
      , Rb.queueDurable = True
      , Rb.queueHeaders = Rb.FieldTable (Map.singleton "x-queue-mode" $ Rb.FVString "lazy")
      }
    Rb.bindQueue rch queue exchangeName routingExpr
    _ <- Rb.consumeMsgs rch queue Rb.Ack safeGet
    pass
  where
    safeGet (msg, env) =
      catches
        (cfn msg >> Rb.ackEnv env)
        [ -- Must be re-thrown so the amqp library can shut the channel thread down.
          Handler $ \(e :: Rb.ChanThreadKilledException) -> throwM e
        , Handler $ \(e :: SomeException) -> do
            print e
            -- Give the message back to the broker so the consumer is not
            -- stalled. Best effort: if the channel itself is broken the
            -- reject fails too, and that failure is only printed.
            Rb.rejectEnv env True `catch` \(e2 :: SomeException) -> print e2
        ]
------------------------------------------------------------------------------------------------------------------------------ | |
------------------------------------------------------------------------------------------------------------------------------ | |
-- Topic publisher | |
------------------------------------------------------------------------------------------------------------------------------ | |
-- | Wrapper for an AMQP routing key. Not referenced by the code visible in
-- this module.
newtype RoutingKey = RoutingKey Text
  deriving (Show, Eq)
-- | Declare a durable topic exchange and return a function that publishes a
-- persistent message body to it with the given routing key.
--
-- The exchange is re-declared after every channel restart, and the returned
-- function looks up the current AMQP channel on every call.
createTopicPublisher :: Connection -> Text -> Text -> IO (BSL.ByteString -> IO ())
createTopicPublisher conn exchangeName routingKey = do
  ch <- withRestartableChannel conn declareTopology
  pure (publish ch)
  where
    -- Run on creation and again after each restart.
    declareTopology ch = do
      rch <- getRabbitChannel ch
      Rb.qos rch 0 1 True
      Rb.declareExchange rch Rb.newExchange
        { Rb.exchangeName = exchangeName
        , Rb.exchangeType = "topic"
        , Rb.exchangeDurable = True
        , Rb.exchangeAutoDelete = False
        }

    publish ch body = do
      rch <- getRabbitChannel ch
      void $ Rb.publishMsg rch exchangeName routingKey Rb.newMsg
        { Rb.msgBody = body
        , Rb.msgDeliveryMode = Just Rb.Persistent
        }
------------------------------------------------------------------------------------------------------------------------------ | |
------------------------------------------------------------------------------------------------------------------------------ | |
-- Fanout consumer | |
------------------------------------------------------------------------------------------------------------------------------ | |
-- | Consume messages from a durable fanout exchange through a private,
-- server-named, exclusive auto-delete queue.
--
-- The exchange, queue and binding are declared on start and again after
-- every channel restart (the queue gets a new name each time).
--
-- Messages are consumed in 'Rb.Ack' mode with @qos 0 1 True@ (presumably a
-- global prefetch count of 1; see the amqp docs):
--
-- * when the handler returns, the message is acked;
-- * when the handler throws, the exception is printed and the message is
--   rejected with requeue. Previously it was left un-acked, which at a
--   prefetch of 1 stops the broker from delivering anything further on this
--   channel.
--
-- NOTE: a message whose handler always fails will be redelivered repeatedly.
runFanoutConsumer :: Connection -> Text -> (Rb.Message -> IO ()) -> IO ()
runFanoutConsumer conn exchangeName cfn =
  void . withRestartableChannel conn $ \ch -> do
    let ex = Rb.newExchange { Rb.exchangeName = exchangeName
                            , Rb.exchangeType = "fanout"
                            , Rb.exchangeDurable = True
                            , Rb.exchangeAutoDelete = False
                            }
    rch <- getRabbitChannel ch
    Rb.qos rch 0 1 True
    Rb.declareExchange rch ex
    (queue, _, _) <- Rb.declareQueue rch Rb.newQueue
      { Rb.queueName = ""
      , Rb.queueAutoDelete = True
      , Rb.queueDurable = False
      , Rb.queueExclusive = True
      , Rb.queueHeaders = Rb.FieldTable (Map.singleton "x-queue-mode" $ Rb.FVString "lazy")
      }
    Rb.bindQueue rch queue exchangeName ""
    _ <- Rb.consumeMsgs rch queue Rb.Ack safeGet
    pass
  where
    safeGet (msg, env) =
      catches
        (cfn msg >> Rb.ackEnv env)
        [ -- Must be re-thrown so the amqp library can shut the channel thread down.
          Handler $ \(e :: Rb.ChanThreadKilledException) -> throwM e
        , Handler $ \(e :: SomeException) -> do
            print e
            -- Give the message back to the broker so the consumer is not
            -- stalled. Best effort: if the channel itself is broken the
            -- reject fails too, and that failure is only printed.
            Rb.rejectEnv env True `catch` \(e2 :: SomeException) -> print e2
        ]
------------------------------------------------------------------------------------------------------------------------------ | |
------------------------------------------------------------------------------------------------------------------------------ | |
-- Fanout publisher | |
------------------------------------------------------------------------------------------------------------------------------ | |
-- | Declare a durable fanout exchange and return a function that publishes a
-- persistent message body to it (with an empty routing key).
--
-- The exchange is re-declared after every channel restart, and the returned
-- function looks up the current AMQP channel on every call.
createFanoutPublisher :: Connection -> Text -> IO (BSL.ByteString -> IO ())
createFanoutPublisher conn exchangeName = do
  ch <- withRestartableChannel conn declareTopology
  pure (publish ch)
  where
    -- Run on creation and again after each restart.
    declareTopology ch = do
      rch <- getRabbitChannel ch
      Rb.qos rch 0 1 True
      Rb.declareExchange rch Rb.newExchange
        { Rb.exchangeName = exchangeName
        , Rb.exchangeType = "fanout"
        , Rb.exchangeDurable = True
        , Rb.exchangeAutoDelete = False
        }

    publish ch body = do
      rch <- getRabbitChannel ch
      void $ Rb.publishMsg rch exchangeName "" Rb.newMsg
        { Rb.msgBody = body
        , Rb.msgDeliveryMode = Just Rb.Persistent
        }
------------------------------------------------------------------------------------------------------------------------------ | |
------------------------------------------------------------------------------------------------------------------------------ | |
-- RPC Handler | |
------------------------------------------------------------------------------------------------------------------------------ | |
-- | Serve an RPC endpoint: consume requests from the (non-durable,
-- auto-delete) queue named after the RPC, run the handler, and publish its
-- result to the request's reply-to queue with the request's correlation id.
--
-- Requests are consumed in 'Rb.NoAck' mode. When the request carries no
-- reply-to, the response is published with routing key @"?"@. Exceptions
-- from the handler or from publishing are printed and not re-thrown, so the
-- caller gets no reply in that case.
handleRpc :: Connection -> Text -> (Rb.Message -> IO BSL.ByteString) -> IO ()
handleRpc conn rpcName cfn = void (withRestartableChannel conn serve)
  where
    -- Run on creation and again after each channel restart.
    serve ch = do
      rch <- getRabbitChannel ch
      Rb.qos rch 0 1 True
      (queue, _, _) <- Rb.declareQueue rch Rb.newQueue
        { Rb.queueName = rpcName
        , Rb.queueAutoDelete = True
        , Rb.queueDurable = False
        }
      void $ Rb.consumeMsgs rch queue Rb.NoAck (respond rch)

    respond rch (msg, _env) =
      catches
        (answer rch msg)
        [ Handler $ \(e :: Rb.ChanThreadKilledException) -> throwM e
        , Handler $ \(e :: SomeException) -> print e -- TODO do something, but don't rethrow
        ]

    answer rch msg = do
      body <- cfn msg
      let replyQueue = fromMaybe "?" (Rb.msgReplyTo msg)
      void $ Rb.publishMsg rch "" replyQueue Rb.newMsg
        { Rb.msgCorrelationID = Rb.msgCorrelationID msg
        , Rb.msgBody = body
        }
------------------------------------------------------------------------------------------------------------------------------ | |
-- | Outcome of an RPC call.
data RpcResult
  = RpcOk BSL.ByteString
    -- ^ The handler replied with this body.
  | RpcCallerTimeout
    -- ^ No reply arrived within the caller's timeout.
  | RpcHandlerTimeout
    -- ^ Not produced by any code visible in this module.
  deriving (Show, Eq)
------------------------------------------------------------------------------------------------------------------------------ | |
-- | Make a single, ad-hoc RPC call. See 'createRpcCaller' for a cheaper
-- alternative when many calls are made.
--
-- Arguments: connection, timeout, RPC (queue) name and request body.
-- Returns 'RpcOk' with the reply body, or 'RpcCallerTimeout' when no reply
-- arrived within the timeout.
--
-- A dedicated channel with a private, exclusive, auto-delete reply queue is
-- created for the call. When the call finishes (reply, timeout or exception)
-- the channel is removed from the connection's registry and closed.
-- Previously it was left open and registered, so every call leaked a
-- channel, a queue and a consumer, all of which were re-opened on reconnect.
--
-- The timeout is honoured to microsecond precision; previously it was
-- truncated to whole seconds, so e.g. 0.5s became an immediate timeout.
callRpc :: Connection -> NominalDiffTime -> Text -> BSL.ByteString -> IO RpcResult
callRpc conn timeout rpcName msg = do
  ch' <- mkChannel conn Nothing
  call ch' `Ex.finally` release ch'
  where
    call ch' = do
      rch <- getRabbitChannel ch'
      Rb.qos rch 0 1 True
      (queue, _, _) <- Rb.declareQueue rch Rb.newQueue
        { Rb.queueName = ""
        , Rb.queueAutoDelete = True
        , Rb.queueDurable = False
        , Rb.queueExclusive = True
        }
      wait <- Mv.newEmptyMVar
      -- Whichever of reply / timeout fills the MVar first wins.
      void $ Rb.consumeMsgs rch queue Rb.NoAck $ \(reply, _) ->
        void . Mv.tryPutMVar wait . RpcOk $ Rb.msgBody reply
      void . forkIO $ do
        threadDelay timeoutMicros
        void . Mv.tryPutMVar wait $ RpcCallerTimeout
      void $ Rb.publishMsg rch "" rpcName
        (Rb.newMsg { Rb.msgBody = msg
                   , Rb.msgDeliveryMode = Just Rb.NonPersistent
                   , Rb.msgReplyTo = Just queue
                   })
      Mv.takeMVar wait

    -- Unregister first so a concurrent reconnect does not re-open the
    -- channel, then close it. Closing is best effort: the channel may already
    -- be gone, in which case the error is only printed.
    release ch' = do
      atomically $ Tv.modifyTVar' (conChannels conn) (Map.delete (chId ch'))
      rch <- getRabbitChannel ch'
      Rb.closeChannel rch `catch` \(e :: SomeException) -> print e

    timeoutMicros :: Int
    timeoutMicros = fst (properFraction (timeout * 1_000_000))
------------------------------------------------------------------------------------------------------------------------------ | |
------------------------------------------------------------------------------------------------------------------------------ | |
-- | Create a long-running RPC caller.
--
-- Similar to 'callRpc' but more lightweight, because a reply queue is not
-- created for each call. Use this when you are going to make multiple RPC
-- calls; use 'callRpc' for simple ad-hoc calls.
--
-- Arguments: connection, per-call timeout and RPC (queue) name. The returned
-- function sends one request and waits for 'RpcOk' or 'RpcCallerTimeout'.
--
-- One restartable channel owns the private reply queue and dispatches
-- replies by correlation id; a second channel is used for publishing.
--
-- Fixes relative to the previous version:
--
-- * the publishing AMQP channel is looked up on every call instead of being
--   captured once, so calls keep working after the channel is restarted;
-- * the wait handle is always removed from the pending map when a call ends
--   (previously timed-out calls stayed in the map forever);
-- * the timeout is honoured to microsecond precision instead of being
--   truncated to whole seconds.
createRpcCaller :: Connection -> NominalDiffTime -> Text -> IO (BSL.ByteString -> IO RpcResult)
createRpcCaller conn timeout rpcName = do
  -- Map of correlation id -> wait handle for calls in flight
  waits' <- Tv.newTVarIO mempty
  -- Name of the reply queue
  queueName' <- Mv.newEmptyMVar

  void . withRestartableChannel conn $ \ch' -> do
    rch <- getRabbitChannel ch'
    Rb.qos rch 0 1 True
    -- Create the reply queue (server-named)
    (queue, _, _) <- Rb.declareQueue rch Rb.newQueue
      { Rb.queueName = ""
      , Rb.queueAutoDelete = True
      , Rb.queueDurable = False
      , Rb.queueExclusive = True
      }
    -- Save the queue name, this will change on reconnect
    void $ Mv.tryTakeMVar queueName'
    Mv.putMVar queueName' queue
    -- Consume responses from the reply queue
    void $ Rb.consumeMsgs rch queue Rb.NoAck $ \(reply, _) -> do
      let corrId = fromMaybe "?" . Rb.msgCorrelationID $ reply
      ackCorrelationId waits' corrId . RpcOk $ Rb.msgBody reply

  -- Channel used for publishing requests
  chSend' <- mkChannel conn Nothing

  pure $ \request -> do
    -- new correlation id
    corrId <- show <$> UU.nextRandom
    -- new wait handle
    wait <- Mv.newEmptyMVar
    atomically $ Tv.modifyTVar' waits' $ \ws -> Map.insert corrId wait ws

    let
      -- Drop the wait handle; harmless if the reply path already removed it.
      forget = atomically $ Tv.modifyTVar' waits' (Map.delete corrId)

      roundTrip = do
        -- get the current reply queue name and the current send channel
        queueName <- Mv.readMVar queueName'
        rchSend <- getRabbitChannel chSend'
        -- send the request
        void $ Rb.publishMsg rchSend "" rpcName
          (Rb.newMsg { Rb.msgBody = request
                     , Rb.msgDeliveryMode = Just Rb.NonPersistent
                     , Rb.msgReplyTo = Just queueName
                     , Rb.msgCorrelationID = Just corrId
                     })
        void . forkIO $ do
          threadDelay timeoutMicros
          void . Mv.tryPutMVar wait $ RpcCallerTimeout
        -- wait for the response (or the timeout)
        Mv.takeMVar wait

    roundTrip `Ex.finally` forget
  where
    timeoutMicros :: Int
    timeoutMicros = fst (properFraction (timeout * 1_000_000))

    -- Hand a reply to the call waiting on this correlation id, if any.
    ackCorrelationId waits' corrId result = do
      -- try to find (and remove) the wait handle for this correlation id
      waitRes <- atomically . Tv.stateTVar waits' $ \waits ->
        case Map.lookup corrId waits of
          Nothing -> (Left $ "unknown correlation Id: " <> corrId, waits)
          Just w -> (Right w, Map.delete corrId waits)
      case waitRes of
        Right wait -> void . Mv.tryPutMVar wait $ result
        Left e -> print e --TODO
------------------------------------------------------------------------------------------------------------------------------ |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment