Skip to content

Instantly share code, notes, and snippets.

@chrismccord
Forked from aaronjensen/drain_stop.ex
Created May 27, 2016 15:26
Show Gist options
  • Save chrismccord/94cfc249f10db999da3ba731857ae797 to your computer and use it in GitHub Desktop.
Save chrismccord/94cfc249f10db999da3ba731857ae797 to your computer and use it in GitHub Desktop.
Phoenix Drain Stop
defmodule DrainStop do
  @moduledoc """
  Attempts to gracefully shut down an endpoint when a normal shutdown occurs.

  It first shuts down the acceptors, ensuring that no new requests can be
  made. It then waits for all pending requests to complete. If the timeout
  expires before this happens, it stops waiting, allowing the supervision tree
  to continue its shutdown order.

  DrainStop should be installed in your supervision tree *after* the
  endpoint it is going to drain stop.

  ## Options

    * `:endpoint` - the `Phoenix.Endpoint` to drain stop. Required.
    * `:timeout` - the amount of time to allow for requests to finish, in
      msec. Defaults to `5000`.

  ## Example

      children = [
        supervisor(MyApp.Endpoint, []),
        worker(
          DrainStop,
          [[timeout: 10_000, endpoint: MyApp.Endpoint]],
          [shutdown: 15_000]
        )
      ]

  Note that the `shutdown` value given to the worker should be larger than
  `:timeout`, otherwise the supervisor may kill DrainStop before the drain
  finishes.
  """

  use GenServer
  require Logger
  import Supervisor, only: [which_children: 1, terminate_child: 2]

  @default_timeout 5000

  @doc """
  Starts the DrainStop process linked to the caller.

  `options` is a keyword list; see the module documentation for the
  supported keys.
  """
  def start_link(options) do
    GenServer.start_link(__MODULE__, options)
  end

  @doc false
  @impl true
  def init(options) do
    # Exits must be trapped so that `terminate/2` runs when the supervisor
    # shuts this process down.
    Process.flag(:trap_exit, true)

    # Raises `KeyError` when `:endpoint` is missing.
    endpoint = Keyword.fetch!(options, :endpoint)
    timeout = Keyword.get(options, :timeout, @default_timeout)

    {:ok, {endpoint, timeout}}
  end

  @doc false
  @impl true
  # Only orderly exits trigger a drain; any other reason returns immediately.
  def terminate(:shutdown, {endpoint, timeout}), do: drain_endpoint(endpoint, timeout)
  def terminate({:shutdown, _}, {endpoint, timeout}), do: drain_endpoint(endpoint, timeout)
  def terminate(:normal, {endpoint, timeout}), do: drain_endpoint(endpoint, timeout)
  def terminate(_, _), do: :ok

  @doc """
  Stops the acceptors of `endpoint` and then waits up to `timeout` msec for
  the requests that are still in flight.
  """
  def drain_endpoint(endpoint, timeout) do
    stop_listening(endpoint)
    wait_for_requests(endpoint, timeout)
  end

  @doc """
  Blocks the calling process until `endpoint` has no pending requests or
  `timeout` msec have elapsed, whichever comes first.

  Returns the result of `:erlang.cancel_timer/1` when all requests finished
  in time, and `:ok` when the timeout expired.
  """
  def wait_for_requests(endpoint, timeout) do
    Logger.info("DrainStop starting graceful shutdown with timeout: #{timeout}")
    timer_ref = :erlang.start_timer(timeout, self(), :timeout)
    do_wait_for_requests(endpoint, timer_ref, %{})
  end

  # Waits for `:DOWN` messages from the pending request processes, re-reading
  # the list of pending requests after every message. `refs` maps each
  # monitored request pid to its monitor reference.
  defp do_wait_for_requests(endpoint, timer_ref, refs) do
    # Reuse the monitor already held for a pid so that every request process
    # is monitored exactly once across iterations.
    refs =
      endpoint
      |> pending_requests()
      |> Map.new(fn pid ->
        {pid, Map.get_lazy(refs, pid, fn -> Process.monitor(pid) end)}
      end)

    case map_size(refs) do
      0 ->
        Logger.info("DrainStop Successful, no more connections")
        :erlang.cancel_timer(timer_ref)

      n ->
        time_left = :erlang.read_timer(timer_ref)
        Logger.info("DrainStop waiting #{time_left} msec for #{n} more connections to shutdown")

        receive do
          {:DOWN, _monitor_ref, _, _, _} ->
            do_wait_for_requests(endpoint, timer_ref, refs)

          {:timeout, ^timer_ref, :timeout} ->
            Logger.error("DrainStop timeout")
            release_monitors(refs)

          msg ->
            Logger.error("DrainStop unexpected msg: #{inspect(msg)}")
            do_wait_for_requests(endpoint, timer_ref, refs)
        end
    end
  end

  # Drops the monitors that are still active and flushes their `:DOWN`
  # messages, so a caller that outlives the wait is not left with stale
  # monitors. Returns `:ok`.
  defp release_monitors(refs) do
    refs
    |> Map.values()
    |> Enum.each(&Process.demonitor(&1, [:flush]))
  end

  @doc """
  Returns the pids of the children of every `:ranch_conns_sup` found under
  the listener supervisors of `endpoint`.

  Entries without a live pid (for example `:undefined` or `:restarting`) are
  skipped.
  """
  def pending_requests(endpoint) do
    for listener_sup <- ranch_listener_sup_pids(endpoint),
        {:ranch_conns_sup, conns_sup, _, _} <- which_children(listener_sup),
        is_pid(conns_sup),
        {_, request_pid, _, _} <- which_children(conns_sup),
        is_pid(request_pid),
        do: request_pid
  end

  @doc """
  Returns the pids of the `{:ranch_listener_sup, _}` children found under the
  `Phoenix.Endpoint.Server` children of `endpoint`.

  This walks the supervision tree by child id, so it depends on the internal
  layout of the endpoint's supervision tree.
  """
  def ranch_listener_sup_pids(endpoint) do
    for {Phoenix.Endpoint.Server, server_sup, _, _} <- which_children(endpoint),
        is_pid(server_sup),
        {{:ranch_listener_sup, _}, listener_sup, _, _} <- which_children(server_sup),
        is_pid(listener_sup),
        do: listener_sup
  end

  @doc """
  Terminates the `:ranch_acceptors_sup` child of every listener supervisor of
  `endpoint`, so that no new connections are accepted. Returns `:ok`.
  """
  def stop_listening(endpoint) do
    endpoint
    |> ranch_listener_sup_pids()
    |> Enum.each(&terminate_child(&1, :ranch_acceptors_sup))
  end
end
defmodule DrainStopTest do
  # These tests start a real HTTP listener on a fixed port and use the global
  # application environment, so they cannot run concurrently.
  use ExUnit.Case, async: false
  use Phoenix.ConnTest

  defmodule TestSupervisor do
    @moduledoc false
    # Supervises the test endpoint followed by DrainStop, mirroring the
    # installation order described in the DrainStop documentation.

    @drain_stop_timeout 100

    def start_link do
      import Supervisor.Spec, warn: false

      children = [
        # Start the endpoint when the application starts
        supervisor(DrainStopTest.TestEndpoint, []),
        worker(
          DrainStop,
          [[timeout: @drain_stop_timeout, endpoint: DrainStopTest.TestEndpoint]],
          [shutdown: @drain_stop_timeout * 2]
        )
      ]

      opts = [strategy: :one_for_one]
      Supervisor.start_link(children, opts)
    end
  end

  defmodule TestPlug do
    @moduledoc false
    # Plug that sleeps for the number of milliseconds given in the `sleep`
    # query parameter, notifying the test process before and after the sleep.

    def init(opts), do: opts

    def call(conn, _) do
      pid = Application.get_env(:drain_stop_test, :test_pid)
      conn = Plug.Conn.fetch_query_params(conn)

      send pid, :request_start
      {time, _} = Integer.parse(conn.params["sleep"])
      :timer.sleep(time)
      send pid, :request_end

      conn
    end
  end

  defmodule TestEndpoint do
    use Phoenix.Endpoint, otp_app: :drain_stop_test

    plug DrainStopTest.TestPlug
  end

  @endpoint DrainStopTest.TestEndpoint

  setup do
    # The plug reads the test pid from the application environment so it can
    # report request progress back to the running test.
    Application.put_env(:drain_stop_test, :test_pid, self())

    Application.put_env(:drain_stop_test, DrainStopTest.TestEndpoint,
      http: [port: "4807"], url: [host: "example.com"], server: true)

    {:ok, pid} = DrainStopTest.TestSupervisor.start_link()
    # Trap exits so that stopping the linked supervisor inside a test does not
    # take the test process down with it.
    Process.flag(:trap_exit, true)

    on_exit(fn ->
      if Process.whereis(DrainStopTest.TestEndpoint) do
        Supervisor.stop(DrainStopTest.TestEndpoint)
      end

      # `:kill` is the untrappable exit reason; a supervisor traps exits, so
      # any other reason is not guaranteed to bring it down.
      Process.exit(pid, :kill)
    end)

    {:ok, pid: pid}
  end

  test "waits for request to finish", %{pid: pid} do
    Task.async(fn ->
      HTTPoison.get("http://localhost:4807/?sleep=50")
    end)

    assert_receive :request_start, 1000
    Supervisor.stop(pid, :shutdown)
    assert_received :request_end
  end

  test "truncates requests that don't finish in time", %{pid: pid} do
    Task.async(fn ->
      HTTPoison.get("http://localhost:4807/?sleep=500")
    end)

    assert_receive :request_start, 1000
    Supervisor.stop(pid, :shutdown)
    refute_received :request_end
  end

  test "does not allow new requests" do
    # This is harder to test without reaching in to internals...
    DrainStop.stop_listening(DrainStopTest.TestEndpoint)

    assert_raise(HTTPoison.Error,
      fn -> HTTPoison.get!("http://localhost:4807/?sleep=500") end)
  end
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment