Last active
July 2, 2021 14:18
-
-
Save brimworks/409e6b847a969896ace387385447c2c6 to your computer and use it in GitHub Desktop.
A blocking implementation of the ByteChannel and GatheringByteChannel interfaces that adds TLS (via an SSLEngine) on top of another ByteChannel.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package io.nats.client.support; | |
import java.io.IOException; | |
import java.nio.BufferUnderflowException; | |
import java.nio.ByteBuffer; | |
import java.nio.channels.ReadableByteChannel; | |
import static java.nio.charset.StandardCharsets.UTF_8; | |
/**
 * Static helpers for working with {@link ByteBuffer}s.
 *
 * <p>Note that a ByteBuffer is always in one of two modes:
 *
 * <ul>
 * <li>Get mode: Valid data is between position and limit. Some documentation may refer to
 * this as "read" mode, since the buffer is ready for various <code>.get()</code> method calls,
 * however this can be confusing, since you would NEVER call the
 * {@link java.nio.channels.ReadableByteChannel#read(ByteBuffer)} method when operating in
 * this mode. This may also be referred to as "flush" mode, but nothing in the Buffer
 * documentation refers to the term "flush", and thus we use the term "get" mode here.
 * <li>Put mode: Valid data is between 0 and position. Some documentation may refer to this as
 * "write" mode, since the buffer is ready for various <code>.put()</code> method calls,
 * however this can be confusing since you would NEVER call the
 * {@link java.nio.channels.WritableByteChannel#write(ByteBuffer)} method when operating in
 * this mode. This may also be referred to as "fill" mode, but the term fill is only used
 * in one place in the Buffer.clear() documentation.
 * </ul>
 *
 * All documentation will be using the terms "get mode" or "put mode".
 */
public interface BufferUtils {
    /** Upper-case hex digits indexed by nibble value; used by {@link #hexdump}. */
    static final char[] HEX_ARRAY = "0123456789ABCDEF".toCharArray();
    /** U+2423 (open box), printed by {@link #hexdump} in place of non-printable bytes. */
    static final char SUBSTITUTE_CHAR = 0x2423;

    /**
     * Moves as many bytes as possible (up to {@code max}) from a temporary buffer into a
     * destination buffer, without ever causing a BufferOverflowException.
     *
     * <p>It is not uncommon to buffer data into a temporary buffer and then append this
     * temporary buffer into a destination buffer. This method makes this operation easy
     * since your temporary buffer is typically in "put" mode and your destination is always
     * in "put" mode, and you need to take into account that the temporary buffer may contain
     * more bytes than your destination buffer. Instead of overflowing, as many bytes as
     * possible are transferred.
     *
     * <p>See <code>org.eclipse.jetty.util.BufferUtil.append(ByteBuffer, ByteBuffer)</code>
     * for a similar method.
     *
     * @param src is a buffer in "put" mode. The transferred bytes are removed from its front
     *     and the remainder is compacted, so it is still in "put" mode afterwards. If the
     *     transfer fails (for example because dst is read-only), src is left unchanged.
     * @param dst is a buffer in "put" mode which will be populated from src.
     * @param max is the max bytes to transfer; must not be negative.
     * @return the number of bytes transferred: min(src.position(), dst.remaining(), max)
     * @throws IllegalArgumentException if max is negative.
     */
    static int append(ByteBuffer src, ByteBuffer dst, int max) {
        if (max < 0) {
            throw new IllegalArgumentException("Expected non-negative max, got " + max);
        }
        int count = Math.min(max, Math.min(src.position(), dst.remaining()));
        src.flip();
        try {
            ByteBuffer slice = src.slice();
            slice.limit(count);
            dst.put(slice);
            // Only mark the bytes as consumed once the put has succeeded.
            src.position(count);
        } finally {
            // Back to "put" mode. If put() threw, position is still 0, so compact() restores
            // the original contents exactly; otherwise the first count bytes are dropped.
            src.compact();
        }
        return count;
    }

    /**
     * Delegates to {@link #append(ByteBuffer,ByteBuffer,int)}, with
     * max set to Integer.MAX_VALUE.
     *
     * @param src is a buffer in "put" mode; transferred bytes are removed from its front
     *     and it remains in "put" mode.
     * @param dst is a buffer in "put" mode which will be populated from src.
     * @return the number of bytes transferred: min(src.position(), dst.remaining())
     */
    static int append(ByteBuffer src, ByteBuffer dst) {
        return append(src, dst, Integer.MAX_VALUE);
    }

    /**
     * Reads one line of UTF-8 text from the front of {@code readBuffer}, pulling more bytes
     * from {@code reader} (if given) until a line terminator is found.
     *
     * <p>Recognised terminators are "\r\n", "\n" and a lone "\r". The line and its
     * terminator are removed from the buffer, which is compacted and so stays in "put" mode.
     * If end of input is reached with unterminated bytes pending, those bytes are returned as
     * the final line.
     *
     * <p>If {@code reader.read()} returns 0 (e.g. a non-blocking channel with no data), the
     * read is retried immediately. This method is therefore intended for blocking channels.
     *
     * @param readBuffer is in "put" mode (0 - position are valid)
     * @param reader is a reader used to populate the buffer if insufficient remaining
     *     bytes exist in buffer. May be null if buffer should not be populated, in which case
     *     the buffered bytes are treated as all of the available input.
     * @throws NullPointerException if readBuffer is null.
     * @throws BufferUnderflowException if the buffer fills up (position == limit) before
     *     a complete line has been found.
     * @throws IOException if reader.read() throws this exception.
     * @return a line without line terminators, or null if end of input was reached and no
     *     bytes remain.
     */
    static String readLine(ByteBuffer readBuffer, ReadableByteChannel reader) throws IOException {
        if (null == readBuffer) {
            throw new NullPointerException("Expected non-null readBuffer");
        }
        int end = 0;
        boolean foundCR = false;
        int newlineLength = 1;
        FIND_END:
        while (true) {
            if (end >= readBuffer.position()) {
                if (readBuffer.position() == readBuffer.limit()) {
                    // Insufficient capacity in ByteBuffer to read a full line!
                    throw new BufferUnderflowException();
                }
                if (null == reader || reader.read(readBuffer) < 0) {
                    if (end > 0) {
                        // End of input: return the pending bytes; only strip a trailing CR.
                        if (!foundCR) {
                            newlineLength = 0;
                        }
                        break FIND_END;
                    }
                    return null;
                }
            }
            switch (readBuffer.get(end++)) {
                case '\r':
                    if (foundCR) {
                        --end; // leave the second CR for the next line
                        break FIND_END; // Legacy MAC end of line
                    }
                    foundCR = true;
                    break;
                case '\n':
                    if (foundCR) {
                        newlineLength++; // strip both bytes of "\r\n"
                    }
                    break FIND_END;
                default:
                    if (foundCR) {
                        --end; // this byte starts the next line
                        break FIND_END; // Legacy MAC end of line
                    }
            }
        }
        String result;
        readBuffer.flip();
        try {
            ByteBuffer slice = readBuffer.slice();
            slice.limit(end - newlineLength);
            result = UTF_8.decode(slice).toString();
        } finally {
            // Discard the line plus its terminator and return to "put" mode.
            readBuffer.position(end);
            readBuffer.compact();
        }
        return result;
    }

    /**
     * Sums {@code remaining()} over {@code buffers[offset .. offset + length)}.
     *
     * @param buffers the buffers to inspect.
     * @param offset index of the first buffer to include.
     * @param length number of buffers to include.
     * @return the total number of remaining bytes; accumulated as a long so that many
     *     large buffers cannot overflow the sum.
     */
    static long remaining(ByteBuffer[] buffers, int offset, int length) {
        long total = 0;
        int end = offset + length;
        for (int i = offset; i < end; i++) {
            total += buffers[i].remaining();
        }
        return total;
    }

    /**
     * Utility method for stringifying a bytebuffer (use ByteBuffer.wrap(byte[])
     * if you want to stringify a byte array). Mostly useful for debugging or
     * tracing.
     *
     * <p>Each output row covers one 16-byte-aligned slice of the buffer: a 4+ digit hex
     * offset, 16 hex columns (with an extra gap before the 9th), then the printable text.
     * Columns for bytes outside {@code [off, off + len)} are left blank, so the text column
     * always starts at the same position even on a partial first or last row.
     *
     * @param bytes is the byte buffer (read with absolute gets; its position is unchanged).
     * @param off is the offset within the byte buffer to begin.
     * @param len is the number of bytes to print.
     * @return a "hexdump" of the bytes
     */
    static String hexdump(ByteBuffer bytes, int off, int len) {
        int end = off + len;
        StringBuilder sb = new StringBuilder();
        int rowStart = off;
        while (rowStart < end) {
            // Rows are aligned on absolute multiples of 16, so the first row may begin
            // part-way through and the last row may end part-way through.
            int firstCol = rowStart % 16;
            int rowEnd = Math.min(end, rowStart - firstCol + 16);
            int lastCol = firstCol + (rowEnd - rowStart); // exclusive
            sb.append(String.format("%04x ", rowStart));
            for (int col = 0; col < 16; col++) {
                sb.append(col == 8 ? "  " : " ");
                if (col >= firstCol && col < lastCol) {
                    int ch = bytes.get(rowStart + (col - firstCol)) & 0xFF;
                    sb.append(HEX_ARRAY[ch >>> 4]);
                    sb.append(HEX_ARRAY[ch & 0x0F]);
                } else {
                    sb.append("  "); // same width as a hex byte
                }
            }
            sb.append(" ");
            for (int col = 0; col < firstCol; col++) {
                sb.append(' ');
            }
            for (int i = rowStart; i < rowEnd; i++) {
                char ch = (char) bytes.get(i);
                if (ch < 0x21) {
                    // Control chars (and space) map to their Unicode "control picture":
                    switch (ch) {
                        case ' ':
                            sb.append((char) 0x2420);
                            break;
                        case '\t':
                            sb.append((char) 0x2409);
                            break;
                        case '\r':
                            sb.append((char) 0x240D);
                            break;
                        case '\n':
                            sb.append((char) 0x2424);
                            break;
                        default:
                            sb.append(SUBSTITUTE_CHAR);
                    }
                } else if (ch < 0x7F) {
                    sb.append(ch);
                } else {
                    // DEL and non-ASCII bytes (sign-extended by the cast):
                    sb.append(SUBSTITUTE_CHAR);
                }
            }
            sb.append("\n");
            rowStart = rowEnd;
        }
        return sb.toString();
    }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package io.nats.client.channels; | |
import java.io.IOException; | |
import java.nio.ByteBuffer; | |
import java.nio.channels.ByteChannel; | |
import java.nio.channels.ClosedChannelException; | |
import java.nio.channels.GatheringByteChannel; | |
import java.util.concurrent.TimeUnit; | |
import java.util.concurrent.locks.Lock; | |
import java.util.concurrent.locks.ReentrantLock; | |
import javax.net.ssl.SSLEngine; | |
import javax.net.ssl.SSLEngineResult; | |
import javax.net.ssl.SSLEngineResult.HandshakeStatus; | |
import static io.nats.client.support.BufferUtils.append; | |
import static io.nats.client.support.BufferUtils.remaining; | |
/** | |
* It blows my mind that JDK doesn't provide this functionality by default. | |
* | |
* This is an implementation of ByteChannel which uses an SSLEngine to encrypt data sent | |
* to a ByteChannel that is being wrapped, and then decrypts received data. | |
*/ | |
public class TLSByteChannel implements ByteChannel, GatheringByteChannel { | |
private static final ByteBuffer[] EMPTY = new ByteBuffer[]{ByteBuffer.allocate(0)}; | |
private final ByteChannel wrap; | |
private final SSLEngine engine; | |
// NOTE: Locks should always be acquired in this order: | |
// readLock > writeLock > stateLock | |
private final Lock readLock = new ReentrantLock(); | |
private final Lock writeLock = new ReentrantLock(); | |
// All of the below state is controlled with this lock: | |
private final Object stateLock = new Object(); | |
private Thread readThread = null; | |
private Thread writeThread = null; | |
private State state; | |
// end state protected by the stateLock | |
private final ByteBuffer outNetBuffer; // in "get" mode, protected by writeLock | |
private final ByteBuffer inNetBuffer; // in "put" mode, protected by readLock | |
private final ByteBuffer inAppBuffer; // in "put" mode, protected by readLock | |
private enum State { | |
HANDSHAKING_READ, | |
HANDSHAKING_WRITE, | |
HANDSHAKING_TASK, | |
OPEN, | |
CLOSING, | |
CLOSED; | |
} | |
public TLSByteChannel(ByteChannel wrap, SSLEngine engine) throws IOException { | |
this.wrap = wrap; | |
this.engine = engine; | |
int netBufferSize = engine.getSession().getPacketBufferSize(); | |
int appBufferSize = engine.getSession().getApplicationBufferSize(); | |
outNetBuffer = ByteBuffer.allocate(netBufferSize); | |
outNetBuffer.flip(); | |
inNetBuffer = ByteBuffer.allocate(netBufferSize); | |
inAppBuffer = ByteBuffer.allocate(appBufferSize); | |
engine.beginHandshake(); | |
state = toState(engine.getHandshakeStatus()); | |
} | |
/** | |
* Translate an SSLEngine.HandshakeStatus into internal state. | |
*/ | |
private static State toState(HandshakeStatus status) { | |
switch (status) { | |
case NEED_TASK: | |
return State.HANDSHAKING_TASK; | |
case NEED_UNWRAP: | |
return State.HANDSHAKING_READ; | |
case NEED_WRAP: | |
return State.HANDSHAKING_WRITE; | |
case FINISHED: | |
case NOT_HANDSHAKING: | |
return State.OPEN; | |
default: | |
throw new IllegalStateException("Unexpected SSLEngine.HandshakeStatus=" + status); | |
} | |
} | |
/** | |
* Force a TLS handshake to take place if it has not already happened. | |
* @return false if end of file is observed. | |
* @throws IOException if any underlying read or write call throws. | |
*/ | |
public boolean handshake() throws IOException { | |
while (true) { | |
boolean needsRead = false; | |
boolean needsWrite = false; | |
synchronized (stateLock) { | |
switch (state) { | |
case HANDSHAKING_TASK: | |
executeTasks(); | |
state = toState(engine.getHandshakeStatus()); | |
break; | |
case HANDSHAKING_READ: | |
needsRead = true; | |
break; | |
case HANDSHAKING_WRITE: | |
needsWrite = true; | |
break; | |
default: | |
return true; | |
} | |
} | |
if (needsRead) { | |
if (readImpl(EMPTY[0]) < 0) { | |
return false; | |
} | |
} else if (needsWrite) { | |
if (writeImpl(EMPTY, 0, 1) < 0) { | |
return false; | |
} | |
} | |
} | |
} | |
/** | |
* Gracefully close the TLS session and the underlying wrap'ed socket. | |
*/ | |
@Override | |
public void close() throws IOException { | |
if (!wrap.isOpen()) { | |
return; | |
} | |
// [1] Make sure any handshake has happened: | |
handshake(); | |
// [2] Set state to closing, or return if another thread | |
// is already closing. | |
synchronized (stateLock) { | |
if (State.CLOSED == state || State.CLOSING == state) { | |
return; | |
} else { | |
state = State.CLOSING; | |
} | |
// [3] Interrupt any reading/writing threads: | |
if (null != readThread) { | |
readThread.interrupt(); | |
} | |
if (null != writeThread) { | |
writeThread.interrupt(); | |
} | |
} | |
// [4] Try to acquire readLock: | |
try { | |
if (!readLock.tryLock(100, TimeUnit.MICROSECONDS)) { | |
wrap.close(); | |
return; | |
} | |
try { | |
// [5] Try to acquire writeLock: | |
if (!writeLock.tryLock(100, TimeUnit.MICROSECONDS)) { | |
wrap.close(); | |
return; | |
} | |
try { | |
// [6] Finally, implement close sequence. | |
closeImpl(); | |
} finally { | |
writeLock.unlock(); | |
} | |
} finally { | |
readLock.unlock(); | |
} | |
} catch (InterruptedException ex) { | |
// Non-graceful close! | |
Thread.currentThread().interrupt(); | |
wrap.close(); | |
return; | |
} | |
} | |
@Override | |
public boolean isOpen() { | |
return wrap.isOpen(); | |
} | |
/** | |
* Implement the close procedure. | |
* | |
* Precondition: read & write locks are acquired | |
* | |
* Postcondition: state is CLOSED | |
*/ | |
private void closeImpl() throws IOException { | |
synchronized (stateLock) { | |
if (State.CLOSED == state) { | |
return; | |
} | |
state = State.CLOSING; | |
} | |
try { | |
// NOTE: unread data may be lost. However, we assume this is desired | |
// since we are transitioning to closing: | |
inAppBuffer.clear(); | |
if (outNetBuffer.hasRemaining()) { | |
wrap.write(outNetBuffer); | |
} | |
engine.closeOutbound(); | |
try { | |
while (!engine.isOutboundDone()) { | |
if (writeImpl(EMPTY, 0, 1) < 0) { | |
throw new ClosedChannelException(); | |
} | |
} | |
while (!engine.isInboundDone()) { | |
if (readImpl(EMPTY[0]) < 0) { | |
throw new ClosedChannelException(); | |
} | |
} | |
engine.closeInbound(); | |
} catch (ClosedChannelException ex) { | |
// already closed, ignore. | |
} | |
} finally { | |
try { | |
// No matter what happens, we need to close the | |
// wrapped channel: | |
wrap.close(); | |
} finally { | |
// ...and no matter what happens, we need to | |
// indicate that we are in a CLOSED state: | |
synchronized (stateLock) { | |
state = State.CLOSED; | |
} | |
} | |
} | |
} | |
/** | |
* Read plaintext by decrypting the underlying wrap'ed sockets encrypted bytes. | |
* | |
* @param dst is the buffer to populate between position and limit. | |
* @return the number of bytes populated or -1 to indicate end of stream, | |
* and the dst position will also be incremented appropriately. | |
*/ | |
@Override | |
public int read(ByteBuffer dst) throws IOException { | |
int result = 0; | |
while (0 == result) { | |
if (!handshake()) { | |
return -1; | |
} | |
if (!dst.hasRemaining()) { | |
return 0; | |
} | |
result = readImpl(dst); | |
} | |
return result; | |
} | |
/** | |
* Precondition: handshake() was called, or this code was called | |
* by the handshake() implementation. | |
*/ | |
private int readImpl(ByteBuffer dst) throws IOException { | |
readLock.lock(); | |
try { | |
// [1] Check if this is a read for a handshake: | |
synchronized (stateLock) { | |
if (isHandshaking(state)) { | |
if (state != State.HANDSHAKING_READ) { | |
return 0; | |
} | |
dst = EMPTY[0]; | |
} | |
readThread = Thread.currentThread(); | |
} | |
// [2] Satisfy read via inAppBuffer: | |
int count = append(inAppBuffer, dst); | |
if (count > 0) { | |
return count; | |
} | |
// [3] Read & decrypt loop: | |
return readAndDecryptLoop(dst); | |
} finally { | |
readLock.unlock(); | |
readThread = null; | |
} | |
} | |
/** | |
* Return true if we are handshaking. | |
*/ | |
private static boolean isHandshaking(State state) { | |
switch (state) { | |
case HANDSHAKING_READ: | |
case HANDSHAKING_WRITE: | |
case HANDSHAKING_TASK: | |
return true; | |
case CLOSED: | |
case OPEN: | |
case CLOSING: | |
} | |
return false; | |
} | |
/** | |
* Precondition: readLock acquired | |
*/ | |
private int readAndDecryptLoop(ByteBuffer dst) throws IOException { | |
boolean networkRead = inNetBuffer.position() == 0; | |
while (true) { | |
// Read from network: | |
if (networkRead) { | |
synchronized (stateLock) { | |
if (State.OPEN == state && !dst.hasRemaining()) { | |
return 0; | |
} | |
} | |
if (wrap.read(inNetBuffer) < 0) { | |
return -1; | |
} | |
} | |
SSLEngineResult result; | |
synchronized(stateLock) { | |
// Decrypt: | |
inNetBuffer.flip(); | |
try { | |
result = engine.unwrap(inNetBuffer, dst); | |
} finally { | |
inNetBuffer.compact(); | |
} | |
State newState = toState(result.getHandshakeStatus()); | |
if (state != State.CLOSING && newState != state) { | |
state = newState; | |
} | |
} | |
SSLEngineResult.Status status = result.getStatus(); | |
switch (status) { | |
case BUFFER_OVERFLOW: | |
if (dst == inAppBuffer) { | |
throw new IllegalStateException( | |
"SSLEngine indicated app buffer size=" + inAppBuffer.capacity() + | |
", but unwrap() returned BUFFER_OVERFLOW with an empty buffer"); | |
} | |
// Not enough space in dst, so buffer it into inAppBuffer: | |
readAndDecryptLoop(inAppBuffer); | |
return append(inAppBuffer, dst); | |
case BUFFER_UNDERFLOW: | |
if (!inNetBuffer.hasRemaining()) { | |
throw new IllegalStateException( | |
"SSLEngine indicated net buffer size=" + inNetBuffer.capacity() + | |
", but unwrap() returned BUFFER_UNDERFLOW with a full buffer"); | |
} | |
networkRead = inNetBuffer.hasRemaining(); | |
break; // retry network read | |
case CLOSED: | |
try { | |
wrap.close(); | |
} finally { | |
synchronized (stateLock) { | |
state = State.CLOSED; | |
} | |
} | |
return -1; | |
case OK: | |
return result.bytesProduced(); | |
default: | |
throw new IllegalStateException("Unexpected status=" + status); | |
} | |
} | |
} | |
/** | |
* Write plaintext by encrypting and writing this to the underlying wrap'ed socket. | |
* | |
* @param srcs are the buffers of plaintext to encrypt. | |
* @param offset is the offset within the array to begin writing. | |
* @param length is the number of buffers within the srcs array that should be written. | |
* @return the number of bytes that got written or -1 to indicate end of | |
* stream and the src position will also be incremented appropriately. | |
*/ | |
@Override | |
public long write(ByteBuffer[] srcs, int offset, int length) throws IOException { | |
int result = 0; | |
while (0 == result) { | |
if (!handshake()) { | |
return -1; | |
} | |
if (0 == remaining(srcs, offset, length)) { | |
return 0; | |
} | |
result = writeImpl(srcs, offset, length); | |
} | |
return result; | |
} | |
@Override | |
public int write(ByteBuffer src) throws IOException { | |
return Math.toIntExact(write(new ByteBuffer[]{src}, 0, 1)); | |
} | |
@Override | |
public long write(ByteBuffer[] srcs) throws IOException { | |
return write(srcs, 0, srcs.length); | |
} | |
/** | |
* While there are delegatedTasks to run, run them. | |
* | |
* Precondition: stateLock acquired | |
*/ | |
private void executeTasks() { | |
while (true) { | |
Runnable runnable = engine.getDelegatedTask(); | |
if (null == runnable) { | |
break; | |
} | |
runnable.run(); | |
} | |
} | |
/** | |
* Implement a write operation. | |
* | |
* @param src is the source buffer to write | |
* @return the number of bytes written or -1 if end of stream. | |
* | |
* Precondition: write lock is acquired. | |
*/ | |
private int writeImpl(ByteBuffer[] srcs, int offset, int length) throws IOException { | |
writeLock.lock(); | |
try { | |
// [1] Wait until handshake is complete in other thread. | |
synchronized (stateLock) { | |
if (isHandshaking(state)) { | |
if (state != State.HANDSHAKING_WRITE) { | |
return 0; | |
} | |
srcs = EMPTY; | |
} | |
writeThread = Thread.currentThread(); | |
} | |
// [2] Write & decrypt loop: | |
return writeAndEncryptLoop(srcs, offset, length); | |
} finally { | |
writeLock.unlock(); | |
writeThread = null; | |
} | |
} | |
private int writeAndEncryptLoop(ByteBuffer[] srcs, int offset, int length) throws IOException { | |
if (offset >= length) { | |
return 0; | |
} | |
int count = 0; | |
boolean finalNetFlush = false; | |
int srcsEnd = offset + length; | |
while (true) { | |
SSLEngineResult result = null; | |
synchronized (stateLock) { | |
// Encrypt: | |
outNetBuffer.compact(); | |
try { | |
for (; offset < srcsEnd; offset++, length--) { | |
ByteBuffer src = srcs[offset]; | |
int startPosition = src.position(); | |
result = engine.wrap(src, outNetBuffer); | |
count += src.position() - startPosition; | |
if (result.getStatus() != SSLEngineResult.Status.OK) { | |
break; | |
} | |
} | |
} finally { | |
outNetBuffer.flip(); | |
} | |
State newState = toState(result.getHandshakeStatus()); | |
if (state != State.CLOSING && state != newState) { | |
state = newState; | |
} | |
} | |
SSLEngineResult.Status status = result.getStatus(); | |
switch (status) { | |
case BUFFER_OVERFLOW: | |
if (outNetBuffer.remaining() == outNetBuffer.capacity()) { | |
throw new IllegalStateException( | |
"SSLEngine indicated net buffer size=" + outNetBuffer.capacity() + | |
", but wrap() returned BUFFER_OVERFLOW with a full buffer"); | |
} | |
break; // retry network write. | |
case BUFFER_UNDERFLOW: | |
throw new IllegalStateException("SSLEngine.wrap() should never return BUFFER_UNDERFLOW"); | |
case CLOSED: | |
finalNetFlush = true; | |
break; | |
case OK: | |
finalNetFlush = offset >= srcsEnd; | |
break; // perform a final net write. | |
default: | |
throw new IllegalStateException("Unexpected status=" + result.getStatus()); | |
} | |
// Write to network: | |
if (outNetBuffer.remaining() > 0) { | |
if (wrap.write(outNetBuffer) < 0) { | |
return -1; | |
} | |
} | |
if (finalNetFlush || 0 == remaining(srcs, offset, length)) { | |
break; | |
} | |
} | |
return count; | |
} | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package io.nats.client.channels; | |
import java.net.URI; | |
import java.nio.ByteBuffer; | |
import java.nio.channels.ByteChannel; | |
import java.time.Duration; | |
import java.util.concurrent.CountDownLatch; | |
import java.util.concurrent.ExecutorService; | |
import java.util.concurrent.Executors; | |
import java.util.concurrent.Future; | |
import java.util.concurrent.atomic.AtomicBoolean; | |
import javax.net.ssl.SSLContext; | |
import javax.net.ssl.SSLEngine; | |
import javax.net.ssl.SSLException; | |
import org.junit.jupiter.api.Test; | |
import io.nats.client.NatsTestServer; | |
import io.nats.client.TestSSLUtils; | |
import static java.nio.charset.StandardCharsets.UTF_8; | |
import static org.junit.jupiter.api.Assertions.assertEquals; | |
import static org.junit.jupiter.api.Assertions.assertFalse; | |
import static org.junit.jupiter.api.Assertions.assertThrows; | |
import static org.junit.jupiter.api.Assertions.assertTrue; | |
import java.io.IOException; | |
public class TLSByteChannelTests { | |
private static final ByteBuffer EMPTY = ByteBuffer.allocate(0); | |
@Test | |
public void testShortAppRead() throws Exception { | |
// Scenario: Net read TLS frame which is larger than the read buffer. | |
try (NatsTestServer ts = new NatsTestServer("src/test/resources/tls.conf", false)) { | |
URI uri = new URI(ts.getURI()); | |
NatsChannel socket = SocketNatsChannel.factory().connect(uri, Duration.ofSeconds(2)); | |
ByteBuffer info = ByteBuffer.allocate(1024 * 1024); | |
socket.read(info); | |
TLSByteChannel tls = new TLSByteChannel(socket, createSSLEngine(uri)); | |
write(tls, "CONNECT {}\r\n"); | |
ByteBuffer oneByte = ByteBuffer.allocate(1); | |
assertEquals(1, tls.read(oneByte)); | |
assertEquals(1, oneByte.position()); | |
assertEquals((byte)'+', oneByte.get(0)); // got 0? | |
oneByte.clear(); | |
assertEquals(1, tls.read(oneByte)); | |
assertEquals((byte)'O', oneByte.get(0)); | |
oneByte.clear(); | |
assertEquals(1, tls.read(oneByte)); | |
assertEquals((byte)'K', oneByte.get(0)); | |
oneByte.clear(); | |
// Follow up with a larger buffer read, | |
// ...to ensure that we don't block on | |
// a net read: | |
info.clear(); | |
int result = tls.read(info); | |
assertEquals(2, result); | |
assertEquals(2, info.position()); | |
assertEquals((byte)'\r', info.get(0)); | |
assertEquals((byte)'\n', info.get(1)); | |
oneByte.clear(); | |
assertTrue(tls.isOpen()); | |
assertTrue(socket.isOpen()); | |
tls.close(); | |
assertFalse(tls.isOpen()); | |
assertFalse(socket.isOpen()); | |
} | |
} | |
@Test | |
public void testImmediateClose() throws Exception { | |
// Scenario: Net read TLS frame which is larger than the read buffer. | |
try (NatsTestServer ts = new NatsTestServer("src/test/resources/tls.conf", false)) { | |
URI uri = new URI(ts.getURI()); | |
NatsChannel socket = SocketNatsChannel.factory().connect(uri, Duration.ofSeconds(2)); | |
ByteBuffer info = ByteBuffer.allocate(1024 * 1024); | |
socket.read(info); | |
TLSByteChannel tls = new TLSByteChannel(socket, createSSLEngine(uri)); | |
assertTrue(tls.isOpen()); | |
assertTrue(socket.isOpen()); | |
tls.close(); | |
assertFalse(tls.isOpen()); | |
assertFalse(socket.isOpen()); | |
} | |
} | |
@Test | |
public void testRenegotiation() throws Exception { | |
try (NatsTestServer ts = new NatsTestServer("src/test/resources/tls.conf", false)) { | |
URI uri = new URI(ts.getURI()); | |
NatsChannel socket = SocketNatsChannel.factory().connect(uri, Duration.ofSeconds(2)); | |
ByteBuffer readBuffer = ByteBuffer.allocate(1024 * 1024); | |
socket.read(readBuffer); | |
SSLEngine sslEngine = createSSLEngine(uri); | |
TLSByteChannel tls = new TLSByteChannel(socket, sslEngine); | |
write(tls, "CONNECT {}\r\n"); | |
readBuffer.clear(); | |
tls.read(readBuffer); | |
readBuffer.flip(); | |
assertEquals(ByteBuffer.wrap("+OK\r\n".getBytes(UTF_8)), readBuffer); | |
// Now force a renegotiation: | |
sslEngine.getSession().invalidate(); | |
sslEngine.beginHandshake(); | |
// nats-server doesn't support renegotion, we just get this error: | |
// javax.net.ssl.SSLException: Received fatal alert: unexpected_message | |
assertThrows(SSLException.class, | |
() -> tls.write(new ByteBuffer[]{ByteBuffer.wrap("PING\r\n".getBytes(UTF_8))})); | |
} | |
} | |
@Test | |
public void testConcurrentHandshake() throws Exception { | |
try (NatsTestServer ts = new NatsTestServer("src/test/resources/tls.conf", false)) { | |
URI uri = new URI(ts.getURI()); | |
NatsChannel socket = SocketNatsChannel.factory().connect(uri, Duration.ofSeconds(2)); | |
ByteBuffer readBuffer = ByteBuffer.allocate(1024 * 1024); | |
socket.read(readBuffer); | |
int numThreads = 10; | |
SSLEngine sslEngine = createSSLEngine(uri); | |
TLSByteChannel tls = new TLSByteChannel(socket, sslEngine); | |
CountDownLatch threadsReady = new CountDownLatch(numThreads); | |
CountDownLatch startLatch = new CountDownLatch(1); | |
ExecutorService executor = Executors.newFixedThreadPool(numThreads); | |
Future<Void>[] futures = new Future[numThreads]; | |
for (int i = 0; i < 10; i++) { | |
boolean isRead = i % 2 == 0; | |
futures[i] = executor.submit(() -> { | |
threadsReady.countDown(); | |
startLatch.await(); | |
if (isRead) { | |
tls.read(EMPTY); | |
} else { | |
tls.write(EMPTY); | |
} | |
return null; | |
}); | |
} | |
threadsReady.await(); | |
startLatch.countDown(); | |
// Make sure no exception happend on any thread: | |
for (int i=0; i < 10; i++) { | |
futures[i].get(); | |
} | |
write(tls, "CONNECT {}\r\n"); | |
readBuffer.clear(); | |
tls.read(readBuffer); | |
readBuffer.flip(); | |
assertEquals(ByteBuffer.wrap("+OK\r\n".getBytes(UTF_8)), readBuffer); | |
tls.close(); | |
} | |
} | |
@Test | |
public void testConcurrentClose() throws Exception { | |
try (NatsTestServer ts = new NatsTestServer("src/test/resources/tls.conf", false)) { | |
URI uri = new URI(ts.getURI()); | |
NatsChannel socket = SocketNatsChannel.factory().connect(uri, Duration.ofSeconds(2)); | |
ByteBuffer readBuffer = ByteBuffer.allocate(1024 * 1024); | |
socket.read(readBuffer); | |
int numThreads = 10; | |
SSLEngine sslEngine = createSSLEngine(uri); | |
TLSByteChannel tls = new TLSByteChannel(socket, sslEngine); | |
tls.handshake(); | |
CountDownLatch threadsReady = new CountDownLatch(numThreads); | |
CountDownLatch startLatch = new CountDownLatch(1); | |
ExecutorService executor = Executors.newFixedThreadPool(numThreads); | |
Future<Void>[] futures = new Future[numThreads]; | |
for (int i = 0; i < 10; i++) { | |
futures[i] = executor.submit(() -> { | |
threadsReady.countDown(); | |
startLatch.await(); | |
tls.close(); | |
return null; | |
}); | |
} | |
threadsReady.await(); | |
startLatch.countDown(); | |
// Make sure no exception happend on any thread: | |
for (int i=0; i < 10; i++) { | |
futures[i].get(); | |
} | |
} | |
} | |
@Test | |
public void testShortNetRead() throws Exception { | |
// Scenario: Net read TLS frame which is larger than the read buffer. | |
try (NatsTestServer ts = new NatsTestServer("src/test/resources/tls.conf", false)) { | |
URI uri = new URI(ts.getURI()); | |
NatsChannel socket = SocketNatsChannel.factory().connect(uri, Duration.ofSeconds(2)); | |
AtomicBoolean readOneByteAtATime = new AtomicBoolean(true); | |
NatsChannel wrapper = new AbstractNatsChannel(socket) { | |
ByteBuffer readBuffer = ByteBuffer.allocate(1); | |
@Override | |
public int read(ByteBuffer dst) throws IOException { | |
if (!readOneByteAtATime.get()) { | |
return socket.read(dst); | |
} | |
readBuffer.clear(); | |
int result = socket.read(readBuffer); | |
if (result <= 0) { | |
return result; | |
} | |
readBuffer.flip(); | |
dst.put(readBuffer); | |
return result; | |
} | |
@Override | |
public long write(ByteBuffer[] srcs, int offset, int length) throws IOException { | |
return socket.write(srcs, offset, length); | |
} | |
@Override | |
public boolean isSecure() { | |
return false; | |
} | |
@Override | |
public String transformConnectUrl(String connectUrl) { | |
return connectUrl; | |
} | |
}; | |
ByteBuffer info = ByteBuffer.allocate(1024 * 1024); | |
socket.read(info); | |
TLSByteChannel tls = new TLSByteChannel(wrapper, createSSLEngine(uri)); | |
// Peform handshake: | |
tls.read(ByteBuffer.allocate(0)); | |
// Send connect & ping, but turn off one-byte at a time for readint PONG: | |
readOneByteAtATime.set(false); | |
write(tls, "CONNECT {}\r\nPING\r\n"); | |
info.clear(); | |
tls.read(info); | |
info.flip(); | |
assertEquals( | |
ByteBuffer.wrap( | |
"+OK\r\nPONG\r\n" | |
.getBytes(UTF_8)), | |
info); | |
tls.close(); | |
} | |
} | |
private static SSLEngine createSSLEngine(URI uri) throws Exception { | |
SSLContext ctx = TestSSLUtils.createTestSSLContext(); | |
SSLEngine engine = ctx.createSSLEngine(uri.getHost(), uri.getPort()); | |
engine.setUseClientMode(true); | |
return engine; | |
} | |
private static void write(ByteChannel channel, String str) throws IOException { | |
channel.write(ByteBuffer.wrap(str.getBytes(UTF_8))); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Note that the tests depend on
io.nats:jnats-server-runner:1.0.7
and a few other things not shown.