Skip to content

Instantly share code, notes, and snippets.

@StevenACoffman
Created April 27, 2018 15:47
Show Gist options
  • Save StevenACoffman/8e2096e7583f3a67fe3d6280b2cb882c to your computer and use it in GitHub Desktop.
Save StevenACoffman/8e2096e7583f3a67fe3d6280b2cb882c to your computer and use it in GitHub Desktop.
Add or remove SSH Certificate based on expiration and source ip
package main
import (
"os"
"io/ioutil"
"time"
"golang.org/x/crypto/ssh"
"net"
"golang.org/x/crypto/ssh/agent"
"crypto/rsa"
"errors"
"log"
"encoding/pem"
"crypto/x509"
)
var rootKeyName string
func init() {
rootKeyName = getEnv("ROOT_KEY_NAME", "bless_rsa")
}
func main() {
myIP := getMyIP()
log.Println("Got IP")
pubKey, pubKeyErr := getPublicKey(rootKeyName)
log.Println("Got PublicKey")
if pubKeyErr != nil {
log.Fatalf("PublicKey error %v", pubKeyErr )
}
privKey, privKeyErr := getPrivateKey(rootKeyName)
log.Println("Got PrivateKey")
if privKeyErr != nil {
log.Fatalf("PrivateKey error %v", privKeyErr )
}
blessCert, blessCertErr := getValidBlessCert(rootKeyName, myIP)
log.Println("Got BlessCert")
if blessCertErr != nil {
removeKey(pubKey)
log.Println("Removed Invalid Blessed Public Key from SSH Agent")
os.Exit(1)
} else {
addKey(privKey, blessCert)
log.Println("Added Valid Blessed Public Key to SSH Agent")
os.Exit(0)
}
}
func getEnv(key, fallback string) string {
if value, ok := os.LookupEnv(key); ok {
return value
}
return fallback
}
func getPublicKey(rootKeyName string) ( ssh.PublicKey, error){
bytes, err := ioutil.ReadFile(os.Getenv("HOME") + "/.ssh/"+rootKeyName+".pub")
if err != nil {
log.Fatalf("Fatal error trying to read public key file: %s", err)
return nil, errors.New("fatal error trying to read public key file")
}
newAuthorizedKey, _, _, _, err := ssh.ParseAuthorizedKey(bytes)
return newAuthorizedKey, err
}
func getPrivateKey(rootKeyName string) ( *rsa.PrivateKey, error){
bytes, err := ioutil.ReadFile(os.Getenv("HOME") + "/.ssh/"+rootKeyName)
if err != nil {
log.Fatalf("Fatal error trying to read private key file: %s", err)
return nil, errors.New("unable to get Private Key")
}
// Extract the PEM-encoded data block
block, _ := pem.Decode(bytes)
if block == nil {
log.Fatalf("bad key data: %s", "not PEM-encoded")
return nil, errors.New("unable to get Private Key")
}
if got, want := block.Type, "RSA PRIVATE KEY"; got != want {
log.Fatalf("unknown key type %q, want %q", got, want)
return nil, errors.New("unable to get Private Key")
}
// Decode the RSA private key
return x509.ParsePKCS1PrivateKey(block.Bytes)
}
func getValidBlessCert(rootKeyName string, currentIP string) (*ssh.Certificate, error) {
// read in public key from file
bytes, err := ioutil.ReadFile(os.Getenv("HOME") + "/.ssh/"+rootKeyName+"-cert.pub")
if err != nil {
log.Fatalf("Fatal error trying to read Bless Certificate file: %s", err)
return nil, errors.New("unable to get Bless Certificate")
}
pubkey, _, _, _, err := ssh.ParseAuthorizedKey(bytes)
if err != nil {
log.Fatalf("Fatal error trying to parse Bless Cert: %s", err)
return nil, errors.New("unable to parse Bless Cert:")
}
cert, ok := pubkey.(*ssh.Certificate)
if !ok {
log.Fatalf("got %v (%T), want *Certificate", pubkey, pubkey)
return nil, errors.New("Unable to get Bless Cert")
}
now := time.Now()
unixNow := now.Unix()
const CertTimeInfinity = 1<<64 - 1
sourceAddress := cert.CriticalOptions["source-address"]
if currentIP != sourceAddress{
log.Print("MyIP %v does not match Source Address %v", currentIP, sourceAddress)
return nil, errors.New("Current IP "+currentIP+" does not match Bless Cert Source Address "+sourceAddress)
}
if after := int64(cert.ValidAfter); after < 0 || unixNow < int64(cert.ValidAfter) {
log.Print("ssh: cert is not yet valid")
return nil, errors.New("ssh: cert is not yet valid")
}
if before := int64(cert.ValidBefore); cert.ValidBefore != uint64(CertTimeInfinity) && (unixNow >= before || before < 0) {
log.Print("ssh: cert has expired")
return nil, errors.New("ssh: cert has expired")
}
return cert, nil
}
func getMyIP() string {
addrs, err := net.InterfaceAddrs()
if err != nil {
log.Fatalf("Oops: " + err.Error() + "\n")
os.Exit(1)
}
for _, a := range addrs {
if ipnet, ok := a.(*net.IPNet); ok && !ipnet.IP.IsLoopback() {
if ipnet.IP.To4() != nil {
return ipnet.IP.String()
}
}
}
return ""
}
func addKey( key *rsa.PrivateKey, cert *ssh.Certificate) {
if key == nil {
return
}
conn, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK"))
if err != nil {
log.Fatalf("Unable to connect to SSH Agent %v", err)
}
defer conn.Close()
agentClient := agent.NewClient(conn)
err = agentClient.Add(agent.AddedKey{
PrivateKey: key,
Certificate: cert,
LifetimeSecs: 14440,
})
if err != nil {
log.Fatalf("failed to add key: %v", err)
}
}
func removeKey(key ssh.PublicKey) {
if key == nil {
log.Fatalf("Unable to Remove Empty Key")
return
}
conn, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK"))
if err != nil {
log.Fatalf("Unable to remove Key")
}
defer conn.Close()
agentClient := agent.NewClient(conn)
err = agentClient.Remove(key)
if err != nil {
log.Print("failed to remove key %q: %v", key, err)
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment