Created
February 2, 2016 12:51
-
-
Save AeroNotix/42fbb5067fe518da1325 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-module(proxy_protocol). | |
%% Code between markers is stolen from Cowboy | |
%%% BEGIN STOLEN CODE | |
-export([parse_request/3]). | |
-export([start_link/4]). | |
-export([init/4]). | |
%% Per-connection state built by init/4 and threaded through this module;
%% parse_request/3 hands it to cowboy_protocol:parse_request/3.
%% NOTE(review): records are plain tuples at runtime, so this definition
%% presumably must keep exactly the field order and count of Cowboy's own
%% cowboy_protocol #state{} record - verify against the Cowboy version in use
%% before adding, removing or reordering fields.
-record(state, {
    socket                          :: inet:socket(),
    transport                       :: module(),
    middlewares                     :: [module()],
    compress                        :: boolean(),
    env                             :: cowboy_middleware:env(),
    onrequest                       :: cowboy:onrequest_fun() | undefined,
    onresponse = undefined          :: cowboy:onresponse_fun() | undefined,
    max_empty_lines                 :: non_neg_integer(),
    req_keepalive = 1               :: non_neg_integer(),
    max_keepalive                   :: non_neg_integer(),
    max_request_line_length         :: non_neg_integer(),
    max_header_name_length          :: non_neg_integer(),
    max_header_value_length         :: non_neg_integer(),
    max_headers                     :: non_neg_integer(),
    timeout                         :: timeout(),
    %% Absolute deadline in milliseconds, as produced by until/1.
    until                           :: infinity | non_neg_integer()
}).
%% Spawn a linked process that runs init/4 for this connection and return
%% {ok, Pid}. The spec uses ranch:ref(), so this is presumably the Ranch
%% protocol start callback.
-spec start_link(ranch:ref(), inet:socket(), module(), term()) -> {ok, pid()}.
start_link(Ref, Socket, Transport, Opts) ->
    {ok, spawn_link(?MODULE, init, [Ref, Socket, Transport, Opts])}.
%% Return the value paired with Key in the option list Opts, or Default when
%% Opts holds no {Key, Value} pair for it. Faster alternative to
%% proplists:get_value/3; only 2-tuples count as a match.
get_value(Key, Opts, Default) ->
    pick_value(lists:keyfind(Key, 1, Opts), Default).

%% Unwrap a lists:keyfind/3 result: a {Key, Value} pair yields Value,
%% anything else (false, or a tuple of another size) yields Default.
pick_value({_Key, Value}, _Default) -> Value;
pick_value(_NotAPair, Default) -> Default.
%% Body of the connection process spawned by start_link/4.
%%
%% Reads Cowboy-style options from Opts (defaults shown inline), acknowledges
%% the connection with ranch:accept_ack/1, builds the #state{} and starts
%% reading the first request with an empty buffer and an empty-line count of 0.
%% The request deadline (until) is derived from the timeout option only after
%% the ack has completed.
-spec init(ranch:ref(), inet:socket(), module(), term()) -> ok.
init(Ref, Socket, Transport, Opts) ->
    Opt = fun(Key, Default) -> get_value(Key, Opts, Default) end,
    Timeout = Opt(timeout, 5000),
    ok = ranch:accept_ack(Ref),
    State = #state{
        socket = Socket,
        transport = Transport,
        middlewares = Opt(middlewares, [cowboy_router, cowboy_handler]),
        compress = Opt(compress, false),
        env = [{listener, Ref} | Opt(env, [])],
        onrequest = Opt(onrequest, undefined),
        onresponse = Opt(onresponse, undefined),
        max_empty_lines = Opt(max_empty_lines, 5),
        max_keepalive = Opt(max_keepalive, 100),
        max_request_line_length = Opt(max_request_line_length, 4096),
        max_header_name_length = Opt(max_header_name_length, 64),
        max_header_value_length = Opt(max_header_value_length, 4096),
        max_headers = Opt(max_headers, 100),
        timeout = Timeout,
        until = until(Timeout)
    },
    wait_request(<<>>, State, 0).
%% Read more bytes from the socket, append them to Buffer and re-dispatch the
%% combined data through parse_request/3. Any read error (closed, timeout, ...)
%% closes the connection via terminate/1.
-spec wait_request(binary(), #state{}, non_neg_integer()) -> ok.
wait_request(Buffer, #state{socket = Sock, transport = Mod, until = Deadline} = State,
             ReqEmpty) ->
    case recv(Sock, Mod, Deadline) of
        {error, _Reason} ->
            terminate(State);
        {ok, Received} ->
            parse_request(<<Buffer/binary, Received/binary>>, State, ReqEmpty)
    end.
%% Read whatever bytes are available on Socket, giving up at the absolute
%% deadline Until (milliseconds on the os:timestamp/0 clock, as produced by
%% until/1). infinity waits forever; a deadline strictly in the past returns
%% {error, timeout} without touching the socket.
%% NOTE(review): State#state.until is also handed to cowboy_protocol, which
%% presumably computes its own timeouts from the same clock - do not switch
%% this to a monotonic clock without changing every producer/consumer.
-spec recv(inet:socket(), module(), non_neg_integer() | infinity)
    -> {ok, binary()} | {error, closed | timeout | atom()}.
recv(Socket, Transport, infinity) ->
    Transport:recv(Socket, 0, infinity);
recv(Socket, Transport, Until) ->
    {MegaSecs, Secs, MicroSecs} = os:timestamp(),
    NowMs = (MegaSecs * 1000000 + Secs) * 1000 + MicroSecs div 1000,
    case Until - NowMs of
        Remaining when Remaining < 0 ->
            {error, timeout};
        Remaining ->
            Transport:recv(Socket, 0, Remaining)
    end.
%% Close the connection's socket through its transport module. The result of
%% close is not inspected; always returns ok.
-spec terminate(#state{}) -> ok.
terminate(State = #state{}) ->
    Transport = State#state.transport,
    Transport:close(State#state.socket),
    ok.
%% Convert a relative Timeout in milliseconds into an absolute deadline, also
%% in milliseconds, on the os:timestamp/0 clock; infinity stays infinity.
-spec until(timeout()) -> non_neg_integer() | infinity.
until(infinity) ->
    infinity;
until(Timeout) ->
    {MegaSecs, Secs, MicroSecs} = os:timestamp(),
    NowMs = (MegaSecs * 1000000 + Secs) * 1000 + MicroSecs div 1000,
    NowMs + Timeout.
%%% END STOLEN CODE | |
%% Longest possible PROXY protocol v1 header line, counting the "PROXY "
%% prefix and the terminating CRLF (HAProxy PROXY protocol spec, v1).
-define(PROXY_V1_MAX_LENGTH, 107).

%% Dispatch the first bytes received on a connection.
%%
%% Data starting with "PROXY " is treated as a PROXY protocol v1 header:
%%   * bytes are buffered (through wait_request/3) until the CRLF that ends
%%     the header arrives; a header line without CRLF within
%%     ?PROXY_V1_MAX_LENGTH bytes closes the connection;
%%   * a parsed TCP4/TCP6 header is stored in the process dictionary under
%%     `proxy_info' so code later running in this connection process can
%%     read it with get(proxy_info);
%%   * an "UNKNOWN" header is accepted without storing anything;
%%   * any other header (not_proxy_protocol, malformed_proxy_protocol, ...)
%%     closes the connection via terminate/1, so bogus peer information never
%%     reaches request handlers;
%%   * the bytes following the header go to cowboy_protocol:parse_request/3.
%% Data that could still become "PROXY " (a strict prefix of it) triggers
%% another read; everything else goes straight to
%% cowboy_protocol:parse_request/3.
%%
%% Returns whatever cowboy_protocol:parse_request/3 returns, or ok (from
%% terminate/1) when the connection is dropped here.
parse_request(<<"PROXY ", Data/binary>> = Buffer, State, ReqEmpty) ->
    case binary:split(Data, <<"\r\n">>) of
        [Header, Rest] ->
            handle_proxy_header(parse_proxy_protocol(Header), Rest, State, ReqEmpty);
        [_Incomplete] when byte_size(Buffer) < ?PROXY_V1_MAX_LENGTH ->
            %% Header split across TCP reads: buffer more and re-dispatch.
            wait_request(Buffer, State, ReqEmpty);
        [_TooLong] ->
            terminate(State)
    end;
parse_request(Data, State, ReqEmpty) ->
    case is_proxy_prefix(Data) of
        true ->
            wait_request(Data, State, ReqEmpty);
        false ->
            cowboy_protocol:parse_request(Data, State, ReqEmpty)
    end.

%% Act on the result of parse_proxy_protocol/1 for a complete header line.
handle_proxy_header(unknown_peer, Rest, State, ReqEmpty) ->
    continue_http(Rest, State, ReqEmpty);
handle_proxy_header({Family, _Src, _Dst, _SrcPort, _DstPort} = ProxyInfo,
                    Rest, State, ReqEmpty)
        when Family =:= ipv4; Family =:= ipv6 ->
    %% Stashed in the process dictionary so handlers running in this
    %% connection process can look up the real peer.
    put(proxy_info, ProxyInfo),
    continue_http(Rest, State, ReqEmpty);
handle_proxy_header(_Invalid, _Rest, State, _ReqEmpty) ->
    terminate(State).

%% Hand the bytes that followed the PROXY header to Cowboy, reading from the
%% socket first when the header was all that had been received.
continue_http(<<>>, State = #state{socket = Socket, transport = Transport,
                                   until = Until}, ReqEmpty) ->
    case recv(Socket, Transport, Until) of
        {ok, Data} -> cowboy_protocol:parse_request(Data, State, ReqEmpty);
        {error, _} -> terminate(State)
    end;
continue_http(Rest, State, ReqEmpty) ->
    cowboy_protocol:parse_request(Rest, State, ReqEmpty).

%% True when Data is shorter than "PROXY " and matches it byte for byte,
%% i.e. more input is needed before deciding whether a header is present.
is_proxy_prefix(Data) when byte_size(Data) < 6 ->
    binary:longest_common_prefix([Data, <<"PROXY ">>]) =:= byte_size(Data);
is_proxy_prefix(_Data) ->
    false.
%% Parse the body of a PROXY protocol v1 header line: the bytes after the
%% "PROXY " prefix, with the trailing CRLF already removed.
%%
%% Returns one of:
%%   {ipv4 | ipv6, SourceAddress, DestAddress, SourcePort, DestPort}
%%       for a "TCP4 " / "TCP6 " line with exactly four fields whose
%%       addresses parse with inet:parse_address/1 and belong to the announced
%%       family (4-tuple for TCP4, 8-tuple for TCP6), and whose ports are
%%       accepted by parse_ports/2;
%%   unknown_peer             for a line starting with "UNKNOWN";
%%   malformed_proxy_protocol for a "TCP" line that does not meet the above
%%                            (this is untrusted network input, so it is
%%                            reported rather than crashing the process);
%%   not_proxy_protocol       for anything else.
parse_proxy_protocol(<<"TCP", Proto:1/binary, " ", Info/binary>>) ->
    case string:tokens(binary_to_list(Info), " \r\n") of
        [SourceAddress, DestAddress, SourcePort, DestPort] ->
            Parsed = {parse_inet(Proto),
                      parse_ips([SourceAddress, DestAddress], []),
                      parse_ports([SourcePort, DestPort], [])},
            case Parsed of
                {ipv4, [Src, Dst], [SrcPort, DstPort]}
                        when tuple_size(Src) =:= 4, tuple_size(Dst) =:= 4 ->
                    {ipv4, Src, Dst, SrcPort, DstPort};
                {ipv6, [Src, Dst], [SrcPort, DstPort]}
                        when tuple_size(Src) =:= 8, tuple_size(Dst) =:= 8 ->
                    {ipv6, Src, Dst, SrcPort, DstPort};
                _ ->
                    malformed_proxy_protocol
            end;
        _WrongFieldCount ->
            malformed_proxy_protocol
    end;
parse_proxy_protocol(<<"UNKNOWN", _/binary>>) ->
    unknown_peer;
parse_proxy_protocol(_) ->
    not_proxy_protocol.
%% Map the family digit of a "TCP4"/"TCP6" header (a one-byte binary) to
%% ipv4 / ipv6; anything else yields {error, invalid_inet_version}.
parse_inet(Family) ->
    case Family of
        <<$4>> -> ipv4;
        <<$6>> -> ipv6;
        _Other -> {error, invalid_inet_version}
    end.
%% Convert the decimal port strings in Ports to integers and append them, in
%% order, to Retval.
%%
%% Returns the extended list, or {error, invalid_port} at the first entry that
%% is not a decimal integer or lies outside the TCP port range 0..65535
%% (so strings such as "-1" or "70000" are rejected).
parse_ports([], Retval) ->
    Retval;
parse_ports([Port | Ports], Retval) ->
    case port_number(Port) of
        {ok, IntPort} -> parse_ports(Ports, Retval ++ [IntPort]);
        error -> {error, invalid_port}
    end.

%% Parse one port string: {ok, Integer} when it is a valid TCP port,
%% error otherwise.
port_number(Port) ->
    try list_to_integer(Port) of
        IntPort when IntPort >= 0, IntPort =< 65535 -> {ok, IntPort};
        _OutOfRange -> error
    catch
        error:badarg -> error
    end.
%% Parse each textual address in the list with inet:parse_address/1 and append
%% the resulting address tuples, in order, to the accumulator. Stops at the
%% first address that fails to parse and returns {error, invalid_address}.
parse_ips([], Parsed) ->
    Parsed;
parse_ips([Ip | Rest], Parsed) ->
    next_ip(inet:parse_address(Ip), Rest, Parsed).

%% Continue parse_ips/2 with the result of parsing a single address.
next_ip({ok, Address}, Rest, Parsed) ->
    parse_ips(Rest, Parsed ++ [Address]);
next_ip(_Failure, _Rest, _Parsed) ->
    {error, invalid_address}.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment