Skip to content

Instantly share code, notes, and snippets.

@stevvooe
Created December 15, 2015 22:29
Show Gist options
  • Save stevvooe/2639d32366925d7bb898 to your computer and use it in GitHub Desktop.
Save stevvooe/2639d32366925d7bb898 to your computer and use it in GitHub Desktop.
Simple SSH Client
// Package main implements a simple ssh client, roughly similar to the
// standard OpenSSH client. It uses golang.org/x/crypto/ssh and integrates
// with the ssh agent.
//
// Mostly, this demonstrates correct resource management and cleanup, as well
// as command dispatch. Obviously, this has never been vetted for security, so
// please don't use it for anything serious.
package main
import (
"fmt"
"log"
"net"
"os"
userpkg "os/user"
"strings"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
)
// main connects to the host named by the first argument, of the form
// [user@]host[:port], authenticating with keys held by the running ssh agent.
// It runs the remaining arguments (joined with spaces) as a remote command,
// with the local stdin, stdout and stderr attached. The process exits with the
// remote command's exit status.
func main() {
	log.SetPrefix("ssh: ")
	log.SetFlags(0)

	// Without this check, os.Args[1] panics with an index-out-of-range error.
	if len(os.Args) < 2 {
		fmt.Fprintln(os.Stderr, "usage: ssh [user@]host[:port] [command ...]")
		os.Exit(2)
	}

	status, err := run(os.Args[1], strings.Join(os.Args[2:], " "))
	if err != nil {
		log.Fatalln(err)
	}
	os.Exit(status)
}

// run dials target, runs command in a new session and returns the remote exit
// status. A remote command that exits non-zero is reported through the status
// with a nil error. The error is reserved for failures to parse the address,
// reach the agent, connect, authenticate or start the command.
//
// It is separate from main so that its deferred Close calls actually execute:
// os.Exit (and therefore log.Fatal) ends the process without running pending
// defers.
func run(target, command string) (int, error) {
	user, addr, err := parseAddr(target)
	if err != nil {
		return 0, err
	}

	agentConn, err := dialAgent()
	if err != nil {
		return 0, err
	}
	defer agentConn.Close()

	config := ssh.ClientConfig{
		User: user,
		Auth: []ssh.AuthMethod{
			ssh.PublicKeysCallback(agent.NewClient(agentConn).Signers),
		},
		// WARNING: the remote host key is NOT verified; any key is accepted,
		// which leaves the connection open to man-in-the-middle attacks.
		// Older x/crypto/ssh releases behaved this way when the callback was
		// left nil; current releases refuse to dial without one. Check the key
		// against known_hosts before using this for anything real.
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
	}

	log.Println("connecting", fmt.Sprintf("%v@%v", user, addr))
	client, err := ssh.Dial("tcp", addr, &config)
	if err != nil {
		return 0, err
	}
	defer client.Close()

	session, err := client.NewSession()
	if err != nil {
		return 0, err
	}
	defer session.Close()

	session.Stdin = os.Stdin
	session.Stdout = os.Stdout
	session.Stderr = os.Stderr

	if runErr := session.Run(command); runErr != nil {
		// The remote command ran but exited non-zero: that is not a client
		// failure, so propagate its status rather than an error.
		if exitErr, ok := runErr.(*ssh.ExitError); ok {
			return exitErr.ExitStatus(), nil
		}
		return 0, runErr
	}
	return 0, nil
}
// parseAddr decomposes an ssh-style address of the form [user@]host[:port],
// possibly extracting the user and port. On success it returns the user and a
// Go-style "host:port" address suitable for net.Dial.
//
// If no user is given, the username of the current OS user is used. If no port
// is given, or the port after the colon is empty, port 22 is used. Unlike plain
// ssh destinations, a port may be appended after a colon. IPv6 literals may be
// written bare ("::1") or in brackets ("[::1]", "[::1]:2222").
//
// The user is everything before the last '@' (OpenSSH also splits on the last
// '@'), so user names that themselves contain '@' survive intact. An empty user
// or an empty host is an error. On error, user and addr are empty.
func parseAddr(s string) (user, addr string, err error) {
	host := s
	if at := strings.LastIndex(s, "@"); at >= 0 {
		user, host = s[:at], s[at+1:]
		if user == "" {
			return "", "", fmt.Errorf("missing user in address %q", s)
		}
	} else {
		current, curErr := userpkg.Current()
		if curErr != nil {
			return "", "", curErr
		}
		user = current.Username
	}

	hostname, port, splitErr := net.SplitHostPort(host)
	if splitErr != nil {
		// No usable port: treat the whole string as the host name.
		hostname, port = host, ""
		switch {
		case strings.HasPrefix(hostname, "[") && strings.HasSuffix(hostname, "]"):
			// Strip the brackets of an IPv6 literal. JoinHostPort adds them
			// back; leaving them would double them ("[[::1]]:22").
			hostname = hostname[1 : len(hostname)-1]
		case strings.Contains(hostname, ":") && net.ParseIP(hostname) == nil:
			// Colons are only acceptable here as part of a bare IPv6
			// literal; anything else (e.g. "a:b:c") is malformed.
			return "", "", splitErr
		}
	}
	if hostname == "" {
		return "", "", fmt.Errorf("missing host in address %q", s)
	}
	if port == "" {
		port = "22"
	}
	return user, net.JoinHostPort(hostname, port), nil
}
// dialAgent connects to the ssh agent listening on the unix socket named by
// the SSH_AUTH_SOCK environment variable. The caller is responsible for
// closing the returned connection.
func dialAgent() (net.Conn, error) {
	socketPath := os.Getenv("SSH_AUTH_SOCK")
	if socketPath == "" {
		// Dialing "" fails with an opaque "missing address" error; say what
		// is actually wrong instead.
		return nil, fmt.Errorf("SSH_AUTH_SOCK is not set; is an ssh agent running?")
	}
	conn, err := net.Dial("unix", socketPath)
	if err != nil {
		return nil, fmt.Errorf("connecting to ssh agent at %q: %w", socketPath, err)
	}
	return conn, nil
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment