Skip to content

Instantly share code, notes, and snippets.

@mpenick
Created May 23, 2019 17:01
Show Gist options
  • Save mpenick/0893ca6f3c2171ee598003183b8af0a5 to your computer and use it in GitHub Desktop.
Save mpenick/0893ca6f3c2171ee598003183b8af0a5 to your computer and use it in GitHub Desktop.
/*
Copyright (c) DataStax, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "loop_test.hpp"
#include "connector.hpp"
#include "map.hpp"
#include "socket_connector.hpp"
#include "ssl.hpp"
#include "startup_request.hpp"
using datastax::String;
using datastax::internal::Map;
using namespace datastax::internal::core;
class MockSocketHandler : public SocketHandler {
public:
MockSocketHandler(size_t expected_response_count, Vector<Response::Ptr>& responses)
: expected_response_count_(expected_response_count)
, responses_(responses)
, response_(new ResponseMessage()) {}
virtual void on_read(Socket* socket, ssize_t nread, const uv_buf_t* buf) {
if (nread > 0) {
response_->decode(buf->base, nread);
if (response_->is_body_ready()) {
responses_.push_back(response_->response_body());
if (responses_.size() == expected_response_count_) {
socket->close();
} else {
response_.reset(new ResponseMessage());
}
}
}
free_buffer(buf);
}
virtual void on_write(Socket* socket, int status, SocketRequest* request) { delete request; }
virtual void on_close() {}
private:
size_t expected_response_count_;
Vector<Response::Ptr>& responses_;
ScopedPtr<ResponseMessage> response_;
};
class MockssandraUnitTest : public LoopTest {
public:
MockssandraUnitTest()
: expected_response_count_(0) {}
const Vector<Response::Ptr>& responses() const { return responses_; }
void on_connected(SocketConnector* connector) {
Socket::Ptr socket = connector->release_socket();
if (connector->error_code() == SocketConnector::SOCKET_OK) {
assert(!connector->ssl_session() && "SSL not supported");
socket->set_handler(new MockSocketHandler(expected_response_count_, responses_));
for (BufferVec::const_iterator it = requests_.begin(), end = requests_.end(); it != end;
++it) {
socket->write(new BufferSocketRequest(*it));
}
socket->flush();
} else {
ASSERT_TRUE(false) << "Failed to connect: " << connector->error_message();
}
}
void append_startup_request(ProtocolVersion version,
const Map<String, String>& extra = Map<String, String>()) {
Map<String, String> options;
options["CQL_VERSION"] = CASS_DEFAULT_CQL_VERSION;
size_t length = start_up_request_length(options);
Buffer header(version >= 3 ? 9 : 8);
size_t offset = 0;
offset = header.encode_byte(offset, version.value());
offset = header.encode_byte(offset, 0);
if (version >= 3) {
offset = header.encode_int16(offset, 0);
} else {
offset = header.encode_byte(offset, 0);
}
offset = header.encode_byte(offset, CQL_OPCODE_STARTUP);
offset = header.encode_int32(offset, length);
requests_.push_back(header);
requests_.push_back(Buffer(length));
requests_.back().encode_string_map(0, options);
expected_response_count_++;
}
void send_requests() {
SocketConnector::Ptr connector(new SocketConnector(
Address("127.0.0.1", 9042), bind_callback(&MockssandraUnitTest::on_connected, this)));
connector->connect(loop());
}
private:
size_t start_up_request_length(const Map<String, String>& options) {
// <options> [string map]
size_t length = sizeof(uint16_t);
for (Map<String, String>::const_iterator it = options.begin(), end = options.end(); it != end;
++it) {
length += sizeof(uint16_t) + it->first.size();
length += sizeof(uint16_t) + it->second.size();
}
return length;
}
private:
BufferVec requests_;
size_t expected_response_count_;
Vector<Response::Ptr> responses_;
};
// A STARTUP request with a protocol version below the versions the cluster is
// configured with (3, 4) should be answered with a single ERROR response.
TEST_F(MockssandraUnitTest, TooLow) {
  mockssandra::SimpleCluster cluster(
      mockssandra::SimpleRequestHandlerBuilder().with_supported_protocol_versions(3, 4).build());
  ASSERT_EQ(cluster.start_all(), 0);

  append_startup_request(1);
  send_requests();
  uv_run(loop(), UV_RUN_DEFAULT);

  const Vector<Response::Ptr>& responses(this->responses());
  ASSERT_EQ(responses.size(), 1u); // Unsigned literal avoids a sign-compare warning.
  EXPECT_EQ(responses[0]->opcode(), CQL_OPCODE_ERROR);
}
// A STARTUP request with a protocol version above the versions the cluster is
// configured with (3, 4) should be answered with a single ERROR response.
TEST_F(MockssandraUnitTest, TooHigh) {
  mockssandra::SimpleCluster cluster(
      mockssandra::SimpleRequestHandlerBuilder().with_supported_protocol_versions(3, 4).build());
  ASSERT_EQ(cluster.start_all(), 0);

  append_startup_request(9);
  send_requests();
  uv_run(loop(), UV_RUN_DEFAULT);

  const Vector<Response::Ptr>& responses(this->responses());
  ASSERT_EQ(responses.size(), 1u); // Unsigned literal avoids a sign-compare warning.
  EXPECT_EQ(responses[0]->opcode(), CQL_OPCODE_ERROR);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment