Created
February 27, 2015 11:38
-
-
Save sykesm/2f9710a2c892c4a854c2 to your computer and use it in GitHub Desktop.
Outline for a simple SSH proxy
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package proxy | |
import ( | |
"io" | |
"net" | |
"os" | |
"sync" | |
"github.com/pivotal-golang/lager" | |
"github.com/tedsuo/ifrit" | |
"golang.org/x/crypto/ssh" | |
) | |
type SSHProxy interface { | |
ifrit.Runner | |
} | |
//go:generate counterfeiter -o fakes/fake_authenticator.go . Authenticator | |
type Authenticator interface { | |
Authenticate(metdata ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) | |
} | |
//go:generate counterfeiter -o fakes/fake_config_factory.go . ConfigFactory | |
type ConfigFactory interface { | |
Create(permissions *ssh.Permissions) (config *ssh.ClientConfig, address string, err error) | |
} | |
//go:generate counterfeiter -o fakes/fake_auditor.go . AuthenticationAuditor | |
type AuthenticationAuditor interface { | |
Audit(conn ssh.ConnMetadata, method string, err error) | |
} | |
// Waiter is anything that can block until it has finished, reporting why it
// finished. ssh connections satisfy it, which lets wait block on several of
// them at once.
type Waiter interface {
	Wait() error
}
type proxy struct { | |
logger lager.Logger | |
proxyConfig *ssh.ServerConfig | |
authenticator Authenticator | |
configFactory ConfigFactory | |
listenAddr string | |
listener net.Listener | |
mutex *sync.Mutex | |
stopping bool | |
stopChan chan struct{} | |
} | |
func NewProxy( | |
logger lager.Logger, | |
hostPrivateKey string, | |
listenAddress string, | |
authenticator Authenticator, | |
configFactory ConfigFactory, | |
) SSHProxy { | |
sshConfig := &ssh.ServerConfig{} | |
proxy := &proxy{ | |
logger: logger, | |
proxyConfig: sshConfig, | |
authenticator: authenticator, | |
configFactory: configFactory, | |
listenAddr: listenAddress, | |
mutex: &sync.Mutex{}, | |
stopChan: make(chan struct{}), | |
} | |
privateKey, err := ssh.ParsePrivateKey([]byte(hostPrivateKey)) | |
if err != nil { | |
logger.Fatal("failed-to-parse-host-key", err) | |
} | |
sshConfig.AddHostKey(privateKey) | |
sshConfig.PasswordCallback = proxy.authenticator.Authenticate | |
sshConfig.AuthLogCallback = proxy.auditAuthentication | |
return proxy | |
} | |
func (p *proxy) auditAuthentication(conn ssh.ConnMetadata, method string, err error) { | |
logger := p.logger.Session("audit-authentication") | |
if err == nil { | |
logger.Info("success", lager.Data{ | |
"method": method, | |
"user": conn.User(), | |
"client-version": conn.ClientVersion(), | |
"remote-address": conn.RemoteAddr(), | |
"session-id": conn.SessionID(), | |
}) | |
} else { | |
logger.Info("failed", lager.Data{ | |
"user": conn.User(), | |
"remote-address": conn.RemoteAddr(), | |
"error": err, | |
}) | |
} | |
} | |
func (p *proxy) Run(signals <-chan os.Signal, ready chan<- struct{}) error { | |
logger := p.logger.Session("run") | |
logger.Info("started") | |
defer logger.Info("ended") | |
stopChan := make(chan struct{}) | |
listener, err := net.Listen("tcp", p.listenAddr) | |
if err != nil { | |
logger.Error("listen-failed", err) | |
return err | |
} | |
p.listener = listener | |
go p.acceptLoop(stopChan) | |
close(ready) | |
for { | |
select { | |
case <-signals: | |
return p.stop() | |
case <-stopChan: | |
return p.stop() | |
} | |
} | |
} | |
func (p *proxy) acceptLoop(stopChan chan<- struct{}) { | |
logger := p.logger.Session("accept-loop") | |
for { | |
netConn, err := p.listener.Accept() | |
if err != nil { | |
if p.isStopping() { | |
break | |
} else { | |
logger.Error("accept-failed", err) | |
continue | |
} | |
} | |
go p.handleConnection(netConn) | |
} | |
close(stopChan) | |
} | |
func (p *proxy) handleConnection(netConn net.Conn) { | |
logger := p.logger.Session("handle-connection") | |
logger.Info("started") | |
defer logger.Info("completed") | |
defer netConn.Close() | |
conn, channels, requests, err := ssh.NewServerConn(netConn, p.proxyConfig) | |
if err != nil { | |
logger.Error("handshake-failed", err) | |
return | |
} | |
defer conn.Close() | |
clientConfig, address, err := p.configFactory.Create(conn.Permissions) | |
if err != nil { | |
logger.Error("failed-to-create-client-config", err) | |
return | |
} | |
clientConn, clientChannels, clientRequests, err := p.NewClientConn(logger, address, clientConfig) | |
if err != nil { | |
logger.Error("failed-to-handshake-target", err) | |
return | |
} | |
defer clientConn.Close() | |
go proxyGlobalRequests(logger, clientConn, requests) | |
go proxyGlobalRequests(logger, conn, clientRequests) | |
go proxyChannels(logger, clientConn, channels) | |
go proxyChannels(logger, conn, clientChannels) | |
wait(logger, conn, clientConn) | |
} | |
func proxyChannels(logger lager.Logger, conn ssh.Conn, newChannelChan <-chan ssh.NewChannel) { | |
logger = logger.Session("proxy-channels") | |
logger.Info("started") | |
defer logger.Info("completed") | |
defer conn.Close() | |
for newChannel := range newChannelChan { | |
logger.Info("new-channel", lager.Data{ | |
"channelType": newChannel.ChannelType(), | |
"extraData": newChannel.ExtraData(), | |
}) | |
targetChan, targetReqs, err := conn.OpenChannel(newChannel.ChannelType(), newChannel.ExtraData()) | |
if err != nil { | |
logger.Error("failed-to-open-channel", err) | |
if openErr, ok := err.(*ssh.OpenChannelError); ok { | |
newChannel.Reject(openErr.Reason, openErr.Message) | |
} else { | |
newChannel.Reject(ssh.ConnectionFailed, openErr.Message) | |
} | |
continue | |
} | |
sourceChan, sourceReqs, err := newChannel.Accept() | |
if err != nil { | |
targetChan.Close() | |
continue | |
} | |
go copyChannel(logger, targetChan, sourceChan) | |
go copyChannel(logger, sourceChan, targetChan) | |
go proxyRequests(logger, newChannel.ChannelType(), sourceReqs, targetChan) | |
go proxyRequests(logger, newChannel.ChannelType(), targetReqs, sourceChan) | |
} | |
} | |
func proxyRequests(logger lager.Logger, channelType string, reqs <-chan *ssh.Request, channel ssh.Channel) { | |
logger = logger.Session("proxy-requests", lager.Data{ | |
"channel-type": channelType, | |
}) | |
logger.Info("started") | |
defer logger.Info("completed") | |
defer channel.Close() | |
for req := range reqs { | |
logger.Info("request", lager.Data{ | |
"type": req.Type, | |
"wantReply": req.WantReply, | |
"payload": req.Payload, | |
}) | |
success, err := channel.SendRequest(req.Type, req.WantReply, req.Payload) | |
if err != nil { | |
logger.Error("send-request-failed", err) | |
continue | |
} | |
if req.WantReply { | |
req.Reply(success, []byte{}) | |
} | |
} | |
} | |
func proxyGlobalRequests(logger lager.Logger, conn ssh.Conn, reqs <-chan *ssh.Request) { | |
logger = logger.Session("proxy-global-requests") | |
logger.Info("started") | |
defer logger.Info("completed") | |
for req := range reqs { | |
logger.Info("request", lager.Data{ | |
"type": req.Type, | |
"wantReply": req.WantReply, | |
"payload": req.Payload, | |
}) | |
success, reply, err := conn.SendRequest(req.Type, req.WantReply, req.Payload) | |
if err != nil { | |
logger.Error("send-request-failed", err) | |
continue | |
} | |
if req.WantReply { | |
req.Reply(success, reply) | |
} | |
} | |
} | |
func copyChannel(logger lager.Logger, dest ssh.Channel, src ssh.Channel) { | |
logger = logger.Session("copy-channel") | |
logger.Info("started") | |
defer logger.Info("completed") | |
io.Copy(dest, src) | |
dest.Close() | |
} | |
func wait(logger lager.Logger, waiters ...Waiter) { | |
wg := &sync.WaitGroup{} | |
for _, waiter := range waiters { | |
wg.Add(1) | |
go func() { | |
waiter.Wait() | |
wg.Done() | |
}() | |
} | |
wg.Wait() | |
} | |
func (p *proxy) NewClientConn(logger lager.Logger, address string, config *ssh.ClientConfig) (ssh.Conn, <-chan ssh.NewChannel, <-chan *ssh.Request, error) { | |
logger = logger.Session("new-client-conn", lager.Data{ | |
"address": address, | |
}) | |
nConn, err := net.Dial("tcp", address) | |
if err != nil { | |
logger.Error("dial-failed", err) | |
return nil, nil, nil, err | |
} | |
conn, ch, req, err := ssh.NewClientConn(nConn, address, config) | |
if err != nil { | |
logger.Error("handshake-failed", err) | |
return nil, nil, nil, err | |
} | |
return conn, ch, req, nil | |
} | |
func (p *proxy) stop() error { | |
p.mutex.Lock() | |
defer p.mutex.Unlock() | |
if !p.stopping { | |
p.logger.Info("stopping-proxy") | |
p.stopping = true | |
p.listener.Close() | |
} | |
return nil | |
} | |
func (p *proxy) isStopping() bool { | |
p.mutex.Lock() | |
defer p.mutex.Unlock() | |
return p.stopping | |
} | |
func (p *proxy) discardRequests(in <-chan *ssh.Request) { | |
logger := p.logger.Session("proxy-discard-requests") | |
for req := range in { | |
logger.Info("discarding", lager.Data{ | |
"type": req.Type, | |
"wantReply": req.WantReply, | |
}) | |
if req.WantReply { | |
req.Reply(false, nil) | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment