Created
November 25, 2023 08:06
-
-
Save Jackarain/49043e156985a806aca76e3ab17e80cc to your computer and use it in GitHub Desktop.
基于 c++ 20 的 boost asio ssl_stream 实现
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
//
// ssl_stream.hpp
// ~~~~~~~~~~~~~~
//
// Copyright (c) 2023 Jack (jack dot wgm at gmail dot com)
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#ifndef INCLUDE__2023_11_24__SSL_STREAM_HPP | |
#define INCLUDE__2023_11_24__SSL_STREAM_HPP | |
#include <cstddef>
#include <type_traits>
#include <utility>
#include <vector>

#include <boost/asio/ssl/context.hpp>
#include <boost/asio/ssl/detail/engine.hpp>
#include <boost/asio/ssl/stream.hpp>
#include <boost/asio/ssl/detail/handshake_op.hpp>
#include <boost/asio/ssl/detail/read_op.hpp>
#include <boost/asio/ssl/detail/write_op.hpp>
#include <boost/asio/ssl/detail/shutdown_op.hpp>
#include <boost/asio/co_spawn.hpp>
#include <boost/asio/detached.hpp>

#include "proxy/use_awaitable.hpp"
namespace secure { | |
namespace net = boost::asio; | |
template <typename Stream> | |
class ssl_stream : public net::ssl::stream_base | |
{ | |
ssl_stream(const ssl_stream&) = delete; | |
ssl_stream& operator=(const ssl_stream&) = delete; | |
enum { max_tls_record_size = 17 * 1024 }; | |
using engine = net::ssl::detail::engine; | |
auto handshake_op(stream_base::handshake_type type) | |
{ | |
return [this, type]( | |
engine& eng, | |
boost::system::error_code& ec, | |
std::size_t& bytes_transferred | |
) mutable | |
{ | |
bytes_transferred = 0; | |
return eng.handshake(type, ec); | |
}; | |
} | |
auto shutdown_op() | |
{ | |
return [this]( | |
engine& eng, | |
boost::system::error_code& ec, | |
std::size_t& bytes_transferred | |
) mutable | |
{ | |
bytes_transferred = 0; | |
return eng.shutdown(ec); | |
}; | |
} | |
template <typename ConstBufferSequence> | |
auto write_op(const ConstBufferSequence& buffers) | |
{ | |
return [this, buffers]( | |
engine& eng, | |
boost::system::error_code& ec, | |
std::size_t& bytes_transferred | |
) mutable | |
{ | |
unsigned char storage[ | |
net::detail::buffer_sequence_adapter< | |
net::const_buffer, | |
ConstBufferSequence>::linearisation_storage_size | |
]; | |
net::const_buffer buffer = | |
net::detail::buffer_sequence_adapter< | |
net::const_buffer, ConstBufferSequence>::linearise( | |
buffers, | |
net::buffer(storage) | |
); | |
return eng.write(buffer, ec, bytes_transferred); | |
}; | |
} | |
template <typename MutableBufferSequence> | |
auto read_op(const MutableBufferSequence& buffers) | |
{ | |
return [this, buffers]( | |
engine& eng, | |
boost::system::error_code& ec, | |
std::size_t& bytes_transferred | |
) mutable | |
{ | |
net::mutable_buffer buffer = | |
net::detail::buffer_sequence_adapter<net::mutable_buffer, | |
MutableBufferSequence>::first(buffers); | |
return eng.read(buffer, ec, bytes_transferred); | |
}; | |
} | |
template <typename ConstBufferSequence> | |
auto buffered_handshake_op(stream_base::handshake_type type, | |
const ConstBufferSequence& buffers) | |
{ | |
return [this, type, buffers, | |
total_buffer_size(net::buffer_size(buffers))] | |
( | |
engine& eng, | |
boost::system::error_code& ec, | |
std::size_t& bytes_transferred | |
) mutable | |
{ | |
auto iter = net::buffer_sequence_begin(buffers); | |
auto end = net::buffer_sequence_end(buffers); | |
std::size_t accumulated_size = 0; | |
for (;;) | |
{ | |
engine::want want = eng.handshake(type, ec); | |
if (want != engine::want_input_and_retry || | |
bytes_transferred == total_buffer_size) | |
return want; | |
while (iter != end) | |
{ | |
net::const_buffer buffer(*iter); | |
if (bytes_transferred >= accumulated_size + buffer.size()) | |
{ | |
accumulated_size += buffer.size(); | |
++iter; | |
continue; | |
} | |
if (bytes_transferred > accumulated_size) | |
buffer = buffer + (bytes_transferred - accumulated_size); | |
bytes_transferred += buffer.size(); | |
buffer = eng.put_input(buffer); | |
bytes_transferred -= buffer.size(); | |
break; | |
} | |
} | |
}; | |
} | |
template <typename Operation> | |
std::size_t | |
sync_io(const Operation& op, boost::system::error_code& ec) | |
{ | |
boost::system::error_code io_ec; | |
std::size_t bytes_transferred = 0; | |
net::mutable_buffer write_buf; | |
do switch (op(engine_, ec, bytes_transferred)) | |
{ | |
case engine::want_input_and_retry: | |
if (input_.size() == 0) | |
{ | |
input_ = net::buffer | |
( | |
input_buffer_, | |
next_layer_.read_some(input_buffer_, io_ec) | |
); | |
if (!ec) | |
ec = io_ec; | |
} | |
input_ = engine_.put_input(input_); | |
continue; | |
case engine::want_output_and_retry: | |
write_buf = engine_.get_output(output_buffer_); | |
net::write( | |
next_layer_, | |
write_buf, | |
io_ec); | |
if (!ec) | |
ec = io_ec; | |
continue; | |
case engine::want_output: | |
write_buf = engine_.get_output(output_buffer_); | |
net::write(next_layer_, | |
write_buf, | |
io_ec); | |
if (!ec) | |
ec = io_ec; | |
engine_.map_error_code(ec); | |
return bytes_transferred; | |
default: | |
engine_.map_error_code(ec); | |
return bytes_transferred; | |
} while(!ec); | |
engine_.map_error_code(ec); | |
return 0; | |
} | |
template <typename Operation, typename Handler> | |
void async_io(const Operation& op, Handler& handler) | |
{ | |
net::co_spawn(get_executor(), | |
[this, op = op, handler = std::move(handler)]() mutable -> net::awaitable<void> | |
{ | |
boost::system::error_code ec; | |
boost::system::error_code io_ec; | |
std::size_t bytes_transferred = 0; | |
net::mutable_buffer write_buf; | |
do switch (op(engine_, ec, bytes_transferred)) | |
{ | |
case engine::want_input_and_retry: | |
if (input_.size() == 0) | |
{ | |
input_ = net::buffer | |
( | |
input_buffer_, | |
co_await next_layer_.async_read_some( | |
input_buffer_, net_awaitable[io_ec]) | |
); | |
if (!ec) | |
ec = io_ec; | |
} | |
input_ = engine_.put_input(input_); | |
continue; | |
case engine::want_output_and_retry: | |
write_buf = engine_.get_output(output_buffer_); | |
co_await net::async_write( | |
next_layer_, | |
write_buf, | |
net_awaitable[io_ec]); | |
if (!ec) | |
ec = io_ec; | |
continue; | |
case engine::want_output: | |
write_buf = engine_.get_output(output_buffer_); | |
co_await net::async_write(next_layer_, | |
write_buf, | |
net_awaitable[io_ec]); | |
if (!ec) | |
ec = io_ec; | |
engine_.map_error_code(ec); | |
handler(ec, bytes_transferred); | |
co_return; | |
default: | |
engine_.map_error_code(ec); | |
handler(ec, bytes_transferred); | |
co_return; | |
} while(!ec); | |
engine_.map_error_code(ec); | |
handler(ec, 0); | |
co_return; | |
}, net::detached); | |
} | |
public: | |
using native_handle_type = SSL*; | |
using next_layer_type = typename std::remove_reference<Stream>::type; | |
using lowest_layer_type = typename next_layer_type::lowest_layer_type; | |
using executor_type = typename lowest_layer_type::executor_type; | |
public: | |
template <typename Arg> | |
ssl_stream(Arg&& arg, net::ssl::context& ctx) | |
: next_layer_(std::move(arg)) | |
, context_(ctx) | |
, engine_(ctx.native_handle()) | |
, input_buffer_space_(max_tls_record_size) | |
, output_buffer_space_(max_tls_record_size) | |
, input_buffer_(boost::asio::buffer(input_buffer_space_)) | |
, output_buffer_(boost::asio::buffer(output_buffer_space_)) | |
{ | |
} | |
~ssl_stream() = default; | |
/// Move-constructs the stream, taking over the next layer, the engine and
/// the buffer storage of `other`.
///
/// The buffer views are rebuilt over the storage now owned by this object.
/// `input_` is copied as-is: it points into the input storage, whose
/// elements are transferred by the vector move. The views of `other` are
/// reset so the moved-from object no longer refers to storage it does not
/// own.
ssl_stream(ssl_stream&& other)
    : next_layer_(std::move(other.next_layer_))
    , context_(other.context_)
    , engine_(std::move(other.engine_))
    // Initializers are listed in member declaration order.
    , output_buffer_space_(std::move(other.output_buffer_space_))
    , output_buffer_(boost::asio::buffer(output_buffer_space_))
    , input_buffer_space_(std::move(other.input_buffer_space_))
    , input_buffer_(boost::asio::buffer(input_buffer_space_))
    , input_(other.input_)
{
    other.output_buffer_ = net::mutable_buffer();
    other.input_buffer_ = net::mutable_buffer();
    other.input_ = net::const_buffer();
}
/// Move-assigns the stream, taking over the next layer, the engine and the
/// buffer storage of `other`. Self-assignment is a no-op.
///
/// `context_` is a reference member and cannot be reseated, and assigning
/// through it would copy-assign the referenced ssl::context object itself,
/// so it is deliberately left untouched; the engine moved in from `other`
/// is what carries the SSL state.
///
/// @return *this
ssl_stream& operator=(ssl_stream&& other)
{
    if (this != &other)
    {
        next_layer_ = std::move(other.next_layer_);
        engine_ = std::move(other.engine_);

        input_buffer_space_ = std::move(other.input_buffer_space_);
        output_buffer_space_ = std::move(other.output_buffer_space_);

        // Rebuild the views over the storage this object now owns.
        input_buffer_ = boost::asio::buffer(input_buffer_space_);
        output_buffer_ = boost::asio::buffer(output_buffer_space_);

        // Pending input points into the transferred input storage.
        input_ = other.input_;

        // The moved-from object must not keep views into storage it no
        // longer owns.
        other.output_buffer_ = net::mutable_buffer();
        other.input_buffer_ = net::mutable_buffer();
        other.input_ = net::const_buffer();
    }
    return *this;
}
/// Returns the executor of the lowest layer.
executor_type get_executor() noexcept
{
    auto& lowest = next_layer_.lowest_layer();
    return lowest.get_executor();
}
/// Returns the engine's native handle.
native_handle_type native_handle()
{
    native_handle_type handle = engine_.native_handle();
    return handle;
}
/// Read-only access to the wrapped stream.
const next_layer_type& next_layer() const
{
    const next_layer_type& wrapped = next_layer_;
    return wrapped;
}

/// Mutable access to the wrapped stream.
next_layer_type& next_layer()
{
    next_layer_type& wrapped = next_layer_;
    return wrapped;
}
/// Mutable access to the lowest layer of the wrapped stream.
lowest_layer_type& lowest_layer()
{
    lowest_layer_type& lowest = next_layer_.lowest_layer();
    return lowest;
}

/// Read-only access to the lowest layer of the wrapped stream.
const lowest_layer_type& lowest_layer() const
{
    const lowest_layer_type& lowest = next_layer_.lowest_layer();
    return lowest;
}
void set_verify_mode(net::ssl::verify_mode v) | |
{ | |
boost::system::error_code ec; | |
set_verify_mode(v, ec); | |
net::detail::throw_error(ec, "set_verify_mode"); | |
} | |
void set_verify_mode(net::ssl::verify_mode v, boost::system::error_code& ec) | |
{ | |
engine_.set_verify_mode(v, ec); | |
} | |
void set_verify_depth(int depth) | |
{ | |
boost::system::error_code ec; | |
set_verify_depth(depth, ec); | |
net::detail::throw_error(ec, "set_verify_depth"); | |
} | |
void set_verify_depth(int depth, boost::system::error_code& ec) | |
{ | |
engine_.set_verify_depth(depth, ec); | |
} | |
template <typename VerifyCallback> | |
void set_verify_callback(VerifyCallback callback) | |
{ | |
boost::system::error_code ec; | |
this->set_verify_callback(callback, ec); | |
net::detail::throw_error(ec, "set_verify_callback"); | |
} | |
template <typename VerifyCallback> | |
void set_verify_callback(VerifyCallback callback, boost::system::error_code& ec) | |
{ | |
engine_.set_verify_callback( | |
new net::ssl::detail::verify_callback<VerifyCallback>(callback), ec); | |
} | |
void handshake(handshake_type type) | |
{ | |
boost::system::error_code ec; | |
handshake(type, ec); | |
net::detail::throw_error(ec, "handshake"); | |
} | |
void handshake(handshake_type type, boost::system::error_code& ec) | |
{ | |
sync_io(handshake_op(type), ec); | |
} | |
template <typename ConstBufferSequence> | |
void handshake(handshake_type type, const ConstBufferSequence& buffers) | |
{ | |
boost::system::error_code ec; | |
handshake(type, buffers, ec); | |
net::detail::throw_error(ec, "handshake"); | |
} | |
template <typename ConstBufferSequence> | |
void handshake(handshake_type type, | |
const ConstBufferSequence& buffers, boost::system::error_code& ec) | |
{ | |
sync_io(buffered_handshake_op<ConstBufferSequence>(type, buffers), ec); | |
} | |
template < | |
BOOST_ASIO_COMPLETION_TOKEN_FOR(void (boost::system::error_code)) | |
HandshakeToken = net::default_completion_token_t<executor_type>> | |
auto async_handshake(handshake_type type, | |
HandshakeToken&& token = net::default_completion_token_t<executor_type>()) | |
{ | |
return net::async_initiate<HandshakeToken, | |
void (boost::system::error_code, std::size_t)>( | |
[this] (auto handler, auto type) mutable | |
{ | |
async_io(handshake_op(type), handler); | |
}, token, type); | |
} | |
template <typename ConstBufferSequence, | |
BOOST_ASIO_COMPLETION_TOKEN_FOR(void (boost::system::error_code, | |
std::size_t)) BufferedHandshakeToken | |
= net::default_completion_token_t<executor_type>> | |
auto async_handshake(handshake_type type, const ConstBufferSequence& buffers, | |
BufferedHandshakeToken&& token | |
= net::default_completion_token_t<executor_type>()) | |
{ | |
return net::async_initiate<BufferedHandshakeToken, | |
void (boost::system::error_code, std::size_t)>( | |
[this] (auto handler, auto type, auto buffers) mutable | |
{ | |
async_io(buffered_handshake_op< | |
decltype(buffers)>(type, buffers), handler); | |
}, token, type, buffers); | |
} | |
void shutdown() | |
{ | |
boost::system::error_code ec; | |
shutdown(ec); | |
net::detail::throw_error(ec, "shutdown"); | |
} | |
void shutdown(boost::system::error_code& ec) | |
{ | |
sync_io(shutdown_op(), ec); | |
} | |
template < | |
BOOST_ASIO_COMPLETION_TOKEN_FOR(void (boost::system::error_code)) | |
ShutdownToken | |
= net::default_completion_token_t<executor_type>> | |
auto async_shutdown( | |
ShutdownToken&& token = net::default_completion_token_t<executor_type>()) | |
{ | |
return net::async_initiate<ShutdownToken, | |
void (boost::system::error_code, std::size_t)>( | |
[this] (auto handler) mutable | |
{ | |
async_io(shutdown_op(), handler); | |
}, token); | |
} | |
template <typename ConstBufferSequence> | |
std::size_t write_some(const ConstBufferSequence& buffers) | |
{ | |
boost::system::error_code ec; | |
std::size_t n = write_some(buffers, ec); | |
net::detail::throw_error(ec, "write_some"); | |
return n; | |
} | |
template <typename ConstBufferSequence> | |
std::size_t write_some(const ConstBufferSequence& buffers, | |
boost::system::error_code& ec) | |
{ | |
return sync_io(write_op<ConstBufferSequence>(buffers), ec); | |
} | |
template <typename ConstBufferSequence, | |
BOOST_ASIO_COMPLETION_TOKEN_FOR(void (boost::system::error_code, | |
std::size_t)) WriteToken = net::default_completion_token_t<executor_type>> | |
auto async_write_some(const ConstBufferSequence& buffers, | |
WriteToken&& token = net::default_completion_token_t<executor_type>()) | |
{ | |
return net::async_initiate<WriteToken, | |
void (boost::system::error_code, std::size_t)>( | |
[this] (auto handler, auto buffers) mutable | |
{ | |
async_io(write_op<decltype(buffers)>(buffers), handler); | |
}, token, buffers); | |
} | |
template <typename MutableBufferSequence> | |
std::size_t read_some(const MutableBufferSequence& buffers) | |
{ | |
boost::system::error_code ec; | |
std::size_t n = read_some(buffers, ec); | |
net::detail::throw_error(ec, "read_some"); | |
return n; | |
} | |
template <typename MutableBufferSequence> | |
std::size_t read_some(const MutableBufferSequence& buffers, | |
boost::system::error_code& ec) | |
{ | |
return sync_io(read_op<MutableBufferSequence>(buffers), ec); | |
} | |
template <typename MutableBufferSequence, | |
BOOST_ASIO_COMPLETION_TOKEN_FOR(void (boost::system::error_code, | |
std::size_t)) ReadToken = net::default_completion_token_t<executor_type>> | |
auto async_read_some(const MutableBufferSequence& buffers, | |
ReadToken&& token = net::default_completion_token_t<executor_type>()) | |
{ | |
return net::async_initiate<ReadToken, | |
void (boost::system::error_code, std::size_t)>( | |
[this] (auto handler, auto buffers) mutable | |
{ | |
async_io(read_op<decltype(buffers)>(buffers), handler); | |
}, token, buffers); | |
} | |
private: | |
Stream next_layer_; | |
net::ssl::context& context_; | |
engine engine_; | |
std::vector<unsigned char> output_buffer_space_; | |
net::mutable_buffer output_buffer_; | |
std::vector<unsigned char> input_buffer_space_; | |
net::mutable_buffer input_buffer_; | |
net::const_buffer input_; | |
}; | |
} | |
#endif // INCLUDE__2023_11_24__SSL_STREAM_HPP |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment