Skip to content

Instantly share code, notes, and snippets.

@rlapz
Last active January 2, 2023 13:54
Show Gist options
  • Save rlapz/5863561987b62cdece4369d35818283a to your computer and use it in GitHub Desktop.
Save rlapz/5863561987b62cdece4369d35818283a to your computer and use it in GitHub Desktop.
const std = @import("std");
const fmt = std.fmt;
const mem = std.mem;
const io = std.io;
const net = std.net;
const os = std.os;
const log = std.log;
const time = std.time;
const linux = os.linux;
/// Maximum number of bytes requested from a socket per relay step; this is
/// the `size` argument `startTunnel` hands to `spipe`.
const BUFFER_SIZE = 4096;
/// A host/port pair describing one endpoint of the tunnel.
const Address = struct {
/// Host part of the endpoint. The listener side feeds this to
/// `net.Address.parseIp`, the target side to `net.getAddressList`.
host: []const u8,
/// TCP port number.
port: u16,
};
/// Tunnel configuration assembled in `main` from the command-line arguments.
const Config = struct {
/// Local endpoint on which client connections are accepted.
listen: Address,
/// Remote endpoint to which traffic is forwarded.
target: Address,
};
/// Writer for standard output (used by `help`).
const stdout = io.getStdOut().writer();
/// Writer for standard error.
const stderr = io.getStdErr().writer();
/// Pointer to the fixed-buffer allocator created in `main`; it is undefined
/// until `main` assigns it. `connectToTarget` resets it after each lookup.
var fba_g: *std.heap.FixedBufferAllocator = undefined;
/// Run flag for the tunnel loops: set to true by `startTunnel` and cleared by
/// `signalHandler`.
var is_alive_g: bool = false;
/// Listening socket descriptor, or -1 while no listener has been created.
/// `signalHandler` shuts it down to stop the tunnel.
var listen_fd_g: os.fd_t = -1;
/// Thin wrapper around the Linux `splice(2)` system call.
///
/// Moves up to `size` bytes from `in` to `out` (one of which must be a pipe)
/// without passing offsets, and returns the number of bytes moved. A return
/// value of 0 means end of input.
///
/// Errors:
///  - `error.WouldBlock`: nothing could be transferred right now (EAGAIN).
///  - `error.ConnectionResetByPeer`: the peer reset the connection.
///  - `error.BrokenPipe`: the receiving end has been closed (EPIPE).
///  - `error.SystemResources`: the kernel is out of memory (ENOMEM).
///  - anything else is reported through `os.unexpectedErrno`.
fn splice(in: os.fd_t, out: os.fd_t, size: usize, flags: u32) !usize {
    while (true) {
        const rc = linux.syscall6(
            .splice,
            @bitCast(usize, @as(isize, in)),
            0, // off_in: null, use the descriptor's own position
            @bitCast(usize, @as(isize, out)),
            0, // off_out: null
            size,
            flags,
        );
        switch (os.errno(rc)) {
            .SUCCESS => return rc,
            // Signal handlers are installed without SA_RESTART, so an
            // interrupted call has to be retried by hand.
            .INTR => continue,
            .AGAIN => return error.WouldBlock,
            .CONNRESET => return error.ConnectionResetByPeer,
            // SIGPIPE is ignored by this program, so a closed peer shows up
            // as EPIPE instead of killing the process.
            .PIPE => return error.BrokenPipe,
            // Out of memory is a runtime condition, not a programming error;
            // it must not be marked unreachable.
            .NOMEM => return error.SystemResources,
            // The remaining cases indicate misuse of the call by this program.
            .BADF => unreachable,
            .INVAL => unreachable,
            .SPIPE => unreachable,
            else => |err| return os.unexpectedErrno(err),
        }
    }
}
/// Relays one chunk of at most `size` bytes from `in` to `out` through the
/// pipe pair `pipe` (`pipe[0]` is the read end, `pipe[1]` the write end).
///
/// The chunk is always drained completely before returning successfully, so
/// no bytes are left behind in the pipe for a later call to forward in the
/// wrong direction.
///
/// Errors:
///  - `error.WouldBlock`: `in` had no data available; nothing was consumed.
///  - `error.EndOfFile`: `in` reached end of stream, or `out` accepted 0 bytes.
///  - `error.Interrupted`: shutdown was requested while waiting for `out`.
///  - any other error from `splice` or `os.poll`.
fn spipe(in: os.fd_t, out: os.fd_t, pipe: [*]const os.fd_t, size: usize) !void {
    const flags = 0x1 | 0x2; // SPLICE_F_MOVE | SPLICE_F_NONBLOCK

    var pending = try splice(in, pipe[1], size, flags);
    if (pending == 0)
        return error.EndOfFile;

    while (pending > 0) {
        const written = splice(pipe[0], out, pending, flags) catch |err| switch (err) {
            error.WouldBlock => {
                // The data already sits in the pipe. Returning now would
                // strand it there, so wait until `out` is writable and retry.
                if (!is_alive_g)
                    return error.Interrupted;
                var wait = [_]os.pollfd{.{
                    .fd = out,
                    .events = os.POLL.OUT,
                    .revents = 0,
                }};
                _ = try os.poll(&wait, 1000);
                continue;
            },
            else => return err,
        };
        if (written == 0)
            return error.EndOfFile;
        pending -= written;
    }
}
/// Installs the process-wide signal dispositions: SIGPIPE is ignored, while
/// SIGTERM, SIGINT and SIGHUP are routed to `handler`.
///
/// Returns any error reported by `os.sigaction`.
fn setSignal(handler: *const fn (c_int) callconv(.C) void) !void {
    // All fields other than the handler stay zeroed (no flags, empty mask).
    const ignore_action = mem.zeroInit(os.Sigaction, .{
        .handler = .{ .handler = os.SIG.IGN },
    });
    try os.sigaction(os.SIG.PIPE, &ignore_action, null);

    // Same zeroed action, only the handler differs.
    var terminate_action = ignore_action;
    terminate_action.handler = .{ .handler = handler };

    inline for (.{ os.SIG.TERM, os.SIG.INT, os.SIG.HUP }) |sig| {
        try os.sigaction(sig, &terminate_action, null);
    }
}
/// Resolves `address` and returns a TCP socket connected to the first
/// resolved endpoint that accepts the connection. The caller owns the
/// returned descriptor and must close it.
///
/// `allocator` is used for the address lookup only; the global fixed-buffer
/// allocator is reset before returning, on success and on every error path.
///
/// Errors:
///  - `error.UnknownHostName`: the lookup produced no addresses.
///  - `error.ConnectionRefused`: no resolved address could be connected.
///  - any error from `net.getAddressList`.
fn connectToTarget(allocator: mem.Allocator, address: Address) !os.fd_t {
    // Registered first so that it runs last: the list is released before the
    // backing buffer is reset, and the reset also happens when the lookup
    // itself fails.
    defer fba_g.reset();

    const list = try net.getAddressList(allocator, address.host, address.port);
    defer list.deinit();

    if (list.addrs.len == 0)
        return error.UnknownHostName;

    const flags = os.SOCK.STREAM;
    for (list.addrs) |addr| {
        const fd = os.socket(addr.any.family, flags, os.IPPROTO.TCP) catch
            continue;
        os.connect(fd, &addr.any, addr.getOsSockLen()) catch {
            // Do not leak the descriptor of a failed attempt.
            os.close(fd);
            continue;
        };

        // Success
        return fd;
    }

    return error.ConnectionRefused;
}
/// Creates a listening TCP socket bound to `address` with SO_REUSEADDR set
/// and a backlog of 10. The caller owns the returned descriptor.
///
/// `address.host` must be an IP literal, since it is parsed with
/// `net.Address.parseIp`. Returns any error from parsing or from the socket
/// calls; the socket is closed again if setup fails part-way.
fn setupListener(address: Address) !os.fd_t {
    const saddr = try net.Address.parseIp(address.host, address.port);

    const fd = try os.socket(saddr.any.family, os.SOCK.STREAM, os.IPPROTO.TCP);
    // Without this the descriptor would leak whenever one of the steps below
    // fails.
    errdefer os.close(fd);

    try os.setsockopt(
        fd,
        os.SOL.SOCKET,
        os.SO.REUSEADDR,
        &mem.toBytes(@as(c_int, 1)),
    );
    try os.bind(fd, &saddr.any, saddr.getOsSockLen());
    try os.listen(fd, 10);

    return fd;
}
/// Runs the tunnel until `is_alive_g` is cleared.
///
/// Opens the listener described by `config.listen`, then repeatedly connects
/// to `config.target`, accepts one client, and relays data in both directions
/// until either side ends the session. `allocator` is passed on to
/// `connectToTarget` for address resolution.
///
/// Returns errors from `setupListener`, `os.pipe` and `os.poll`. Failures to
/// reach the target are logged and retried after one second.
fn startTunnel(allocator: mem.Allocator, config: Config) !void {
    listen_fd_g = try setupListener(config.listen);
    defer {
        os.close(listen_fd_g);
        // Do not leave a stale descriptor number behind for the signal handler.
        listen_fd_g = -1;
    }

    // Readiness bits that signal a dead connection even without readable data.
    const closed_mask = os.POLL.HUP | os.POLL.ERR | os.POLL.NVAL;

    is_alive_g = true;
    while (is_alive_g) {
        const tfd = connectToTarget(allocator, config.target) catch |err| {
            log.err("{s}", .{@errorName(err)});
            time.sleep(time.ns_per_s);
            continue;
        };
        defer os.close(tfd);

        const nfd = os.accept(listen_fd_g, null, null, os.SOCK.NONBLOCK) catch |err| {
            // A failure during shutdown is expected: the signal handler shuts
            // the listener down to unblock this call.
            if (is_alive_g)
                log.err("accept: {s}", .{@errorName(err)});
            continue;
        };
        defer os.close(nfd);

        // A fresh pipe per session: bytes stranded by an aborted relay are
        // dropped with the pipe instead of leaking into the next session.
        const pipe = try os.pipe();
        defer for (pipe) |p|
            os.close(p);

        var pollfds = [2]os.pollfd{
            .{ .fd = nfd, .events = os.POLL.IN, .revents = 0 },
            .{ .fd = tfd, .events = os.POLL.IN, .revents = 0 },
        };

        while (is_alive_g) {
            if ((try os.poll(&pollfds, 1000)) == 0)
                continue;

            const client_events = pollfds[0].revents;
            const target_events = pollfds[1].revents;

            if ((client_events & os.POLL.IN) != 0) {
                spipe(nfd, tfd, &pipe, BUFFER_SIZE) catch |err| switch (err) {
                    error.WouldBlock => continue,
                    else => break,
                };
            } else if ((client_events & closed_mask) != 0) {
                // Hang-up or error without data: end the session rather than
                // spinning on a poll that keeps returning immediately.
                break;
            }

            if ((target_events & os.POLL.IN) != 0) {
                spipe(tfd, nfd, &pipe, BUFFER_SIZE) catch |err| switch (err) {
                    error.WouldBlock => continue,
                    else => break,
                };
            } else if ((target_events & closed_mask) != 0) {
                break;
            }
        }
    }
}
/// Signal handler for the termination signals: reports the signal number on
/// standard error, clears `is_alive_g`, and shuts down the listening socket
/// so that a blocked `accept` returns.
///
/// The message is formatted into a stack buffer and emitted with a raw
/// `write`. Going through `std.log` here would take the stderr mutex, which
/// can deadlock when the signal interrupts code that already holds it. The
/// bytes written match what the default log function prints for
/// `log.err("Interrupted: {}", ...)` preceded by a newline.
fn signalHandler(sig: c_int) callconv(.C) void {
    var buf: [64]u8 = undefined;
    const msg: []const u8 = fmt.bufPrint(&buf, "\nerror: Interrupted: {}\n", .{sig}) catch
        "\nerror: Interrupted\n";
    // Best effort: there is nothing useful to do if stderr is unwritable.
    _ = os.write(os.STDERR_FILENO, msg) catch 0;

    if (is_alive_g) {
        is_alive_g = false;
        if (listen_fd_g != -1)
            os.shutdown(listen_fd_g, .both) catch {};
    }
}
/// Prints the command-line usage text to standard output, substituting
/// `name` for the program name. Output errors are ignored.
fn help(name: [*:0]const u8) void {
    const usage =
        \\{s} [listen host] [listen port] [target host] [target port]
        \\
        \\Example:
        \\ {s} 127.0.0.1 8000 192.168.12.1 6969
        \\
    ;
    stdout.print(usage, .{ name, name }) catch {};
}
/// Program entry point.
///
/// Expects exactly four arguments: listen host, listen port, target host and
/// target port. Prints the usage text and fails with `error.InvalidArgument`
/// when the argument count is wrong; port parsing errors are propagated.
/// Afterwards it sets up the fixed-buffer allocator, installs the signal
/// handlers, and runs the tunnel.
pub fn main() !void {
    const args = os.argv;
    if (args.len != 5) {
        help(args[0]);
        return error.InvalidArgument;
    }

    const listen_address = Address{
        .host = mem.span(args[1]),
        .port = try fmt.parseUnsigned(u16, mem.span(args[2]), 10),
    };
    const target_address = Address{
        .host = mem.span(args[3]),
        .port = try fmt.parseUnsigned(u16, mem.span(args[4]), 10),
    };
    const cfg = Config{
        .listen = listen_address,
        .target = target_address,
    };

    // 1 MiB of stack memory backs every allocation made by the tunnel.
    var backing_store: [1024 * 1024]u8 = undefined;
    var fba = std.heap.FixedBufferAllocator.init(&backing_store);
    fba_g = &fba;

    try setSignal(signalHandler);
    try startTunnel(fba.allocator(), cfg);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment