-
-
Save chrismccord/94cfc249f10db999da3ba731857ae797 to your computer and use it in GitHub Desktop.
Phoenix Drain Stop
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
defmodule DrainStop do
  @moduledoc """
  DrainStop attempts to gracefully shut down an endpoint when a normal shutdown
  occurs.

  It first shuts down the acceptor, ensuring that no new requests can be
  made. It then waits for all pending requests to complete. If the timeout
  expires before this happens, it stops waiting, allowing the supervision tree
  to continue its shutdown order.

  DrainStop should be installed in your supervision tree *after* the
  endpoint it is going to drain stop.

  DrainStop takes two options:

    * `endpoint` - The `Phoenix.Endpoint` to drain stop. Required.
    * `timeout` - The amount of time to allow for requests to finish, in
      msec. Defaults to `5000`.

  For example:

      children = [
        supervisor(MyApp.Endpoint, []),
        worker(
          DrainStop,
          [[timeout: 10_000, endpoint: MyApp.Endpoint]],
          [shutdown: 15_000]
        )
      ]
  """

  use GenServer
  require Logger
  import Supervisor, only: [which_children: 1, terminate_child: 2]

  @default_timeout 5000

  @doc """
  Starts the DrainStop process.

  `options` is a keyword list; `:endpoint` is required and `:timeout`
  (milliseconds) is optional.
  """
  def start_link(options) do
    GenServer.start_link(__MODULE__, options)
  end

  @doc false
  def init(options) do
    # Exits must be trapped so that `terminate/2` is invoked when the
    # supervisor shuts this process down.
    Process.flag(:trap_exit, true)
    endpoint = Keyword.fetch!(options, :endpoint)
    timeout = Keyword.get(options, :timeout, @default_timeout)
    {:ok, {endpoint, timeout}}
  end

  @doc false
  # Only orderly shutdown reasons trigger a drain; any other reason returns
  # `:ok` immediately.
  def terminate(:shutdown, {endpoint, timeout}), do: drain_endpoint(endpoint, timeout)
  def terminate({:shutdown, _}, {endpoint, timeout}), do: drain_endpoint(endpoint, timeout)
  def terminate(:normal, {endpoint, timeout}), do: drain_endpoint(endpoint, timeout)
  def terminate(_, _), do: :ok

  @doc """
  Stops the endpoint's acceptors, then waits up to `timeout` msec for the
  pending requests to finish.
  """
  def drain_endpoint(endpoint, timeout) do
    stop_listening(endpoint)
    wait_for_requests(endpoint, timeout)
  end

  @doc """
  Blocks the calling process until `endpoint` has no pending requests or
  `timeout` msec have elapsed, whichever comes first.
  """
  def wait_for_requests(endpoint, timeout) do
    Logger.info("DrainStop starting graceful shutdown with timeout: #{timeout}")
    timer_ref = :erlang.start_timer(timeout, self(), :timeout)
    do_wait_for_requests(endpoint, timer_ref, %{})
  end

  # Re-reads the pending request pids, monitors any that are not monitored
  # yet, and waits for either a monitored process to go down or the timer to
  # fire. `monitors` maps request pid => monitor reference.
  defp do_wait_for_requests(endpoint, timer_ref, monitors) do
    monitors =
      endpoint
      |> pending_requests()
      |> Map.new(fn pid ->
        {pid, Map.get_lazy(monitors, pid, fn -> Process.monitor(pid) end)}
      end)

    case map_size(monitors) do
      0 ->
        Logger.info("DrainStop Successful, no more connections")
        :erlang.cancel_timer(timer_ref)

      n ->
        # `read_timer/1` returns `false` once the timer has fired; report 0
        # instead of interpolating "false" into the log line.
        time_left = :erlang.read_timer(timer_ref) || 0
        Logger.info("DrainStop waiting #{time_left} msec for #{n} more connections to shutdown")

        receive do
          {:DOWN, _monitor_ref, _, _, _} ->
            do_wait_for_requests(endpoint, timer_ref, monitors)

          {:timeout, ^timer_ref, :timeout} ->
            Logger.error("DrainStop timeout")

          msg ->
            Logger.error("DrainStop unexpected msg: #{inspect(msg)}")
            do_wait_for_requests(endpoint, timer_ref, monitors)
        end
    end
  end

  @doc """
  Returns the pids of the children of every `:ranch_conns_sup` found under
  the endpoint's ranch listener supervisors.
  """
  def pending_requests(endpoint) do
    for listener_sup <- ranch_listener_sup_pids(endpoint),
        {:ranch_conns_sup, conns_sup, _, _} <- which_children(listener_sup),
        is_pid(conns_sup),
        {_, request_pid, _, _} <- which_children(conns_sup),
        # Skip `:undefined` / `:restarting` entries: they cannot be monitored.
        is_pid(request_pid),
        do: request_pid
  end

  @doc """
  Returns the pids of the `{:ranch_listener_sup, ref}` children found under
  the endpoint's `Phoenix.Endpoint.Server` child.
  """
  def ranch_listener_sup_pids(endpoint) do
    for {Phoenix.Endpoint.Server, server_pid, _, _} <- which_children(endpoint),
        is_pid(server_pid),
        {{:ranch_listener_sup, _}, listener_pid, _, _} <- which_children(server_pid),
        is_pid(listener_pid),
        do: listener_pid
  end

  @doc """
  Terminates the `:ranch_acceptors_sup` child of every ranch listener
  supervisor of `endpoint`, so no new connections are accepted.
  """
  def stop_listening(endpoint) do
    endpoint
    |> ranch_listener_sup_pids()
    |> Enum.each(&terminate_child(&1, :ranch_acceptors_sup))
  end
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
defmodule DrainStopTest do
  # Not async: every test binds the same fixed port and writes to the global
  # application environment.
  use ExUnit.Case, async: false
  use Phoenix.ConnTest

  defmodule TestSupervisor do
    @moduledoc false
    # Supervises the test endpoint followed by DrainStop, mirroring the
    # ordering a real application is expected to use.

    @drain_stop_timeout 100

    def start_link do
      import Supervisor.Spec, warn: false

      children = [
        # Start the endpoint when the application starts
        supervisor(DrainStopTest.TestEndpoint, []),
        worker(
          DrainStop,
          [[timeout: @drain_stop_timeout, endpoint: DrainStopTest.TestEndpoint]],
          [shutdown: @drain_stop_timeout * 2]
        )
      ]

      opts = [strategy: :one_for_one]
      Supervisor.start_link(children, opts)
    end
  end

  defmodule TestPlug do
    @moduledoc false
    # Sleeps for the number of milliseconds given in the `sleep` query
    # parameter, notifying the test process before and after the sleep.

    def init(opts), do: opts

    def call(conn, _) do
      pid = Application.get_env(:drain_stop_test, :test_pid)
      conn = Plug.Conn.fetch_query_params(conn)
      send(pid, :request_start)
      {time, _} = Integer.parse(conn.params["sleep"])
      :timer.sleep(time)
      send(pid, :request_end)
      conn
    end
  end

  defmodule TestEndpoint do
    use Phoenix.Endpoint, otp_app: :drain_stop_test
    plug DrainStopTest.TestPlug
  end

  @endpoint DrainStopTest.TestEndpoint

  setup do
    Application.put_env(:drain_stop_test, :test_pid, self())

    Application.put_env(
      :drain_stop_test,
      DrainStopTest.TestEndpoint,
      http: [port: "4807"],
      url: [host: "example.com"],
      server: true
    )

    {:ok, pid} = DrainStopTest.TestSupervisor.start_link()
    Process.flag(:trap_exit, true)

    on_exit(fn ->
      if Process.whereis(DrainStopTest.TestEndpoint) do
        Supervisor.stop(DrainStopTest.TestEndpoint)
      end

      # `:kill` is the only untrappable exit reason. `:brutal_kill` is a
      # child-spec shutdown option, not an exit reason, and a supervisor
      # (which traps exits) would not stop on it.
      Process.exit(pid, :kill)
    end)

    {:ok, pid: pid}
  end

  test "waits for request to finish", %{pid: pid} do
    Task.async(fn ->
      HTTPoison.get("http://localhost:4807/?sleep=50")
    end)

    assert_receive :request_start, 1000
    Supervisor.stop(pid, :shutdown)
    assert_received :request_end
  end

  test "truncates requests that don't finish in time", %{pid: pid} do
    Task.async(fn ->
      HTTPoison.get("http://localhost:4807/?sleep=500")
    end)

    assert_receive :request_start, 1000
    Supervisor.stop(pid, :shutdown)
    refute_received :request_end
  end

  test "does not allow new requests" do
    # This is harder to test without reaching in to internals...
    DrainStop.stop_listening(DrainStopTest.TestEndpoint)

    assert_raise(HTTPoison.Error, fn ->
      HTTPoison.get!("http://localhost:4807/?sleep=500")
    end)
  end
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment