Skip to content

Instantly share code, notes, and snippets.

@MichaelNesterenko
Created September 23, 2020 20:52
Show Gist options
  • Save MichaelNesterenko/c6906348a49bb3fcdd174584c53242a5 to your computer and use it in GitHub Desktop.
Save MichaelNesterenko/c6906348a49bb3fcdd174584c53242a5 to your computer and use it in GitHub Desktop.
SQL Server certificate chain
package mn
import java.nio.ByteBuffer
import java.nio.channels.SocketChannel
import java.security.cert.Certificate
import java.security.cert.CertificateException
import java.security.cert.X509Certificate
import javax.net.ssl.SSLContext
import javax.net.ssl.SSLEngine
import javax.net.ssl.SSLEngineResult
import javax.net.ssl.TrustManager
import javax.net.ssl.X509TrustManager
// Entry point: connects to the given host/port, sends a TDS PRELOGIN request,
// starts a TLS handshake tunnelled through TDS packets and prints the server's
// certificate chain.
// Usage: <script> <host> <port>
if (args.length < 2) {
    fail("Usage: <host> <port>")
}
def host = args[0];
def port = args[1] as int;
def socket = createSocket(host, port);
try {
    def sslEngine = createSslEngine(host, port);
    prelogin(socket);
    handshakeViaTds(sslEngine, socket);
    dumpCertificates(sslEngine.handshakeSession.peerCertificates);
} finally {
    // Release the connection on success and when any step above throws.
    socket.close();
}
/**
 * Opens a TCP connection to {@code host}:{@code port} and switches the channel
 * to non-blocking mode, so a single read or write may transfer fewer bytes than
 * requested (possibly zero).
 */
SocketChannel createSocket(String host, int port) {
    def remote = new InetSocketAddress(host, port)
    def channel = SocketChannel.open(remote)
    channel.configureBlocking(false)
    channel
}
/**
 * Creates a client-mode TLS 1.2 engine for {@code host}:{@code port} whose trust
 * manager accepts every certificate chain.
 *
 * Trusting everything is deliberate: the goal is to capture whatever chain the
 * server presents, not to validate it. Never use this engine for real traffic.
 */
SSLEngine createSslEngine(String host, int port) {
    def trustEverything = new X509TrustManager() {
        @Override
        public void checkClientTrusted(X509Certificate[] chain, String authType) { }

        @Override
        public void checkServerTrusted(X509Certificate[] chain, String authType) { }

        @Override
        public X509Certificate[] getAcceptedIssuers() {
            // The X509TrustManager contract requires a non-null (possibly empty) array.
            return new X509Certificate[0];
        }
    }
    def sslContext = SSLContext.getInstance("TLSv1.2");
    sslContext.init(null, [trustEverything] as TrustManager[], null);
    def sslEngine = sslContext.createSSLEngine(host, port);
    sslEngine.useClientMode = true;
    // NOTE: getSSLParameters() hands back a copy; tweaks such as maximumPacketSize only
    // take effect when written back with setSSLParameters(). None are applied here, so
    // the engine keeps its default packet sizing.
    return sslEngine;
}
/**
 * Prints, for each certificate, its toString() form followed by its DER
 * encoding as a PEM block (Base64, 64-character lines).
 *
 * A MIME encoder with a "\n" separator never emits a trailing separator, so no
 * blank line appears before the END marker. A regex that appends "\n" after every
 * 64 characters would add one whenever the encoded length is a multiple of 64.
 */
void dumpCertificates(Certificate... certificates) {
    def pemEncoder = Base64.getMimeEncoder(64, "\n".getBytes("US-ASCII"))
    certificates.each { cert ->
        println(cert);
        println("-----BEGIN CERTIFICATE-----");
        println(pemEncoder.encodeToString(cert.encoded));
        println("-----END CERTIFICATE-----");
    }
}
/**
 * Sends a fixed TDS PRELOGIN request and reads the server's reply packet,
 * which is then discarded.
 *
 * The request sets ENCRYPTION to 0x01 (ENCRYPT_ON), asking the server to follow
 * up with a TLS handshake carried inside PRELOGIN packets.
 */
void prelogin(SocketChannel socket) {
    byte[] request = [
        // TDS header: type 0x12 (PRELOGIN), status 0x01 (EOM), length 0x002F (47), SPID 0, packet id 1, window 0
        0x12, 0x01, 0x00, 0x2F, 0x00, 0x00, 0x01, 0x00,
        // Option tokens: id, 2-byte offset, 2-byte length (offsets are relative to the payload)
        0x00, 0x00, 0x1A, 0x00, 0x06,   // VERSION    at 26, 6 bytes
        0x01, 0x00, 0x20, 0x00, 0x01,   // ENCRYPTION at 32, 1 byte
        0x02, 0x00, 0x21, 0x00, 0x01,   // INSTOPT    at 33, 1 byte
        0x03, 0x00, 0x22, 0x00, 0x04,   // THREADID   at 34, 4 bytes
        0x04, 0x00, 0x26, 0x00, 0x01,   // MARS       at 38, 1 byte
        0xFF,                           // option list terminator
        // Option data
        0x09, 0x00, 0x00, 0x00, 0x00, 0x00,  // VERSION
        0x01,                                // ENCRYPTION = ENCRYPT_ON
        0x00,                                // INSTOPT (single 0x00 byte)
        0xB8, 0x0D, 0x00, 0x00,              // THREADID
        0x01                                 // MARS
    ]
    def reply = ByteBuffer.allocate(1024)
    write(socket, ByteBuffer.wrap(request))
    readPacket(socket, reply)
}
/**
 * Drives the client side of the TLS handshake with every TLS chunk tunnelled
 * through TDS packets: outbound bytes go out via writePacket, and inbound bytes
 * come in via readPacket.
 *
 * The handshake is intentionally not completed. After delegated tasks have run,
 * the method returns unless the engine still wants more inbound data. The caller
 * reads sslEngine.handshakeSession.peerCertificates right after this returns, so
 * it relies on the server's certificate having been processed by that point.
 *
 * Failures (unexpected wrap/unwrap/handshake status) are raised via fail(), i.e.
 * as IllegalStateException.
 */
void handshakeViaTds(SSLEngine sslEngine, SocketChannel socket) {
    // Inbound TLS bytes, kept in read mode between iterations; starts empty.
    def netData = ByteBuffer.allocate(102400);
    netData.limit(0);
    // Outbound TLS bytes. Kept separate from netData so that producing a record
    // never throws away inbound bytes that have not been unwrapped yet.
    def outData = ByteBuffer.allocate(102400);
    def appData = ByteBuffer.allocate(102400);
    def noAppData = ByteBuffer.allocate(0);
    sslEngine.beginHandshake();
    SSLEngineResult.HandshakeStatus handshakeStatus = sslEngine.handshakeStatus;
    // Status of the most recent unwrap (null right after a wrap).
    SSLEngineResult.Status status = null;
    while (true) {
        switch (handshakeStatus) {
            case SSLEngineResult.HandshakeStatus.NEED_WRAP:
                println("Need wrap");
                outData.clear();
                def wrapStatus = sslEngine.wrap(noAppData, outData).status;
                if (wrapStatus != SSLEngineResult.Status.OK) {
                    // Overflow or a closed engine cannot be recovered here; retrying
                    // would loop without making progress.
                    fail("Unexpected wrap status: ${wrapStatus}");
                }
                outData.flip();
                writePacket(socket, outData);
                handshakeStatus = sslEngine.handshakeStatus;
                status = null;
                break;
            case SSLEngineResult.HandshakeStatus.NEED_UNWRAP:
                println("Need unwrap ${netData}");
                // Fetch another TDS packet when nothing is buffered, or when the engine
                // reported that the buffered bytes end in the middle of a TLS record.
                if (!netData.hasRemaining() || status == SSLEngineResult.Status.BUFFER_UNDERFLOW) {
                    netData.compact();
                    readPacket(socket, netData);
                    println("Received ${netData}");
                }
                def unwrapResult = sslEngine.unwrap(netData, appData);
                println(unwrapResult);
                if (unwrapResult.status == SSLEngineResult.Status.CLOSED) {
                    fail("Engine closed during handshake");
                }
                handshakeStatus = unwrapResult.handshakeStatus;
                status = unwrapResult.status;
                break;
            case SSLEngineResult.HandshakeStatus.NEED_TASK:
                def task;
                println("Need task");
                while ((task = sslEngine.getDelegatedTask()) != null) {
                    task.run();
                }
                handshakeStatus = sslEngine.handshakeStatus
                // Early exit: once the engine stops asking for inbound data, the
                // server messages received so far have been processed.
                if (handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_UNWRAP) {
                    return;
                }
                break;
            default:
                fail("Unexpected status: ${handshakeStatus}");
                break;
        }
    }
}
/**
 * Sends the remaining bytes of {@code buffer} as the payload of one TDS packet.
 *
 * Header layout written here: type 0x12 (PRELOGIN), status 0x01 (end of
 * message), 2-byte big-endian total length (header included), SPID 0,
 * packet id 1, window 0.
 *
 * Throws IllegalStateException (via fail) when the payload does not fit the
 * 16-bit length field; the header is never emitted with a truncated length.
 */
void writePacket(SocketChannel socket, ByteBuffer buffer) {
    int lengthWithHeader = buffer.remaining() + 8;
    if (lengthWithHeader > 0xffff) {
        fail("Payload too large for a single TDS packet: ${buffer.remaining()} bytes");
    }
    byte[] header = [
        0x12,                            // type: PRELOGIN
        0x01,                            // status: end of message
        (lengthWithHeader >> 8) & 0xff,  // length, high byte
        lengthWithHeader & 0xff,         // length, low byte
        0x00, 0x00,                      // SPID
        0x01,                            // packet id
        0x00                             // window
    ]
    write(socket, ByteBuffer.wrap(header));
    write(socket, buffer);
}
/**
 * Reads exactly one TDS packet from {@code socket} and appends its payload
 * (without the 8-byte TDS header) to {@code buffer}.
 *
 * On entry {@code buffer} must be in write mode (freshly allocated or just
 * compacted); bytes before its position are preserved. On return the buffer is
 * flipped for reading: position 0, limit = previous position + payload length.
 *
 * The header and payload are read into windows of exactly the right size, so
 * bytes of a following TDS packet are never consumed here. Otherwise its header
 * would end up inside the TLS byte stream handed to the SSLEngine.
 *
 * Throws IllegalStateException (via fail) on a malformed length or when the
 * payload does not fit into {@code buffer}.
 */
void readPacket(SocketChannel socket, ByteBuffer buffer) {
    def header = ByteBuffer.allocate(8)
    read(socket, header, 8)  // fills all 8 header bytes, then flips header
    // Bytes 2-3 of the header: total packet length including the header (big-endian).
    int length = header.getShort(2) & 0xffff
    int payloadLength = length - 8
    if (payloadLength < 0 || payloadLength > buffer.remaining()) {
        fail("Invalid TDS packet length: ${length}")
    }
    int start = buffer.position()
    if (payloadLength > 0) {
        // Window over buffer's free space, limited to exactly this packet's payload.
        def payload = buffer.duplicate()
        payload.limit(start + payloadLength)
        read(socket, payload, payloadLength)
    }
    buffer.limit(start + payloadLength)
    buffer.position(0)
}
/**
 * Overload of read(SocketChannel, ByteBuffer, int) without a minimum size:
 * performs a single channel read into {@code buffer}, flips it and returns it.
 */
ByteBuffer read(SocketChannel socket, ByteBuffer buffer) {
    int noMinimum = 0
    read(socket, buffer, noMinimum)
}
/**
 * Reads from {@code socket} into {@code buffer} until at least {@code minSize}
 * bytes have arrived or the buffer is full, then flips the buffer and returns it.
 *
 * At least one read is always attempted, even when {@code minSize} is 0.
 * Throws IllegalStateException (via fail) on end-of-stream, or when the buffer
 * fills up before {@code minSize} bytes were read.
 */
ByteBuffer read(SocketChannel socket, ByteBuffer buffer, int minSize) {
    int outstanding = minSize
    while (true) {
        int received = socket.read(buffer)
        outstanding -= received
        if (received < 0) {
            fail("Could not read")
        }
        boolean keepReading = outstanding > 0 && buffer.hasRemaining()
        if (!keepReading) {
            break
        }
    }
    if (outstanding > 0) {
        fail("Was unable to read minimum size")
    }
    buffer.flip()
    return buffer
}
/**
 * Writes every remaining byte of {@code buffer} to {@code socket}. Each call may
 * accept only part of the buffer, so writing repeats until the buffer is drained.
 */
void write(SocketChannel socket, ByteBuffer buffer) {
    while (buffer.remaining() > 0) {
        int written = socket.write(buffer)
        if (written < 0) {
            fail("Could not write")
        }
    }
}
/**
 * Aborts the current operation by throwing an IllegalStateException that
 * carries {@code msg}; never returns normally.
 */
void fail(String msg) {
    def error = new IllegalStateException(msg)
    throw error
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment