Skip to content

Instantly share code, notes, and snippets.

@chop0
Last active November 17, 2023 08:20
Show Gist options
  • Save chop0/79654a42ebb0e07d3d092e6663ef76c9 to your computer and use it in GitHub Desktop.
Save chop0/79654a42ebb0e07d3d092e6663ef76c9 to your computer and use it in GitHub Desktop.
A Java socks5 server
package ax.xz.census.proxy;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.Closeable;
import java.io.IOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.nio.BufferOverflowException;
import java.nio.BufferUnderflowException;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedSelectorException;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.nio.charset.StandardCharsets;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.StructuredTaskScope;
public class Socks5Server implements Closeable {
/** Logger for the server; static so it is looked up once rather than per instance. */
private static final Logger logger = LoggerFactory.getLogger(Socks5Server.class);
/** Size in bytes of the array attached to each newly accepted client key (see handleAccept). */
private static final int BUFFER_SIZE = 8192;
/** Used to open upstream connections ({@code connect}) and listening sockets ({@code bind}) for clients. */
private final Dialer dialer;
/**
 * Selector driving the event loop. Assigned in {@link #start(int)} and read in {@link #close()};
 * volatile because start() blocks its calling thread, so close() is presumably invoked from another thread.
 */
private volatile Selector selector;
/** Maps each end of an established session to its peer (client to upstream and upstream to client). */
private final Map<SocketChannel, SocketChannel> associatedChannels = new ConcurrentHashMap<>();
/**
 * Creates a server that reaches remote hosts through the given dialer.
 * The server does not listen until {@link #start(int)} is called.
 *
 * @param dialer The dialer used by the command handlers to connect and bind
 */
public Socks5Server(Dialer dialer) {
    this.dialer = dialer;
}
/**
 * Binds a listening socket to the given port and runs the selector loop on the calling thread.
 * <p>
 * Each pass waits in {@code select()}, handles every selected key in its own forked task, and joins
 * all of them before clearing the selected-key set and selecting again.
 * <p>
 * The method returns when the selector is closed (see {@link #close()}) or when the calling thread is
 * interrupted; in the latter case the interrupt flag is restored for the caller. The listening socket
 * is closed on every exit path.
 *
 * @param port The local TCP port to listen on
 * @throws IOException If the selector or listening socket cannot be set up, or selecting fails
 */
public void start(int port) throws IOException {
    selector = Selector.open();
    try (var serverSocket = ServerSocketChannel.open()) {
        serverSocket.bind(new InetSocketAddress(port));
        serverSocket.configureBlocking(false);
        serverSocket.register(selector, SelectionKey.OP_ACCEPT);
        while (selector.isOpen()) {
            selector.select();
            try (var scope = new StructuredTaskScope<>()) {
                for (SelectionKey key : selector.selectedKeys()) {
                    scope.fork(() -> {
                        handleSelectableKey(key);
                        return null;
                    });
                }
                scope.join();
            } catch (InterruptedException e) {
                // Preserve the interruption for whoever owns this thread.
                Thread.currentThread().interrupt();
                return;
            }
            selector.selectedKeys().clear();
        }
    } catch (ClosedSelectorException e) {
        // close() ran while this thread was inside the loop: closing the selector wakes select(),
        // and the next selector call throws. That is the normal shutdown path, not an error.
        logger.debug("Selector closed, stopping server loop");
    }
}
/**
 * Dispatches a selected key to the accept or read handler.
 * <p>
 * An {@link IOException} or {@link BufferOverflowException} raised while handling the key is logged
 * and the key's connection is closed via {@link #closeKey(SelectionKey)}. Any other exception
 * propagates to the caller.
 *
 * @param key The key to process
 */
private void handleSelectableKey(SelectionKey key) {
    try {
        if (key.isAcceptable()) {
            handleAccept(key);
        } else if (key.isReadable()) {
            handleRead(key);
        }
    } catch (IOException | BufferOverflowException e) {
        // Record why the connection is being dropped; previously the cause was discarded.
        logger.debug("Closing connection after failure", e);
        closeKey(key);
    }
}
/**
 * Parses the fixed header of a SOCKS5 request (version, command, reserved byte) and selects the
 * handler for the requested command. The buffer is left positioned at the address-type byte.
 *
 * @param dialer The dialer to use for connecting to remote hosts
 * @param clientChannel The client channel
 * @param buffer The buffer to read from
 * @return A handler for CONNECT (command 1) or BIND (command 2)
 * @throws IOException If the version is not 5, the reserved byte is not 0, or the command is unsupported
 */
private PacketHandler getPacketHandler(Dialer dialer, SocketChannel clientChannel, ByteBuffer buffer) throws IOException {
    final byte protocolVersion = buffer.get();
    if (protocolVersion != 5) {
        throw new IOException("Unsupported SOCKS version");
    }
    final byte cmd = buffer.get();
    final byte rsv = buffer.get();
    if (rsv != 0) {
        throw new IOException("Invalid SOCKS5 reserved byte");
    }
    switch (cmd) {
        case 1:
            return new ConnectCommandHandler(dialer, clientChannel, buffer);
        case 2:
            return new BindCommandHandler(dialer, clientChannel, buffer);
        default:
            throw new IOException("Unsupported SOCKS command");
    }
}
/**
 * Accepts a pending client connection and registers it with the selector for read events.
 * <p>
 * The array attached to the new key acts as the "greeting still pending" marker that
 * {@link #isGreetingDone(SelectionKey)} checks and {@link #markGreetingDone(SelectionKey)} clears.
 *
 * @param key The key for the channel on which a new client is connecting (key should be for a server socket channel)
 * @throws IOException If accepting, configuring or registering the client channel fails
 */
private void handleAccept(SelectionKey key) throws IOException {
    ServerSocketChannel serverSocket = (ServerSocketChannel) key.channel();
    SocketChannel clientChannel = serverSocket.accept();
    if (clientChannel == null) {
        // A non-blocking accept returns null when no connection is pending any more.
        return;
    }
    try {
        clientChannel.configureBlocking(false);
        clientChannel.register(selector, SelectionKey.OP_READ, new byte[BUFFER_SIZE]);
    } catch (IOException e) {
        // The channel was accepted but never registered, so nothing else would ever close it.
        try {
            clientChannel.close();
        } catch (IOException closeError) {
            e.addSuppressed(closeError);
        }
        throw e;
    }
}
/**
 * Handles the SOCKS5 method-negotiation greeting.
 * <p>
 * Reads the version byte, the method count and the offered methods, then replies with the selected
 * method: {@code 0x00} (no authentication) if the client offered it, otherwise {@code 0xFF}
 * (no acceptable methods). In the latter case an {@link IOException} is thrown after the reply is
 * written so that the caller stops processing the rest of the buffer and closes the connection,
 * instead of continuing to parse a request from a rejected client.
 *
 * @param channel The channel to handle the handshake for
 * @param buffer The buffer to read from
 * @throws IOException If the version is not 5, no offered method is acceptable, or writing the reply fails
 * @throws java.nio.BufferUnderflowException If the buffer ends before the announced number of methods
 */
private void handleGreeting(SocketChannel channel, ByteBuffer buffer) throws IOException {
    if (buffer.get() != 0x05) {
        throw new IOException("Unsupported SOCKS version");
    }
    int numMethods = Byte.toUnsignedInt(buffer.get());
    boolean noAuthOffered = false;
    for (int i = 0; i < numMethods; i++) {
        // Consume every offered method so the buffer ends up positioned after the greeting.
        if (buffer.get() == 0x00) {
            noAuthOffered = true;
        }
    }
    byte selectedMethod = noAuthOffered ? (byte) 0x00 : (byte) 0xFF;
    ByteBuffer response = ByteBuffer.wrap(new byte[] {0x05, selectedMethod});
    // The channel is non-blocking, so a single write may not take the whole reply.
    while (response.hasRemaining()) {
        channel.write(response);
    }
    if (!noAuthOffered) {
        throw new IOException("No acceptable authentication method offered");
    }
}
/**
 * Handles a read event on a client or upstream channel.
 * <p>
 * Reads the available bytes and then, in order:
 * <ol>
 * <li>performs the greeting if it is still pending for this key;</li>
 * <li>if the channel has no associated peer yet, parses and executes a SOCKS5 request;</li>
 * <li>forwards any bytes still left in the buffer to the associated peer, or logs and discards
 * them when no session was established.</li>
 * </ol>
 * A read that yields no bytes, or a greeting that arrives without a request behind it, simply
 * returns and waits for more data. A message that ends in the middle of a field is reported as an
 * {@link IOException} so the caller closes the connection.
 *
 * @param key The key for the readable channel
 * @throws IOException If the peer closed the stream, a message is malformed or truncated, or an IO error occurs
 */
private void handleRead(SelectionKey key) throws IOException {
    var channel = (SocketChannel) key.channel();
    var buffer = ByteBuffer.allocate(BUFFER_SIZE);
    if (channel.read(buffer) < 0) {
        throw new IOException("End of stream");
    }
    buffer.flip();
    if (!buffer.hasRemaining()) {
        // Nothing was read; there is nothing to parse or forward.
        return;
    }
    try {
        if (!isGreetingDone(key)) {
            handleGreeting(channel, buffer);
            markGreetingDone(key);
        }
        if (!associatedChannels.containsKey(channel)) {
            if (!buffer.hasRemaining()) {
                // The greeting arrived on its own; the request will come with a later read.
                return;
            }
            getPacketHandler(dialer, channel, buffer).handle();
        }
    } catch (BufferUnderflowException e) {
        throw new IOException("Truncated SOCKS5 message", e);
    }
    var upstreamChannel = associatedChannels.get(channel);
    if (upstreamChannel == null) {
        if (buffer.hasRemaining()) {
            logger.warn("Discarding {} bytes [session not established]", buffer.remaining());
        }
        return;
    }
    // The peer is non-blocking, so keep writing until the whole buffer has been handed over.
    while (buffer.hasRemaining()) {
        upstreamChannel.write(buffer);
    }
}
/**
 * Checks if the greeting has been done for the provided key.
 * <p>
 * The greeting is pending only while the key still carries the {@code byte[]} marker attached in
 * {@link #handleAccept(SelectionKey)}. A plain null check is not enough: the command handlers
 * register channels with the peer channel as the attachment, and registering an already registered
 * channel replaces its key's attachment. Treating that as "greeting pending" would make relayed
 * payload be parsed as a SOCKS5 greeting.
 *
 * @param key The key to check
 * @return True if the greeting has been done, false otherwise
 */
private boolean isGreetingDone(SelectionKey key) {
    return !(key.attachment() instanceof byte[]);
}
/**
 * Records that the greeting has completed for the provided key by clearing the key's attachment.
 *
 * @param key The key to mark
 */
private void markGreetingDone(SelectionKey key) {
    final Object noMarker = null;
    key.attach(noMarker);
}
/**
 * Closes the channel behind the provided key together with its associated peer channel, and removes
 * both directions of the association.
 * <p>
 * Keys whose channel is not a {@link SocketChannel} (for example the listening socket's key after a
 * failed accept) are ignored, so a single failed accept does not take the listener down.
 * Errors while closing are logged and never thrown.
 *
 * @param key The key for the client
 */
private void closeKey(SelectionKey key) {
    if (!(key.channel() instanceof SocketChannel channel)) {
        logger.debug("Ignoring close request for non-client channel {}", key.channel());
        return;
    }
    // A single remove() replaces the containsKey()+remove() pair, which could see the entry vanish
    // in between when both ends of a session fail during the same selection pass.
    SocketChannel peer = associatedChannels.remove(channel);
    if (peer != null) {
        // Drop the reverse mapping as well, but only if it still points back at this channel.
        associatedChannels.remove(peer, channel);
        try {
            peer.close();
        } catch (IOException e) {
            logger.warn("Error closing upstream channel", e);
        }
    }
    try {
        channel.close();
    } catch (IOException e) {
        logger.warn("Error closing channel", e);
    }
}
/**
 * Base class for SOCKS5 command handlers. Provides parsing of the address and port fields of a
 * request and serialization of the reply; subclasses implement the command itself in {@link #handle()}.
 */
abstract static class PacketHandler {
    /** Largest possible reply: VER + REP + RSV + ATYP + 16 address bytes (IPv6) + 2 port bytes. */
    private static final int MAX_REPLY_LENGTH = 22;

    protected Dialer dialer;
    protected SocketChannel clientChannel;
    /** Request buffer, positioned at the address-type byte when the handler is created. */
    protected ByteBuffer buffer;

    /**
     * @param dialer The dialer to use for reaching remote hosts
     * @param clientChannel The channel of the requesting client; replies are written to it
     * @param buffer The buffer holding the remainder of the request
     */
    public PacketHandler(Dialer dialer, SocketChannel clientChannel, ByteBuffer buffer) {
        this.dialer = dialer;
        this.clientChannel = clientChannel;
        this.buffer = buffer;
    }

    /**
     * Reads the address-type byte and the address that follows it.
     *
     * @return The textual IP address for IPv4 (type 1) and IPv6 (type 4), or the domain name (type 3)
     * @throws IOException If the address type is not 1, 3 or 4
     * @throws java.nio.BufferUnderflowException If the buffer ends before the address does
     */
    protected String readAddress() throws IOException {
        byte type = buffer.get();
        return switch (type) {
            case 1 -> { // IPv4
                byte[] ipv4 = new byte[4];
                buffer.get(ipv4);
                yield InetAddress.getByAddress(ipv4).getHostAddress();
            }
            case 3 -> { // Domain name, prefixed with a one-byte unsigned length
                byte[] name = new byte[buffer.get() & 0xff];
                buffer.get(name);
                yield new String(name, StandardCharsets.UTF_8);
            }
            case 4 -> { // IPv6
                byte[] ipv6 = new byte[16];
                buffer.get(ipv6);
                yield InetAddress.getByAddress(ipv6).getHostAddress();
            }
            default -> throw new IOException("Unsupported address type: " + type);
        };
    }

    /**
     * Reads the two-byte port field.
     *
     * @return The port as an unsigned value in the range 0-65535
     */
    protected int readPort() {
        return buffer.getShort() & 0xffff;
    }

    /**
     * Writes a SOCKS5 reply to the client.
     * <p>
     * The reply is built in its own buffer, so the request buffer, including any bytes that follow
     * the request, is left untouched for the caller. If {@code bndAddr} is not an
     * {@link InetSocketAddress}, or its address is unresolved, the bound address is reported as
     * the IPv4 address 0.0.0.0.
     *
     * @param rep The reply code (0 means succeeded)
     * @param bndAddr The bound address to report, or null
     * @throws IOException If writing to the client fails
     */
    protected void sendResponse(byte rep, SocketAddress bndAddr) throws IOException {
        byte[] addrBytes = new byte[4]; // 0.0.0.0 unless a concrete address is known
        int port = 0;
        if (bndAddr instanceof InetSocketAddress inetBndAddr) {
            port = inetBndAddr.getPort();
            if (inetBndAddr.getAddress() != null) { // null for unresolved addresses
                addrBytes = inetBndAddr.getAddress().getAddress();
            }
        }
        ByteBuffer response = ByteBuffer.allocate(MAX_REPLY_LENGTH);
        response.put((byte) 5); // version
        response.put(rep); // rep
        response.put((byte) 0); // rsv
        response.put((byte) (addrBytes.length == 4 ? 1 : 4)); // atyp: IPv4 or IPv6
        response.put(addrBytes); // bnd.addr
        response.putShort((short) port); // bnd.port
        response.flip();
        // A single write on a non-blocking channel may be partial; send the whole reply.
        while (response.hasRemaining()) {
            clientChannel.write(response);
        }
    }

    /**
     * Executes the command, reading its arguments from {@link #buffer}.
     *
     * @throws IOException If the request is malformed or the reply cannot be delivered
     */
    public abstract void handle() throws IOException;
}
/**
 * Handles the SOCKS5 CONNECT command: dials the requested host, registers both ends with the
 * selector, records the two-way association and reports the outcome to the client.
 */
class ConnectCommandHandler extends PacketHandler {
    // Static so the logger is looked up once instead of for every request.
    private static final Logger logger = LoggerFactory.getLogger(ConnectCommandHandler.class);

    public ConnectCommandHandler(Dialer dialer, SocketChannel clientChannel, ByteBuffer buffer) {
        super(dialer, clientChannel, buffer);
    }

    /**
     * Connects to the requested destination and links it to the client.
     * <p>
     * On success a reply with code 0 and the upstream channel's local address is sent. If any step
     * fails with an {@link IOException}, an upstream channel that was already opened is
     * disassociated and closed, and a reply with code 5 is sent instead.
     *
     * @throws IOException If the request address is malformed or the failure reply cannot be written
     */
    @Override
    public void handle() throws IOException {
        String host = readAddress();
        int port = readPort();
        SocketChannel upstreamChannel = null;
        try {
            upstreamChannel = dialer.connect(host, port);
            upstreamChannel.configureBlocking(false);
            upstreamChannel.register(selector, SelectionKey.OP_READ, clientChannel);
            clientChannel.register(selector, SelectionKey.OP_READ, upstreamChannel);
            associatedChannels.put(upstreamChannel, clientChannel);
            associatedChannels.put(clientChannel, upstreamChannel);
            sendResponse((byte) 0, upstreamChannel.getLocalAddress()); // succeeded
            logger.debug("Connected to {}", upstreamChannel.getRemoteAddress());
        } catch (IOException e) {
            if (upstreamChannel != null) {
                // The connection was opened but the session could not be completed; without this
                // the upstream socket would stay open with nobody left to close it.
                associatedChannels.remove(upstreamChannel);
                associatedChannels.remove(clientChannel, upstreamChannel);
                try {
                    upstreamChannel.close();
                } catch (IOException closeError) {
                    e.addSuppressed(closeError);
                }
            }
            sendResponse((byte) 5, null); // connection refused
            logger.debug("Connection refused", e);
        }
    }
}
/**
 * Handles the SOCKS5 BIND command: opens a listening channel through the dialer, reports its
 * address to the client, accepts one inbound connection and links it to the client.
 */
class BindCommandHandler extends PacketHandler {
    public BindCommandHandler(Dialer dialer, SocketChannel clientChannel, ByteBuffer buffer) {
        super(dialer, clientChannel, buffer);
    }

    /**
     * Executes the BIND command.
     * <p>
     * The address and port in the request are read but not used. A first reply (code 0) carries
     * the listening address; a second reply (code 0) follows once an inbound connection has been
     * accepted and associated with the client. On any {@link IOException} an already accepted
     * channel is disassociated and closed, and a reply with code 5 is sent. The listening channel
     * is closed on every path, since only a single inbound connection is ever accepted.
     * <p>
     * NOTE(review): the listening channel is switched to non-blocking mode and {@code accept()} is
     * called once, immediately. Unless the peer has already connected by then, {@code accept()}
     * returns null and the command fails. Waiting for the peer would require registering the
     * listening channel with the selector — confirm the intended behaviour.
     *
     * @throws IOException If the request address is malformed or the failure reply cannot be written
     */
    @Override
    public void handle() throws IOException {
        readAddress();
        readPort();
        SocketChannel upstreamChannel = null;
        // try-with-resources: the listener is no longer needed once handle() returns.
        try (var serverChannel = dialer.bind(null, 0)) {
            serverChannel.configureBlocking(false);
            sendResponse((byte) 0, serverChannel.getLocalAddress()); // succeeded
            upstreamChannel = serverChannel.accept();
            if (upstreamChannel == null) {
                throw new IOException("Failed to accept connection");
            }
            upstreamChannel.configureBlocking(false);
            upstreamChannel.register(selector, SelectionKey.OP_READ, clientChannel);
            associatedChannels.put(upstreamChannel, clientChannel);
            clientChannel.register(selector, SelectionKey.OP_READ, upstreamChannel);
            associatedChannels.put(clientChannel, upstreamChannel);
            sendResponse((byte) 0, upstreamChannel.getLocalAddress()); // succeeded
        } catch (IOException e) {
            if (upstreamChannel != null) {
                // Undo a half-established session so the accepted socket does not leak.
                associatedChannels.remove(upstreamChannel);
                associatedChannels.remove(clientChannel, upstreamChannel);
                try {
                    upstreamChannel.close();
                } catch (IOException closeError) {
                    e.addSuppressed(closeError);
                }
            }
            sendResponse((byte) 5, null); // connection refused
        }
    }
}
/**
 * Closes the selector, which ends the loop running in {@link #start(int)}.
 * Does nothing if {@code start} has not been called yet.
 *
 * @throws IOException If closing the selector fails
 */
@Override
public void close() throws IOException {
    // Read the field once: it stays null until start() has opened the selector.
    Selector current = selector;
    if (current != null) {
        current.close();
    }
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment