Created
November 5, 2020 11:17
-
-
Save meetme2meat/5652d7b803f5a57760fafd08e5d0d55d to your computer and use it in GitHub Desktop.
SFTP-proxy
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package server | |
import (
	"crypto/subtle"
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	"net"
	"time"

	log "github.com/sirupsen/logrus"
	"golang.org/x/crypto/ssh"
)
type Config struct { | |
Server struct { | |
Bind string | |
PrivateKey string | |
Auth sshAuth | |
AuthorizedKeys string | |
} | |
Remote struct { | |
Host string | |
Port int | |
Auth sshAuth | |
} | |
} | |
// sshAuth is a set of SSH credentials. Which fields are consulted depends on
// where the value sits in Config. Type is not read anywhere in this file;
// presumably it names the auth method (TODO confirm).
type sshAuth struct {
	Type, User, Password, PrivateKey string
}
type Server struct { | |
c *Config | |
} | |
func Start(config *Config) { | |
server := newServer(config) | |
go server.run() | |
} | |
func newServer(config *Config) *Server { | |
return &Server{c: config} | |
} | |
func (s *Server) run() { | |
log.Debugf("Listening at %s", s.c.Server.Bind) | |
s.listen(s.c.Server.Bind) | |
} | |
func (server *Server) listen(addr string) { | |
listener, err := net.Listen("tcp", addr) | |
if err != nil { | |
log.Fatalf("could not start the server %s", err) | |
} | |
go server.Accept(listener) | |
} | |
func (s *Server) Accept(l net.Listener) { | |
for { | |
conn, err := l.Accept() | |
if err != nil { | |
// The server should not die | |
fmt.Printf("failed to accept incoming connection %s", err.Error()) | |
continue | |
} | |
go s.handleConnection(conn) | |
} | |
} | |
func (s *Server) handleConnection(conn net.Conn) { | |
// check the TCP connection whether they obey handshake | |
defer func() { | |
if r := recover(); r != nil { | |
fmt.Println("Recovered the handleConnection", r) | |
} | |
}() | |
log.Debugf("Received connection from %s", conn.RemoteAddr()) | |
var rClient *ssh.Client | |
config := s.buildServerConfig(&rClient) | |
privateBytes, err := ioutil.ReadFile(s.c.Server.PrivateKey) | |
if err != nil { | |
panic("Failed to load private key") | |
} | |
private, err := ssh.ParsePrivateKey(privateBytes) | |
if err != nil { | |
panic("Failed to parse private key") | |
} | |
// add the private key | |
config.AddHostKey(private) | |
sconn, chans, reqs, err := ssh.NewServerConn(conn, config) | |
if err != nil { | |
log.Errorf("Could not establish connection with %s : %v", conn.RemoteAddr().String(), err) | |
} | |
defer sconn.Close() | |
defer rClient.Close() | |
// we will discarding all out-of-band request (essentially global request as they are not channel request) | |
// like SSHKEEPALIVE which does not require a reply | |
go ssh.DiscardRequests(reqs) | |
for inputChannel := range chans { | |
log.Debug("Received a new channel again") | |
s.handleChannel(inputChannel, rClient) | |
} | |
log.Debugf("Lost connection with %s", conn.RemoteAddr()) | |
} | |
func (s *Server) handleChannel(newChannel ssh.NewChannel, rClient *ssh.Client) { | |
// check if the channel is a terminal channel | |
if newChannel.ChannelType() != "session" { | |
newChannel.Reject(ssh.UnknownChannelType, "unknown channel type: "+newChannel.ChannelType()) | |
return | |
} | |
// accept the request | |
// this is meat of the code | |
// The inputChannel where SSH_MSG_CHANNEL_DATA is received | |
// and inputReq where SSH_MSG_CHANNEL_REQUEST is received | |
inputChannel, inputReq, err := newChannel.Accept() | |
if err != nil { | |
panic("could not accept channel") | |
} | |
// this would open the remote client channel | |
// which would be use copy the stream back and forth i.e SSH_MSG_CHANNEL_DATA | |
// reference used here https://github.com/jpillora/go-and-ssh/blob/master/channels/client.go#L41 | |
outputChannel, outputReq, err := rClient.OpenChannel(newChannel.ChannelType(), nil) | |
if err != nil { | |
panic("could not open channel") | |
} | |
go s.bypass(inputChannel, outputChannel, inputReq, outputReq) | |
time.Sleep(5 * time.Second) | |
go s.copyStream(outputChannel, inputChannel) | |
go s.copyStream(inputChannel, outputChannel) | |
} | |
func (s *Server) bypass(chan1, chan2 ssh.Channel, req1, req2 <-chan *ssh.Request) { | |
defer func() { | |
if r := recover(); r != nil { | |
fmt.Println("Recovered in bypass", r) | |
} | |
}() | |
defer chan2.Close() | |
defer chan1.Close() | |
for { | |
select { | |
case req, ok := <-req1: | |
if !ok { | |
// if the channel is closed | |
return | |
} | |
if err := s.forwardRequest(req, chan2); err != nil { | |
fmt.Println("forward Error: " + err.Error()) | |
continue | |
} | |
case req, ok := <-req2: | |
if !ok { | |
return | |
} | |
if err := s.forwardRequest(req, chan1); err != nil { | |
fmt.Println("forward Error: " + err.Error()) | |
continue | |
} | |
} | |
} | |
} | |
func (s *Server) forwardRequest(req *ssh.Request, channel ssh.Channel) error { | |
reply, err := channel.SendRequest(req.Type, req.WantReply, req.Payload) | |
if err != nil { | |
return err | |
} | |
if req.WantReply { | |
req.Reply(reply, nil) | |
} | |
return nil | |
} | |
func (s *Server) copyStream(writer, reader ssh.Channel) { | |
_, err := io.Copy(writer, reader) | |
if err != nil { | |
log.Errorf("Copy stream error %s", err) | |
} | |
writer.CloseWrite() | |
log.Debug("Closing writer") | |
} | |
func (s *Server) buildServerConfig(rClient **ssh.Client) *ssh.ServerConfig { | |
config := &ssh.ServerConfig{ | |
ServerVersion: "SSH-2.0-ProxyServer", | |
PasswordCallback: s.passwordHook(rClient), | |
PublicKeyCallback: s.publicKeyHook(rClient), | |
} | |
return config | |
} | |
func (s *Server) publicKeyHook(rclient **ssh.Client) func(c ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { | |
// https://github.com/golang/crypto/blob/master/ssh/example_test.go | |
// extracted from here. | |
authorizedKeysBytes, err := ioutil.ReadFile(s.c.Server.AuthorizedKeys) | |
if err != nil { | |
log.Fatalf("Failed to load authorized_keys, err: %v", err) | |
} | |
authorizedKeysMap := map[string]bool{} | |
for len(authorizedKeysBytes) > 0 { | |
pubKey, _, _, rest, err := ssh.ParseAuthorizedKey(authorizedKeysBytes) | |
if err != nil { | |
log.Fatalf("ParseAuthorized Key : %v", err) | |
} | |
authorizedKeysMap[string(pubKey.Marshal())] = true | |
authorizedKeysBytes = rest | |
} | |
return func(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) { | |
if c.User() == s.c.Server.Auth.User { | |
if authorizedKeysMap[string(pubKey.Marshal())] { | |
*rclient, err = s.remoteSSHClient() | |
if err == nil { | |
return &ssh.Permissions{ | |
// Record the public key used for authentication. | |
Extensions: map[string]string{ | |
"pubkey-fp": ssh.FingerprintSHA256(pubKey), | |
}, | |
}, nil | |
} | |
} | |
} | |
return nil, fmt.Errorf("unknown public key for %q", c.User()) | |
} | |
} | |
func (s *Server) passwordHook(rclient **ssh.Client) func(c ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) { | |
return func(c ssh.ConnMetadata, passwd []byte) (*ssh.Permissions, error) { | |
// verify host | |
host := s.c.Remote.Host | |
// just a fail safe since we have already done the check at bootstrap | |
if host == "" { | |
return nil, fmt.Errorf("unknown user %s", c.User()) | |
} | |
if c.User() != s.c.Server.Auth.User { | |
return nil, fmt.Errorf("unknown user %s", c.User()) | |
} | |
if s.c.Server.Auth.Password != "" && s.c.Server.Auth.Password == string(passwd) { | |
var err error | |
*rclient, err = s.remoteSSHClient() | |
if err != nil { | |
log.Errorf("Could not authorize %s on %s: %s", c.User(), c.RemoteAddr().String(), err) | |
return nil, fmt.Errorf("Could not authorize %s on %s: %s", | |
c.User(), c.RemoteAddr().String(), err) | |
} | |
log.Debugf("User %s authenticated", s.c.Server.Auth.User) | |
return nil, nil | |
} | |
return nil, errors.New("passwords do not match") | |
} | |
} | |
func (s *Server) remoteSSHClient() (*ssh.Client, error) { | |
key, err := ioutil.ReadFile(s.c.Remote.Auth.PrivateKey) | |
if err != nil { | |
return nil, fmt.Errorf("unable to connect: not a valid key") | |
} | |
signer, err := ssh.ParsePrivateKey(key) | |
if err != nil { | |
return nil, fmt.Errorf("Parsing err: %v", err) | |
} | |
config := &ssh.ClientConfig{ | |
User: s.c.Remote.Auth.User, | |
HostKeyCallback: ssh.InsecureIgnoreHostKey(), | |
Auth: []ssh.AuthMethod{ | |
ssh.PublicKeys(signer), | |
}, | |
} | |
conn, err := ssh.Dial("tcp", fmt.Sprintf("%s:%d", s.c.Remote.Host, s.c.Remote.Port), config) | |
if err != nil { | |
return nil, fmt.Errorf("remote server connect error %v", err) | |
} | |
return conn, nil | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment