Skip to content

Instantly share code, notes, and snippets.

@sykesm
Created February 27, 2015 11:38
Show Gist options
  • Save sykesm/2f9710a2c892c4a854c2 to your computer and use it in GitHub Desktop.
Save sykesm/2f9710a2c892c4a854c2 to your computer and use it in GitHub Desktop.
Outline for a simple SSH proxy
package proxy
import (
"io"
"net"
"os"
"sync"
"github.com/pivotal-golang/lager"
"github.com/tedsuo/ifrit"
"golang.org/x/crypto/ssh"
)
// SSHProxy is the type returned by NewProxy. It embeds ifrit.Runner, so the
// only behavior it exposes is whatever that interface declares; in this file
// it is implemented by *proxy through its Run method.
type SSHProxy interface {
ifrit.Runner
}
//go:generate counterfeiter -o fakes/fake_authenticator.go . Authenticator

// Authenticator decides whether a password login is accepted. NewProxy
// installs Authenticate as the ssh.ServerConfig PasswordCallback, so it is
// called with the connection metadata and the password bytes supplied by the
// client and returns the permissions to attach to the connection, or an error.
type Authenticator interface {
Authenticate(metdata ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error)
}
//go:generate counterfeiter -o fakes/fake_config_factory.go . ConfigFactory

// ConfigFactory produces the settings used to reach the proxy target.
// handleConnection calls Create with the Permissions of the accepted server
// connection and then dials the returned address using the returned client
// config. A non-nil err makes handleConnection drop the connection.
type ConfigFactory interface {
Create(permissions *ssh.Permissions) (config *ssh.ClientConfig, address string, err error)
}
//go:generate counterfeiter -o fakes/fake_auditor.go . AuthenticationAuditor

// AuthenticationAuditor receives the outcome of an authentication attempt:
// the connection metadata, the authentication method name, and the error (nil
// on success). Its Audit method has the same parameter list as
// proxy.auditAuthentication.
// NOTE(review): nothing visible in this file consumes this interface; confirm
// whether it is used elsewhere.
type AuthenticationAuditor interface {
Audit(conn ssh.ConnMetadata, method string, err error)
}
// Waiter is anything that exposes a Wait method returning an error. The wait
// helper in this file accepts a list of Waiters; handleConnection passes it
// the two ssh connections of a proxied session.
type Waiter interface {
Wait() error
}
// proxy is the SSHProxy implementation returned by NewProxy.
type proxy struct {
// logger is the base logger; each method derives its own session from it.
logger lager.Logger
// proxyConfig is the server-side SSH configuration passed to
// ssh.NewServerConn for every accepted connection.
proxyConfig *ssh.ServerConfig
// authenticator supplies the password callback installed on proxyConfig.
authenticator Authenticator
// configFactory yields the client config and address of the target.
configFactory ConfigFactory
// listenAddr is the TCP address Run listens on.
listenAddr string
// listener is set by Run, read by acceptLoop and closed by stop.
listener net.Listener
// mutex guards stopping.
mutex *sync.Mutex
// stopping is set to true by stop and read through isStopping.
stopping bool
// stopChan is allocated by NewProxy.
// NOTE(review): no method visible in this file reads or closes this field;
// Run uses a local channel of the same name instead.
stopChan chan struct{}
}
// NewProxy assembles an SSHProxy that listens on listenAddress once it is run.
//
// hostPrivateKey is parsed with ssh.ParsePrivateKey and registered as the
// server host key; a parse failure is reported through logger.Fatal.
// Password logins are delegated to authenticator.Authenticate, and every
// authentication attempt is reported to the proxy's auditAuthentication
// method. configFactory is kept for use when connections are handled.
func NewProxy(
	logger lager.Logger,
	hostPrivateKey string,
	listenAddress string,
	authenticator Authenticator,
	configFactory ConfigFactory,
) SSHProxy {
	serverConfig := new(ssh.ServerConfig)

	p := &proxy{
		logger:        logger,
		proxyConfig:   serverConfig,
		authenticator: authenticator,
		configFactory: configFactory,
		listenAddr:    listenAddress,
		mutex:         new(sync.Mutex),
		stopChan:      make(chan struct{}),
	}

	signer, parseErr := ssh.ParsePrivateKey([]byte(hostPrivateKey))
	if parseErr != nil {
		logger.Fatal("failed-to-parse-host-key", parseErr)
	}
	serverConfig.AddHostKey(signer)

	// Wire the authentication hooks onto the server configuration.
	serverConfig.PasswordCallback = p.authenticator.Authenticate
	serverConfig.AuthLogCallback = p.auditAuthentication

	return p
}
// auditAuthentication logs the outcome of one authentication attempt under
// the "audit-authentication" session. NewProxy installs it as the server
// config's AuthLogCallback.
//
// A nil err is logged as "success" together with the method, user, client
// version, remote address and session id. A non-nil err is logged as "failed"
// with the user, remote address and the error.
func (p *proxy) auditAuthentication(conn ssh.ConnMetadata, method string, err error) {
	auditLog := p.logger.Session("audit-authentication")

	if err != nil {
		auditLog.Info("failed", lager.Data{
			"user":           conn.User(),
			"remote-address": conn.RemoteAddr(),
			"error":          err,
		})
		return
	}

	auditLog.Info("success", lager.Data{
		"method":         method,
		"user":           conn.User(),
		"client-version": conn.ClientVersion(),
		"remote-address": conn.RemoteAddr(),
		"session-id":     conn.SessionID(),
	})
}
// Run opens a TCP listener on the configured address, starts the accept loop
// in its own goroutine, closes ready, and then blocks until either a value
// arrives on signals or the accept loop reports that it has exited. In both
// cases it returns the result of p.stop.
//
// If the listener cannot be opened the error is logged and returned, and
// ready is left untouched.
func (p *proxy) Run(signals <-chan os.Signal, ready chan<- struct{}) error {
	runLog := p.logger.Session("run")
	runLog.Info("started")
	defer runLog.Info("ended")

	ln, listenErr := net.Listen("tcp", p.listenAddr)
	if listenErr != nil {
		runLog.Error("listen-failed", listenErr)
		return listenErr
	}
	p.listener = ln

	// acceptDone is closed by acceptLoop when it terminates.
	acceptDone := make(chan struct{})
	go p.acceptLoop(acceptDone)

	close(ready)

	// Both wake-up conditions lead to the same shutdown path.
	select {
	case <-signals:
	case <-acceptDone:
	}
	return p.stop()
}
// acceptLoop accepts connections from p.listener and serves each one in a new
// goroutine via handleConnection.
//
// When Accept fails and the proxy is stopping, the loop closes stopChan and
// returns. When Accept fails for any other reason the error is logged as
// "accept-failed" and Accept is retried immediately.
func (p *proxy) acceptLoop(stopChan chan<- struct{}) {
	acceptLog := p.logger.Session("accept-loop")

	for {
		netConn, acceptErr := p.listener.Accept()

		switch {
		case acceptErr == nil:
			go p.handleConnection(netConn)
		case p.isStopping():
			// Tell Run that no further connections will be accepted.
			close(stopChan)
			return
		default:
			acceptLog.Error("accept-failed", acceptErr)
		}
	}
}
// handleConnection serves one accepted TCP connection.
//
// It performs the server-side SSH handshake on netConn, asks the config
// factory for the client config and address that correspond to the
// connection's permissions, and opens an SSH client connection to that
// address. It then starts goroutines that forward global requests and new
// channels in both directions, and finally hands both SSH connections to
// wait. A failure at any setup step is logged and ends the function; the
// deferred Close calls release whatever had been opened by then.
func (p *proxy) handleConnection(netConn net.Conn) {
	connLog := p.logger.Session("handle-connection")
	connLog.Info("started")
	defer connLog.Info("completed")
	defer netConn.Close()

	// Inbound side: the party that connected to the proxy.
	serverConn, serverChannels, serverRequests, handshakeErr := ssh.NewServerConn(netConn, p.proxyConfig)
	if handshakeErr != nil {
		connLog.Error("handshake-failed", handshakeErr)
		return
	}
	defer serverConn.Close()

	targetConfig, targetAddr, configErr := p.configFactory.Create(serverConn.Permissions)
	if configErr != nil {
		connLog.Error("failed-to-create-client-config", configErr)
		return
	}

	// Outbound side: the proxy's own connection to the target.
	targetConn, targetChannels, targetRequests, dialErr := p.NewClientConn(connLog, targetAddr, targetConfig)
	if dialErr != nil {
		connLog.Error("failed-to-handshake-target", dialErr)
		return
	}
	defer targetConn.Close()

	// Whatever arrives on one connection is re-issued on the other.
	go proxyGlobalRequests(connLog, targetConn, serverRequests)
	go proxyGlobalRequests(connLog, serverConn, targetRequests)
	go proxyChannels(connLog, targetConn, serverChannels)
	go proxyChannels(connLog, serverConn, targetChannels)

	wait(connLog, serverConn, targetConn)
}
// proxyChannels forwards every channel-open request received on
// newChannelChan to conn, and closes conn when newChannelChan is drained.
//
// For each incoming request it opens a channel of the same type and extra
// data on conn. If that fails, the incoming request is rejected: with the
// reason and message from the *ssh.OpenChannelError when the error is one,
// otherwise with ssh.ConnectionFailed and the error text. If it succeeds, the
// incoming channel is accepted and four goroutines are started to copy data
// and forward channel requests in both directions.
func proxyChannels(logger lager.Logger, conn ssh.Conn, newChannelChan <-chan ssh.NewChannel) {
	logger = logger.Session("proxy-channels")
	logger.Info("started")
	defer logger.Info("completed")
	defer conn.Close()

	for newChannel := range newChannelChan {
		logger.Info("new-channel", lager.Data{
			"channelType": newChannel.ChannelType(),
			"extraData":   newChannel.ExtraData(),
		})

		targetChan, targetReqs, err := conn.OpenChannel(newChannel.ChannelType(), newChannel.ExtraData())
		if err != nil {
			logger.Error("failed-to-open-channel", err)

			// Default to a generic rejection; only read Reason/Message when
			// the error really is an *ssh.OpenChannelError. The previous code
			// dereferenced the nil result of a failed type assertion.
			reason, message := ssh.ConnectionFailed, err.Error()
			if openErr, ok := err.(*ssh.OpenChannelError); ok {
				reason, message = openErr.Reason, openErr.Message
			}
			newChannel.Reject(reason, message)
			continue
		}

		sourceChan, sourceReqs, err := newChannel.Accept()
		if err != nil {
			logger.Error("failed-to-accept-channel", err)
			// The target side was already opened; release it.
			targetChan.Close()
			continue
		}

		go copyChannel(logger, targetChan, sourceChan)
		go copyChannel(logger, sourceChan, targetChan)
		go proxyRequests(logger, newChannel.ChannelType(), sourceReqs, targetChan)
		go proxyRequests(logger, newChannel.ChannelType(), targetReqs, sourceChan)
	}
}
// proxyRequests forwards every channel request received on reqs to channel
// and closes channel once reqs is drained.
//
// Each request is logged and re-sent with the same type, WantReply flag and
// payload. When the request wants a reply, the forwarded result is sent back
// with an empty payload. If forwarding fails, the error is logged and the
// request is answered with failure, so that a request with WantReply set
// always receives a reply.
func proxyRequests(logger lager.Logger, channelType string, reqs <-chan *ssh.Request, channel ssh.Channel) {
	logger = logger.Session("proxy-requests", lager.Data{
		"channel-type": channelType,
	})
	logger.Info("started")
	defer logger.Info("completed")
	defer channel.Close()

	for req := range reqs {
		logger.Info("request", lager.Data{
			"type":      req.Type,
			"wantReply": req.WantReply,
			"payload":   req.Payload,
		})

		success, err := channel.SendRequest(req.Type, req.WantReply, req.Payload)
		if err != nil {
			logger.Error("send-request-failed", err)
			// Fall through and report the failure to the requester instead
			// of leaving it without an answer.
			success = false
		}

		if req.WantReply {
			if replyErr := req.Reply(success, []byte{}); replyErr != nil {
				logger.Error("reply-failed", replyErr)
			}
		}
	}
}
// proxyGlobalRequests forwards every global request received on reqs to conn
// until reqs is drained.
//
// Each request is logged and re-sent with the same type, WantReply flag and
// payload. When the request wants a reply, the status and payload returned by
// conn are relayed back. If forwarding fails, the error is logged and the
// request is answered with failure and no payload, so that a request with
// WantReply set always receives a reply.
func proxyGlobalRequests(logger lager.Logger, conn ssh.Conn, reqs <-chan *ssh.Request) {
	logger = logger.Session("proxy-global-requests")
	logger.Info("started")
	defer logger.Info("completed")

	for req := range reqs {
		logger.Info("request", lager.Data{
			"type":      req.Type,
			"wantReply": req.WantReply,
			"payload":   req.Payload,
		})

		success, reply, err := conn.SendRequest(req.Type, req.WantReply, req.Payload)
		if err != nil {
			logger.Error("send-request-failed", err)
			// Report the failure to the requester instead of leaving it
			// without an answer.
			success, reply = false, nil
		}

		if req.WantReply {
			if replyErr := req.Reply(success, reply); replyErr != nil {
				logger.Error("reply-failed", replyErr)
			}
		}
	}
}
// copyChannel streams everything readable from src into dest with io.Copy
// and, once the copy returns, calls Close on dest. The results of both the
// copy and the close are discarded.
func copyChannel(logger lager.Logger, dest ssh.Channel, src ssh.Channel) {
	copyLog := logger.Session("copy-channel")
	copyLog.Info("started")
	defer copyLog.Info("completed")

	// Byte count and error are intentionally not inspected.
	_, _ = io.Copy(dest, src)
	_ = dest.Close()
}
// wait blocks until Wait has returned on every one of the given waiters.
// Each waiter is waited on in its own goroutine; the errors returned by Wait
// are discarded. The logger parameter is currently unused and is kept so
// existing callers continue to compile.
func wait(logger lager.Logger, waiters ...Waiter) {
	var wg sync.WaitGroup
	wg.Add(len(waiters))

	for _, waiter := range waiters {
		// Pass the waiter as an argument so every goroutine gets its own
		// value; capturing the loop variable would, before Go 1.22, let all
		// goroutines observe the last element.
		go func(w Waiter) {
			defer wg.Done()
			// Only the fact that Wait returned matters here, not its result.
			_ = w.Wait()
		}(waiter)
	}

	wg.Wait()
}
// NewClientConn dials address over TCP and runs the SSH client handshake on
// the resulting connection using config. On success it returns the SSH
// connection together with its new-channel and global-request streams. A dial
// or handshake failure is logged under the "new-client-conn" session and
// returned with nil for the other three results.
func (p *proxy) NewClientConn(logger lager.Logger, address string, config *ssh.ClientConfig) (ssh.Conn, <-chan ssh.NewChannel, <-chan *ssh.Request, error) {
	dialLog := logger.Session("new-client-conn", lager.Data{
		"address": address,
	})

	tcpConn, dialErr := net.Dial("tcp", address)
	if dialErr != nil {
		dialLog.Error("dial-failed", dialErr)
		return nil, nil, nil, dialErr
	}

	sshConn, newChannels, globalReqs, handshakeErr := ssh.NewClientConn(tcpConn, address, config)
	if handshakeErr != nil {
		dialLog.Error("handshake-failed", handshakeErr)
		return nil, nil, nil, handshakeErr
	}

	return sshConn, newChannels, globalReqs, nil
}
// stop marks the proxy as stopping and closes its listener. Only the first
// call has any effect; later calls find the flag already set and do nothing.
// The flag is checked and updated while holding p.mutex. It always returns
// nil.
func (p *proxy) stop() error {
	p.mutex.Lock()
	defer p.mutex.Unlock()

	if p.stopping {
		return nil
	}

	p.logger.Info("stopping-proxy")
	p.stopping = true
	p.listener.Close()
	return nil
}
// isStopping reports whether stop has already set the stopping flag. The flag
// is read while holding p.mutex.
func (p *proxy) isStopping() bool {
	p.mutex.Lock()
	stopping := p.stopping
	p.mutex.Unlock()

	return stopping
}
// discardRequests drains in, logging the type and WantReply flag of every
// request under the "proxy-discard-requests" session. Requests that want a
// reply are answered with failure and no payload.
func (p *proxy) discardRequests(in <-chan *ssh.Request) {
	discardLog := p.logger.Session("proxy-discard-requests")

	for request := range in {
		discardLog.Info("discarding", lager.Data{
			"type":      request.Type,
			"wantReply": request.WantReply,
		})

		if !request.WantReply {
			continue
		}
		request.Reply(false, nil)
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment