Last active
May 24, 2023 17:36
-
-
Save michaelst/b49c05f39019fd5a92b6377f97dca3dd to your computer and use it in GitHub Desktop.
Postgres Proxy
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Opens a listening TCP socket on `port` (binary mode, passive, raw packets) and hands it
# to the accept loop. Raises MatchError if the socket cannot be opened.
def accept(port) do
  listen_opts = [:binary, active: false, reuseaddr: true, packet: 0, nodelay: true]
  {:ok, listener} = :gen_tcp.listen(port, listen_opts)

  Logger.info("Accepting connections on port #{port}")

  loop_acceptor(listener)
end
# Waits for the next client connection, hands it to a supervised task, then recurses to
# wait for the following client. Any non-{:ok, _} result along the way raises MatchError.
defp loop_acceptor(socket) do | |
{:ok, client_conn} = :gen_tcp.accept(socket) | |
# Initial per-connection state; the remaining fields are elided in this excerpt.
state = %{ | |
client_conn: client_conn, | |
... | |
} | |
# Each client is served by its own task started under the proxy's Task.Supervisor.
{:ok, pid} = | |
Task.Supervisor.start_child(QueryDesk.PostgresProxy.TaskSupervisor, fn -> | |
read_client(state, nil) | |
end) | |
# Make the task the controlling process of the client socket instead of the acceptor.
:ok = :gen_tcp.controlling_process(client_conn, pid) | |
loop_acceptor(socket) | |
end | |
# Clause for a nil second argument: reads the next chunk from the client socket,
# parses it and dispatches the parse result to handle_message/2.
defp read_client(state = %{client_conn: conn}, nil) do
  parsed =
    conn
    |> read_line()
    |> parse_msg()

  handle_message(state, parsed)
end
# Blocking read from a TLS socket. A length of 0 asks for whatever bytes are available.
# Anything other than {:ok, bytes} raises MatchError.
defp read_line(ssl_socket = {:sslsocket, _, _}) do
  {:ok, bytes} = :ssl.recv(ssl_socket, 0)
  bytes
end

# Blocking read from a plain TCP socket, with the same length-0 semantics and the same
# MatchError on failure.
defp read_line(tcp_socket) when is_port(tcp_socket) do
  {:ok, bytes} = :gen_tcp.recv(tcp_socket, 0)
  bytes
end
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Clause for a session that has no database connection yet (database_conn: nil).
# Looks up the Database record by team id and the "database" connect param, and only
# proceeds when the default credential requires no reviews and the user may access it.
# NOTE(review): the `if` has no `else`, so this clause returns nil (not the state) when
# either check fails — confirm callers handle that.
defp maybe_connect_to_database( | |
%{user: user, team: team, connect_params: connect_params, database_conn: nil} = state | |
) do | |
# Loads :default_credential and :users because both are needed by the checks below.
# NOTE(review): presumably get!/3 raises when no database matches — confirm.
database = | |
QueryDesk.Api.get!( | |
Database, | |
[team_id: team.id, name: connect_params["database"]], | |
load: [:default_credential, :users] | |
) | |
# only allow connections if no reviews are required and they are allowed to access | |
if database.default_credential.reviews_required == 0 and | |
QueryDesk.Auth.Utils.can_access_database?(user, database) do | |
{:ok, database_conn, pid} = Utils.open_database_connection(database, state) | |
# (intermediate steps are elided in this excerpt)
... | |
# Send the startup message built from the database record and the client's connect params.
:ok = | |
Utils.send( | |
state, | |
Utils.startup_message(database, connect_params) | |
) | |
state | |
end | |
end | |
# Connects to `database`, starts a supervised task running
# QueryDesk.PostgresProxy.read_database/1 with the connection added to the state, and
# makes that task the controlling process of the new socket.
#
# Returns {:ok, database_conn, pid}. Any failing step raises (MatchError, or
# CaseClauseError for an unrecognised socket term).
defp open_local_database_connection(database, state) do
  {:ok, database_conn} = open_database_connection(database)

  reader_state = Map.merge(state, %{database_conn: database_conn, database: database})

  {:ok, reader_pid} =
    Task.Supervisor.start_child(
      QueryDesk.PostgresProxy.TaskSupervisor,
      fn -> QueryDesk.PostgresProxy.read_database(reader_state) end
    )

  # TLS sockets and plain ports are owned through different modules.
  transport =
    case database_conn do
      {:sslsocket, _port, _pids} -> :ssl
      tcp_socket when is_port(tcp_socket) -> :gen_tcp
    end

  :ok = transport.controlling_process(database_conn, reader_pid)

  {:ok, database_conn, reader_pid}
end
# Opens a TCP connection to `database.hostname` on port 5432 and, when `database.ssl` is
# truthy, upgrades it to TLS via the PostgreSQL SSLRequest handshake.
#
# Returns {:ok, socket} for a plain connection, or the result of :ssl.connect/2 for a
# TLS one. Raises MatchError when the TCP connect fails, the SSLRequest cannot be sent,
# or the server does not answer with "S".
def open_database_connection(database) do
  {:ok, database_conn} =
    :gen_tcp.connect(to_charlist(database.hostname), 5432,
      mode: :binary,
      active: false,
      packet: :raw
    )

  if database.ssl do
    # send ssl request: Int32 length (8) followed by the request code 80877103,
    # written as its two 16-bit halves 1234 and 5679. The result is asserted so that a
    # failed write surfaces here instead of as a confusing recv error below.
    :ok =
      :gen_tcp.send(
        database_conn,
        <<8::integer-size(32), 1234::integer-size(16), 5679::integer-size(16)>>
      )

    # S means ssl is supported and that we can start the connection
    {:ok, <<?S>>} = :gen_tcp.recv(database_conn, 1)

    # Options whose value is "" are dropped so :ssl is not given empty file paths.
    # NOTE(review): verify: :verify_none disables server certificate verification even
    # when a CA file is supplied — confirm this is intentional.
    ssl_opts =
      Enum.reject(
        [
          verify: :verify_none,
          cacertfile: create_ssl_file(database, :cacertfile),
          keyfile: create_ssl_file(database, :keyfile),
          certfile: create_ssl_file(database, :certfile)
        ],
        fn {_k, v} -> v == "" end
      )

    :ssl.connect(database_conn, ssl_opts)
  else
    {:ok, database_conn}
  end
end
# After connecting, the database asks for a password. This clause matches the MD5
# variant of that request (type R, Int32 length 12, Int32 code 5, 4-byte salt) and
# answers with a password message built from the database's default credential.
defp maybe_send_to_client(
       <<?R, 12::integer-size(32), 5::integer-size(32), salt::binary-size(4)>>,
       %{database: database} = state
     ) do
  username = database.default_credential.username
  password = database.default_credential.password

  md5_hex = fn iodata -> iodata |> :erlang.md5() |> Base.encode16(case: :lower) end

  # md5(md5(password <> username) <> salt), hex-encoded at each step
  salted_digest = md5_hex.([md5_hex.([password, username]), salt])

  body = <<"md5", salted_digest::binary, 0>>
  # The length field counts its own four bytes plus the body, but not the type byte.
  length = byte_size(body) + 4

  :ok = Utils.send(state, <<?p, length::integer-size(32), body::binary>>)
end
# Builds the startup packet sent to the database: Int32 total length, protocol version
# 3.0, the connect params as NUL-terminated key/value pairs, and a final NUL.
#
# The client's "database", "user" and "application_name" params are overridden with the
# database name, the default credential's username and "QueryDesk Proxy".
def startup_message(database, connect_params) do
  params =
    Map.merge(connect_params, %{
      "database" => database.database,
      "user" => database.default_credential.username,
      "application_name" => "QueryDesk Proxy"
    })

  body = for {key, value} <- params, into: <<>>, do: <<key::binary, 0, value::binary, 0>>

  # 4 length bytes + 4 protocol-version bytes + 1 trailing NUL
  total = byte_size(body) + 9

  <<total::integer-size(32), 3::integer-size(16), 0::integer-size(16), body::binary, 0>>
end
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# SSL Request | |
# Matches a message of exactly 8 bytes: a 32-bit length followed by the 16-bit values
# 1234 and 5679. It carries no type byte, so the tag is reported as {:msgSSLRequest, nil}.
def parse_msg(<<len::integer-32, 1234::integer-16, 5679::integer-16>> = bin) do | |
case bin do | |
# Split off `len` bytes as the message body; anything after it is returned as the remainder.
<<msg_body::binary-size(len), final_rest::binary>> -> | |
{:ok, {{:msgSSLRequest, nil}, msg_body}, final_rest} | |
# (the rest of this clause is elided in this excerpt)
... | |
# Most Messages
#
# Parses a tagged message: one type byte `c`, then a 32-bit length, then the payload.
# The length is measured from the start of the length field itself, so `msg_body`
# still begins with those four length bytes.
#
# Returns:
#   {:ok, {{tag, c}, msg_body}, other_msg} - a complete message plus any trailing bytes
#   {:continuation, fun}                   - more bytes are needed; call `fun` with the
#                                            next chunk read from the socket
def parse_msg(<<c::size(8), rest::binary>>) do
  tag = tag_to_msg_type(c)

  case rest do
    <<len::integer-32, _::binary>> when byte_size(rest) >= len ->
      <<msg_body::binary-size(len), other_msg::binary>> = rest
      {:ok, {{tag, c}, msg_body}, other_msg}

    <<len::integer-32, _::binary>> ->
      {:continuation,
       fn data ->
         handle_continuation(len, {tag, c}, rest, data)
       end}

    _short_header ->
      # A TCP read may end before all four length bytes have arrived. Instead of
      # raising MatchError, wait for more data and parse again from the type byte.
      # NOTE(review): assumes continuation results share parse_msg/1's return shape;
      # handle_continuation is not visible here — confirm.
      {:continuation,
       fn data ->
         parse_msg(<<c, rest::binary, data::binary>>)
       end}
  end
end
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Query message from the client: drops the 4-byte length prefix and the trailing NUL
# bytes to obtain the query text, passes it to handle_query/3, then continues with
# next_message/2.
defp handle_message(
       state,
       {:ok, {{:msgQuery, _type_byte}, <<_len::unsigned-integer-32, query_data::binary>>},
        _rest} = msg
     ) do
  query_data
  |> String.trim_trailing(<<0>>)
  |> handle_query(state, msg)

  next_message(state, msg)
end
# Logs the query text and forwards the original message through Utils.send/2.
# Part of the body is elided in this excerpt.
defp handle_query(query, state, {:ok, {{_msg_type, c}, data}, _rest}) do | |
... | |
# NOTE(review): this writes the full query text to the debug log — confirm that is
# acceptable for queries that may contain sensitive values.
Logger.debug("running query: #{query}") | |
# Rebuild the wire message by putting the type byte `c` back in front of `data`.
:ok = Utils.send(state, <<c, data::binary>>) | |
end | |
# Utils.send/2
#
# Writes `binary` to the database connection held in the state map, using :ssl for TLS
# sockets and :gen_tcp for plain ports. The transport's result is returned unchanged.
def send(%{database_conn: {:sslsocket, _, _} = ssl_socket}, binary),
  do: :ssl.send(ssl_socket, binary)

def send(%{database_conn: tcp_socket}, binary) when is_port(tcp_socket),
  do: :gen_tcp.send(tcp_socket, binary)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment