Last active
September 3, 2024 04:05
-
-
Save 0187773933/0f1061d6ada5333dbe462ae2bacd7bbd to your computer and use it in GitHub Desktop.
Golang SSH Forward and Reverse Tunnel
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"context" | |
"fmt" | |
"os" | |
"os/signal" | |
"path" | |
"sync" | |
"syscall" | |
"time" | |
"golang.org/x/crypto/ssh" | |
"io" | |
"net" | |
"sync/atomic" | |
) | |
// https://github.com/dsnet/sshtunnel
// version is a package-level string that nothing in the visible source
// assigns or reads (presumably injected at link time via -ldflags — confirm
// against the build scripts).
var version string
type TunnelConfig struct { | |
// LogFile is where the proxy daemon will direct its output log. | |
// If the path is empty, then the server will output to os.Stderr. | |
LogFile string `json:",omitempty"` | |
// KeyFiles is a list of SSH private key files. | |
KeyFiles []string | |
// KnownHostFiles is a list of key database files for host public keys | |
// in the OpenSSH known_hosts file format. | |
KnownHostFiles []string | |
// KeepAlive sets the keep alive settings for each SSH connection. | |
// It is recommended that these values match the AliveInterval and | |
// AliveCountMax parameters on the remote OpenSSH server. | |
// If unset, then the default is an interval of 30s with 2 max counts. | |
KeepAlive *KeepAliveConfig `json:",omitempty"` | |
// Tunnels is a list of tunnels to establish. | |
// The same set of SSH keys will be used to authenticate the | |
// SSH connection for each server. | |
Tunnels []struct { | |
// Tunnel is a pair of host:port endpoints that can be configured | |
// to either operate as a forward tunnel or a reverse tunnel. | |
// | |
// The syntax of a forward tunnel is: | |
// "bind_address:port -> dial_address:port" | |
// | |
// A forward tunnel opens a listening TCP socket on the | |
// local side (at bind_address:port) and proxies all traffic to a | |
// socket on the remote side (at dial_address:port). | |
// | |
// The syntax of a reverse tunnel is: | |
// "dial_address:port <- bind_address:port" | |
// | |
// A reverse tunnel opens a listening TCP socket on the | |
// remote side (at bind_address:port) and proxies all traffic to a | |
// socket on the local side (at dial_address:port). | |
Tunnel string | |
// Server is a remote SSH host. It has the following syntax: | |
// "user@host:port" | |
// | |
// If the user is missing, then it defaults to the current process user. | |
// If the port is missing, then it defaults to 22. | |
Server string | |
// KeepAlive is a tunnel-specific setting of the global KeepAlive. | |
// If unspecified, it uses the global KeepAlive settings. | |
KeepAlive *KeepAliveConfig `json:",omitempty"` | |
} | |
} | |
// KeepAliveConfig controls how an SSH connection is probed for liveness.
type KeepAliveConfig struct {
	// Interval is the number of seconds the tunnel client waits between
	// keep-alive messages, guaranteeing some minimum traffic on the SSH
	// connection.
	Interval uint

	// CountMax is how many consecutive keep-alive messages may go unanswered
	// before the client treats the SSH connection as dead.
	CountMax uint
}
type tunnel struct { | |
auth []ssh.AuthMethod | |
hostKeys ssh.HostKeyCallback | |
mode byte // '>' for forward, '<' for reverse | |
user string | |
hostAddr string | |
bindAddr string | |
dialAddr string | |
retryInterval time.Duration | |
keepAlive KeepAliveConfig | |
//log logger | |
} | |
func (t tunnel) String() string { | |
var left, right string | |
mode := "<?>" | |
switch t.mode { | |
case '>': | |
left, mode, right = t.bindAddr, "->", t.dialAddr | |
case '<': | |
left, mode, right = t.dialAddr, "<-", t.bindAddr | |
} | |
return fmt.Sprintf("%s@%s | %s %s %s", t.user, t.hostAddr, left, mode, right) | |
} | |
func (t tunnel) bindTunnel(ctx context.Context, wg *sync.WaitGroup) { | |
defer wg.Done() | |
for { | |
var once sync.Once // Only print errors once per session | |
func() { | |
// Connect to the server host via SSH. | |
cl, err := ssh.Dial("tcp", t.hostAddr, &ssh.ClientConfig{ | |
User: t.user, | |
Auth: t.auth, | |
HostKeyCallback: t.hostKeys, | |
Timeout: 5 * time.Second, | |
}) | |
if err != nil { | |
once.Do(func() { fmt.Printf("(%v) SSH dial error: %v\n", t, err) }) | |
return | |
} | |
wg.Add(1) | |
go t.keepAliveMonitor(&once, wg, cl) | |
defer cl.Close() | |
// Attempt to bind to the inbound socket. | |
var ln net.Listener | |
switch t.mode { | |
case '>': | |
ln, err = net.Listen("tcp", t.bindAddr) | |
case '<': | |
ln, err = cl.Listen("tcp", t.bindAddr) | |
} | |
if err != nil { | |
once.Do(func() { fmt.Printf("(%v) bind error: %v\n", t, err) }) | |
return | |
} | |
// The socket is binded. Make sure we close it eventually. | |
bindCtx, cancel := context.WithCancel(ctx) | |
defer cancel() | |
go func() { | |
cl.Wait() | |
cancel() | |
}() | |
go func() { | |
<-bindCtx.Done() | |
once.Do(func() {}) // Suppress future errors | |
ln.Close() | |
}() | |
fmt.Printf("(%v) binded tunnel\n", t) | |
defer fmt.Printf("(%v) collapsed tunnel\n", t) | |
// Accept all incoming connections. | |
for { | |
cn1, err := ln.Accept() | |
if err != nil { | |
once.Do(func() { fmt.Printf("(%v) accept error: %v\n", t, err) }) | |
return | |
} | |
wg.Add(1) | |
go t.dialTunnel(bindCtx, wg, cl, cn1) | |
} | |
}() | |
select { | |
case <-ctx.Done(): | |
return | |
case <-time.After(t.retryInterval): | |
fmt.Printf("(%v) retrying...\n", t) | |
} | |
} | |
} | |
func (t tunnel) dialTunnel(ctx context.Context, wg *sync.WaitGroup, client *ssh.Client, cn1 net.Conn) { | |
defer wg.Done() | |
// The inbound connection is established. Make sure we close it eventually. | |
connCtx, cancel := context.WithCancel(ctx) | |
defer cancel() | |
go func() { | |
<-connCtx.Done() | |
cn1.Close() | |
}() | |
// Establish the outbound connection. | |
var cn2 net.Conn | |
var err error | |
switch t.mode { | |
case '>': | |
cn2, err = client.Dial("tcp", t.dialAddr) | |
case '<': | |
cn2, err = net.Dial("tcp", t.dialAddr) | |
} | |
if err != nil { | |
fmt.Printf("(%v) dial error: %v", t, err) | |
return | |
} | |
go func() { | |
<-connCtx.Done() | |
cn2.Close() | |
}() | |
fmt.Printf("(%v) connection established", t) | |
defer fmt.Printf("(%v) connection closed", t) | |
// Copy bytes from one connection to the other until one side closes. | |
var once sync.Once | |
var wg2 sync.WaitGroup | |
wg2.Add(2) | |
go func() { | |
defer wg2.Done() | |
defer cancel() | |
if _, err := io.Copy(cn1, cn2); err != nil { | |
once.Do(func() { fmt.Printf("(%v) connection error: %v", t, err) }) | |
} | |
once.Do(func() {}) // Suppress future errors | |
}() | |
go func() { | |
defer wg2.Done() | |
defer cancel() | |
if _, err := io.Copy(cn2, cn1); err != nil { | |
once.Do(func() { fmt.Printf("(%v) connection error: %v", t, err) }) | |
} | |
once.Do(func() {}) // Suppress future errors | |
}() | |
wg2.Wait() | |
} | |
// keepAliveMonitor periodically sends messages to invoke a response. | |
// If the server does not respond after some period of time, | |
// assume that the underlying net.Conn abruptly died. | |
func (t tunnel) keepAliveMonitor(once *sync.Once, wg *sync.WaitGroup, client *ssh.Client) { | |
defer wg.Done() | |
if t.keepAlive.Interval == 0 || t.keepAlive.CountMax == 0 { | |
return | |
} | |
// Detect when the SSH connection is closed. | |
wait := make(chan error, 1) | |
wg.Add(1) | |
go func() { | |
defer wg.Done() | |
wait <- client.Wait() | |
}() | |
// Repeatedly check if the remote server is still alive. | |
var aliveCount int32 | |
ticker := time.NewTicker(time.Duration(t.keepAlive.Interval) * time.Second) | |
defer ticker.Stop() | |
for { | |
select { | |
case err := <-wait: | |
if err != nil && err != io.EOF { | |
once.Do(func() { fmt.Printf("(%v) SSH error: %v", t, err) }) | |
} | |
return | |
case <-ticker.C: | |
if n := atomic.AddInt32(&aliveCount, 1); n > int32(t.keepAlive.CountMax) { | |
once.Do(func() { fmt.Printf("(%v) SSH keep-alive termination", t) }) | |
client.Close() | |
return | |
} | |
} | |
wg.Add(1) | |
go func() { | |
defer wg.Done() | |
_, _, err := client.SendRequest("[email protected]", true, nil) | |
if err == nil { | |
atomic.StoreInt32(&aliveCount, 0) | |
} | |
}() | |
} | |
} | |
// SSH_KEY_FILE_DATA is the PEM-encoded OpenSSH private key that loadConfig
// parses to build its public-key auth method.
//
// NOTE(review): the body below appears to be a placeholder. A real private key
// must not be committed in source; load it from a file or a secret store.
var SSH_KEY_FILE_DATA = []byte(`-----BEGIN OPENSSH PRIVATE KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==
-----END OPENSSH PRIVATE KEY-----`)
func loadConfig() ( tunns []tunnel , closer func() error ) { | |
// 1.) Build Auth Agent and Config | |
var auth []ssh.AuthMethod | |
if SSH_KEY_FILE_PASSWORD != "" { | |
auth = append( auth , ssh.Password( SSH_KEY_FILE_PASSWORD ) ) | |
} | |
signer , err := ssh.ParsePrivateKey( SSH_KEY_FILE_DATA ) | |
if err != nil { | |
fmt.Printf( "unable to parse private key: %v\n" , err ) | |
} | |
auth = append( auth , ssh.PublicKeys( signer ) ) | |
// case "->": | |
// tunn.bindAddr, tunn.mode, tunn.dialAddr = tt[0], '>', tt[2] | |
// case "<-": | |
// tunn.dialAddr, tunn.mode, tunn.bindAddr = tt[0], '<', tt[2] | |
// ssh $USER@$HOST -i /path/to/key.priv -L $BIND_ADDRESS:$BIND_PORT:$DIAL_ADDRESS:$DIAL_PORT | |
// ssh $USER@$HOST -i /path/to/key.priv -R $BIND_ADDRESS:$BIND_PORT:$DIAL_ADDRESS:$DIAL_PORT | |
// Example 1 | |
// Binds Redis from Tailscale Pihole to Localhost of Mini | |
var tunn1 tunnel | |
tunn1.auth = auth | |
tunn1.hostKeys = func( hostname string , remote net.Addr , key ssh.PublicKey ) error { | |
return nil | |
} | |
tunn1.mode = '>' // '>' for forward, '<' for reverse | |
tunn1.user = "pi" | |
tunn1.hostAddr = net.JoinHostPort( "111.111.111.111" , "22" ) | |
tunn1.bindAddr = "localhost:6379" | |
tunn1.dialAddr = "localhost:6379" | |
tunn1.retryInterval = 30 * time.Second | |
//tunn1.keepAlive = *KeepAliveConfig | |
tunns = append( tunns , tunn1 ) | |
// Example 2 | |
// Binds Temporary Python Server from Mini to Localhost of Tailscale Pihole | |
var tunn2 tunnel | |
tunn2.auth = auth | |
tunn2.hostKeys = func( hostname string , remote net.Addr , key ssh.PublicKey ) error { | |
return nil | |
} | |
tunn2.mode = '<' // '>' for forward, '<' for reverse | |
tunn2.user = "pi" | |
tunn2.hostAddr = net.JoinHostPort( "111.111.111.111" , "22" ) | |
tunn2.bindAddr = "localhost:9559" | |
tunn2.dialAddr = "localhost:9559" | |
tunn2.retryInterval = 30 * time.Second | |
//tunn1.keepAlive = *KeepAliveConfig | |
tunns = append( tunns , tunn2 ) | |
return tunns , closer | |
} | |
func main() { | |
tunns , closer := loadConfig() | |
defer closer() | |
// Setup signal handler to initiate shutdown. | |
ctx , cancel := context.WithCancel( context.Background() ) | |
go func() { | |
sigc := make( chan os.Signal , 1 ) | |
signal.Notify( sigc , syscall.SIGINT , syscall.SIGTERM ) | |
fmt.Printf( "received %v - initiating shutdown\n" , <-sigc ) | |
cancel() | |
}() | |
// Start a bridge for each tunnel. | |
var wg sync.WaitGroup | |
fmt.Printf( "%s starting\n" , path.Base( os.Args[ 0 ] ) ) | |
defer fmt.Printf( "%s shutdown\n" , path.Base( os.Args[ 0 ] ) ) | |
for _ , t := range tunns { | |
wg.Add( 1 ) | |
go t.bindTunnel( ctx , &wg ) | |
} | |
wg.Wait() | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Could an explanation of the code be provided? I am a little unsure how it works and would appreciate any more details.