Created June 9, 2017 21:20
Save meatballhat/854b13af8305f3ea4468d0bf9a1f5f7b to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/acme/acme.go b/acme/acme.go | |
index 8619508..a7b6ce4 100644 | |
--- a/acme/acme.go | |
+++ b/acme/acme.go | |
@@ -15,6 +15,7 @@ package acme | |
import ( | |
"bytes" | |
+ "context" | |
"crypto" | |
"crypto/ecdsa" | |
"crypto/elliptic" | |
@@ -36,9 +37,6 @@ import ( | |
"strings" | |
"sync" | |
"time" | |
- | |
- "golang.org/x/net/context" | |
- "golang.org/x/net/context/ctxhttp" | |
) | |
// LetsEncryptURL is the Directory endpoint of Let's Encrypt CA. | |
@@ -133,7 +131,7 @@ func (c *Client) Discover(ctx context.Context) (Directory, error) { | |
if dirURL == "" { | |
dirURL = LetsEncryptURL | |
} | |
- res, err := ctxhttp.Get(ctx, c.HTTPClient, dirURL) | |
+ res, err := c.get(ctx, dirURL) | |
if err != nil { | |
return Directory{}, err | |
} | |
@@ -154,7 +152,7 @@ func (c *Client) Discover(ctx context.Context) (Directory, error) { | |
CAA []string `json:"caa-identities"` | |
} | |
} | |
- if json.NewDecoder(res.Body).Decode(&v); err != nil { | |
+ if err := json.NewDecoder(res.Body).Decode(&v); err != nil { | |
return Directory{}, err | |
} | |
c.dir = &Directory{ | |
@@ -200,7 +198,7 @@ func (c *Client) CreateCert(ctx context.Context, csr []byte, exp time.Duration, | |
req.NotAfter = now.Add(exp).Format(time.RFC3339) | |
} | |
- res, err := c.postJWS(ctx, c.Key, c.dir.CertURL, req) | |
+ res, err := c.retryPostJWS(ctx, c.Key, c.dir.CertURL, req) | |
if err != nil { | |
return nil, "", err | |
} | |
@@ -216,7 +214,7 @@ func (c *Client) CreateCert(ctx context.Context, csr []byte, exp time.Duration, | |
return cert, curl, err | |
} | |
// slurp issued cert and CA chain, if requested | |
- cert, err := responseCert(ctx, c.HTTPClient, res, bundle) | |
+ cert, err := c.responseCert(ctx, res, bundle) | |
return cert, curl, err | |
} | |
@@ -231,13 +229,13 @@ func (c *Client) CreateCert(ctx context.Context, csr []byte, exp time.Duration, | |
// and has expected features. | |
func (c *Client) FetchCert(ctx context.Context, url string, bundle bool) ([][]byte, error) { | |
for { | |
- res, err := ctxhttp.Get(ctx, c.HTTPClient, url) | |
+ res, err := c.get(ctx, url) | |
if err != nil { | |
return nil, err | |
} | |
defer res.Body.Close() | |
if res.StatusCode == http.StatusOK { | |
- return responseCert(ctx, c.HTTPClient, res, bundle) | |
+ return c.responseCert(ctx, res, bundle) | |
} | |
if res.StatusCode > 299 { | |
return nil, responseError(res) | |
@@ -275,7 +273,7 @@ func (c *Client) RevokeCert(ctx context.Context, key crypto.Signer, cert []byte, | |
if key == nil { | |
key = c.Key | |
} | |
- res, err := c.postJWS(ctx, key, c.dir.RevokeURL, body) | |
+ res, err := c.retryPostJWS(ctx, key, c.dir.RevokeURL, body) | |
if err != nil { | |
return err | |
} | |
@@ -363,7 +361,7 @@ func (c *Client) Authorize(ctx context.Context, domain string) (*Authorization, | |
Resource: "new-authz", | |
Identifier: authzID{Type: "dns", Value: domain}, | |
} | |
- res, err := c.postJWS(ctx, c.Key, c.dir.AuthzURL, req) | |
+ res, err := c.retryPostJWS(ctx, c.Key, c.dir.AuthzURL, req) | |
if err != nil { | |
return nil, err | |
} | |
@@ -387,7 +385,7 @@ func (c *Client) Authorize(ctx context.Context, domain string) (*Authorization, | |
// If a caller needs to poll an authorization until its status is final, | |
// see the WaitAuthorization method. | |
func (c *Client) GetAuthorization(ctx context.Context, url string) (*Authorization, error) { | |
- res, err := ctxhttp.Get(ctx, c.HTTPClient, url) | |
+ res, err := c.get(ctx, url) | |
if err != nil { | |
return nil, err | |
} | |
@@ -421,7 +419,7 @@ func (c *Client) RevokeAuthorization(ctx context.Context, url string) error { | |
Status: "deactivated", | |
Delete: true, | |
} | |
- res, err := c.postJWS(ctx, c.Key, url, req) | |
+ res, err := c.retryPostJWS(ctx, c.Key, url, req) | |
if err != nil { | |
return err | |
} | |
@@ -438,25 +436,11 @@ func (c *Client) RevokeAuthorization(ctx context.Context, url string) error { | |
// | |
// It returns a non-nil Authorization only if its Status is StatusValid. | |
// In all other cases WaitAuthorization returns an error. | |
-// If the Status is StatusInvalid, the returned error is ErrAuthorizationFailed. | |
+// If the Status is StatusInvalid, the returned error is of type *AuthorizationError. | |
func (c *Client) WaitAuthorization(ctx context.Context, url string) (*Authorization, error) { | |
- var count int | |
- sleep := func(v string, inc int) error { | |
- count += inc | |
- d := backoff(count, 10*time.Second) | |
- d = retryAfter(v, d) | |
- wakeup := time.NewTimer(d) | |
- defer wakeup.Stop() | |
- select { | |
- case <-ctx.Done(): | |
- return ctx.Err() | |
- case <-wakeup.C: | |
- return nil | |
- } | |
- } | |
- | |
+ sleep := sleeper(ctx) | |
for { | |
- res, err := ctxhttp.Get(ctx, c.HTTPClient, url) | |
+ res, err := c.get(ctx, url) | |
if err != nil { | |
return nil, err | |
} | |
@@ -481,7 +465,7 @@ func (c *Client) WaitAuthorization(ctx context.Context, url string) (*Authorizat | |
return raw.authorization(url), nil | |
} | |
if raw.Status == StatusInvalid { | |
- return nil, ErrAuthorizationFailed | |
+ return nil, raw.error(url) | |
} | |
if err := sleep(retry, 0); err != nil { | |
return nil, err | |
@@ -493,7 +477,7 @@ func (c *Client) WaitAuthorization(ctx context.Context, url string) (*Authorizat | |
// | |
// A client typically polls a challenge status using this method. | |
func (c *Client) GetChallenge(ctx context.Context, url string) (*Challenge, error) { | |
- res, err := ctxhttp.Get(ctx, c.HTTPClient, url) | |
+ res, err := c.get(ctx, url) | |
if err != nil { | |
return nil, err | |
} | |
@@ -527,7 +511,7 @@ func (c *Client) Accept(ctx context.Context, chal *Challenge) (*Challenge, error | |
Type: chal.Type, | |
Auth: auth, | |
} | |
- res, err := c.postJWS(ctx, c.Key, chal.URI, req) | |
+ res, err := c.retryPostJWS(ctx, c.Key, chal.URI, req) | |
if err != nil { | |
return nil, err | |
} | |
@@ -660,7 +644,7 @@ func (c *Client) doReg(ctx context.Context, url string, typ string, acct *Accoun | |
req.Contact = acct.Contact | |
req.Agreement = acct.AgreedTerms | |
} | |
- res, err := c.postJWS(ctx, c.Key, url, req) | |
+ res, err := c.retryPostJWS(ctx, c.Key, url, req) | |
if err != nil { | |
return nil, err | |
} | |
@@ -697,6 +681,40 @@ func (c *Client) doReg(ctx context.Context, url string, typ string, acct *Accoun | |
}, nil | |
} | |
+// retryPostJWS will retry calls to postJWS if there is a badNonce error, | |
+// clearing the stored nonces after each error. | |
+// If the response was 4XX-5XX, then responseError is called on the body, | |
+// the body is closed, and the error returned. | |
+func (c *Client) retryPostJWS(ctx context.Context, key crypto.Signer, url string, body interface{}) (*http.Response, error) { | |
+ sleep := sleeper(ctx) | |
+ for { | |
+ res, err := c.postJWS(ctx, key, url, body) | |
+ if err != nil { | |
+ return nil, err | |
+ } | |
+ // handle errors 4XX-5XX with responseError | |
+ if res.StatusCode >= 400 && res.StatusCode <= 599 { | |
+ err := responseError(res) | |
+ res.Body.Close() | |
+ // according to spec badNonce is urn:ietf:params:acme:error:badNonce | |
+ // however, acme servers in the wild return their version of the error | |
+ // https://tools.ietf.org/html/draft-ietf-acme-acme-02#section-5.4 | |
+ if ae, ok := err.(*Error); ok && strings.HasSuffix(strings.ToLower(ae.ProblemType), ":badnonce") { | |
+ // clear any nonces that we might've stored that might now be | |
+ // considered bad | |
+ c.clearNonces() | |
+ retry := res.Header.Get("retry-after") | |
+ if err := sleep(retry, 1); err != nil { | |
+ return nil, err | |
+ } | |
+ continue | |
+ } | |
+ return nil, err | |
+ } | |
+ return res, nil | |
+ } | |
+} | |
+ | |
// postJWS signs the body with the given key and POSTs it to the provided url. | |
// The body argument must be JSON-serializable. | |
func (c *Client) postJWS(ctx context.Context, key crypto.Signer, url string, body interface{}) (*http.Response, error) { | |
@@ -708,7 +726,7 @@ func (c *Client) postJWS(ctx context.Context, key crypto.Signer, url string, bod | |
if err != nil { | |
return nil, err | |
} | |
- res, err := ctxhttp.Post(ctx, c.HTTPClient, url, "application/jose+json", bytes.NewReader(b)) | |
+ res, err := c.post(ctx, url, "application/jose+json", bytes.NewReader(b)) | |
if err != nil { | |
return nil, err | |
} | |
@@ -722,7 +740,7 @@ func (c *Client) popNonce(ctx context.Context, url string) (string, error) { | |
c.noncesMu.Lock() | |
defer c.noncesMu.Unlock() | |
if len(c.nonces) == 0 { | |
- return fetchNonce(ctx, c.HTTPClient, url) | |
+ return c.fetchNonce(ctx, url) | |
} | |
var nonce string | |
for nonce = range c.nonces { | |
@@ -732,6 +750,13 @@ func (c *Client) popNonce(ctx context.Context, url string) (string, error) { | |
return nonce, nil | |
} | |
+// clearNonces clears any stored nonces | |
+func (c *Client) clearNonces() { | |
+ c.noncesMu.Lock() | |
+ defer c.noncesMu.Unlock() | |
+ c.nonces = make(map[string]struct{}) | |
+} | |
+ | |
// addNonce stores a nonce value found in h (if any) for future use. | |
func (c *Client) addNonce(h http.Header) { | |
v := nonceFromHeader(h) | |
@@ -749,8 +774,58 @@ func (c *Client) addNonce(h http.Header) { | |
c.nonces[v] = struct{}{} | |
} | |
-func fetchNonce(ctx context.Context, client *http.Client, url string) (string, error) { | |
- resp, err := ctxhttp.Head(ctx, client, url) | |
+func (c *Client) httpClient() *http.Client { | |
+ if c.HTTPClient != nil { | |
+ return c.HTTPClient | |
+ } | |
+ return http.DefaultClient | |
+} | |
+ | |
+func (c *Client) get(ctx context.Context, urlStr string) (*http.Response, error) { | |
+ req, err := http.NewRequest("GET", urlStr, nil) | |
+ if err != nil { | |
+ return nil, err | |
+ } | |
+ return c.do(ctx, req) | |
+} | |
+ | |
+func (c *Client) head(ctx context.Context, urlStr string) (*http.Response, error) { | |
+ req, err := http.NewRequest("HEAD", urlStr, nil) | |
+ if err != nil { | |
+ return nil, err | |
+ } | |
+ return c.do(ctx, req) | |
+} | |
+ | |
+func (c *Client) post(ctx context.Context, urlStr, contentType string, body io.Reader) (*http.Response, error) { | |
+ req, err := http.NewRequest("POST", urlStr, body) | |
+ if err != nil { | |
+ return nil, err | |
+ } | |
+ req.Header.Set("Content-Type", contentType) | |
+ return c.do(ctx, req) | |
+} | |
+ | |
+func (c *Client) do(ctx context.Context, req *http.Request) (*http.Response, error) { | |
+ res, err := c.httpClient().Do(req.WithContext(ctx)) | |
+ if err != nil { | |
+ select { | |
+ case <-ctx.Done(): | |
+ // Prefer the unadorned context error. | |
+ // (The acme package had tests assuming this, previously from ctxhttp's | |
+ // behavior, predating net/http supporting contexts natively) | |
+ // TODO(bradfitz): reconsider this in the future. But for now this | |
+ // requires no test updates. | |
+ return nil, ctx.Err() | |
+ default: | |
+ return nil, err | |
+ } | |
+ } | |
+ return res, nil | |
+} | |
+ | |
+func (c *Client) fetchNonce(ctx context.Context, url string) (string, error) { | |
+ resp, err := c.head(ctx, url) | |
if err != nil { | |
return "", err | |
} | |
@@ -769,7 +844,7 @@ func nonceFromHeader(h http.Header) string { | |
return h.Get("Replay-Nonce") | |
} | |
-func responseCert(ctx context.Context, client *http.Client, res *http.Response, bundle bool) ([][]byte, error) { | |
+func (c *Client) responseCert(ctx context.Context, res *http.Response, bundle bool) ([][]byte, error) { | |
b, err := ioutil.ReadAll(io.LimitReader(res.Body, maxCertSize+1)) | |
if err != nil { | |
return nil, fmt.Errorf("acme: response stream: %v", err) | |
@@ -793,7 +868,7 @@ func responseCert(ctx context.Context, client *http.Client, res *http.Response, | |
return nil, errors.New("acme: rel=up link is too large") | |
} | |
for _, url := range up { | |
- cc, err := chainCert(ctx, client, url, 0) | |
+ cc, err := c.chainCert(ctx, url, 0) | |
if err != nil { | |
return nil, err | |
} | |
@@ -807,14 +882,8 @@ func responseError(resp *http.Response) error { | |
// don't care if ReadAll returns an error: | |
// json.Unmarshal will fail in that case anyway | |
b, _ := ioutil.ReadAll(resp.Body) | |
- e := struct { | |
- Status int | |
- Type string | |
- Detail string | |
- }{ | |
- Status: resp.StatusCode, | |
- } | |
- if err := json.Unmarshal(b, &e); err != nil { | |
+ e := &wireError{Status: resp.StatusCode} | |
+ if err := json.Unmarshal(b, e); err != nil { | |
// this is not a regular error response: | |
// populate detail with anything we received, | |
// e.Status will already contain HTTP response code value | |
@@ -823,12 +892,7 @@ func responseError(resp *http.Response) error { | |
e.Detail = resp.Status | |
} | |
} | |
- return &Error{ | |
- StatusCode: e.Status, | |
- ProblemType: e.Type, | |
- Detail: e.Detail, | |
- Header: resp.Header, | |
- } | |
+ return e.error(resp.Header) | |
} | |
// chainCert fetches CA certificate chain recursively by following "up" links. | |
@@ -836,12 +900,12 @@ func responseError(resp *http.Response) error { | |
// if the recursion level reaches maxChainLen. | |
// | |
// First chainCert call starts with depth of 0. | |
-func chainCert(ctx context.Context, client *http.Client, url string, depth int) ([][]byte, error) { | |
+func (c *Client) chainCert(ctx context.Context, url string, depth int) ([][]byte, error) { | |
if depth >= maxChainLen { | |
return nil, errors.New("acme: certificate chain is too deep") | |
} | |
- res, err := ctxhttp.Get(ctx, client, url) | |
+ res, err := c.get(ctx, url) | |
if err != nil { | |
return nil, err | |
} | |
@@ -863,7 +927,7 @@ func chainCert(ctx context.Context, client *http.Client, url string, depth int) | |
return nil, errors.New("acme: certificate chain is too large") | |
} | |
for _, up := range uplink { | |
- cc, err := chainCert(ctx, client, up, depth+1) | |
+ cc, err := c.chainCert(ctx, up, depth+1) | |
if err != nil { | |
return nil, err | |
} | |
@@ -893,6 +957,28 @@ func linkHeader(h http.Header, rel string) []string { | |
return links | |
} | |
+// sleeper returns a function that accepts the Retry-After HTTP header value | |
+// and an increment that's used with backoff to increasingly sleep on | |
+// consecutive calls until the context is done. If the Retry-After header | |
+// cannot be parsed, then backoff is used with a maximum sleep time of 10 | |
+// seconds. | |
+func sleeper(ctx context.Context) func(ra string, inc int) error { | |
+ var count int | |
+ return func(ra string, inc int) error { | |
+ count += inc | |
+ d := backoff(count, 10*time.Second) | |
+ d = retryAfter(ra, d) | |
+ wakeup := time.NewTimer(d) | |
+ defer wakeup.Stop() | |
+ select { | |
+ case <-ctx.Done(): | |
+ return ctx.Err() | |
+ case <-wakeup.C: | |
+ return nil | |
+ } | |
+ } | |
+} | |
+ | |
// retryAfter parses a Retry-After HTTP header value, | |
// trying to convert v into an int (seconds) or use http.ParseTime otherwise. | |
// It returns d if v cannot be parsed. | |
@@ -974,7 +1060,8 @@ func tlsChallengeCert(san []string, opt []CertOption) (tls.Certificate, error) { | |
NotBefore: time.Now(), | |
NotAfter: time.Now().Add(24 * time.Hour), | |
BasicConstraintsValid: true, | |
- KeyUsage: x509.KeyUsageKeyEncipherment, | |
+ KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, | |
+ ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, | |
} | |
} | |
tmpl.DNSNames = san | |
diff --git a/acme/acme_test.go b/acme/acme_test.go | |
index 1205dbb..a4d276d 100644 | |
--- a/acme/acme_test.go | |
+++ b/acme/acme_test.go | |
@@ -6,6 +6,7 @@ package acme | |
import ( | |
"bytes" | |
+ "context" | |
"crypto/rand" | |
"crypto/rsa" | |
"crypto/tls" | |
@@ -23,8 +24,6 @@ import ( | |
"strings" | |
"testing" | |
"time" | |
- | |
- "golang.org/x/net/context" | |
) | |
// Decodes a JWS-encoded request and unmarshals the decoded JSON into a provided | |
@@ -544,6 +543,9 @@ func TestWaitAuthorizationInvalid(t *testing.T) { | |
if err == nil { | |
t.Error("err is nil") | |
} | |
+ if _, ok := err.(*AuthorizationError); !ok { | |
+ t.Errorf("err is %T; want *AuthorizationError", err) | |
+ } | |
} | |
} | |
@@ -981,7 +983,8 @@ func TestNonce_fetch(t *testing.T) { | |
defer ts.Close() | |
for ; i < len(tests); i++ { | |
test := tests[i] | |
- n, err := fetchNonce(context.Background(), http.DefaultClient, ts.URL) | |
+ c := &Client{} | |
+ n, err := c.fetchNonce(context.Background(), ts.URL) | |
if n != test.nonce { | |
t.Errorf("%d: n=%q; want %q", i, n, test.nonce) | |
} | |
@@ -999,7 +1002,8 @@ func TestNonce_fetchError(t *testing.T) { | |
w.WriteHeader(http.StatusTooManyRequests) | |
})) | |
defer ts.Close() | |
- _, err := fetchNonce(context.Background(), http.DefaultClient, ts.URL) | |
+ c := &Client{} | |
+ _, err := c.fetchNonce(context.Background(), ts.URL) | |
e, ok := err.(*Error) | |
if !ok { | |
t.Fatalf("err is %T; want *Error", err) | |
@@ -1064,6 +1068,44 @@ func TestNonce_postJWS(t *testing.T) { | |
} | |
} | |
+func TestRetryPostJWS(t *testing.T) { | |
+ var count int | |
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | |
+ count++ | |
+ w.Header().Set("replay-nonce", fmt.Sprintf("nonce%d", count)) | |
+ if r.Method == "HEAD" { | |
+ // We expect the client to do 2 head requests to fetch | |
+ // nonces, one to start and another after getting badNonce | |
+ return | |
+ } | |
+ | |
+ head, err := decodeJWSHead(r) | |
+ if err != nil { | |
+ t.Errorf("decodeJWSHead: %v", err) | |
+ } else if head.Nonce == "" { | |
+ t.Error("head.Nonce is empty") | |
+ } else if head.Nonce == "nonce1" { | |
+ // return a badNonce error to force the call to retry | |
+ w.WriteHeader(http.StatusBadRequest) | |
+ w.Write([]byte(`{"type":"urn:ietf:params:acme:error:badNonce"}`)) | |
+ return | |
+ } | |
+ // Make client.Authorize happy; we're not testing its result. | |
+ w.WriteHeader(http.StatusCreated) | |
+ w.Write([]byte(`{"status":"valid"}`)) | |
+ })) | |
+ defer ts.Close() | |
+ | |
+ client := Client{Key: testKey, dir: &Directory{AuthzURL: ts.URL}} | |
+ // This call will fail with badNonce, causing a retry | |
+ if _, err := client.Authorize(context.Background(), "example.com"); err != nil { | |
+ t.Errorf("client.Authorize 1: %v", err) | |
+ } | |
+ if count != 4 { | |
+ t.Errorf("total requests count: %d; want 4", count) | |
+ } | |
+} | |
+ | |
func TestLinkHeader(t *testing.T) { | |
h := http.Header{"Link": { | |
`<https://example.com/acme/new-authz>;rel="next"`, | |
diff --git a/acme/autocert/autocert.go b/acme/autocert/autocert.go | |
index 4b15816..a478eff 100644 | |
--- a/acme/autocert/autocert.go | |
+++ b/acme/autocert/autocert.go | |
@@ -10,6 +10,7 @@ package autocert | |
import ( | |
"bytes" | |
+ "context" | |
"crypto" | |
"crypto/ecdsa" | |
"crypto/elliptic" | |
@@ -30,9 +31,14 @@ import ( | |
"time" | |
"golang.org/x/crypto/acme" | |
- "golang.org/x/net/context" | |
) | |
+// createCertRetryAfter is how much time to wait before removing a failed state | |
+// entry due to an unsuccessful createCert call. | |
+// This is a variable instead of a const for testing. | |
+// TODO: Consider making it configurable or an exp backoff? | |
+var createCertRetryAfter = time.Minute | |
+ | |
// pseudoRand is safe for concurrent use. | |
var pseudoRand *lockedMathRand | |
@@ -41,8 +47,9 @@ func init() { | |
pseudoRand = &lockedMathRand{rnd: mathrand.New(src)} | |
} | |
-// AcceptTOS always returns true to indicate the acceptance of a CA Terms of Service | |
-// during account registration. | |
+// AcceptTOS is a Manager.Prompt function that always returns true to | |
+// indicate acceptance of the CA's Terms of Service during account | |
+// registration. | |
func AcceptTOS(tosURL string) bool { return true } | |
// HostPolicy specifies which host names the Manager is allowed to respond to. | |
@@ -76,18 +83,6 @@ func defaultHostPolicy(context.Context, string) error { | |
// It obtains and refreshes certificates automatically, | |
// as well as providing them to a TLS server via tls.Config. | |
// | |
-// A simple usage example: | |
-// | |
-// m := autocert.Manager{ | |
-// Prompt: autocert.AcceptTOS, | |
-// HostPolicy: autocert.HostWhitelist("example.org"), | |
-// } | |
-// s := &http.Server{ | |
-// Addr: ":https", | |
-// TLSConfig: &tls.Config{GetCertificate: m.GetCertificate}, | |
-// } | |
-// s.ListenAndServeTLS("", "") | |
-// | |
// To preserve issued certificates and improve overall performance, | |
// use a cache implementation of Cache. For instance, DirCache. | |
type Manager struct { | |
@@ -123,7 +118,7 @@ type Manager struct { | |
// RenewBefore optionally specifies how early certificates should | |
// be renewed before they expire. | |
// | |
- // If zero, they're renewed 1 week before expiration. | |
+ // If zero, they're renewed 30 days before expiration. | |
RenewBefore time.Duration | |
// Client is used to perform low-level operations, such as account registration | |
@@ -173,10 +168,23 @@ type Manager struct { | |
// The error is propagated back to the caller of GetCertificate and is user-visible. | |
// This does not affect cached certs. See HostPolicy field description for more details. | |
func (m *Manager) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { | |
+ if m.Prompt == nil { | |
+ return nil, errors.New("acme/autocert: Manager.Prompt not set") | |
+ } | |
+ | |
name := hello.ServerName | |
if name == "" { | |
return nil, errors.New("acme/autocert: missing server name") | |
} | |
+ if !strings.Contains(strings.Trim(name, "."), ".") { | |
+ return nil, errors.New("acme/autocert: server name component count invalid") | |
+ } | |
+ if strings.ContainsAny(name, `/\`) { | |
+ return nil, errors.New("acme/autocert: server name contains invalid character") | |
+ } | |
+ | |
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) | |
+ defer cancel() | |
// check whether this is a token cert requested for TLS-SNI challenge | |
if strings.HasSuffix(name, ".acme.invalid") { | |
@@ -185,7 +193,7 @@ func (m *Manager) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, | |
if cert := m.tokenCert[name]; cert != nil { | |
return cert, nil | |
} | |
- if cert, err := m.cacheGet(name); err == nil { | |
+ if cert, err := m.cacheGet(ctx, name); err == nil { | |
return cert, nil | |
} | |
// TODO: cache error results? | |
@@ -194,7 +202,7 @@ func (m *Manager) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, | |
// regular domain | |
name = strings.TrimSuffix(name, ".") // golang.org/issue/18114 | |
- cert, err := m.cert(name) | |
+ cert, err := m.cert(ctx, name) | |
if err == nil { | |
return cert, nil | |
} | |
@@ -203,7 +211,6 @@ func (m *Manager) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, | |
} | |
// first-time | |
- ctx := context.Background() // TODO: use a deadline? | |
if err := m.hostPolicy()(ctx, name); err != nil { | |
return nil, err | |
} | |
@@ -211,14 +218,14 @@ func (m *Manager) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, | |
if err != nil { | |
return nil, err | |
} | |
- m.cachePut(name, cert) | |
+ m.cachePut(ctx, name, cert) | |
return cert, nil | |
} | |
// cert returns an existing certificate either from m.state or cache. | |
// If a certificate is found in cache but not in m.state, the latter will be filled | |
// with the cached value. | |
-func (m *Manager) cert(name string) (*tls.Certificate, error) { | |
+func (m *Manager) cert(ctx context.Context, name string) (*tls.Certificate, error) { | |
m.stateMu.Lock() | |
if s, ok := m.state[name]; ok { | |
m.stateMu.Unlock() | |
@@ -227,7 +234,7 @@ func (m *Manager) cert(name string) (*tls.Certificate, error) { | |
return s.tlscert() | |
} | |
defer m.stateMu.Unlock() | |
- cert, err := m.cacheGet(name) | |
+ cert, err := m.cacheGet(ctx, name) | |
if err != nil { | |
return nil, err | |
} | |
@@ -249,12 +256,11 @@ func (m *Manager) cert(name string) (*tls.Certificate, error) { | |
} | |
// cacheGet always returns a valid certificate, or an error otherwise. | |
-func (m *Manager) cacheGet(domain string) (*tls.Certificate, error) { | |
+// If a cached certificate exists but is not valid, ErrCacheMiss is returned. | |
+func (m *Manager) cacheGet(ctx context.Context, domain string) (*tls.Certificate, error) { | |
if m.Cache == nil { | |
return nil, ErrCacheMiss | |
} | |
- // TODO: might want to define a cache timeout on m | |
- ctx := context.Background() | |
data, err := m.Cache.Get(ctx, domain) | |
if err != nil { | |
return nil, err | |
@@ -263,7 +269,7 @@ func (m *Manager) cacheGet(domain string) (*tls.Certificate, error) { | |
// private | |
priv, pub := pem.Decode(data) | |
if priv == nil || !strings.Contains(priv.Type, "PRIVATE") { | |
- return nil, errors.New("acme/autocert: no private key found in cache") | |
+ return nil, ErrCacheMiss | |
} | |
privKey, err := parsePrivateKey(priv.Bytes) | |
if err != nil { | |
@@ -281,13 +287,14 @@ func (m *Manager) cacheGet(domain string) (*tls.Certificate, error) { | |
pubDER = append(pubDER, b.Bytes) | |
} | |
if len(pub) > 0 { | |
- return nil, errors.New("acme/autocert: invalid public key") | |
+ // Leftover content not consumed by pem.Decode. Corrupt. Ignore. | |
+ return nil, ErrCacheMiss | |
} | |
// verify and create TLS cert | |
leaf, err := validCert(domain, pubDER, privKey) | |
if err != nil { | |
- return nil, err | |
+ return nil, ErrCacheMiss | |
} | |
tlscert := &tls.Certificate{ | |
Certificate: pubDER, | |
@@ -297,7 +304,7 @@ func (m *Manager) cacheGet(domain string) (*tls.Certificate, error) { | |
return tlscert, nil | |
} | |
-func (m *Manager) cachePut(domain string, tlscert *tls.Certificate) error { | |
+func (m *Manager) cachePut(ctx context.Context, domain string, tlscert *tls.Certificate) error { | |
if m.Cache == nil { | |
return nil | |
} | |
@@ -329,8 +336,6 @@ func (m *Manager) cachePut(domain string, tlscert *tls.Certificate) error { | |
} | |
} | |
- // TODO: might want to define a cache timeout on m | |
- ctx := context.Background() | |
return m.Cache.Put(ctx, domain, buf.Bytes()) | |
} | |
@@ -370,6 +375,23 @@ func (m *Manager) createCert(ctx context.Context, domain string) (*tls.Certifica | |
der, leaf, err := m.authorizedCert(ctx, state.key, domain) | |
if err != nil { | |
+ // Remove the failed state after some time, | |
+ // making the manager call createCert again on the following TLS hello. | |
+ time.AfterFunc(createCertRetryAfter, func() { | |
+ defer testDidRemoveState(domain) | |
+ m.stateMu.Lock() | |
+ defer m.stateMu.Unlock() | |
+ // Verify the state hasn't changed and it's still invalid | |
+ // before deleting. | |
+ s, ok := m.state[domain] | |
+ if !ok { | |
+ return | |
+ } | |
+ if _, err := validCert(domain, s.cert, s.key); err == nil { | |
+ return | |
+ } | |
+ delete(m.state, domain) | |
+ }) | |
return nil, err | |
} | |
state.cert = der | |
@@ -418,7 +440,6 @@ func (m *Manager) certState(domain string) (*certState, error) { | |
// authorizedCert starts domain ownership verification process and requests a new cert upon success. | |
// The key argument is the certificate private key. | |
func (m *Manager) authorizedCert(ctx context.Context, key crypto.Signer, domain string) (der [][]byte, leaf *x509.Certificate, err error) { | |
- // TODO: make m.verify retry or retry m.verify calls here | |
if err := m.verify(ctx, domain); err != nil { | |
return nil, nil, err | |
} | |
@@ -494,7 +515,7 @@ func (m *Manager) verify(ctx context.Context, domain string) error { | |
if err != nil { | |
return err | |
} | |
- m.putTokenCert(name, &cert) | |
+ m.putTokenCert(ctx, name, &cert) | |
defer func() { | |
// verification has ended at this point | |
// don't need token cert anymore | |
@@ -512,14 +533,14 @@ func (m *Manager) verify(ctx context.Context, domain string) error { | |
// putTokenCert stores the cert under the named key in both m.tokenCert map | |
// and m.Cache. | |
-func (m *Manager) putTokenCert(name string, cert *tls.Certificate) { | |
+func (m *Manager) putTokenCert(ctx context.Context, name string, cert *tls.Certificate) { | |
m.tokenCertMu.Lock() | |
defer m.tokenCertMu.Unlock() | |
if m.tokenCert == nil { | |
m.tokenCert = make(map[string]*tls.Certificate) | |
} | |
m.tokenCert[name] = cert | |
- m.cachePut(name, cert) | |
+ m.cachePut(ctx, name, cert) | |
} | |
// deleteTokenCert removes the token certificate for the specified domain name | |
@@ -644,10 +665,10 @@ func (m *Manager) hostPolicy() HostPolicy { | |
} | |
func (m *Manager) renewBefore() time.Duration { | |
- if m.RenewBefore > maxRandRenew { | |
+ if m.RenewBefore > renewJitter { | |
return m.RenewBefore | |
} | |
- return 7 * 24 * time.Hour // 1 week | |
+ return 720 * time.Hour // 30 days | |
} | |
// certState is ready when its mutex is unlocked for reading. | |
@@ -789,5 +810,10 @@ func (r *lockedMathRand) int63n(max int64) int64 { | |
return n | |
} | |
-// for easier testing | |
-var timeNow = time.Now | |
+// For easier testing. | |
+var ( | |
+ timeNow = time.Now | |
+ | |
+ // Called when a state is removed. | |
+ testDidRemoveState = func(domain string) {} | |
+) | |
diff --git a/acme/autocert/autocert_test.go b/acme/autocert/autocert_test.go | |
index 7afb213..0352e34 100644 | |
--- a/acme/autocert/autocert_test.go | |
+++ b/acme/autocert/autocert_test.go | |
@@ -5,6 +5,7 @@ | |
package autocert | |
import ( | |
+ "context" | |
"crypto" | |
"crypto/ecdsa" | |
"crypto/elliptic" | |
@@ -27,7 +28,6 @@ import ( | |
"time" | |
"golang.org/x/crypto/acme" | |
- "golang.org/x/net/context" | |
) | |
var discoTmpl = template.Must(template.New("disco").Parse(`{ | |
@@ -150,7 +150,7 @@ func TestGetCertificate_ForceRSA(t *testing.T) { | |
hello := &tls.ClientHelloInfo{ServerName: "example.org"} | |
testGetCertificate(t, man, "example.org", hello) | |
- cert, err := man.cacheGet("example.org") | |
+ cert, err := man.cacheGet(context.Background(), "example.org") | |
if err != nil { | |
t.Fatalf("man.cacheGet: %v", err) | |
} | |
@@ -159,9 +159,110 @@ func TestGetCertificate_ForceRSA(t *testing.T) { | |
} | |
} | |
-// tests man.GetCertificate flow using the provided hello argument. | |
+func TestGetCertificate_nilPrompt(t *testing.T) { | |
+ man := &Manager{} | |
+ defer man.stopRenew() | |
+ url, finish := startACMEServerStub(t, man, "example.org") | |
+ defer finish() | |
+ key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) | |
+ if err != nil { | |
+ t.Fatal(err) | |
+ } | |
+ man.Client = &acme.Client{ | |
+ Key: key, | |
+ DirectoryURL: url, | |
+ } | |
+ hello := &tls.ClientHelloInfo{ServerName: "example.org"} | |
+ if _, err := man.GetCertificate(hello); err == nil { | |
+ t.Error("got certificate for example.org; wanted error") | |
+ } | |
+} | |
+ | |
+func TestGetCertificate_expiredCache(t *testing.T) { | |
+ // Make an expired cert and cache it. | |
+ pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) | |
+ if err != nil { | |
+ t.Fatal(err) | |
+ } | |
+ tmpl := &x509.Certificate{ | |
+ SerialNumber: big.NewInt(1), | |
+ Subject: pkix.Name{CommonName: "example.org"}, | |
+ NotAfter: time.Now(), | |
+ } | |
+ pub, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &pk.PublicKey, pk) | |
+ if err != nil { | |
+ t.Fatal(err) | |
+ } | |
+ tlscert := &tls.Certificate{ | |
+ Certificate: [][]byte{pub}, | |
+ PrivateKey: pk, | |
+ } | |
+ | |
+ man := &Manager{Prompt: AcceptTOS, Cache: newMemCache()} | |
+ defer man.stopRenew() | |
+ if err := man.cachePut(context.Background(), "example.org", tlscert); err != nil { | |
+ t.Fatalf("man.cachePut: %v", err) | |
+ } | |
+ | |
+ // The expired cached cert should trigger a new cert issuance | |
+ // and return without an error. | |
+ hello := &tls.ClientHelloInfo{ServerName: "example.org"} | |
+ testGetCertificate(t, man, "example.org", hello) | |
+} | |
+ | |
+func TestGetCertificate_failedAttempt(t *testing.T) { | |
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | |
+ w.WriteHeader(http.StatusBadRequest) | |
+ })) | |
+ defer ts.Close() | |
+ | |
+ const example = "example.org" | |
+ d := createCertRetryAfter | |
+ f := testDidRemoveState | |
+ defer func() { | |
+ createCertRetryAfter = d | |
+ testDidRemoveState = f | |
+ }() | |
+ createCertRetryAfter = 0 | |
+ done := make(chan struct{}) | |
+ testDidRemoveState = func(domain string) { | |
+ if domain != example { | |
+ t.Errorf("testDidRemoveState: domain = %q; want %q", domain, example) | |
+ } | |
+ close(done) | |
+ } | |
+ | |
+ key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) | |
+ if err != nil { | |
+ t.Fatal(err) | |
+ } | |
+ man := &Manager{ | |
+ Prompt: AcceptTOS, | |
+ Client: &acme.Client{ | |
+ Key: key, | |
+ DirectoryURL: ts.URL, | |
+ }, | |
+ } | |
+ defer man.stopRenew() | |
+ hello := &tls.ClientHelloInfo{ServerName: example} | |
+ if _, err := man.GetCertificate(hello); err == nil { | |
+ t.Error("GetCertificate: err is nil") | |
+ } | |
+ select { | |
+ case <-time.After(5 * time.Second): | |
+ t.Errorf("took too long to remove the %q state", example) | |
+ case <-done: | |
+ man.stateMu.Lock() | |
+ defer man.stateMu.Unlock() | |
+ if v, exist := man.state[example]; exist { | |
+ t.Errorf("state exists for %q: %+v", example, v) | |
+ } | |
+ } | |
+} | |
+ | |
+// startACMEServerStub runs an ACME server | |
// The domain argument is the expected domain name of a certificate request. | |
-func testGetCertificate(t *testing.T, man *Manager, domain string, hello *tls.ClientHelloInfo) { | |
+func startACMEServerStub(t *testing.T, man *Manager, domain string) (url string, finish func()) { | |
// echo token-02 | shasum -a 256 | |
// then divide result in 2 parts separated by dot | |
tokenCertName := "4e8eb87631187e9ff2153b56b13a4dec.13a35d002e485d60ff37354b32f665d9.token.acme.invalid" | |
@@ -187,7 +288,7 @@ func testGetCertificate(t *testing.T, man *Manager, domain string, hello *tls.Cl | |
// discovery | |
case "/": | |
if err := discoTmpl.Execute(w, ca.URL); err != nil { | |
- t.Fatalf("discoTmpl: %v", err) | |
+ t.Errorf("discoTmpl: %v", err) | |
} | |
// client key registration | |
case "/new-reg": | |
@@ -197,7 +298,7 @@ func testGetCertificate(t *testing.T, man *Manager, domain string, hello *tls.Cl | |
w.Header().Set("location", ca.URL+"/authz/1") | |
w.WriteHeader(http.StatusCreated) | |
if err := authzTmpl.Execute(w, ca.URL); err != nil { | |
- t.Fatalf("authzTmpl: %v", err) | |
+ t.Errorf("authzTmpl: %v", err) | |
} | |
// accept tls-sni-02 challenge | |
case "/challenge/2": | |
@@ -215,14 +316,14 @@ func testGetCertificate(t *testing.T, man *Manager, domain string, hello *tls.Cl | |
b, _ := base64.RawURLEncoding.DecodeString(req.CSR) | |
csr, err := x509.ParseCertificateRequest(b) | |
if err != nil { | |
- t.Fatalf("new-cert: CSR: %v", err) | |
+ t.Errorf("new-cert: CSR: %v", err) | |
} | |
if csr.Subject.CommonName != domain { | |
t.Errorf("CommonName in CSR = %q; want %q", csr.Subject.CommonName, domain) | |
} | |
der, err := dummyCert(csr.PublicKey, domain) | |
if err != nil { | |
- t.Fatalf("new-cert: dummyCert: %v", err) | |
+ t.Errorf("new-cert: dummyCert: %v", err) | |
} | |
chainUp := fmt.Sprintf("<%s/ca-cert>; rel=up", ca.URL) | |
w.Header().Set("link", chainUp) | |
@@ -232,14 +333,51 @@ func testGetCertificate(t *testing.T, man *Manager, domain string, hello *tls.Cl | |
case "/ca-cert": | |
der, err := dummyCert(nil, "ca") | |
if err != nil { | |
- t.Fatalf("ca-cert: dummyCert: %v", err) | |
+ t.Errorf("ca-cert: dummyCert: %v", err) | |
} | |
w.Write(der) | |
default: | |
t.Errorf("unrecognized r.URL.Path: %s", r.URL.Path) | |
} | |
})) | |
- defer ca.Close() | |
+ finish = func() { | |
+ ca.Close() | |
+ | |
+ // make sure token cert was removed | |
+ cancel := make(chan struct{}) | |
+ done := make(chan struct{}) | |
+ go func() { | |
+ defer close(done) | |
+ tick := time.NewTicker(100 * time.Millisecond) | |
+ defer tick.Stop() | |
+ for { | |
+ hello := &tls.ClientHelloInfo{ServerName: tokenCertName} | |
+ if _, err := man.GetCertificate(hello); err != nil { | |
+ return | |
+ } | |
+ select { | |
+ case <-tick.C: | |
+ case <-cancel: | |
+ return | |
+ } | |
+ } | |
+ }() | |
+ select { | |
+ case <-done: | |
+ case <-time.After(5 * time.Second): | |
+ close(cancel) | |
+ t.Error("token cert was not removed") | |
+ <-done | |
+ } | |
+ } | |
+ return ca.URL, finish | |
+} | |
+ | |
+// tests man.GetCertificate flow using the provided hello argument. | |
+// The domain argument is the expected domain name of a certificate request. | |
+func testGetCertificate(t *testing.T, man *Manager, domain string, hello *tls.ClientHelloInfo) { | |
+ url, finish := startACMEServerStub(t, man, domain) | |
+ defer finish() | |
// use EC key to run faster on 386 | |
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) | |
@@ -248,7 +386,7 @@ func testGetCertificate(t *testing.T, man *Manager, domain string, hello *tls.Cl | |
} | |
man.Client = &acme.Client{ | |
Key: key, | |
- DirectoryURL: ca.URL, | |
+ DirectoryURL: url, | |
} | |
// simulate tls.Config.GetCertificate | |
@@ -279,23 +417,6 @@ func testGetCertificate(t *testing.T, man *Manager, domain string, hello *tls.Cl | |
t.Errorf("cert.DNSNames = %v; want %q", cert.DNSNames, domain) | |
} | |
- // make sure token cert was removed | |
- done = make(chan struct{}) | |
- go func() { | |
- for { | |
- hello := &tls.ClientHelloInfo{ServerName: tokenCertName} | |
- if _, err := man.GetCertificate(hello); err != nil { | |
- break | |
- } | |
- time.Sleep(100 * time.Millisecond) | |
- } | |
- close(done) | |
- }() | |
- select { | |
- case <-time.After(5 * time.Second): | |
- t.Error("token cert was not removed") | |
- case <-done: | |
- } | |
} | |
func TestAccountKeyCache(t *testing.T) { | |
@@ -335,10 +456,11 @@ func TestCache(t *testing.T) { | |
man := &Manager{Cache: newMemCache()} | |
defer man.stopRenew() | |
- if err := man.cachePut("example.org", tlscert); err != nil { | |
+ ctx := context.Background() | |
+ if err := man.cachePut(ctx, "example.org", tlscert); err != nil { | |
t.Fatalf("man.cachePut: %v", err) | |
} | |
- res, err := man.cacheGet("example.org") | |
+ res, err := man.cacheGet(ctx, "example.org") | |
if err != nil { | |
t.Fatalf("man.cacheGet: %v", err) | |
} | |
@@ -438,3 +560,47 @@ func TestValidCert(t *testing.T) { | |
} | |
} | |
} | |
+ | |
+type cacheGetFunc func(ctx context.Context, key string) ([]byte, error) | |
+ | |
+func (f cacheGetFunc) Get(ctx context.Context, key string) ([]byte, error) { | |
+ return f(ctx, key) | |
+} | |
+ | |
+func (f cacheGetFunc) Put(ctx context.Context, key string, data []byte) error { | |
+ return fmt.Errorf("unsupported Put of %q = %q", key, data) | |
+} | |
+ | |
+func (f cacheGetFunc) Delete(ctx context.Context, key string) error { | |
+ return fmt.Errorf("unsupported Delete of %q", key) | |
+} | |
+ | |
+func TestManagerGetCertificateBogusSNI(t *testing.T) { | |
+ m := Manager{ | |
+ Prompt: AcceptTOS, | |
+ Cache: cacheGetFunc(func(ctx context.Context, key string) ([]byte, error) { | |
+ return nil, fmt.Errorf("cache.Get of %s", key) | |
+ }), | |
+ } | |
+ tests := []struct { | |
+ name string | |
+ wantErr string | |
+ }{ | |
+ {"foo.com", "cache.Get of foo.com"}, | |
+ {"foo.com.", "cache.Get of foo.com"}, | |
+ {`a\b.com`, "acme/autocert: server name contains invalid character"}, | |
+ {`a/b.com`, "acme/autocert: server name contains invalid character"}, | |
+ {"", "acme/autocert: missing server name"}, | |
+ {"foo", "acme/autocert: server name component count invalid"}, | |
+ {".foo", "acme/autocert: server name component count invalid"}, | |
+ {"foo.", "acme/autocert: server name component count invalid"}, | |
+ {"fo.o", "cache.Get of fo.o"}, | |
+ } | |
+ for _, tt := range tests { | |
+ _, err := m.GetCertificate(&tls.ClientHelloInfo{ServerName: tt.name}) | |
+ got := fmt.Sprint(err) | |
+ if got != tt.wantErr { | |
+ t.Errorf("GetCertificate(SNI = %q) = %q; want %q", tt.name, got, tt.wantErr) | |
+ } | |
+ } | |
+} | |
diff --git a/acme/autocert/cache.go b/acme/autocert/cache.go | |
index 9b184aa..61a5fd2 100644 | |
--- a/acme/autocert/cache.go | |
+++ b/acme/autocert/cache.go | |
@@ -5,12 +5,11 @@ | |
package autocert | |
import ( | |
+ "context" | |
"errors" | |
"io/ioutil" | |
"os" | |
"path/filepath" | |
- | |
- "golang.org/x/net/context" | |
) | |
// ErrCacheMiss is returned when a certificate is not found in cache. | |
@@ -78,12 +77,13 @@ func (d DirCache) Put(ctx context.Context, name string, data []byte) error { | |
if tmp, err = d.writeTempFile(name, data); err != nil { | |
return | |
} | |
- // prevent overwriting the file if the context was cancelled | |
- if ctx.Err() != nil { | |
- return // no need to set err | |
+ select { | |
+ case <-ctx.Done(): | |
+ // Don't overwrite the file if the context was canceled. | |
+ default: | |
+ newName := filepath.Join(string(d), name) | |
+ err = os.Rename(tmp, newName) | |
} | |
- name = filepath.Join(string(d), name) | |
- err = os.Rename(tmp, name) | |
}() | |
select { | |
case <-ctx.Done(): | |
diff --git a/acme/autocert/cache_test.go b/acme/autocert/cache_test.go | |
index ad6d4a4..6e1b88d 100644 | |
--- a/acme/autocert/cache_test.go | |
+++ b/acme/autocert/cache_test.go | |
@@ -5,13 +5,12 @@ | |
package autocert | |
import ( | |
+ "context" | |
"io/ioutil" | |
"os" | |
"path/filepath" | |
"reflect" | |
"testing" | |
- | |
- "golang.org/x/net/context" | |
) | |
// make sure DirCache satisfies Cache interface | |
diff --git a/acme/autocert/example_test.go b/acme/autocert/example_test.go | |
new file mode 100644 | |
index 0000000..c6267b8 | |
--- /dev/null | |
+++ b/acme/autocert/example_test.go | |
@@ -0,0 +1,34 @@ | |
+// Copyright 2017 The Go Authors. All rights reserved. | |
+// Use of this source code is governed by a BSD-style | |
+// license that can be found in the LICENSE file. | |
+ | |
+package autocert_test | |
+ | |
+import ( | |
+ "crypto/tls" | |
+ "fmt" | |
+ "log" | |
+ "net/http" | |
+ | |
+ "golang.org/x/crypto/acme/autocert" | |
+) | |
+ | |
+func ExampleNewListener() { | |
+ mux := http.NewServeMux() | |
+ mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { | |
+ fmt.Fprintf(w, "Hello, TLS user! Your config: %+v", r.TLS) | |
+ }) | |
+ log.Fatal(http.Serve(autocert.NewListener("example.com"), mux)) | |
+} | |
+ | |
+func ExampleManager() { | |
+ m := autocert.Manager{ | |
+ Prompt: autocert.AcceptTOS, | |
+ HostPolicy: autocert.HostWhitelist("example.org"), | |
+ } | |
+ s := &http.Server{ | |
+ Addr: ":https", | |
+ TLSConfig: &tls.Config{GetCertificate: m.GetCertificate}, | |
+ } | |
+ s.ListenAndServeTLS("", "") | |
+} | |
diff --git a/acme/autocert/listener.go b/acme/autocert/listener.go | |
new file mode 100644 | |
index 0000000..d4c93d2 | |
--- /dev/null | |
+++ b/acme/autocert/listener.go | |
@@ -0,0 +1,153 @@ | |
+// Copyright 2017 The Go Authors. All rights reserved. | |
+// Use of this source code is governed by a BSD-style | |
+// license that can be found in the LICENSE file. | |
+ | |
+package autocert | |
+ | |
+import ( | |
+ "crypto/tls" | |
+ "log" | |
+ "net" | |
+ "os" | |
+ "path/filepath" | |
+ "runtime" | |
+ "time" | |
+) | |
+ | |
+// NewListener returns a net.Listener that listens on the standard TLS | |
+// port (443) on all interfaces and returns *tls.Conn connections with | |
+// LetsEncrypt certificates for the provided domain or domains. | |
+// | |
+// It enables one-line HTTPS servers: | |
+// | |
+// log.Fatal(http.Serve(autocert.NewListener("example.com"), handler)) | |
+// | |
+// NewListener is a convenience function for a common configuration. | |
+// More complex or custom configurations can use the autocert.Manager | |
+// type instead. | |
+// | |
+// Use of this function implies acceptance of the LetsEncrypt Terms of | |
+// Service. If domains is not empty, the provided domains are passed | |
+// to HostWhitelist. If domains is empty, the listener will do | |
+// LetsEncrypt challenges for any requested domain, which is not | |
+// recommended. | |
+// | |
+// Certificates are cached in a "golang-autocert" directory under an | |
+// operating system-specific cache or temp directory. This may not | |
+// be suitable for servers spanning multiple machines. | |
+// | |
+// The returned Listener also enables TCP keep-alives on the accepted | |
+// connections. The returned *tls.Conn are returned before their TLS | |
+// handshake has completed. | |
+func NewListener(domains ...string) net.Listener { | |
+ m := &Manager{ | |
+ Prompt: AcceptTOS, | |
+ } | |
+ if len(domains) > 0 { | |
+ m.HostPolicy = HostWhitelist(domains...) | |
+ } | |
+ dir := cacheDir() | |
+ if err := os.MkdirAll(dir, 0700); err != nil { | |
+ log.Printf("warning: autocert.NewListener not using a cache: %v", err) | |
+ } else { | |
+ m.Cache = DirCache(dir) | |
+ } | |
+ return m.Listener() | |
+} | |
+ | |
+// Listener listens on the standard TLS port (443) on all interfaces | |
+// and returns a net.Listener returning *tls.Conn connections. | |
+// | |
+// The returned Listener also enables TCP keep-alives on the accepted | |
+// connections. The returned *tls.Conn are returned before their TLS | |
+// handshake has completed. | |
+// | |
+// Unlike NewListener, it is the caller's responsibility to initialize | |
+// the Manager m's Prompt, Cache, HostPolicy, and other desired options. | |
+func (m *Manager) Listener() net.Listener { | |
+ ln := &listener{ | |
+ m: m, | |
+ conf: &tls.Config{ | |
+ GetCertificate: m.GetCertificate, // bonus: panic on nil m | |
+ }, | |
+ } | |
+ ln.tcpListener, ln.tcpListenErr = net.Listen("tcp", ":443") | |
+ return ln | |
+} | |
+ | |
+type listener struct { | |
+ m *Manager | |
+ conf *tls.Config | |
+ | |
+ tcpListener net.Listener | |
+ tcpListenErr error | |
+} | |
+ | |
+func (ln *listener) Accept() (net.Conn, error) { | |
+ if ln.tcpListenErr != nil { | |
+ return nil, ln.tcpListenErr | |
+ } | |
+ conn, err := ln.tcpListener.Accept() | |
+ if err != nil { | |
+ return nil, err | |
+ } | |
+ tcpConn := conn.(*net.TCPConn) | |
+ | |
+ // Because Listener is a convenience function, help out with | |
+ // this too. This is not possible for the caller to set once | |
+ // we return a *tcp.Conn wrapping an inaccessible net.Conn. | |
+ // If callers don't want this, they can do things the manual | |
+ // way and tweak as needed. But this is what net/http does | |
+ // itself, so copy that. If net/http changes, we can change | |
+ // here too. | |
+ tcpConn.SetKeepAlive(true) | |
+ tcpConn.SetKeepAlivePeriod(3 * time.Minute) | |
+ | |
+ return tls.Server(tcpConn, ln.conf), nil | |
+} | |
+ | |
+func (ln *listener) Addr() net.Addr { | |
+ if ln.tcpListener != nil { | |
+ return ln.tcpListener.Addr() | |
+ } | |
+ // net.Listen failed. Return something non-nil in case callers | |
+ // call Addr before Accept: | |
+ return &net.TCPAddr{IP: net.IP{0, 0, 0, 0}, Port: 443} | |
+} | |
+ | |
+func (ln *listener) Close() error { | |
+ if ln.tcpListenErr != nil { | |
+ return ln.tcpListenErr | |
+ } | |
+ return ln.tcpListener.Close() | |
+} | |
+ | |
+func homeDir() string { | |
+ if runtime.GOOS == "windows" { | |
+ return os.Getenv("HOMEDRIVE") + os.Getenv("HOMEPATH") | |
+ } | |
+ if h := os.Getenv("HOME"); h != "" { | |
+ return h | |
+ } | |
+ return "/" | |
+} | |
+ | |
+func cacheDir() string { | |
+ const base = "golang-autocert" | |
+ switch runtime.GOOS { | |
+ case "darwin": | |
+ return filepath.Join(homeDir(), "Library", "Caches", base) | |
+ case "windows": | |
+ for _, ev := range []string{"APPDATA", "CSIDL_APPDATA", "TEMP", "TMP"} { | |
+ if v := os.Getenv(ev); v != "" { | |
+ return filepath.Join(v, base) | |
+ } | |
+ } | |
+ // Worst case: | |
+ return filepath.Join(homeDir(), base) | |
+ } | |
+ if xdg := os.Getenv("XDG_CACHE_HOME"); xdg != "" { | |
+ return filepath.Join(xdg, base) | |
+ } | |
+ return filepath.Join(homeDir(), ".cache", base) | |
+} | |
diff --git a/acme/autocert/renewal.go b/acme/autocert/renewal.go | |
index 1a5018c..6c5da2b 100644 | |
--- a/acme/autocert/renewal.go | |
+++ b/acme/autocert/renewal.go | |
@@ -5,15 +5,14 @@ | |
package autocert | |
import ( | |
+ "context" | |
"crypto" | |
"sync" | |
"time" | |
- | |
- "golang.org/x/net/context" | |
) | |
-// maxRandRenew is a maximum deviation from Manager.RenewBefore. | |
-const maxRandRenew = time.Hour | |
+// renewJitter is the maximum deviation from Manager.RenewBefore. | |
+const renewJitter = time.Hour | |
// domainRenewal tracks the state used by the periodic timers | |
// renewing a single domain's cert. | |
@@ -65,7 +64,7 @@ func (dr *domainRenewal) renew() { | |
// TODO: rotate dr.key at some point? | |
next, err := dr.do(ctx) | |
if err != nil { | |
- next = maxRandRenew / 2 | |
+ next = renewJitter / 2 | |
next += time.Duration(pseudoRand.int63n(int64(next))) | |
} | |
dr.timer = time.AfterFunc(next, dr.renew) | |
@@ -83,9 +82,9 @@ func (dr *domainRenewal) renew() { | |
func (dr *domainRenewal) do(ctx context.Context) (time.Duration, error) { | |
// a race is likely unavoidable in a distributed environment | |
// but we try nonetheless | |
- if tlscert, err := dr.m.cacheGet(dr.domain); err == nil { | |
+ if tlscert, err := dr.m.cacheGet(ctx, dr.domain); err == nil { | |
next := dr.next(tlscert.Leaf.NotAfter) | |
- if next > dr.m.renewBefore()+maxRandRenew { | |
+ if next > dr.m.renewBefore()+renewJitter { | |
return next, nil | |
} | |
} | |
@@ -103,7 +102,7 @@ func (dr *domainRenewal) do(ctx context.Context) (time.Duration, error) { | |
if err != nil { | |
return 0, err | |
} | |
- dr.m.cachePut(dr.domain, tlscert) | |
+ dr.m.cachePut(ctx, dr.domain, tlscert) | |
dr.m.stateMu.Lock() | |
defer dr.m.stateMu.Unlock() | |
// m.state is guaranteed to be non-nil at this point | |
@@ -114,7 +113,7 @@ func (dr *domainRenewal) do(ctx context.Context) (time.Duration, error) { | |
func (dr *domainRenewal) next(expiry time.Time) time.Duration { | |
d := expiry.Sub(timeNow()) - dr.m.renewBefore() | |
// add a bit of randomness to renew deadline | |
- n := pseudoRand.int63n(int64(maxRandRenew)) | |
+ n := pseudoRand.int63n(int64(renewJitter)) | |
d -= time.Duration(n) | |
if d < 0 { | |
return 0 | |
diff --git a/acme/autocert/renewal_test.go b/acme/autocert/renewal_test.go | |
index 10c811a..f232619 100644 | |
--- a/acme/autocert/renewal_test.go | |
+++ b/acme/autocert/renewal_test.go | |
@@ -5,6 +5,7 @@ | |
package autocert | |
import ( | |
+ "context" | |
"crypto/ecdsa" | |
"crypto/elliptic" | |
"crypto/rand" | |
@@ -31,7 +32,7 @@ func TestRenewalNext(t *testing.T) { | |
expiry time.Time | |
min, max time.Duration | |
}{ | |
- {now.Add(90 * 24 * time.Hour), 83*24*time.Hour - maxRandRenew, 83 * 24 * time.Hour}, | |
+ {now.Add(90 * 24 * time.Hour), 83*24*time.Hour - renewJitter, 83 * 24 * time.Hour}, | |
{now.Add(time.Hour), 0, 1}, | |
{now, 0, 1}, | |
{now.Add(-time.Hour), 0, 1}, | |
@@ -127,7 +128,7 @@ func TestRenewFromCache(t *testing.T) { | |
t.Fatal(err) | |
} | |
tlscert := &tls.Certificate{PrivateKey: key, Certificate: [][]byte{cert}} | |
- if err := man.cachePut(domain, tlscert); err != nil { | |
+ if err := man.cachePut(context.Background(), domain, tlscert); err != nil { | |
t.Fatal(err) | |
} | |
@@ -151,7 +152,7 @@ func TestRenewFromCache(t *testing.T) { | |
// ensure the new cert is cached | |
after := time.Now().Add(future) | |
- tlscert, err := man.cacheGet(domain) | |
+ tlscert, err := man.cacheGet(context.Background(), domain) | |
if err != nil { | |
t.Fatalf("man.cacheGet: %v", err) | |
} | |
diff --git a/acme/jws.go b/acme/jws.go | |
index 49ba313..6cbca25 100644 | |
--- a/acme/jws.go | |
+++ b/acme/jws.go | |
@@ -134,7 +134,7 @@ func jwsHasher(key crypto.Signer) (string, crypto.Hash) { | |
return "ES256", crypto.SHA256 | |
case "P-384": | |
return "ES384", crypto.SHA384 | |
- case "P-512": | |
+ case "P-521": | |
return "ES512", crypto.SHA512 | |
} | |
} | |
diff --git a/acme/jws_test.go b/acme/jws_test.go | |
index 1def873..0ff0fb5 100644 | |
--- a/acme/jws_test.go | |
+++ b/acme/jws_test.go | |
@@ -12,11 +12,13 @@ import ( | |
"encoding/base64" | |
"encoding/json" | |
"encoding/pem" | |
+ "fmt" | |
"math/big" | |
"testing" | |
) | |
-const testKeyPEM = ` | |
+const ( | |
+ testKeyPEM = ` | |
-----BEGIN RSA PRIVATE KEY----- | |
MIIEowIBAAKCAQEA4xgZ3eRPkwoRvy7qeRUbmMDe0V+xH9eWLdu0iheeLlrmD2mq | |
WXfP9IeSKApbn34g8TuAS9g5zhq8ELQ3kmjr+KV86GAMgI6VAcGlq3QrzpTCf/30 | |
@@ -46,10 +48,9 @@ EQeIP6dZtv8IMgtGIb91QX9pXvP0aznzQKwYIA8nZgoENCPfiMTPiEDT9e/0lObO | |
-----END RSA PRIVATE KEY----- | |
` | |
-// This thumbprint is for the testKey defined above. | |
-const testKeyThumbprint = "6nicxzh6WETQlrvdchkz-U3e3DOQZ4heJKU63rfqMqQ" | |
+ // This thumbprint is for the testKey defined above. | |
+ testKeyThumbprint = "6nicxzh6WETQlrvdchkz-U3e3DOQZ4heJKU63rfqMqQ" | |
-const ( | |
// openssl ecparam -name secp256k1 -genkey -noout | |
testKeyECPEM = ` | |
-----BEGIN EC PRIVATE KEY----- | |
@@ -58,39 +59,78 @@ AwEHoUQDQgAE5lhEug5xK4xBDZ2nAbaxLtaLiv85bxJ7ePd1dkO23HThqIrvawF5 | |
QAaS/RNouybCiRhRjI3EaxLkQwgrCw0gqQ== | |
-----END EC PRIVATE KEY----- | |
` | |
- // 1. opnessl ec -in key.pem -noout -text | |
+ // openssl ecparam -name secp384r1 -genkey -noout | |
+ testKeyEC384PEM = ` | |
+-----BEGIN EC PRIVATE KEY----- | |
+MIGkAgEBBDAQ4lNtXRORWr1bgKR1CGysr9AJ9SyEk4jiVnlUWWUChmSNL+i9SLSD | |
+Oe/naPqXJ6CgBwYFK4EEACKhZANiAAQzKtj+Ms0vHoTX5dzv3/L5YMXOWuI5UKRj | |
+JigpahYCqXD2BA1j0E/2xt5vlPf+gm0PL+UHSQsCokGnIGuaHCsJAp3ry0gHQEke | |
+WYXapUUFdvaK1R2/2hn5O+eiQM8YzCg= | |
+-----END EC PRIVATE KEY----- | |
+` | |
+ // openssl ecparam -name secp521r1 -genkey -noout | |
+ testKeyEC512PEM = ` | |
+-----BEGIN EC PRIVATE KEY----- | |
+MIHcAgEBBEIBSNZKFcWzXzB/aJClAb305ibalKgtDA7+70eEkdPt28/3LZMM935Z | |
+KqYHh/COcxuu3Kt8azRAUz3gyr4zZKhlKUSgBwYFK4EEACOhgYkDgYYABAHUNKbx | |
+7JwC7H6pa2sV0tERWhHhB3JmW+OP6SUgMWryvIKajlx73eS24dy4QPGrWO9/ABsD | |
+FqcRSkNVTXnIv6+0mAF25knqIBIg5Q8M9BnOu9GGAchcwt3O7RDHmqewnJJDrbjd | |
+GGnm6rb+NnWR9DIopM0nKNkToWoF/hzopxu4Ae/GsQ== | |
+-----END EC PRIVATE KEY----- | |
+` | |
+ // 1. openssl ec -in key.pem -noout -text | |
// 2. remove first byte, 04 (the header); the rest is X and Y | |
- // 3. covert each with: echo <val> | xxd -r -p | base64 | tr -d '=' | tr '/+' '_-' | |
- testKeyECPubX = "5lhEug5xK4xBDZ2nAbaxLtaLiv85bxJ7ePd1dkO23HQ" | |
- testKeyECPubY = "4aiK72sBeUAGkv0TaLsmwokYUYyNxGsS5EMIKwsNIKk" | |
+ // 3. convert each with: echo <val> | xxd -r -p | base64 -w 100 | tr -d '=' | tr '/+' '_-' | |
+ testKeyECPubX = "5lhEug5xK4xBDZ2nAbaxLtaLiv85bxJ7ePd1dkO23HQ" | |
+ testKeyECPubY = "4aiK72sBeUAGkv0TaLsmwokYUYyNxGsS5EMIKwsNIKk" | |
+ testKeyEC384PubX = "MyrY_jLNLx6E1-Xc79_y-WDFzlriOVCkYyYoKWoWAqlw9gQNY9BP9sbeb5T3_oJt" | |
+ testKeyEC384PubY = "Dy_lB0kLAqJBpyBrmhwrCQKd68tIB0BJHlmF2qVFBXb2itUdv9oZ-TvnokDPGMwo" | |
+ testKeyEC512PubX = "AdQ0pvHsnALsfqlraxXS0RFaEeEHcmZb44_pJSAxavK8gpqOXHvd5Lbh3LhA8atY738AGwMWpxFKQ1VNeci_r7SY" | |
+ testKeyEC512PubY = "AXbmSeogEiDlDwz0Gc670YYByFzC3c7tEMeap7CckkOtuN0Yaebqtv42dZH0MiikzSco2ROhagX-HOinG7gB78ax" | |
+ | |
// echo -n '{"crv":"P-256","kty":"EC","x":"<testKeyECPubX>","y":"<testKeyECPubY>"}' | \ | |
// openssl dgst -binary -sha256 | base64 | tr -d '=' | tr '/+' '_-' | |
testKeyECThumbprint = "zedj-Bd1Zshp8KLePv2MB-lJ_Hagp7wAwdkA0NUTniU" | |
) | |
var ( | |
- testKey *rsa.PrivateKey | |
- testKeyEC *ecdsa.PrivateKey | |
+ testKey *rsa.PrivateKey | |
+ testKeyEC *ecdsa.PrivateKey | |
+ testKeyEC384 *ecdsa.PrivateKey | |
+ testKeyEC512 *ecdsa.PrivateKey | |
) | |
func init() { | |
- d, _ := pem.Decode([]byte(testKeyPEM)) | |
+ testKey = parseRSA(testKeyPEM, "testKeyPEM") | |
+ testKeyEC = parseEC(testKeyECPEM, "testKeyECPEM") | |
+ testKeyEC384 = parseEC(testKeyEC384PEM, "testKeyEC384PEM") | |
+ testKeyEC512 = parseEC(testKeyEC512PEM, "testKeyEC512PEM") | |
+} | |
+ | |
+func decodePEM(s, name string) []byte { | |
+ d, _ := pem.Decode([]byte(s)) | |
if d == nil { | |
- panic("no block found in testKeyPEM") | |
+ panic("no block found in " + name) | |
} | |
- var err error | |
- testKey, err = x509.ParsePKCS1PrivateKey(d.Bytes) | |
+ return d.Bytes | |
+} | |
+ | |
+func parseRSA(s, name string) *rsa.PrivateKey { | |
+ b := decodePEM(s, name) | |
+ k, err := x509.ParsePKCS1PrivateKey(b) | |
if err != nil { | |
- panic(err.Error()) | |
+ panic(fmt.Sprintf("%s: %v", name, err)) | |
} | |
+ return k | |
+} | |
- if d, _ = pem.Decode([]byte(testKeyECPEM)); d == nil { | |
- panic("no block found in testKeyECPEM") | |
- } | |
- testKeyEC, err = x509.ParseECPrivateKey(d.Bytes) | |
+func parseEC(s, name string) *ecdsa.PrivateKey { | |
+ b := decodePEM(s, name) | |
+ k, err := x509.ParseECPrivateKey(b) | |
if err != nil { | |
- panic(err.Error()) | |
+ panic(fmt.Sprintf("%s: %v", name, err)) | |
} | |
+ return k | |
} | |
func TestJWSEncodeJSON(t *testing.T) { | |
@@ -141,50 +181,63 @@ func TestJWSEncodeJSON(t *testing.T) { | |
} | |
func TestJWSEncodeJSONEC(t *testing.T) { | |
- claims := struct{ Msg string }{"Hello JWS"} | |
- | |
- b, err := jwsEncodeJSON(claims, testKeyEC, "nonce") | |
- if err != nil { | |
- t.Fatal(err) | |
- } | |
- var jws struct{ Protected, Payload, Signature string } | |
- if err := json.Unmarshal(b, &jws); err != nil { | |
- t.Fatal(err) | |
+ tt := []struct { | |
+ key *ecdsa.PrivateKey | |
+ x, y string | |
+ alg, crv string | |
+ }{ | |
+ {testKeyEC, testKeyECPubX, testKeyECPubY, "ES256", "P-256"}, | |
+ {testKeyEC384, testKeyEC384PubX, testKeyEC384PubY, "ES384", "P-384"}, | |
+ {testKeyEC512, testKeyEC512PubX, testKeyEC512PubY, "ES512", "P-521"}, | |
} | |
+ for i, test := range tt { | |
+ claims := struct{ Msg string }{"Hello JWS"} | |
+ b, err := jwsEncodeJSON(claims, test.key, "nonce") | |
+ if err != nil { | |
+ t.Errorf("%d: %v", i, err) | |
+ continue | |
+ } | |
+ var jws struct{ Protected, Payload, Signature string } | |
+ if err := json.Unmarshal(b, &jws); err != nil { | |
+ t.Errorf("%d: %v", i, err) | |
+ continue | |
+ } | |
- if b, err = base64.RawURLEncoding.DecodeString(jws.Protected); err != nil { | |
- t.Fatalf("jws.Protected: %v", err) | |
- } | |
- var head struct { | |
- Alg string | |
- Nonce string | |
- JWK struct { | |
- Crv string | |
- Kty string | |
- X string | |
- Y string | |
- } `json:"jwk"` | |
- } | |
- if err := json.Unmarshal(b, &head); err != nil { | |
- t.Fatalf("jws.Protected: %v", err) | |
- } | |
- if head.Alg != "ES256" { | |
- t.Errorf("head.Alg = %q; want ES256", head.Alg) | |
- } | |
- if head.Nonce != "nonce" { | |
- t.Errorf("head.Nonce = %q; want nonce", head.Nonce) | |
- } | |
- if head.JWK.Crv != "P-256" { | |
- t.Errorf("head.JWK.Crv = %q; want P-256", head.JWK.Crv) | |
- } | |
- if head.JWK.Kty != "EC" { | |
- t.Errorf("head.JWK.Kty = %q; want EC", head.JWK.Kty) | |
- } | |
- if head.JWK.X != testKeyECPubX { | |
- t.Errorf("head.JWK.X = %q; want %q", head.JWK.X, testKeyECPubX) | |
- } | |
- if head.JWK.Y != testKeyECPubY { | |
- t.Errorf("head.JWK.Y = %q; want %q", head.JWK.Y, testKeyECPubY) | |
+ b, err = base64.RawURLEncoding.DecodeString(jws.Protected) | |
+ if err != nil { | |
+ t.Errorf("%d: jws.Protected: %v", i, err) | |
+ } | |
+ var head struct { | |
+ Alg string | |
+ Nonce string | |
+ JWK struct { | |
+ Crv string | |
+ Kty string | |
+ X string | |
+ Y string | |
+ } `json:"jwk"` | |
+ } | |
+ if err := json.Unmarshal(b, &head); err != nil { | |
+ t.Errorf("%d: jws.Protected: %v", i, err) | |
+ } | |
+ if head.Alg != test.alg { | |
+ t.Errorf("%d: head.Alg = %q; want %q", i, head.Alg, test.alg) | |
+ } | |
+ if head.Nonce != "nonce" { | |
+ t.Errorf("%d: head.Nonce = %q; want nonce", i, head.Nonce) | |
+ } | |
+ if head.JWK.Crv != test.crv { | |
+ t.Errorf("%d: head.JWK.Crv = %q; want %q", i, head.JWK.Crv, test.crv) | |
+ } | |
+ if head.JWK.Kty != "EC" { | |
+ t.Errorf("%d: head.JWK.Kty = %q; want EC", i, head.JWK.Kty) | |
+ } | |
+ if head.JWK.X != test.x { | |
+ t.Errorf("%d: head.JWK.X = %q; want %q", i, head.JWK.X, test.x) | |
+ } | |
+ if head.JWK.Y != test.y { | |
+ t.Errorf("%d: head.JWK.Y = %q; want %q", i, head.JWK.Y, test.y) | |
+ } | |
} | |
} | |
diff --git a/acme/types.go b/acme/types.go | |
index 0513b2e..ab4de0b 100644 | |
--- a/acme/types.go | |
+++ b/acme/types.go | |
@@ -1,9 +1,15 @@ | |
+// Copyright 2016 The Go Authors. All rights reserved. | |
+// Use of this source code is governed by a BSD-style | |
+// license that can be found in the LICENSE file. | |
+ | |
package acme | |
import ( | |
"errors" | |
"fmt" | |
"net/http" | |
+ "strings" | |
+ "time" | |
) | |
// ACME server response statuses used to describe Authorization and Challenge states. | |
@@ -33,14 +39,8 @@ const ( | |
CRLReasonAACompromise CRLReasonCode = 10 | |
) | |
-var ( | |
- // ErrAuthorizationFailed indicates that an authorization for an identifier | |
- // did not succeed. | |
- ErrAuthorizationFailed = errors.New("acme: identifier authorization failed") | |
- | |
- // ErrUnsupportedKey is returned when an unsupported key type is encountered. | |
- ErrUnsupportedKey = errors.New("acme: unknown key type; only RSA and ECDSA are supported") | |
-) | |
+// ErrUnsupportedKey is returned when an unsupported key type is encountered. | |
+var ErrUnsupportedKey = errors.New("acme: unknown key type; only RSA and ECDSA are supported") | |
// Error is an ACME error, defined in Problem Details for HTTP APIs doc | |
// http://tools.ietf.org/html/draft-ietf-appsawg-http-problem. | |
@@ -53,6 +53,7 @@ type Error struct { | |
// Detail is a human-readable explanation specific to this occurrence of the problem. | |
Detail string | |
// Header is the original server error response headers. | |
+ // It may be nil. | |
Header http.Header | |
} | |
@@ -60,6 +61,50 @@ func (e *Error) Error() string { | |
return fmt.Sprintf("%d %s: %s", e.StatusCode, e.ProblemType, e.Detail) | |
} | |
+// AuthorizationError indicates that an authorization for an identifier | |
+// did not succeed. | |
+// It contains all errors from Challenge items of the failed Authorization. | |
+type AuthorizationError struct { | |
+ // URI uniquely identifies the failed Authorization. | |
+ URI string | |
+ | |
+ // Identifier is an AuthzID.Value of the failed Authorization. | |
+ Identifier string | |
+ | |
+ // Errors is a collection of non-nil error values of Challenge items | |
+ // of the failed Authorization. | |
+ Errors []error | |
+} | |
+ | |
+func (a *AuthorizationError) Error() string { | |
+ e := make([]string, len(a.Errors)) | |
+ for i, err := range a.Errors { | |
+ e[i] = err.Error() | |
+ } | |
+ return fmt.Sprintf("acme: authorization error for %s: %s", a.Identifier, strings.Join(e, "; ")) | |
+} | |
+ | |
+// RateLimit reports whether err represents a rate limit error and | |
+// any Retry-After duration returned by the server. | |
+// | |
+// See the following for more details on rate limiting: | |
+// https://tools.ietf.org/html/draft-ietf-acme-acme-05#section-5.6 | |
+func RateLimit(err error) (time.Duration, bool) { | |
+ e, ok := err.(*Error) | |
+ if !ok { | |
+ return 0, false | |
+ } | |
+ // Some CA implementations may return incorrect values. | |
+ // Use case-insensitive comparison. | |
+ if !strings.HasSuffix(strings.ToLower(e.ProblemType), ":ratelimited") { | |
+ return 0, false | |
+ } | |
+ if e.Header == nil { | |
+ return 0, true | |
+ } | |
+ return retryAfter(e.Header.Get("Retry-After"), 0), true | |
+} | |
+ | |
// Account is a user account. It is associated with a private key. | |
type Account struct { | |
// URI is the account unique ID, which is also a URL used to retrieve | |
@@ -118,6 +163,8 @@ type Directory struct { | |
} | |
// Challenge encodes a returned CA challenge. | |
+// Its Error field may be non-nil if the challenge is part of an Authorization | |
+// with StatusInvalid. | |
type Challenge struct { | |
// Type is the challenge type, e.g. "http-01", "tls-sni-02", "dns-01". | |
Type string | |
@@ -130,6 +177,11 @@ type Challenge struct { | |
// Status identifies the status of this challenge. | |
Status string | |
+ | |
+ // Error indicates the reason for an authorization failure | |
+ // when this challenge was used. | |
+ // The type of a non-nil value is *Error. | |
+ Error error | |
} | |
// Authorization encodes an authorization response. | |
@@ -187,12 +239,26 @@ func (z *wireAuthz) authorization(uri string) *Authorization { | |
return a | |
} | |
+func (z *wireAuthz) error(uri string) *AuthorizationError { | |
+ err := &AuthorizationError{ | |
+ URI: uri, | |
+ Identifier: z.Identifier.Value, | |
+ } | |
+ for _, raw := range z.Challenges { | |
+ if raw.Error != nil { | |
+ err.Errors = append(err.Errors, raw.Error.error(nil)) | |
+ } | |
+ } | |
+ return err | |
+} | |
+ | |
// wireChallenge is ACME JSON challenge representation. | |
type wireChallenge struct { | |
URI string `json:"uri"` | |
Type string | |
Token string | |
Status string | |
+ Error *wireError | |
} | |
func (c *wireChallenge) challenge() *Challenge { | |
@@ -205,5 +271,25 @@ func (c *wireChallenge) challenge() *Challenge { | |
if v.Status == "" { | |
v.Status = StatusPending | |
} | |
+ if c.Error != nil { | |
+ v.Error = c.Error.error(nil) | |
+ } | |
return v | |
} | |
+ | |
+// wireError is a subset of fields of the Problem Details object | |
+// as described in https://tools.ietf.org/html/rfc7807#section-3.1. | |
+type wireError struct { | |
+ Status int | |
+ Type string | |
+ Detail string | |
+} | |
+ | |
+func (e *wireError) error(h http.Header) *Error { | |
+ return &Error{ | |
+ StatusCode: e.Status, | |
+ ProblemType: e.Type, | |
+ Detail: e.Detail, | |
+ Header: h, | |
+ } | |
+} | |
diff --git a/acme/types_test.go b/acme/types_test.go | |
new file mode 100644 | |
index 0000000..a7553e6 | |
--- /dev/null | |
+++ b/acme/types_test.go | |
@@ -0,0 +1,63 @@ | |
+// Copyright 2017 The Go Authors. All rights reserved. | |
+// Use of this source code is governed by a BSD-style | |
+// license that can be found in the LICENSE file. | |
+ | |
+package acme | |
+ | |
+import ( | |
+ "errors" | |
+ "net/http" | |
+ "testing" | |
+ "time" | |
+) | |
+ | |
+func TestRateLimit(t *testing.T) { | |
+ now := time.Date(2017, 04, 27, 10, 0, 0, 0, time.UTC) | |
+ f := timeNow | |
+ defer func() { timeNow = f }() | |
+ timeNow = func() time.Time { return now } | |
+ | |
+ h120, hTime := http.Header{}, http.Header{} | |
+ h120.Set("Retry-After", "120") | |
+ hTime.Set("Retry-After", "Tue Apr 27 11:00:00 2017") | |
+ | |
+ err1 := &Error{ | |
+ ProblemType: "urn:ietf:params:acme:error:nolimit", | |
+ Header: h120, | |
+ } | |
+ err2 := &Error{ | |
+ ProblemType: "urn:ietf:params:acme:error:rateLimited", | |
+ Header: h120, | |
+ } | |
+ err3 := &Error{ | |
+ ProblemType: "urn:ietf:params:acme:error:rateLimited", | |
+ Header: nil, | |
+ } | |
+ err4 := &Error{ | |
+ ProblemType: "urn:ietf:params:acme:error:rateLimited", | |
+ Header: hTime, | |
+ } | |
+ | |
+ tt := []struct { | |
+ err error | |
+ res time.Duration | |
+ ok bool | |
+ }{ | |
+ {nil, 0, false}, | |
+ {errors.New("dummy"), 0, false}, | |
+ {err1, 0, false}, | |
+ {err2, 2 * time.Minute, true}, | |
+ {err3, 0, true}, | |
+ {err4, time.Hour, true}, | |
+ } | |
+ for i, test := range tt { | |
+ res, ok := RateLimit(test.err) | |
+ if ok != test.ok { | |
+ t.Errorf("%d: RateLimit(%+v): ok = %v; want %v", i, test.err, ok, test.ok) | |
+ continue | |
+ } | |
+ if res != test.res { | |
+ t.Errorf("%d: RateLimit(%+v) = %v; want %v", i, test.err, res, test.res) | |
+ } | |
+ } | |
+} | |
diff --git a/bcrypt/bcrypt.go b/bcrypt/bcrypt.go | |
index f8b807f..202fa8a 100644 | |
--- a/bcrypt/bcrypt.go | |
+++ b/bcrypt/bcrypt.go | |
@@ -12,9 +12,10 @@ import ( | |
"crypto/subtle" | |
"errors" | |
"fmt" | |
- "golang.org/x/crypto/blowfish" | |
"io" | |
"strconv" | |
+ | |
+ "golang.org/x/crypto/blowfish" | |
) | |
const ( | |
@@ -205,7 +206,6 @@ func bcrypt(password []byte, cost int, salt []byte) ([]byte, error) { | |
} | |
func expensiveBlowfishSetup(key []byte, cost uint32, salt []byte) (*blowfish.Cipher, error) { | |
- | |
csalt, err := base64Decode(salt) | |
if err != nil { | |
return nil, err | |
@@ -213,7 +213,8 @@ func expensiveBlowfishSetup(key []byte, cost uint32, salt []byte) (*blowfish.Cip | |
// Bug compatibility with C bcrypt implementations. They use the trailing | |
// NULL in the key string during expansion. | |
- ckey := append(key, 0) | |
+ // We copy the key to prevent changing the underlying array. | |
+ ckey := append(key[:len(key):len(key)], 0) | |
c, err := blowfish.NewSaltedCipher(ckey, csalt) | |
if err != nil { | |
diff --git a/bcrypt/bcrypt_test.go b/bcrypt/bcrypt_test.go | |
index f08a6f5..aecf759 100644 | |
--- a/bcrypt/bcrypt_test.go | |
+++ b/bcrypt/bcrypt_test.go | |
@@ -224,3 +224,20 @@ func BenchmarkGeneration(b *testing.B) { | |
GenerateFromPassword(passwd, 10) | |
} | |
} | |
+ | |
+// See Issue https://github.com/golang/go/issues/20425. | |
+func TestNoSideEffectsFromCompare(t *testing.T) { | |
+ source := []byte("passw0rd123456") | |
+ password := source[:len(source)-6] | |
+ token := source[len(source)-6:] | |
+ want := make([]byte, len(source)) | |
+ copy(want, source) | |
+ | |
+ wantHash := []byte("$2a$10$LK9XRuhNxHHCvjX3tdkRKei1QiCDUKrJRhZv7WWZPuQGRUM92rOUa") | |
+ _ = CompareHashAndPassword(wantHash, password) | |
+ | |
+ got := bytes.Join([][]byte{password, token}, []byte("")) | |
+ if !bytes.Equal(got, want) { | |
+ t.Errorf("got=%q want=%q", got, want) | |
+ } | |
+} | |
diff --git a/blake2b/blake2b.go b/blake2b/blake2b.go | |
index fa9e48e..ce62241 100644 | |
--- a/blake2b/blake2b.go | |
+++ b/blake2b/blake2b.go | |
@@ -4,7 +4,7 @@ | |
// Package blake2b implements the BLAKE2b hash algorithm as | |
// defined in RFC 7693. | |
-package blake2b | |
+package blake2b // import "golang.org/x/crypto/blake2b" | |
import ( | |
"encoding/binary" | |
diff --git a/blake2b/blake2b_test.go b/blake2b/blake2b_test.go | |
index a38fceb..7954346 100644 | |
--- a/blake2b/blake2b_test.go | |
+++ b/blake2b/blake2b_test.go | |
@@ -22,7 +22,7 @@ func fromHex(s string) []byte { | |
func TestHashes(t *testing.T) { | |
defer func(sse4, avx, avx2 bool) { | |
- useSSE4, useAVX, useAVX2 = sse4, useAVX, avx2 | |
+ useSSE4, useAVX, useAVX2 = sse4, avx, avx2 | |
}(useSSE4, useAVX, useAVX2) | |
if useAVX2 { | |
diff --git a/blake2s/blake2s.go b/blake2s/blake2s.go | |
index 394c121..f2d8221 100644 | |
--- a/blake2s/blake2s.go | |
+++ b/blake2s/blake2s.go | |
@@ -4,7 +4,7 @@ | |
// Package blake2s implements the BLAKE2s hash algorithm as | |
// defined in RFC 7693. | |
-package blake2s | |
+package blake2s // import "golang.org/x/crypto/blake2s" | |
import ( | |
"encoding/binary" | |
@@ -15,8 +15,12 @@ import ( | |
const ( | |
// The blocksize of BLAKE2s in bytes. | |
BlockSize = 64 | |
+ | |
// The hash size of BLAKE2s-256 in bytes. | |
Size = 32 | |
+ | |
+ // The hash size of BLAKE2s-128 in bytes. | |
+ Size128 = 16 | |
) | |
var errKeySize = errors.New("blake2s: invalid key size") | |
@@ -37,6 +41,17 @@ func Sum256(data []byte) [Size]byte { | |
// key turns the hash into a MAC. The key must between zero and 32 bytes long. | |
func New256(key []byte) (hash.Hash, error) { return newDigest(Size, key) } | |
+// New128 returns a new hash.Hash computing the BLAKE2s-128 checksum given a | |
+// non-empty key. Note that a 128-bit digest is too small to be secure as a | |
+// cryptographic hash and should only be used as a MAC, thus the key argument | |
+// is not optional. | |
+func New128(key []byte) (hash.Hash, error) { | |
+ if len(key) == 0 { | |
+ return nil, errors.New("blake2s: a key is required for a 128-bit hash") | |
+ } | |
+ return newDigest(Size128, key) | |
+} | |
+ | |
func newDigest(hashSize int, key []byte) (*digest, error) { | |
if len(key) > Size { | |
return nil, errKeySize | |
diff --git a/blake2s/blake2s_test.go b/blake2s/blake2s_test.go | |
index e6f2eeb..ff41670 100644 | |
--- a/blake2s/blake2s_test.go | |
+++ b/blake2s/blake2s_test.go | |
@@ -18,21 +18,25 @@ func TestHashes(t *testing.T) { | |
if useSSE4 { | |
t.Log("SSE4 version") | |
testHashes(t) | |
+ testHashes128(t) | |
useSSE4 = false | |
} | |
if useSSSE3 { | |
t.Log("SSSE3 version") | |
testHashes(t) | |
+ testHashes128(t) | |
useSSSE3 = false | |
} | |
if useSSE2 { | |
t.Log("SSE2 version") | |
testHashes(t) | |
+ testHashes128(t) | |
useSSE2 = false | |
} | |
if useGeneric { | |
t.Log("generic version") | |
testHashes(t) | |
+ testHashes128(t) | |
} | |
} | |
@@ -69,6 +73,39 @@ func testHashes(t *testing.T) { | |
} | |
} | |
+func testHashes128(t *testing.T) { | |
+ key, _ := hex.DecodeString("000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f") | |
+ | |
+ input := make([]byte, 255) | |
+ for i := range input { | |
+ input[i] = byte(i) | |
+ } | |
+ | |
+ for i, expectedHex := range hashes128 { | |
+ h, err := New128(key) | |
+ if err != nil { | |
+ t.Fatalf("#%d: error from New128: %v", i, err) | |
+ } | |
+ | |
+ h.Write(input[:i]) | |
+ sum := h.Sum(nil) | |
+ | |
+ if gotHex := fmt.Sprintf("%x", sum); gotHex != expectedHex { | |
+ t.Fatalf("#%d (single write): got %s, wanted %s", i, gotHex, expectedHex) | |
+ } | |
+ | |
+ h.Reset() | |
+ for j := 0; j < i; j++ { | |
+ h.Write(input[j : j+1]) | |
+ } | |
+ | |
+ sum = h.Sum(sum[:0]) | |
+ if gotHex := fmt.Sprintf("%x", sum); gotHex != expectedHex { | |
+ t.Fatalf("#%d (byte-by-byte): got %s, wanted %s", i, gotHex, expectedHex) | |
+ } | |
+ } | |
+} | |
+ | |
// Benchmarks | |
func benchmarkSum(b *testing.B, size int) { | |
@@ -355,3 +392,262 @@ var hashes = []string{ | |
"db444c15597b5f1a03d1f9edd16e4a9f43a667cc275175dfa2b704e3bb1a9b83", | |
"3fb735061abc519dfe979e54c1ee5bfad0a9d858b3315bad34bde999efd724dd", | |
} | |
+ | |
+var hashes128 = []string{ | |
+ "9536f9b267655743dee97b8a670f9f53", | |
+ "13bacfb85b48a1223c595f8c1e7e82cb", | |
+ "d47a9b1645e2feae501cd5fe44ce6333", | |
+ "1e2a79436a7796a3e9826bfedf07659f", | |
+ "7640360ed3c4f3054dba79a21dda66b7", | |
+ "d1207ac2bf5ac84fc9ef016da5a46a86", | |
+ "3123987871e59305ece3125abfc0099a", | |
+ "cf9e072ad522f2cda2d825218086731c", | |
+ "95d22870392efe2846b12b6e8e84efbb", | |
+ "7d63c30e2d51333f245601b038c0b93b", | |
+ "ed608b98e13976bdf4bedc63fa35e443", | |
+ "ed704b5cd1abf8e0dd67a6ac667a3fa5", | |
+ "77dc70109827dc74c70fd26cba379ae5", | |
+ "d2bf34508b07825ee934f33958f4560e", | |
+ "a340baa7b8a93a6e658adef42e78eeb7", | |
+ "b85c5ceaecbe9a251eac76f6932ba395", | |
+ "246519722001f6e8e97a2183f5985e53", | |
+ "5bce5aa0b7c6cac2ecf6406183cd779a", | |
+ "13408f1647c02f6efd0047ad8344f695", | |
+ "a63970f196760aa36cb965ab62f0e0fa", | |
+ "bc26f48421dd99fd45e15e736d3e7dac", | |
+ "4c6f70f9e3237cde918afb52d26f1823", | |
+ "45ed610cfbc37db80c4bf0eef14ae8d6", | |
+ "87c4c150705ea5078209ec008200539c", | |
+ "54de21f5e0e6f2afe04daeb822b6931e", | |
+ "9732a04e505064e19de3d542e7e71631", | |
+ "d2bd27e95531d6957eef511c4ba64ad4", | |
+ "7a36c9f70dcc7c3063b547101a5f6c35", | |
+ "322007d1a44c4257bc7903b183305529", | |
+ "dbcc9a09f412290ca2e0d53dfd142ddb", | |
+ "df12ed43b8e53a56db20e0f83764002c", | |
+ "d114cc11e7d5b33a360c45f18d4c7c6e", | |
+ "c43b5e836af88620a8a71b1652cb8640", | |
+ "9491c653e8867ed73c1b4ac6b5a9bb4d", | |
+ "06d0e988df94ada6c6f9f36f588ab7c5", | |
+ "561efad2480e93262c8eeaa3677615c4", | |
+ "ba8ffc702e5adc93503045eca8702312", | |
+ "5782be6ccdc78c8425285e85de8ccdc6", | |
+ "aa1c4393e4c07b53ea6e2b5b1e970771", | |
+ "42a229dc50e52271c51e8666023ebc1e", | |
+ "53706110e919f84de7f8d6c7f0e7b831", | |
+ "fc5ac8ee39cc1dd1424391323e2901bd", | |
+ "bed27b62ff66cac2fbb68193c727106a", | |
+ "cd5e689b96d0b9ea7e08dac36f7b211e", | |
+ "0b4c7f604eba058d18e322c6e1baf173", | |
+ "eb838227fdfad09a27f0f8413120675d", | |
+ "3149cf9d19a7fd529e6154a8b4c3b3ad", | |
+ "ca1e20126df930fd5fb7afe4422191e5", | |
+ "b23398f910599f3c09b6549fa81bcb46", | |
+ "27fb17c11b34fa5d8b5afe5ee3321ead", | |
+ "0f665f5f04cf2d46b7fead1a1f328158", | |
+ "8f068be73b3681f99f3b282e3c02bba5", | |
+ "ba189bbd13808dcf4e002a4dd21660d5", | |
+ "2732dcd1b16668ae6ab6a61595d0d62a", | |
+ "d410ccdd059f0e02b472ec9ec54bdd3c", | |
+ "b2eaa07b055b3a03a399971327f7e8c2", | |
+ "2e8a225655e9f99b69c60dc8b4d8e566", | |
+ "4eb55416c853f2152e67f8a224133cec", | |
+ "49552403790d8de0505a8e317a443687", | |
+ "7f2747cd41f56942752e868212c7d5ac", | |
+ "02a28f10e193b430df7112d2d98cf759", | |
+ "d4213404a9f1cf759017747cf5958270", | |
+ "faa34884344f9c65e944882db8476d34", | |
+ "ece382a8bd5018f1de5da44b72cea75b", | |
+ "f1efa90d2547036841ecd3627fafbc36", | |
+ "811ff8686d23a435ecbd0bdafcd27b1b", | |
+ "b21beea9c7385f657a76558530438721", | |
+ "9cb969da4f1b4fc5b13bf78fe366f0c4", | |
+ "8850d16d7b614d3268ccfa009d33c7fc", | |
+ "aa98a2b6176ea86415b9aff3268c6f6d", | |
+ "ec3e1efa5ed195eff667e16b1af1e39e", | |
+ "e40787dca57411d2630db2de699beb08", | |
+ "554835890735babd06318de23d31e78a", | |
+ "493957feecddc302ee2bb2086b6ebfd3", | |
+ "f6069709ad5b0139163717e9ce1114ab", | |
+ "ba5ed386098da284484b211555505a01", | |
+ "9244c8dfad8cbb68c118fa51465b3ae4", | |
+ "51e309a5008eb1f5185e5cc007cfb36f", | |
+ "6ce9ff712121b4f6087955f4911eafd4", | |
+ "59b51d8dcda031218ccdd7c760828155", | |
+ "0012878767a3d4f1c8194458cf1f8832", | |
+ "82900708afd5b6582dc16f008c655edd", | |
+ "21302c7e39b5a4cdf1d6f86b4f00c9b4", | |
+ "e894c7431591eab8d1ce0fe2aa1f01df", | |
+ "b67e1c40ee9d988226d605621854d955", | |
+ "6237bdafa34137cbbec6be43ea9bd22c", | |
+ "4172a8e19b0dcb09b978bb9eff7af52b", | |
+ "5714abb55bd4448a5a6ad09fbd872fdf", | |
+ "7ce1700bef423e1f958a94a77a94d44a", | |
+ "3742ec50cded528527775833453e0b26", | |
+ "5d41b135724c7c9c689495324b162f18", | |
+ "85c523333c6442c202e9e6e0f1185f93", | |
+ "5c71f5222d40ff5d90e7570e71ab2d30", | |
+ "6e18912e83d012efb4c66250ced6f0d9", | |
+ "4add4448c2e35e0b138a0bac7b4b1775", | |
+ "c0376c6bc5e7b8b9d2108ec25d2aab53", | |
+ "f72261d5ed156765c977751c8a13fcc1", | |
+ "cff4156c48614b6ceed3dd6b9058f17e", | |
+ "36bfb513f76c15f514bcb593419835aa", | |
+ "166bf48c6bffaf8291e6fdf63854bef4", | |
+ "0b67d33f8b859c3157fbabd9e6e47ed0", | |
+ "e4da659ca76c88e73a9f9f10f3d51789", | |
+ "33c1ae2a86b3f51c0642e6ed5b5aa1f1", | |
+ "27469b56aca2334449c1cf4970dcd969", | |
+ "b7117b2e363378aa0901b0d6a9f6ddc0", | |
+ "a9578233b09e5cd5231943fdb12cd90d", | |
+ "486d7d75253598b716a068243c1c3e89", | |
+ "66f6b02d682b78ffdc85e9ec86852489", | |
+ "38a07b9a4b228fbcc305476e4d2e05d2", | |
+ "aedb61c7970e7d05bf9002dae3c6858c", | |
+ "c03ef441f7dd30fdb61ad2d4d8e4c7da", | |
+ "7f45cc1eea9a00cb6aeb2dd748361190", | |
+ "a59538b358459132e55160899e47bd65", | |
+ "137010fef72364411820c3fbed15c8df", | |
+ "d8362b93fc504500dbd33ac74e1b4d70", | |
+ "a7e49f12c8f47e3b29cf8c0889b0a9c8", | |
+ "072e94ffbfc684bd8ab2a1b9dade2fd5", | |
+ "5ab438584bd2229e452052e002631a5f", | |
+ "f233d14221097baef57d3ec205c9e086", | |
+ "3a95db000c4a8ff98dc5c89631a7f162", | |
+ "0544f18c2994ab4ddf1728f66041ff16", | |
+ "0bc02116c60a3cc331928d6c9d3ba37e", | |
+ "b189dca6cb5b813c74200834fba97f29", | |
+ "ac8aaab075b4a5bc24419da239212650", | |
+ "1e9f19323dc71c29ae99c479dc7e8df9", | |
+ "12d944c3fa7caa1b3d62adfc492274dd", | |
+ "b4c68f1fffe8f0030e9b18aad8c9dc96", | |
+ "25887fab1422700d7fa3edc0b20206e2", | |
+ "8c09f698d03eaf88abf69f8147865ef6", | |
+ "5c363ae42a5bec26fbc5e996428d9bd7", | |
+ "7fdfc2e854fbb3928150d5e3abcf56d6", | |
+ "f0c944023f714df115f9e4f25bcdb89b", | |
+ "6d19534b4c332741c8ddd79a9644de2d", | |
+ "32595eb23764fbfc2ee7822649f74a12", | |
+ "5a51391aab33c8d575019b6e76ae052a", | |
+ "98b861ce2c620f10f913af5d704a5afd", | |
+ "b7fe2fc8b77fb1ce434f8465c7ddf793", | |
+ "0e8406e0cf8e9cc840668ece2a0fc64e", | |
+ "b89922db99c58f6a128ccffe19b6ce60", | |
+ "e1be9af665f0932b77d7f5631a511db7", | |
+ "74b96f20f58de8dc9ff5e31f91828523", | |
+ "36a4cfef5a2a7d8548db6710e50b3009", | |
+ "007e95e8d3b91948a1dedb91f75de76b", | |
+ "a87a702ce08f5745edf765bfcd5fbe0d", | |
+ "847e69a388a749a9c507354d0dddfe09", | |
+ "07176eefbc107a78f058f3d424ca6a54", | |
+ "ad7e80682333b68296f6cb2b4a8e446d", | |
+ "53c4aba43896ae422e5de5b9edbd46bf", | |
+ "33bd6c20ca2a7ab916d6e98003c6c5f8", | |
+ "060d088ea94aa093f9981a79df1dfcc8", | |
+ "5617b214b9df08d4f11e58f5e76d9a56", | |
+ "ca3a60ee85bd971e1daf9f7db059d909", | |
+ "cd2b7754505d8c884eddf736f1ec613e", | |
+ "f496163b252f1439e7e113ba2ecabd8e", | |
+ "5719c7dcf9d9f756d6213354acb7d5cf", | |
+ "6f7dd40b245c54411e7a9be83ae5701c", | |
+ "c8994dd9fdeb077a45ea04a30358b637", | |
+ "4b1184f1e35458c1c747817d527a252f", | |
+ "fc7df674afeac7a3fd994183f4c67a74", | |
+ "4f68e05ce4dcc533acf9c7c01d95711e", | |
+ "d4ebc59e918400720035dfc88e0c486a", | |
+ "d3105dd6fa123e543b0b3a6e0eeaea9e", | |
+ "874196128ed443f5bdb2800ca048fcad", | |
+ "01645f134978dc8f9cf0abc93b53780e", | |
+ "5b8b64caa257873a0ffd47c981ef6c3f", | |
+ "4ee208fc50ba0a6e65c5b58cec44c923", | |
+ "53f409a52427b3b7ffabb057ca088428", | |
+ "c1d6cd616f5341a93d921e356e5887a9", | |
+ "e85c20fea67fa7320dc23379181183c8", | |
+ "7912b6409489df001b7372bc94aebde7", | |
+ "e559f761ec866a87f1f331767fafc60f", | |
+ "20a6f5a36bc37043d977ed7708465ef8", | |
+ "6a72f526965ab120826640dd784c6cc4", | |
+ "bf486d92ad68e87c613689dd370d001b", | |
+ "d339fd0eb35edf3abd6419c8d857acaf", | |
+ "9521cd7f32306d969ddabc4e6a617f52", | |
+ "a1cd9f3e81520842f3cf6cc301cb0021", | |
+ "18e879b6f154492d593edd3f4554e237", | |
+ "66e2329c1f5137589e051592587e521e", | |
+ "e899566dd6c3e82cbc83958e69feb590", | |
+ "8a4b41d7c47e4e80659d77b4e4bfc9ae", | |
+ "f1944f6fcfc17803405a1101998c57dd", | |
+ "f6bcec07567b4f72851b307139656b18", | |
+ "22e7bb256918fe9924dce9093e2d8a27", | |
+ "dd25b925815fe7b50b7079f5f65a3970", | |
+ "0457f10f299acf0c230dd4007612e58f", | |
+ "ecb420c19efd93814fae2964d69b54af", | |
+ "14eb47b06dff685d88751c6e32789db4", | |
+ "e8f072dbb50d1ab6654aa162604a892d", | |
+ "69cff9c62092332f03a166c7b0034469", | |
+ "d3619f98970b798ca32c6c14cd25af91", | |
+ "2246d423774ee9d51a551e89c0539d9e", | |
+ "75e5d1a1e374a04a699247dad827b6cf", | |
+ "6d087dd1d4cd15bf47db07c7a96b1db8", | |
+ "967e4c055ac51b4b2a3e506cebd5826f", | |
+ "7417aa79247e473401bfa92a25b62e2a", | |
+ "24f3f4956da34b5c533d9a551ccd7b16", | |
+ "0c40382de693a5304e2331eb951cc962", | |
+ "9436f949d51b347db5c8e6258dafaaac", | |
+ "d2084297fe84c4ba6e04e4fb73d734fe", | |
+ "42a6f8ff590af21b512e9e088257aa34", | |
+ "c484ad06b1cdb3a54f3f6464a7a2a6fd", | |
+ "1b8ac860f5ceb4365400a201ed2917aa", | |
+ "c43eadabbe7b7473f3f837fc52650f54", | |
+ "0e5d3205406126b1f838875deb150d6a", | |
+ "6bf4946f8ec8a9c417f50cd1e67565be", | |
+ "42f09a2522314799c95b3fc121a0e3e8", | |
+ "06b8f1487f691a3f7c3f74e133d55870", | |
+ "1a70a65fb4f314dcf6a31451a9d2704f", | |
+ "7d4acdd0823279fd28a1e48b49a04669", | |
+ "09545cc8822a5dfc93bbab708fd69174", | |
+ "efc063db625013a83c9a426d39a9bddb", | |
+ "213bbf89b3f5be0ffdb14854bbcb2588", | |
+ "b69624d89fe2774df9a6f43695d755d4", | |
+ "c0f9ff9ded82bd73c512e365a894774d", | |
+ "d1b68507ed89c17ead6f69012982db71", | |
+ "14cf16db04648978e35c44850855d1b0", | |
+ "9f254d4eccab74cd91d694df863650a8", | |
+ "8f8946e2967baa4a814d36ff01d20813", | |
+ "6b9dc4d24ecba166cb2915d7a6cba43b", | |
+ "eb35a80418a0042b850e294db7898d4d", | |
+ "f55f925d280c637d54055c9df088ef5f", | |
+ "f48427a04f67e33f3ba0a17f7c9704a7", | |
+ "4a9f5bfcc0321aea2eced896cee65894", | |
+ "8723a67d1a1df90f1cef96e6fe81e702", | |
+ "c166c343ee25998f80bad4067960d3fd", | |
+ "dab67288d16702e676a040fd42344d73", | |
+ "c8e9e0d80841eb2c116dd14c180e006c", | |
+ "92294f546bacf0dea9042c93ecba8b34", | |
+ "013705b1502b37369ad22fe8237d444e", | |
+ "9b97f8837d5f2ebab0768fc9a6446b93", | |
+ "7e7e5236b05ec35f89edf8bf655498e7", | |
+ "7be8f2362c174c776fb9432fe93bf259", | |
+ "2422e80420276d2df5702c6470879b01", | |
+ "df645795db778bcce23bbe819a76ba48", | |
+ "3f97a4ac87dfc58761cda1782d749074", | |
+ "50e3f45df21ebfa1b706b9c0a1c245a8", | |
+ "7879541c7ff612c7ddf17cb8f7260183", | |
+ "67f6542b903b7ba1945eba1a85ee6b1c", | |
+ "b34b73d36ab6234b8d3f5494d251138e", | |
+ "0aea139641fdba59ab1103479a96e05f", | |
+ "02776815a87b8ba878453666d42afe3c", | |
+ "5929ab0a90459ebac5a16e2fb37c847e", | |
+ "c244def5b20ce0468f2b5012d04ac7fd", | |
+ "12116add6fefce36ed8a0aeccce9b6d3", | |
+ "3cd743841e9d8b878f34d91b793b4fad", | |
+ "45e87510cf5705262185f46905fae35f", | |
+ "276047016b0bfb501b2d4fc748165793", | |
+ "ddd245df5a799417d350bd7f4e0b0b7e", | |
+ "d34d917a54a2983f3fdbc4b14caae382", | |
+ "7730fbc09d0c1fb1939a8fc436f6b995", | |
+ "eb4899ef257a1711cc9270a19702e5b5", | |
+ "8a30932014bce35bba620895d374df7a", | |
+ "1924aabf9c50aa00bee5e1f95b5d9e12", | |
+ "1758d6f8b982aec9fbe50f20e3082b46", | |
+ "cd075928ab7e6883e697fe7fd3ac43ee", | |
+} | |
diff --git a/chacha20poly1305/chacha20poly1305.go b/chacha20poly1305/chacha20poly1305.go | |
index eb6739a..3f0dcb9 100644 | |
--- a/chacha20poly1305/chacha20poly1305.go | |
+++ b/chacha20poly1305/chacha20poly1305.go | |
@@ -3,7 +3,7 @@ | |
// license that can be found in the LICENSE file. | |
// Package chacha20poly1305 implements the ChaCha20-Poly1305 AEAD as specified in RFC 7539. | |
-package chacha20poly1305 | |
+package chacha20poly1305 // import "golang.org/x/crypto/chacha20poly1305" | |
import ( | |
"crypto/cipher" | |
diff --git a/chacha20poly1305/chacha20poly1305_amd64.go b/chacha20poly1305/chacha20poly1305_amd64.go | |
index 4755033..7cd7ad8 100644 | |
--- a/chacha20poly1305/chacha20poly1305_amd64.go | |
+++ b/chacha20poly1305/chacha20poly1305_amd64.go | |
@@ -14,13 +14,60 @@ func chacha20Poly1305Open(dst []byte, key []uint32, src, ad []byte) bool | |
//go:noescape | |
func chacha20Poly1305Seal(dst []byte, key []uint32, src, ad []byte) | |
-//go:noescape | |
-func haveSSSE3() bool | |
+// cpuid is implemented in chacha20poly1305_amd64.s. | |
+func cpuid(eaxArg, ecxArg uint32) (eax, ebx, ecx, edx uint32) | |
+ | |
+// xgetbv with ecx = 0 is implemented in chacha20poly1305_amd64.s. | |
+func xgetbv() (eax, edx uint32) | |
-var canUseASM bool | |
+var ( | |
+ useASM bool | |
+ useAVX2 bool | |
+) | |
func init() { | |
- canUseASM = haveSSSE3() | |
+ detectCPUFeatures() | |
+} | |
+ | |
+// detectCPUFeatures is used to detect if cpu instructions | |
+// used by the functions implemented in assembler in | |
+// chacha20poly1305_amd64.s are supported. | |
+func detectCPUFeatures() { | |
+ maxID, _, _, _ := cpuid(0, 0) | |
+ if maxID < 1 { | |
+ return | |
+ } | |
+ | |
+ _, _, ecx1, _ := cpuid(1, 0) | |
+ | |
+ haveSSSE3 := isSet(9, ecx1) | |
+ useASM = haveSSSE3 | |
+ | |
+ haveOSXSAVE := isSet(27, ecx1) | |
+ | |
+ osSupportsAVX := false | |
+ // For XGETBV, OSXSAVE bit is required and sufficient. | |
+ if haveOSXSAVE { | |
+ eax, _ := xgetbv() | |
+ // Check if XMM and YMM registers have OS support. | |
+ osSupportsAVX = isSet(1, eax) && isSet(2, eax) | |
+ } | |
+ haveAVX := isSet(28, ecx1) && osSupportsAVX | |
+ | |
+ if maxID < 7 { | |
+ return | |
+ } | |
+ | |
+ _, ebx7, _, _ := cpuid(7, 0) | |
+ haveAVX2 := isSet(5, ebx7) && haveAVX | |
+ haveBMI2 := isSet(8, ebx7) | |
+ | |
+ useAVX2 = haveAVX2 && haveBMI2 | |
+} | |
+ | |
+// isSet checks if bit at bitpos is set in value. | |
+func isSet(bitpos uint, value uint32) bool { | |
+ return value&(1<<bitpos) != 0 | |
} | |
// setupState writes a ChaCha20 input matrix to state. See | |
@@ -47,7 +94,7 @@ func setupState(state *[16]uint32, key *[32]byte, nonce []byte) { | |
} | |
func (c *chacha20poly1305) seal(dst, nonce, plaintext, additionalData []byte) []byte { | |
- if !canUseASM { | |
+ if !useASM { | |
return c.sealGeneric(dst, nonce, plaintext, additionalData) | |
} | |
@@ -60,7 +107,7 @@ func (c *chacha20poly1305) seal(dst, nonce, plaintext, additionalData []byte) [] | |
} | |
func (c *chacha20poly1305) open(dst, nonce, ciphertext, additionalData []byte) ([]byte, error) { | |
- if !canUseASM { | |
+ if !useASM { | |
return c.openGeneric(dst, nonce, ciphertext, additionalData) | |
} | |
diff --git a/chacha20poly1305/chacha20poly1305_amd64.s b/chacha20poly1305/chacha20poly1305_amd64.s | |
index 39c58b4..1c57e38 100644 | |
--- a/chacha20poly1305/chacha20poly1305_amd64.s | |
+++ b/chacha20poly1305/chacha20poly1305_amd64.s | |
@@ -278,15 +278,8 @@ TEXT ·chacha20Poly1305Open(SB), 0, $288-97 | |
MOVQ ad+72(FP), adp | |
// Check for AVX2 support | |
- CMPB runtime·support_avx2(SB), $0 | |
- JE noavx2bmi2Open | |
- | |
- // Check BMI2 bit for MULXQ. | |
- // runtime·cpuid_ebx7 is always available here | |
- // because it passed avx2 check | |
- TESTL $(1<<8), runtime·cpuid_ebx7(SB) | |
- JNE chacha20Poly1305Open_AVX2 | |
-noavx2bmi2Open: | |
+ CMPB ·useAVX2(SB), $1 | |
+ JE chacha20Poly1305Open_AVX2 | |
// Special optimization, for very short buffers | |
CMPQ inl, $128 | |
@@ -1491,16 +1484,8 @@ TEXT ·chacha20Poly1305Seal(SB), 0, $288-96 | |
MOVQ src_len+56(FP), inl | |
MOVQ ad+72(FP), adp | |
- // Check for AVX2 support | |
- CMPB runtime·support_avx2(SB), $0 | |
- JE noavx2bmi2Seal | |
- | |
- // Check BMI2 bit for MULXQ. | |
- // runtime·cpuid_ebx7 is always available here | |
- // because it passed avx2 check | |
- TESTL $(1<<8), runtime·cpuid_ebx7(SB) | |
- JNE chacha20Poly1305Seal_AVX2 | |
-noavx2bmi2Seal: | |
+ CMPB ·useAVX2(SB), $1 | |
+ JE chacha20Poly1305Seal_AVX2 | |
// Special optimization, for very short buffers | |
CMPQ inl, $128 | |
@@ -2709,13 +2694,21 @@ sealAVX2Tail512LoopB: | |
JMP sealAVX2SealHash | |
-// func haveSSSE3() bool | |
-TEXT ·haveSSSE3(SB), NOSPLIT, $0 | |
- XORQ AX, AX | |
- INCL AX | |
+// func cpuid(eaxArg, ecxArg uint32) (eax, ebx, ecx, edx uint32) | |
+TEXT ·cpuid(SB), NOSPLIT, $0-24 | |
+ MOVL eaxArg+0(FP), AX | |
+ MOVL ecxArg+4(FP), CX | |
CPUID | |
- SHRQ $9, CX | |
- ANDQ $1, CX | |
- MOVB CX, ret+0(FP) | |
+ MOVL AX, eax+8(FP) | |
+ MOVL BX, ebx+12(FP) | |
+ MOVL CX, ecx+16(FP) | |
+ MOVL DX, edx+20(FP) | |
RET | |
+// func xgetbv() (eax, edx uint32) | |
+TEXT ·xgetbv(SB),NOSPLIT,$0-8 | |
+ MOVL $0, CX | |
+ XGETBV | |
+ MOVL AX, eax+0(FP) | |
+ MOVL DX, edx+4(FP) | |
+ RET | |
diff --git a/cryptobyte/string.go b/cryptobyte/string.go | |
index b1215b3..6780336 100644 | |
--- a/cryptobyte/string.go | |
+++ b/cryptobyte/string.go | |
@@ -5,7 +5,7 @@ | |
// Package cryptobyte implements building and parsing of byte strings for | |
// DER-encoded ASN.1 and TLS messages. See the examples for the Builder and | |
// String types to get started. | |
-package cryptobyte | |
+package cryptobyte // import "golang.org/x/crypto/cryptobyte" | |
// String represents a string of bytes. It provides methods for parsing | |
// fixed-length and length-prefixed values from it. | |
diff --git a/curve25519/cswap_amd64.s b/curve25519/cswap_amd64.s | |
index 45484d1..cd793a5 100644 | |
--- a/curve25519/cswap_amd64.s | |
+++ b/curve25519/cswap_amd64.s | |
@@ -2,87 +2,64 @@ | |
// Use of this source code is governed by a BSD-style | |
// license that can be found in the LICENSE file. | |
-// This code was translated into a form compatible with 6a from the public | |
-// domain sources in SUPERCOP: http://bench.cr.yp.to/supercop.html | |
- | |
// +build amd64,!gccgo,!appengine | |
-// func cswap(inout *[5]uint64, v uint64) | |
+// func cswap(inout *[4][5]uint64, v uint64) | |
TEXT ·cswap(SB),7,$0 | |
MOVQ inout+0(FP),DI | |
MOVQ v+8(FP),SI | |
- CMPQ SI,$1 | |
- MOVQ 0(DI),SI | |
- MOVQ 80(DI),DX | |
- MOVQ 8(DI),CX | |
- MOVQ 88(DI),R8 | |
- MOVQ SI,R9 | |
- CMOVQEQ DX,SI | |
- CMOVQEQ R9,DX | |
- MOVQ CX,R9 | |
- CMOVQEQ R8,CX | |
- CMOVQEQ R9,R8 | |
- MOVQ SI,0(DI) | |
- MOVQ DX,80(DI) | |
- MOVQ CX,8(DI) | |
- MOVQ R8,88(DI) | |
- MOVQ 16(DI),SI | |
- MOVQ 96(DI),DX | |
- MOVQ 24(DI),CX | |
- MOVQ 104(DI),R8 | |
- MOVQ SI,R9 | |
- CMOVQEQ DX,SI | |
- CMOVQEQ R9,DX | |
- MOVQ CX,R9 | |
- CMOVQEQ R8,CX | |
- CMOVQEQ R9,R8 | |
- MOVQ SI,16(DI) | |
- MOVQ DX,96(DI) | |
- MOVQ CX,24(DI) | |
- MOVQ R8,104(DI) | |
- MOVQ 32(DI),SI | |
- MOVQ 112(DI),DX | |
- MOVQ 40(DI),CX | |
- MOVQ 120(DI),R8 | |
- MOVQ SI,R9 | |
- CMOVQEQ DX,SI | |
- CMOVQEQ R9,DX | |
- MOVQ CX,R9 | |
- CMOVQEQ R8,CX | |
- CMOVQEQ R9,R8 | |
- MOVQ SI,32(DI) | |
- MOVQ DX,112(DI) | |
- MOVQ CX,40(DI) | |
- MOVQ R8,120(DI) | |
- MOVQ 48(DI),SI | |
- MOVQ 128(DI),DX | |
- MOVQ 56(DI),CX | |
- MOVQ 136(DI),R8 | |
- MOVQ SI,R9 | |
- CMOVQEQ DX,SI | |
- CMOVQEQ R9,DX | |
- MOVQ CX,R9 | |
- CMOVQEQ R8,CX | |
- CMOVQEQ R9,R8 | |
- MOVQ SI,48(DI) | |
- MOVQ DX,128(DI) | |
- MOVQ CX,56(DI) | |
- MOVQ R8,136(DI) | |
- MOVQ 64(DI),SI | |
- MOVQ 144(DI),DX | |
- MOVQ 72(DI),CX | |
- MOVQ 152(DI),R8 | |
- MOVQ SI,R9 | |
- CMOVQEQ DX,SI | |
- CMOVQEQ R9,DX | |
- MOVQ CX,R9 | |
- CMOVQEQ R8,CX | |
- CMOVQEQ R9,R8 | |
- MOVQ SI,64(DI) | |
- MOVQ DX,144(DI) | |
- MOVQ CX,72(DI) | |
- MOVQ R8,152(DI) | |
- MOVQ DI,AX | |
- MOVQ SI,DX | |
+ SUBQ $1, SI | |
+ NOTQ SI | |
+ MOVQ SI, X15 | |
+ PSHUFD $0x44, X15, X15 | |
+ | |
+ MOVOU 0(DI), X0 | |
+ MOVOU 16(DI), X2 | |
+ MOVOU 32(DI), X4 | |
+ MOVOU 48(DI), X6 | |
+ MOVOU 64(DI), X8 | |
+ MOVOU 80(DI), X1 | |
+ MOVOU 96(DI), X3 | |
+ MOVOU 112(DI), X5 | |
+ MOVOU 128(DI), X7 | |
+ MOVOU 144(DI), X9 | |
+ | |
+ MOVO X1, X10 | |
+ MOVO X3, X11 | |
+ MOVO X5, X12 | |
+ MOVO X7, X13 | |
+ MOVO X9, X14 | |
+ | |
+ PXOR X0, X10 | |
+ PXOR X2, X11 | |
+ PXOR X4, X12 | |
+ PXOR X6, X13 | |
+ PXOR X8, X14 | |
+ PAND X15, X10 | |
+ PAND X15, X11 | |
+ PAND X15, X12 | |
+ PAND X15, X13 | |
+ PAND X15, X14 | |
+ PXOR X10, X0 | |
+ PXOR X10, X1 | |
+ PXOR X11, X2 | |
+ PXOR X11, X3 | |
+ PXOR X12, X4 | |
+ PXOR X12, X5 | |
+ PXOR X13, X6 | |
+ PXOR X13, X7 | |
+ PXOR X14, X8 | |
+ PXOR X14, X9 | |
+ | |
+ MOVOU X0, 0(DI) | |
+ MOVOU X2, 16(DI) | |
+ MOVOU X4, 32(DI) | |
+ MOVOU X6, 48(DI) | |
+ MOVOU X8, 64(DI) | |
+ MOVOU X1, 80(DI) | |
+ MOVOU X3, 96(DI) | |
+ MOVOU X5, 112(DI) | |
+ MOVOU X7, 128(DI) | |
+ MOVOU X9, 144(DI) | |
RET | |
diff --git a/curve25519/curve25519.go b/curve25519/curve25519.go | |
index 6918c47..2d14c2a 100644 | |
--- a/curve25519/curve25519.go | |
+++ b/curve25519/curve25519.go | |
@@ -8,6 +8,10 @@ | |
package curve25519 | |
+import ( | |
+ "encoding/binary" | |
+) | |
+ | |
// This code is a port of the public domain, "ref10" implementation of | |
// curve25519 from SUPERCOP 20130419 by D. J. Bernstein. | |
@@ -50,17 +54,11 @@ func feCopy(dst, src *fieldElement) { | |
// | |
// Preconditions: b in {0,1}. | |
func feCSwap(f, g *fieldElement, b int32) { | |
- var x fieldElement | |
b = -b | |
- for i := range x { | |
- x[i] = b & (f[i] ^ g[i]) | |
- } | |
- | |
for i := range f { | |
- f[i] ^= x[i] | |
- } | |
- for i := range g { | |
- g[i] ^= x[i] | |
+ t := b & (f[i] ^ g[i]) | |
+ f[i] ^= t | |
+ g[i] ^= t | |
} | |
} | |
@@ -75,12 +73,7 @@ func load3(in []byte) int64 { | |
// load4 reads a 32-bit, little-endian value from in. | |
func load4(in []byte) int64 { | |
- var r int64 | |
- r = int64(in[0]) | |
- r |= int64(in[1]) << 8 | |
- r |= int64(in[2]) << 16 | |
- r |= int64(in[3]) << 24 | |
- return r | |
+ return int64(binary.LittleEndian.Uint32(in)) | |
} | |
func feFromBytes(dst *fieldElement, src *[32]byte) { | |
diff --git a/curve25519/curve25519_test.go b/curve25519/curve25519_test.go | |
index 14b0ee8..051a830 100644 | |
--- a/curve25519/curve25519_test.go | |
+++ b/curve25519/curve25519_test.go | |
@@ -27,3 +27,13 @@ func TestBaseScalarMult(t *testing.T) { | |
t.Errorf("incorrect result: got %s, want %s", result, expectedHex) | |
} | |
} | |
+ | |
+func BenchmarkScalarBaseMult(b *testing.B) { | |
+ var in, out [32]byte | |
+ in[0] = 1 | |
+ | |
+ b.SetBytes(32) | |
+ for i := 0; i < b.N; i++ { | |
+ ScalarBaseMult(&out, &in) | |
+ } | |
+} | |
diff --git a/nacl/box/example_test.go b/nacl/box/example_test.go | |
new file mode 100644 | |
index 0000000..25e42d2 | |
--- /dev/null | |
+++ b/nacl/box/example_test.go | |
@@ -0,0 +1,95 @@ | |
+package box_test | |
+ | |
+import ( | |
+ crypto_rand "crypto/rand" // Custom so it's clear which rand we're using. | |
+ "fmt" | |
+ "io" | |
+ | |
+ "golang.org/x/crypto/nacl/box" | |
+) | |
+ | |
+func Example() { | |
+ senderPublicKey, senderPrivateKey, err := box.GenerateKey(crypto_rand.Reader) | |
+ if err != nil { | |
+ panic(err) | |
+ } | |
+ | |
+ recipientPublicKey, recipientPrivateKey, err := box.GenerateKey(crypto_rand.Reader) | |
+ if err != nil { | |
+ panic(err) | |
+ } | |
+ | |
+ // You must use a different nonce for each message you encrypt with the | |
+ // same key. Since the nonce here is 192 bits long, a random value | |
+ // provides a sufficiently small probability of repeats. | |
+ var nonce [24]byte | |
+ if _, err := io.ReadFull(crypto_rand.Reader, nonce[:]); err != nil { | |
+ panic(err) | |
+ } | |
+ | |
+ msg := []byte("Alas, poor Yorick! I knew him, Horatio") | |
+ // This encrypts msg and appends the result to the nonce. | |
+ encrypted := box.Seal(nonce[:], msg, &nonce, recipientPublicKey, senderPrivateKey) | |
+ | |
+ // The recipient can decrypt the message using their private key and the | |
+ // sender's public key. When you decrypt, you must use the same nonce you | |
+ // used to encrypt the message. One way to achieve this is to store the | |
+ // nonce alongside the encrypted message. Above, we stored the nonce in the | |
+ // first 24 bytes of the encrypted text. | |
+ var decryptNonce [24]byte | |
+ copy(decryptNonce[:], encrypted[:24]) | |
+ decrypted, ok := box.Open(nil, encrypted[24:], &decryptNonce, senderPublicKey, recipientPrivateKey) | |
+ if !ok { | |
+ panic("decryption error") | |
+ } | |
+ fmt.Println(string(decrypted)) | |
+ // Output: Alas, poor Yorick! I knew him, Horatio | |
+} | |
+ | |
+func Example_precompute() { | |
+ senderPublicKey, senderPrivateKey, err := box.GenerateKey(crypto_rand.Reader) | |
+ if err != nil { | |
+ panic(err) | |
+ } | |
+ | |
+ recipientPublicKey, recipientPrivateKey, err := box.GenerateKey(crypto_rand.Reader) | |
+ if err != nil { | |
+ panic(err) | |
+ } | |
+ | |
+ // The shared key can be used to speed up processing when using the same | |
+ // pair of keys repeatedly. | |
+ sharedEncryptKey := new([32]byte) | |
+ box.Precompute(sharedEncryptKey, recipientPublicKey, senderPrivateKey) | |
+ | |
+ // You must use a different nonce for each message you encrypt with the | |
+ // same key. Since the nonce here is 192 bits long, a random value | |
+ // provides a sufficiently small probability of repeats. | |
+ var nonce [24]byte | |
+ if _, err := io.ReadFull(crypto_rand.Reader, nonce[:]); err != nil { | |
+ panic(err) | |
+ } | |
+ | |
+ msg := []byte("A fellow of infinite jest, of most excellent fancy") | |
+ // This encrypts msg and appends the result to the nonce. | |
+ encrypted := box.SealAfterPrecomputation(nonce[:], msg, &nonce, sharedEncryptKey) | |
+ | |
+	// The recipient derives the same shared key from the sender's public key | |
+	// and their own private key, and can likewise reuse it for many messages. | |
+ var sharedDecryptKey [32]byte | |
+ box.Precompute(&sharedDecryptKey, senderPublicKey, recipientPrivateKey) | |
+ | |
+ // The recipient can decrypt the message using the shared key. When you | |
+ // decrypt, you must use the same nonce you used to encrypt the message. | |
+ // One way to achieve this is to store the nonce alongside the encrypted | |
+ // message. Above, we stored the nonce in the first 24 bytes of the | |
+ // encrypted text. | |
+ var decryptNonce [24]byte | |
+ copy(decryptNonce[:], encrypted[:24]) | |
+ decrypted, ok := box.OpenAfterPrecomputation(nil, encrypted[24:], &decryptNonce, &sharedDecryptKey) | |
+ if !ok { | |
+ panic("decryption error") | |
+ } | |
+ fmt.Println(string(decrypted)) | |
+ // Output: A fellow of infinite jest, of most excellent fancy | |
+} | |
diff --git a/nacl/secretbox/example_test.go b/nacl/secretbox/example_test.go | |
index b25e663..789f4ff 100644 | |
--- a/nacl/secretbox/example_test.go | |
+++ b/nacl/secretbox/example_test.go | |
@@ -43,7 +43,7 @@ func Example() { | |
// 24 bytes of the encrypted text. | |
var decryptNonce [24]byte | |
copy(decryptNonce[:], encrypted[:24]) | |
- decrypted, ok := secretbox.Open([]byte{}, encrypted[24:], &decryptNonce, &secretKey) | |
+ decrypted, ok := secretbox.Open(nil, encrypted[24:], &decryptNonce, &secretKey) | |
if !ok { | |
panic("decryption error") | |
} | |
diff --git a/ocsp/ocsp.go b/ocsp/ocsp.go | |
index 8ed8796..6bd347e 100644 | |
--- a/ocsp/ocsp.go | |
+++ b/ocsp/ocsp.go | |
@@ -450,8 +450,8 @@ func ParseRequest(bytes []byte) (*Request, error) { | |
// then the signature over the response is checked. If issuer is not nil then | |
// it will be used to validate the signature or embedded certificate. | |
// | |
-// Invalid signatures or parse failures will result in a ParseError. Error | |
-// responses will result in a ResponseError. | |
+// Invalid responses and parse failures will result in a ParseError. | |
+// Error responses will result in a ResponseError. | |
func ParseResponse(bytes []byte, issuer *x509.Certificate) (*Response, error) { | |
return ParseResponseForCert(bytes, nil, issuer) | |
} | |
@@ -462,8 +462,8 @@ func ParseResponse(bytes []byte, issuer *x509.Certificate) (*Response, error) { | |
// issuer is not nil then it will be used to validate the signature or embedded | |
// certificate. | |
// | |
-// Invalid signatures or parse failures will result in a ParseError. Error | |
-// responses will result in a ResponseError. | |
+// Invalid responses and parse failures will result in a ParseError. | |
+// Error responses will result in a ResponseError. | |
func ParseResponseForCert(bytes []byte, cert, issuer *x509.Certificate) (*Response, error) { | |
var resp responseASN1 | |
rest, err := asn1.Unmarshal(bytes, &resp) | |
@@ -496,10 +496,32 @@ func ParseResponseForCert(bytes []byte, cert, issuer *x509.Certificate) (*Respon | |
return nil, ParseError("OCSP response contains bad number of responses") | |
} | |
+ var singleResp singleResponse | |
+ if cert == nil { | |
+ singleResp = basicResp.TBSResponseData.Responses[0] | |
+ } else { | |
+ match := false | |
+ for _, resp := range basicResp.TBSResponseData.Responses { | |
+ if cert == nil || cert.SerialNumber.Cmp(resp.CertID.SerialNumber) == 0 { | |
+ singleResp = resp | |
+ match = true | |
+ break | |
+ } | |
+ } | |
+ if !match { | |
+ return nil, ParseError("no response matching the supplied certificate") | |
+ } | |
+ } | |
+ | |
ret := &Response{ | |
TBSResponseData: basicResp.TBSResponseData.Raw, | |
Signature: basicResp.Signature.RightAlign(), | |
SignatureAlgorithm: getSignatureAlgorithmFromOID(basicResp.SignatureAlgorithm.Algorithm), | |
+ Extensions: singleResp.SingleExtensions, | |
+ SerialNumber: singleResp.CertID.SerialNumber, | |
+ ProducedAt: basicResp.TBSResponseData.ProducedAt, | |
+ ThisUpdate: singleResp.ThisUpdate, | |
+ NextUpdate: singleResp.NextUpdate, | |
} | |
// Handle the ResponderID CHOICE tag. ResponderID can be flattened into | |
@@ -542,25 +564,14 @@ func ParseResponseForCert(bytes []byte, cert, issuer *x509.Certificate) (*Respon | |
} | |
} | |
- var r singleResponse | |
- for _, resp := range basicResp.TBSResponseData.Responses { | |
- if cert == nil || cert.SerialNumber.Cmp(resp.CertID.SerialNumber) == 0 { | |
- r = resp | |
- break | |
- } | |
- } | |
- | |
- for _, ext := range r.SingleExtensions { | |
+ for _, ext := range singleResp.SingleExtensions { | |
if ext.Critical { | |
return nil, ParseError("unsupported critical extension") | |
} | |
} | |
- ret.Extensions = r.SingleExtensions | |
- | |
- ret.SerialNumber = r.CertID.SerialNumber | |
for h, oid := range hashOIDs { | |
- if r.CertID.HashAlgorithm.Algorithm.Equal(oid) { | |
+ if singleResp.CertID.HashAlgorithm.Algorithm.Equal(oid) { | |
ret.IssuerHash = h | |
break | |
} | |
@@ -570,20 +581,16 @@ func ParseResponseForCert(bytes []byte, cert, issuer *x509.Certificate) (*Respon | |
} | |
switch { | |
- case bool(r.Good): | |
+ case bool(singleResp.Good): | |
ret.Status = Good | |
- case bool(r.Unknown): | |
+ case bool(singleResp.Unknown): | |
ret.Status = Unknown | |
default: | |
ret.Status = Revoked | |
- ret.RevokedAt = r.Revoked.RevocationTime | |
- ret.RevocationReason = int(r.Revoked.Reason) | |
+ ret.RevokedAt = singleResp.Revoked.RevocationTime | |
+ ret.RevocationReason = int(singleResp.Revoked.Reason) | |
} | |
- ret.ProducedAt = basicResp.TBSResponseData.ProducedAt | |
- ret.ThisUpdate = r.ThisUpdate | |
- ret.NextUpdate = r.NextUpdate | |
- | |
return ret, nil | |
} | |
diff --git a/ocsp/ocsp_test.go b/ocsp/ocsp_test.go | |
index d325d85..df674b3 100644 | |
--- a/ocsp/ocsp_test.go | |
+++ b/ocsp/ocsp_test.go | |
@@ -343,6 +343,21 @@ func TestOCSPDecodeMultiResponse(t *testing.T) { | |
} | |
} | |
+func TestOCSPDecodeMultiResponseWithoutMatchingCert(t *testing.T) { | |
+ wrongCert, _ := hex.DecodeString(startComHex) | |
+ cert, err := x509.ParseCertificate(wrongCert) | |
+ if err != nil { | |
+ t.Fatal(err) | |
+ } | |
+ | |
+ responseBytes, _ := hex.DecodeString(ocspMultiResponseHex) | |
+ _, err = ParseResponseForCert(responseBytes, cert, nil) | |
+ want := ParseError("no response matching the supplied certificate") | |
+ if err != want { | |
+ t.Errorf("err: got %q, want %q", err, want) | |
+ } | |
+} | |
+ | |
// This OCSP response was taken from Thawte's public OCSP responder. | |
// To recreate: | |
// $ openssl s_client -tls1 -showcerts -servername www.google.com -connect www.google.com:443 | |
diff --git a/pkcs12/pkcs12.go b/pkcs12/pkcs12.go | |
index ad6341e..eff9ad3 100644 | |
--- a/pkcs12/pkcs12.go | |
+++ b/pkcs12/pkcs12.go | |
@@ -109,6 +109,10 @@ func ToPEM(pfxData []byte, password string) ([]*pem.Block, error) { | |
bags, encodedPassword, err := getSafeContents(pfxData, encodedPassword) | |
+ if err != nil { | |
+ return nil, err | |
+ } | |
+ | |
blocks := make([]*pem.Block, 0, len(bags)) | |
for _, bag := range bags { | |
block, err := convertBag(&bag, encodedPassword) | |
diff --git a/ssh/agent/client_test.go b/ssh/agent/client_test.go | |
index a13a650..5fc47e5 100644 | |
--- a/ssh/agent/client_test.go | |
+++ b/ssh/agent/client_test.go | |
@@ -180,9 +180,12 @@ func TestCert(t *testing.T) { | |
// therefore is buffered (net.Pipe deadlocks if both sides start with | |
// a write.) | |
func netPipe() (net.Conn, net.Conn, error) { | |
- listener, err := net.Listen("tcp", ":0") | |
+ listener, err := net.Listen("tcp", "127.0.0.1:0") | |
if err != nil { | |
- return nil, nil, err | |
+ listener, err = net.Listen("tcp", "[::1]:0") | |
+ if err != nil { | |
+ return nil, nil, err | |
+ } | |
} | |
defer listener.Close() | |
c1, err := net.Dial("tcp", listener.Addr().String()) | |
@@ -200,6 +203,9 @@ func netPipe() (net.Conn, net.Conn, error) { | |
} | |
func TestAuth(t *testing.T) { | |
+ agent, _, cleanup := startAgent(t) | |
+ defer cleanup() | |
+ | |
a, b, err := netPipe() | |
if err != nil { | |
t.Fatalf("netPipe: %v", err) | |
@@ -208,9 +214,6 @@ func TestAuth(t *testing.T) { | |
defer a.Close() | |
defer b.Close() | |
- agent, _, cleanup := startAgent(t) | |
- defer cleanup() | |
- | |
if err := agent.Add(AddedKey{PrivateKey: testPrivateKeys["rsa"], Comment: "comment"}); err != nil { | |
t.Errorf("Add: %v", err) | |
} | |
@@ -233,7 +236,9 @@ func TestAuth(t *testing.T) { | |
conn.Close() | |
}() | |
- conf := ssh.ClientConfig{} | |
+ conf := ssh.ClientConfig{ | |
+ HostKeyCallback: ssh.InsecureIgnoreHostKey(), | |
+ } | |
conf.Auth = append(conf.Auth, ssh.PublicKeysCallback(agent.Signers)) | |
conn, _, _, err := ssh.NewClientConn(b, "", &conf) | |
if err != nil { | |
diff --git a/ssh/agent/example_test.go b/ssh/agent/example_test.go | |
index c1130f7..8556225 100644 | |
--- a/ssh/agent/example_test.go | |
+++ b/ssh/agent/example_test.go | |
@@ -6,20 +6,20 @@ package agent_test | |
import ( | |
"log" | |
- "os" | |
"net" | |
+ "os" | |
- "golang.org/x/crypto/ssh" | |
- "golang.org/x/crypto/ssh/agent" | |
+ "golang.org/x/crypto/ssh" | |
+ "golang.org/x/crypto/ssh/agent" | |
) | |
func ExampleClientAgent() { | |
// ssh-agent has a UNIX socket under $SSH_AUTH_SOCK | |
socket := os.Getenv("SSH_AUTH_SOCK") | |
- conn, err := net.Dial("unix", socket) | |
- if err != nil { | |
- log.Fatalf("net.Dial: %v", err) | |
- } | |
+ conn, err := net.Dial("unix", socket) | |
+ if err != nil { | |
+ log.Fatalf("net.Dial: %v", err) | |
+ } | |
agentClient := agent.NewClient(conn) | |
config := &ssh.ClientConfig{ | |
User: "username", | |
@@ -29,6 +29,7 @@ func ExampleClientAgent() { | |
// wants it. | |
ssh.PublicKeysCallback(agentClient.Signers), | |
}, | |
+ HostKeyCallback: ssh.InsecureIgnoreHostKey(), | |
} | |
sshc, err := ssh.Dial("tcp", "localhost:22", config) | |
diff --git a/ssh/agent/server_test.go b/ssh/agent/server_test.go | |
index ec9cdee..6b0837d 100644 | |
--- a/ssh/agent/server_test.go | |
+++ b/ssh/agent/server_test.go | |
@@ -56,7 +56,9 @@ func TestSetupForwardAgent(t *testing.T) { | |
incoming <- conn | |
}() | |
- conf := ssh.ClientConfig{} | |
+ conf := ssh.ClientConfig{ | |
+ HostKeyCallback: ssh.InsecureIgnoreHostKey(), | |
+ } | |
conn, chans, reqs, err := ssh.NewClientConn(b, "", &conf) | |
if err != nil { | |
t.Fatalf("NewClientConn: %v", err) | |
diff --git a/ssh/certs.go b/ssh/certs.go | |
index 6331c94..b1f0220 100644 | |
--- a/ssh/certs.go | |
+++ b/ssh/certs.go | |
@@ -251,10 +251,18 @@ type CertChecker struct { | |
// for user certificates. | |
SupportedCriticalOptions []string | |
- // IsAuthority should return true if the key is recognized as | |
- // an authority. This allows for certificates to be signed by other | |
- // certificates. | |
- IsAuthority func(auth PublicKey) bool | |
+ // IsUserAuthority should return true if the key is recognized as an | |
+ // authority for the given user certificate. This allows for | |
+ // certificates to be signed by other certificates. This must be set | |
+ // if this CertChecker will be checking user certificates. | |
+ IsUserAuthority func(auth PublicKey) bool | |
+ | |
+ // IsHostAuthority should report whether the key is recognized as | |
+ // an authority for this host. This allows for certificates to be | |
+ // signed by other keys, and for those other keys to only be valid | |
+ // signers for particular hostnames. This must be set if this | |
+ // CertChecker will be checking host certificates. | |
+ IsHostAuthority func(auth PublicKey, address string) bool | |
// Clock is used for verifying time stamps. If nil, time.Now | |
// is used. | |
@@ -268,7 +276,7 @@ type CertChecker struct { | |
// HostKeyFallback is called when CertChecker.CheckHostKey encounters a | |
// public key that is not a certificate. It must implement host key | |
// validation or else, if nil, all such keys are rejected. | |
- HostKeyFallback func(addr string, remote net.Addr, key PublicKey) error | |
+ HostKeyFallback HostKeyCallback | |
// IsRevoked is called for each certificate so that revocation checking | |
// can be implemented. It should return true if the given certificate | |
@@ -290,8 +298,17 @@ func (c *CertChecker) CheckHostKey(addr string, remote net.Addr, key PublicKey) | |
if cert.CertType != HostCert { | |
return fmt.Errorf("ssh: certificate presented as a host key has type %d", cert.CertType) | |
} | |
+ if !c.IsHostAuthority(cert.SignatureKey, addr) { | |
+ return fmt.Errorf("ssh: no authorities for hostname: %v", addr) | |
+ } | |
+ | |
+ hostname, _, err := net.SplitHostPort(addr) | |
+ if err != nil { | |
+ return err | |
+ } | |
- return c.CheckCert(addr, cert) | |
+ // Pass hostname only as principal for host certificates (consistent with OpenSSH) | |
+ return c.CheckCert(hostname, cert) | |
} | |
// Authenticate checks a user certificate. Authenticate can be used as | |
@@ -308,6 +325,9 @@ func (c *CertChecker) Authenticate(conn ConnMetadata, pubKey PublicKey) (*Permis | |
if cert.CertType != UserCert { | |
return nil, fmt.Errorf("ssh: cert has type %d", cert.CertType) | |
} | |
+ if !c.IsUserAuthority(cert.SignatureKey) { | |
+ return nil, fmt.Errorf("ssh: certificate signed by unrecognized authority") | |
+ } | |
if err := c.CheckCert(conn.User(), cert); err != nil { | |
return nil, err | |
@@ -356,10 +376,6 @@ func (c *CertChecker) CheckCert(principal string, cert *Certificate) error { | |
} | |
} | |
- if !c.IsAuthority(cert.SignatureKey) { | |
- return fmt.Errorf("ssh: certificate signed by unrecognized authority") | |
- } | |
- | |
clock := c.Clock | |
if clock == nil { | |
clock = time.Now | |
diff --git a/ssh/certs_test.go b/ssh/certs_test.go | |
index c5f2e53..0200531 100644 | |
--- a/ssh/certs_test.go | |
+++ b/ssh/certs_test.go | |
@@ -104,7 +104,7 @@ func TestValidateCert(t *testing.T) { | |
t.Fatalf("got %v (%T), want *Certificate", key, key) | |
} | |
checker := CertChecker{} | |
- checker.IsAuthority = func(k PublicKey) bool { | |
+ checker.IsUserAuthority = func(k PublicKey) bool { | |
return bytes.Equal(k.Marshal(), validCert.SignatureKey.Marshal()) | |
} | |
@@ -142,7 +142,7 @@ func TestValidateCertTime(t *testing.T) { | |
checker := CertChecker{ | |
Clock: func() time.Time { return time.Unix(ts, 0) }, | |
} | |
- checker.IsAuthority = func(k PublicKey) bool { | |
+ checker.IsUserAuthority = func(k PublicKey) bool { | |
return bytes.Equal(k.Marshal(), | |
testPublicKeys["ecdsa"].Marshal()) | |
} | |
@@ -160,7 +160,7 @@ func TestValidateCertTime(t *testing.T) { | |
func TestHostKeyCert(t *testing.T) { | |
cert := &Certificate{ | |
- ValidPrincipals: []string{"hostname", "hostname.domain"}, | |
+ ValidPrincipals: []string{"hostname", "hostname.domain", "otherhost"}, | |
Key: testPublicKeys["rsa"], | |
ValidBefore: CertTimeInfinity, | |
CertType: HostCert, | |
@@ -168,8 +168,8 @@ func TestHostKeyCert(t *testing.T) { | |
cert.SignCert(rand.Reader, testSigners["ecdsa"]) | |
checker := &CertChecker{ | |
- IsAuthority: func(p PublicKey) bool { | |
- return bytes.Equal(testPublicKeys["ecdsa"].Marshal(), p.Marshal()) | |
+ IsHostAuthority: func(p PublicKey, addr string) bool { | |
+ return addr == "hostname:22" && bytes.Equal(testPublicKeys["ecdsa"].Marshal(), p.Marshal()) | |
}, | |
} | |
@@ -178,7 +178,14 @@ func TestHostKeyCert(t *testing.T) { | |
t.Errorf("NewCertSigner: %v", err) | |
} | |
- for _, name := range []string{"hostname", "otherhost"} { | |
+ for _, test := range []struct { | |
+ addr string | |
+ succeed bool | |
+ }{ | |
+ {addr: "hostname:22", succeed: true}, | |
+ {addr: "otherhost:22", succeed: false}, // The certificate is valid for 'otherhost' as hostname, but we only recognize the authority of the signer for the address 'hostname:22' | |
+ {addr: "lasthost:22", succeed: false}, | |
+ } { | |
c1, c2, err := netPipe() | |
if err != nil { | |
t.Fatalf("netPipe: %v", err) | |
@@ -201,16 +208,15 @@ func TestHostKeyCert(t *testing.T) { | |
User: "user", | |
HostKeyCallback: checker.CheckHostKey, | |
} | |
- _, _, _, err = NewClientConn(c2, name, config) | |
+ _, _, _, err = NewClientConn(c2, test.addr, config) | |
- succeed := name == "hostname" | |
- if (err == nil) != succeed { | |
- t.Fatalf("NewClientConn(%q): %v", name, err) | |
+ if (err == nil) != test.succeed { | |
+ t.Fatalf("NewClientConn(%q): %v", test.addr, err) | |
} | |
err = <-errc | |
- if (err == nil) != succeed { | |
- t.Fatalf("NewServerConn(%q): %v", name, err) | |
+ if (err == nil) != test.succeed { | |
+ t.Fatalf("NewServerConn(%q): %v", test.addr, err) | |
} | |
} | |
} | |
diff --git a/ssh/client.go b/ssh/client.go | |
index c97f297..a7e3263 100644 | |
--- a/ssh/client.go | |
+++ b/ssh/client.go | |
@@ -5,6 +5,7 @@ | |
package ssh | |
import ( | |
+ "bytes" | |
"errors" | |
"fmt" | |
"net" | |
@@ -13,7 +14,7 @@ import ( | |
) | |
// Client implements a traditional SSH client that supports shells, | |
-// subprocesses, port forwarding and tunneled dialing. | |
+// subprocesses, TCP port/streamlocal forwarding and tunneled dialing. | |
type Client struct { | |
Conn | |
@@ -59,6 +60,7 @@ func NewClient(c Conn, chans <-chan NewChannel, reqs <-chan *Request) *Client { | |
conn.forwards.closeAll() | |
}() | |
go conn.forwards.handleChannels(conn.HandleChannelOpen("forwarded-tcpip")) | |
+ go conn.forwards.handleChannels(conn.HandleChannelOpen("[email protected]")) | |
return conn | |
} | |
@@ -68,6 +70,11 @@ func NewClient(c Conn, chans <-chan NewChannel, reqs <-chan *Request) *Client { | |
func NewClientConn(c net.Conn, addr string, config *ClientConfig) (Conn, <-chan NewChannel, <-chan *Request, error) { | |
fullConf := *config | |
fullConf.SetDefaults() | |
+ if fullConf.HostKeyCallback == nil { | |
+ c.Close() | |
+ return nil, nil, nil, errors.New("ssh: must specify HostKeyCallback") | |
+ } | |
+ | |
conn := &connection{ | |
sshConn: sshConn{conn: c}, | |
} | |
@@ -173,6 +180,13 @@ func Dial(network, addr string, config *ClientConfig) (*Client, error) { | |
return NewClient(c, chans, reqs), nil | |
} | |
+// HostKeyCallback is the function type used for verifying server | |
+// keys. A HostKeyCallback must return nil if the host key is OK, or | |
+// an error to reject it. It receives the hostname as passed to Dial | |
+// or NewClientConn. The remote address is the RemoteAddr of the | |
+// net.Conn underlying the SSH connection. | |
+type HostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error | |
+ | |
// A ClientConfig structure is used to configure a Client. It must not be | |
// modified after having been passed to an SSH function. | |
type ClientConfig struct { | |
@@ -188,10 +202,12 @@ type ClientConfig struct { | |
// be used during authentication. | |
Auth []AuthMethod | |
- // HostKeyCallback, if not nil, is called during the cryptographic | |
- // handshake to validate the server's host key. A nil HostKeyCallback | |
- // implies that all host keys are accepted. | |
- HostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error | |
+ // HostKeyCallback is called during the cryptographic | |
+ // handshake to validate the server's host key. The client | |
+ // configuration must supply this callback for the connection | |
+ // to succeed. The functions InsecureIgnoreHostKey or | |
+ // FixedHostKey can be used for simplistic host key checks. | |
+ HostKeyCallback HostKeyCallback | |
// ClientVersion contains the version identification string that will | |
// be used for the connection. If empty, a reasonable default is used. | |
@@ -209,3 +225,33 @@ type ClientConfig struct { | |
// A Timeout of zero means no timeout. | |
Timeout time.Duration | |
} | |
+ | |
+// InsecureIgnoreHostKey returns a function that can be used for | |
+// ClientConfig.HostKeyCallback to accept any host key. It should | |
+// not be used for production code. | |
+func InsecureIgnoreHostKey() HostKeyCallback { | |
+ return func(hostname string, remote net.Addr, key PublicKey) error { | |
+ return nil | |
+ } | |
+} | |
+ | |
+type fixedHostKey struct { | |
+ key PublicKey | |
+} | |
+ | |
+func (f *fixedHostKey) check(hostname string, remote net.Addr, key PublicKey) error { | |
+ if f.key == nil { | |
+ return fmt.Errorf("ssh: required host key was nil") | |
+ } | |
+ if !bytes.Equal(key.Marshal(), f.key.Marshal()) { | |
+ return fmt.Errorf("ssh: host key mismatch") | |
+ } | |
+ return nil | |
+} | |
+ | |
+// FixedHostKey returns a function for use in | |
+// ClientConfig.HostKeyCallback to accept only a specific host key. | |
+func FixedHostKey(key PublicKey) HostKeyCallback { | |
+ hk := &fixedHostKey{key} | |
+ return hk.check | |
+} | |
diff --git a/ssh/client_auth.go b/ssh/client_auth.go | |
index fd1ec5d..b882da0 100644 | |
--- a/ssh/client_auth.go | |
+++ b/ssh/client_auth.go | |
@@ -179,31 +179,26 @@ func (cb publicKeyCallback) method() string { | |
} | |
func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) { | |
- // Authentication is performed in two stages. The first stage sends an | |
- // enquiry to test if each key is acceptable to the remote. The second | |
- // stage attempts to authenticate with the valid keys obtained in the | |
- // first stage. | |
+ // Authentication is performed by sending an enquiry to test if a key is | |
+ // acceptable to the remote. If the key is acceptable, the client will | |
+	// attempt to authenticate with the valid key. If not, the client will repeat | |
+ // the process with the remaining keys. | |
signers, err := cb() | |
if err != nil { | |
return false, nil, err | |
} | |
- var validKeys []Signer | |
+ var methods []string | |
for _, signer := range signers { | |
- if ok, err := validateKey(signer.PublicKey(), user, c); ok { | |
- validKeys = append(validKeys, signer) | |
- } else { | |
- if err != nil { | |
- return false, nil, err | |
- } | |
+ ok, err := validateKey(signer.PublicKey(), user, c) | |
+ if err != nil { | |
+ return false, nil, err | |
+ } | |
+ if !ok { | |
+ continue | |
} | |
- } | |
- // methods that may continue if this auth is not successful. | |
- var methods []string | |
- for _, signer := range validKeys { | |
pub := signer.PublicKey() | |
- | |
pubKey := pub.Marshal() | |
sign, err := signer.Sign(rand, buildDataSignedForAuth(session, userAuthRequestMsg{ | |
User: user, | |
@@ -236,13 +231,29 @@ func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand | |
if err != nil { | |
return false, nil, err | |
} | |
- if success { | |
+ | |
+ // If authentication succeeds or the list of available methods does not | |
+ // contain the "publickey" method, do not attempt to authenticate with any | |
+ // other keys. According to RFC 4252 Section 7, the latter can occur when | |
+ // additional authentication methods are required. | |
+ if success || !containsMethod(methods, cb.method()) { | |
return success, methods, err | |
} | |
} | |
+ | |
return false, methods, nil | |
} | |
+func containsMethod(methods []string, method string) bool { | |
+ for _, m := range methods { | |
+ if m == method { | |
+ return true | |
+ } | |
+ } | |
+ | |
+ return false | |
+} | |
+ | |
// validateKey validates the key provided is acceptable to the server. | |
func validateKey(key PublicKey, user string, c packetConn) (bool, error) { | |
pubKey := key.Marshal() | |
diff --git a/ssh/client_auth_test.go b/ssh/client_auth_test.go | |
index e384c79..bd9f8a1 100644 | |
--- a/ssh/client_auth_test.go | |
+++ b/ssh/client_auth_test.go | |
@@ -38,7 +38,7 @@ func tryAuth(t *testing.T, config *ClientConfig) error { | |
defer c2.Close() | |
certChecker := CertChecker{ | |
- IsAuthority: func(k PublicKey) bool { | |
+ IsUserAuthority: func(k PublicKey) bool { | |
return bytes.Equal(k.Marshal(), testPublicKeys["ecdsa"].Marshal()) | |
}, | |
UserKeyFallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) { | |
@@ -76,8 +76,6 @@ func tryAuth(t *testing.T, config *ClientConfig) error { | |
} | |
return nil, errors.New("keyboard-interactive failed") | |
}, | |
- AuthLogCallback: func(conn ConnMetadata, method string, err error) { | |
- }, | |
} | |
serverConfig.AddHostKey(testSigners["rsa"]) | |
@@ -92,6 +90,7 @@ func TestClientAuthPublicKey(t *testing.T) { | |
Auth: []AuthMethod{ | |
PublicKeys(testSigners["rsa"]), | |
}, | |
+ HostKeyCallback: InsecureIgnoreHostKey(), | |
} | |
if err := tryAuth(t, config); err != nil { | |
t.Fatalf("unable to dial remote side: %s", err) | |
@@ -104,6 +103,7 @@ func TestAuthMethodPassword(t *testing.T) { | |
Auth: []AuthMethod{ | |
Password(clientPassword), | |
}, | |
+ HostKeyCallback: InsecureIgnoreHostKey(), | |
} | |
if err := tryAuth(t, config); err != nil { | |
@@ -123,6 +123,7 @@ func TestAuthMethodFallback(t *testing.T) { | |
return "WRONG", nil | |
}), | |
}, | |
+ HostKeyCallback: InsecureIgnoreHostKey(), | |
} | |
if err := tryAuth(t, config); err != nil { | |
@@ -141,6 +142,7 @@ func TestAuthMethodWrongPassword(t *testing.T) { | |
Password("wrong"), | |
PublicKeys(testSigners["rsa"]), | |
}, | |
+ HostKeyCallback: InsecureIgnoreHostKey(), | |
} | |
if err := tryAuth(t, config); err != nil { | |
@@ -158,6 +160,7 @@ func TestAuthMethodKeyboardInteractive(t *testing.T) { | |
Auth: []AuthMethod{ | |
KeyboardInteractive(answers.Challenge), | |
}, | |
+ HostKeyCallback: InsecureIgnoreHostKey(), | |
} | |
if err := tryAuth(t, config); err != nil { | |
@@ -203,6 +206,7 @@ func TestAuthMethodRSAandDSA(t *testing.T) { | |
Auth: []AuthMethod{ | |
PublicKeys(testSigners["dsa"], testSigners["rsa"]), | |
}, | |
+ HostKeyCallback: InsecureIgnoreHostKey(), | |
} | |
if err := tryAuth(t, config); err != nil { | |
t.Fatalf("client could not authenticate with rsa key: %v", err) | |
@@ -219,6 +223,7 @@ func TestClientHMAC(t *testing.T) { | |
Config: Config{ | |
MACs: []string{mac}, | |
}, | |
+ HostKeyCallback: InsecureIgnoreHostKey(), | |
} | |
if err := tryAuth(t, config); err != nil { | |
t.Fatalf("client could not authenticate with mac algo %s: %v", mac, err) | |
@@ -254,6 +259,7 @@ func TestClientUnsupportedKex(t *testing.T) { | |
Config: Config{ | |
KeyExchanges: []string{"diffie-hellman-group-exchange-sha256"}, // not currently supported | |
}, | |
+ HostKeyCallback: InsecureIgnoreHostKey(), | |
} | |
if err := tryAuth(t, config); err == nil || !strings.Contains(err.Error(), "common algorithm") { | |
t.Errorf("got %v, expected 'common algorithm'", err) | |
@@ -273,7 +279,8 @@ func TestClientLoginCert(t *testing.T) { | |
} | |
clientConfig := &ClientConfig{ | |
- User: "user", | |
+ User: "user", | |
+ HostKeyCallback: InsecureIgnoreHostKey(), | |
} | |
clientConfig.Auth = append(clientConfig.Auth, PublicKeys(certSigner)) | |
@@ -363,6 +370,7 @@ func testPermissionsPassing(withPermissions bool, t *testing.T) { | |
Auth: []AuthMethod{ | |
PublicKeys(testSigners["rsa"]), | |
}, | |
+ HostKeyCallback: InsecureIgnoreHostKey(), | |
} | |
if withPermissions { | |
clientConfig.User = "permissions" | |
@@ -409,6 +417,7 @@ func TestRetryableAuth(t *testing.T) { | |
}), 2), | |
PublicKeys(testSigners["rsa"]), | |
}, | |
+ HostKeyCallback: InsecureIgnoreHostKey(), | |
} | |
if err := tryAuth(t, config); err != nil { | |
@@ -430,7 +439,8 @@ func ExampleRetryableAuthMethod(t *testing.T) { | |
} | |
config := &ClientConfig{ | |
- User: user, | |
+ HostKeyCallback: InsecureIgnoreHostKey(), | |
+ User: user, | |
Auth: []AuthMethod{ | |
RetryableAuthMethod(KeyboardInteractiveChallenge(Cb), NumberOfPrompts), | |
}, | |
@@ -450,7 +460,8 @@ func TestClientAuthNone(t *testing.T) { | |
serverConfig.AddHostKey(testSigners["rsa"]) | |
clientConfig := &ClientConfig{ | |
- User: user, | |
+ User: user, | |
+ HostKeyCallback: InsecureIgnoreHostKey(), | |
} | |
c1, c2, err := netPipe() | |
@@ -469,3 +480,100 @@ func TestClientAuthNone(t *testing.T) { | |
t.Fatalf("server: got %q, want %q", serverConn.User(), user) | |
} | |
} | |
+ | |
+// Test if authentication attempts are limited on server when MaxAuthTries is set | |
+func TestClientAuthMaxAuthTries(t *testing.T) { | |
+ user := "testuser" | |
+ | |
+ serverConfig := &ServerConfig{ | |
+ MaxAuthTries: 2, | |
+ PasswordCallback: func(conn ConnMetadata, pass []byte) (*Permissions, error) { | |
+ if conn.User() == "testuser" && string(pass) == "right" { | |
+ return nil, nil | |
+ } | |
+ return nil, errors.New("password auth failed") | |
+ }, | |
+ } | |
+ serverConfig.AddHostKey(testSigners["rsa"]) | |
+ | |
+ expectedErr := fmt.Errorf("ssh: handshake failed: %v", &disconnectMsg{ | |
+ Reason: 2, | |
+ Message: "too many authentication failures", | |
+ }) | |
+ | |
+ for tries := 2; tries < 4; tries++ { | |
+ n := tries | |
+ clientConfig := &ClientConfig{ | |
+ User: user, | |
+ Auth: []AuthMethod{ | |
+ RetryableAuthMethod(PasswordCallback(func() (string, error) { | |
+ n-- | |
+ if n == 0 { | |
+ return "right", nil | |
+ } else { | |
+ return "wrong", nil | |
+ } | |
+ }), tries), | |
+ }, | |
+ HostKeyCallback: InsecureIgnoreHostKey(), | |
+ } | |
+ | |
+ c1, c2, err := netPipe() | |
+ if err != nil { | |
+ t.Fatalf("netPipe: %v", err) | |
+ } | |
+ defer c1.Close() | |
+ defer c2.Close() | |
+ | |
+ go newServer(c1, serverConfig) | |
+ _, _, _, err = NewClientConn(c2, "", clientConfig) | |
+ if tries > 2 { | |
+ if err == nil { | |
+ t.Fatalf("client: got no error, want %s", expectedErr) | |
+ } else if err.Error() != expectedErr.Error() { | |
+ t.Fatalf("client: got %s, want %s", err, expectedErr) | |
+ } | |
+ } else { | |
+ if err != nil { | |
+ t.Fatalf("client: got %s, want no error", err) | |
+ } | |
+ } | |
+ } | |
+} | |
+ | |
+// Test if authentication attempts are correctly limited on server | |
+// when more public keys are provided then MaxAuthTries | |
+func TestClientAuthMaxAuthTriesPublicKey(t *testing.T) { | |
+ signers := []Signer{} | |
+ for i := 0; i < 6; i++ { | |
+ signers = append(signers, testSigners["dsa"]) | |
+ } | |
+ | |
+ validConfig := &ClientConfig{ | |
+ User: "testuser", | |
+ Auth: []AuthMethod{ | |
+ PublicKeys(append([]Signer{testSigners["rsa"]}, signers...)...), | |
+ }, | |
+ HostKeyCallback: InsecureIgnoreHostKey(), | |
+ } | |
+ if err := tryAuth(t, validConfig); err != nil { | |
+ t.Fatalf("unable to dial remote side: %s", err) | |
+ } | |
+ | |
+ expectedErr := fmt.Errorf("ssh: handshake failed: %v", &disconnectMsg{ | |
+ Reason: 2, | |
+ Message: "too many authentication failures", | |
+ }) | |
+ invalidConfig := &ClientConfig{ | |
+ User: "testuser", | |
+ Auth: []AuthMethod{ | |
+ PublicKeys(append(signers, testSigners["rsa"])...), | |
+ }, | |
+ HostKeyCallback: InsecureIgnoreHostKey(), | |
+ } | |
+ if err := tryAuth(t, invalidConfig); err == nil { | |
+ t.Fatalf("client: got no error, want %s", expectedErr) | |
+ } else if err.Error() != expectedErr.Error() { | |
+ t.Fatalf("client: got %s, want %s", err, expectedErr) | |
+ } | |
+} | |
diff --git a/ssh/client_test.go b/ssh/client_test.go | |
index 1fe790c..ccf5607 100644 | |
--- a/ssh/client_test.go | |
+++ b/ssh/client_test.go | |
@@ -6,6 +6,7 @@ package ssh | |
import ( | |
"net" | |
+ "strings" | |
"testing" | |
) | |
@@ -13,6 +14,7 @@ func testClientVersion(t *testing.T, config *ClientConfig, expected string) { | |
clientConn, serverConn := net.Pipe() | |
defer clientConn.Close() | |
receivedVersion := make(chan string, 1) | |
+ config.HostKeyCallback = InsecureIgnoreHostKey() | |
go func() { | |
version, err := readVersion(serverConn) | |
if err != nil { | |
@@ -37,3 +39,43 @@ func TestCustomClientVersion(t *testing.T) { | |
func TestDefaultClientVersion(t *testing.T) { | |
testClientVersion(t, &ClientConfig{}, packageVersion) | |
} | |
+ | |
+func TestHostKeyCheck(t *testing.T) { | |
+ for _, tt := range []struct { | |
+ name string | |
+ wantError string | |
+ key PublicKey | |
+ }{ | |
+ {"no callback", "must specify HostKeyCallback", nil}, | |
+ {"correct key", "", testSigners["rsa"].PublicKey()}, | |
+ {"mismatch", "mismatch", testSigners["ecdsa"].PublicKey()}, | |
+ } { | |
+ c1, c2, err := netPipe() | |
+ if err != nil { | |
+ t.Fatalf("netPipe: %v", err) | |
+ } | |
+ defer c1.Close() | |
+ defer c2.Close() | |
+ serverConf := &ServerConfig{ | |
+ NoClientAuth: true, | |
+ } | |
+ serverConf.AddHostKey(testSigners["rsa"]) | |
+ | |
+ go NewServerConn(c1, serverConf) | |
+ clientConf := ClientConfig{ | |
+ User: "user", | |
+ } | |
+ if tt.key != nil { | |
+ clientConf.HostKeyCallback = FixedHostKey(tt.key) | |
+ } | |
+ | |
+ _, _, _, err = NewClientConn(c2, "", &clientConf) | |
+ if err != nil { | |
+ if tt.wantError == "" || !strings.Contains(err.Error(), tt.wantError) { | |
+ t.Errorf("%s: got error %q, missing %q", tt.name, err.Error(), tt.wantError) | |
+ } | |
+ } else if tt.wantError != "" { | |
+ t.Errorf("%s: succeeded, but want error string %q", tt.name, tt.wantError) | |
+ } | |
+ } | |
+} | |
diff --git a/ssh/common.go b/ssh/common.go | |
index 8656d0f..dc39e4d 100644 | |
--- a/ssh/common.go | |
+++ b/ssh/common.go | |
@@ -9,6 +9,7 @@ import ( | |
"crypto/rand" | |
"fmt" | |
"io" | |
+ "math" | |
"sync" | |
_ "crypto/sha1" | |
@@ -40,7 +41,7 @@ var supportedKexAlgos = []string{ | |
kexAlgoDH14SHA1, kexAlgoDH1SHA1, | |
} | |
-// supportedKexAlgos specifies the supported host-key algorithms (i.e. methods | |
+// supportedHostKeyAlgos specifies the supported host-key algorithms (i.e. methods | |
// of authenticating servers) in preference order. | |
var supportedHostKeyAlgos = []string{ | |
CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, | |
@@ -186,7 +187,7 @@ type Config struct { | |
// The maximum number of bytes sent or received after which a | |
// new key is negotiated. It must be at least 256. If | |
- // unspecified, 1 gigabyte is used. | |
+ // unspecified, a size suitable for the chosen cipher is used. | |
RekeyThreshold uint64 | |
// The allowed key exchanges algorithms. If unspecified then a | |
@@ -230,11 +231,12 @@ func (c *Config) SetDefaults() { | |
} | |
if c.RekeyThreshold == 0 { | |
- // RFC 4253, section 9 suggests rekeying after 1G. | |
- c.RekeyThreshold = 1 << 30 | |
- } | |
- if c.RekeyThreshold < minRekeyThreshold { | |
+ // cipher specific default | |
+ } else if c.RekeyThreshold < minRekeyThreshold { | |
c.RekeyThreshold = minRekeyThreshold | |
+ } else if c.RekeyThreshold >= math.MaxInt64 { | |
+ // Avoid weirdness if somebody uses -1 as a threshold. | |
+ c.RekeyThreshold = math.MaxInt64 | |
} | |
} | |
diff --git a/ssh/connection.go b/ssh/connection.go | |
index e786f2f..fd6b068 100644 | |
--- a/ssh/connection.go | |
+++ b/ssh/connection.go | |
@@ -25,7 +25,7 @@ type ConnMetadata interface { | |
// User returns the user ID for this connection. | |
User() string | |
- // SessionID returns the sesson hash, also denoted by H. | |
+ // SessionID returns the session hash, also denoted by H. | |
SessionID() []byte | |
// ClientVersion returns the client's version string as hashed | |
diff --git a/ssh/doc.go b/ssh/doc.go | |
index d6be894..67b7322 100644 | |
--- a/ssh/doc.go | |
+++ b/ssh/doc.go | |
@@ -14,5 +14,8 @@ others. | |
References: | |
[PROTOCOL.certkeys]: http://cvsweb.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.certkeys?rev=HEAD | |
[SSH-PARAMETERS]: http://www.iana.org/assignments/ssh-parameters/ssh-parameters.xml#ssh-parameters-1 | |
+ | |
+This package does not fall under the stability promise of the Go language itself, | |
+so its API may be changed when pressing needs arise. | |
*/ | |
package ssh // import "golang.org/x/crypto/ssh" | |
diff --git a/ssh/example_test.go b/ssh/example_test.go | |
index 4d2eabd..618398c 100644 | |
--- a/ssh/example_test.go | |
+++ b/ssh/example_test.go | |
@@ -5,12 +5,16 @@ | |
package ssh_test | |
import ( | |
+ "bufio" | |
"bytes" | |
"fmt" | |
"io/ioutil" | |
"log" | |
"net" | |
"net/http" | |
+ "os" | |
+ "path/filepath" | |
+ "strings" | |
"golang.org/x/crypto/ssh" | |
"golang.org/x/crypto/ssh/terminal" | |
@@ -91,8 +95,6 @@ func ExampleNewServerConn() { | |
go ssh.DiscardRequests(reqs) | |
// Service the incoming Channel channel. | |
- | |
- // Service the incoming Channel channel. | |
for newChannel := range chans { | |
// Channels have a type, depending on the application level | |
// protocol intended. In the case of a shell, the type is | |
@@ -131,16 +133,59 @@ func ExampleNewServerConn() { | |
} | |
} | |
+func ExampleHostKeyCheck() { | |
+ // Every client must provide a host key check. Here is a | |
+ // simple-minded parse of OpenSSH's known_hosts file | |
+ host := "hostname" | |
+ file, err := os.Open(filepath.Join(os.Getenv("HOME"), ".ssh", "known_hosts")) | |
+ if err != nil { | |
+ log.Fatal(err) | |
+ } | |
+ defer file.Close() | |
+ | |
+ scanner := bufio.NewScanner(file) | |
+ var hostKey ssh.PublicKey | |
+ for scanner.Scan() { | |
+ fields := strings.Split(scanner.Text(), " ") | |
+ if len(fields) != 3 { | |
+ continue | |
+ } | |
+ if strings.Contains(fields[0], host) { | |
+ var err error | |
+ hostKey, _, _, _, err = ssh.ParseAuthorizedKey(scanner.Bytes()) | |
+ if err != nil { | |
+ log.Fatalf("error parsing %q: %v", fields[2], err) | |
+ } | |
+ break | |
+ } | |
+ } | |
+ | |
+ if hostKey == nil { | |
+ log.Fatalf("no hostkey for %s", host) | |
+ } | |
+ | |
+ config := ssh.ClientConfig{ | |
+ User: os.Getenv("USER"), | |
+ HostKeyCallback: ssh.FixedHostKey(hostKey), | |
+ } | |
+ | |
+ _, err = ssh.Dial("tcp", host+":22", &config) | |
+ log.Println(err) | |
+} | |
+ | |
func ExampleDial() { | |
+ var hostKey ssh.PublicKey | |
// An SSH client is represented with a ClientConn. | |
// | |
// To authenticate with the remote server you must pass at least one | |
- // implementation of AuthMethod via the Auth field in ClientConfig. | |
+ // implementation of AuthMethod via the Auth field in ClientConfig, | |
+ // and provide a HostKeyCallback. | |
config := &ssh.ClientConfig{ | |
User: "username", | |
Auth: []ssh.AuthMethod{ | |
ssh.Password("yourpassword"), | |
}, | |
+ HostKeyCallback: ssh.FixedHostKey(hostKey), | |
} | |
client, err := ssh.Dial("tcp", "yourserver.com:22", config) | |
if err != nil { | |
@@ -166,6 +211,7 @@ func ExampleDial() { | |
} | |
func ExamplePublicKeys() { | |
+ var hostKey ssh.PublicKey | |
// A public key may be used to authenticate against the remote | |
// server by using an unencrypted PEM-encoded private key file. | |
// | |
@@ -188,6 +234,7 @@ func ExamplePublicKeys() { | |
// Use the PublicKeys method for remote authentication. | |
ssh.PublicKeys(signer), | |
}, | |
+ HostKeyCallback: ssh.FixedHostKey(hostKey), | |
} | |
// Connect to the remote server and perform the SSH handshake. | |
@@ -199,11 +246,13 @@ func ExamplePublicKeys() { | |
} | |
func ExampleClient_Listen() { | |
+ var hostKey ssh.PublicKey | |
config := &ssh.ClientConfig{ | |
User: "username", | |
Auth: []ssh.AuthMethod{ | |
ssh.Password("password"), | |
}, | |
+ HostKeyCallback: ssh.FixedHostKey(hostKey), | |
} | |
// Dial your ssh server. | |
conn, err := ssh.Dial("tcp", "localhost:22", config) | |
@@ -226,12 +275,14 @@ func ExampleClient_Listen() { | |
} | |
func ExampleSession_RequestPty() { | |
+ var hostKey ssh.PublicKey | |
// Create client config | |
config := &ssh.ClientConfig{ | |
User: "username", | |
Auth: []ssh.AuthMethod{ | |
ssh.Password("password"), | |
}, | |
+ HostKeyCallback: ssh.FixedHostKey(hostKey), | |
} | |
// Connect to ssh server | |
conn, err := ssh.Dial("tcp", "localhost:22", config) | |
diff --git a/ssh/handshake.go b/ssh/handshake.go | |
index 8de6506..932ce83 100644 | |
--- a/ssh/handshake.go | |
+++ b/ssh/handshake.go | |
@@ -74,7 +74,7 @@ type handshakeTransport struct { | |
startKex chan *pendingKex | |
// data for host key checking | |
- hostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error | |
+ hostKeyCallback HostKeyCallback | |
dialAddress string | |
remoteAddr net.Addr | |
@@ -107,6 +107,8 @@ func newHandshakeTransport(conn keyingTransport, config *Config, clientVersion, | |
config: config, | |
} | |
+ t.resetReadThresholds() | |
+ t.resetWriteThresholds() | |
// We always start with a mandatory key exchange. | |
t.requestKex <- struct{}{} | |
@@ -237,6 +239,17 @@ func (t *handshakeTransport) requestKeyExchange() { | |
} | |
} | |
+func (t *handshakeTransport) resetWriteThresholds() { | |
+ t.writePacketsLeft = packetRekeyThreshold | |
+ if t.config.RekeyThreshold > 0 { | |
+ t.writeBytesLeft = int64(t.config.RekeyThreshold) | |
+ } else if t.algorithms != nil { | |
+ t.writeBytesLeft = t.algorithms.w.rekeyBytes() | |
+ } else { | |
+ t.writeBytesLeft = 1 << 30 | |
+ } | |
+} | |
+ | |
func (t *handshakeTransport) kexLoop() { | |
write: | |
@@ -285,12 +298,8 @@ write: | |
t.writeError = err | |
t.sentInitPacket = nil | |
t.sentInitMsg = nil | |
- t.writePacketsLeft = packetRekeyThreshold | |
- if t.config.RekeyThreshold > 0 { | |
- t.writeBytesLeft = int64(t.config.RekeyThreshold) | |
- } else if t.algorithms != nil { | |
- t.writeBytesLeft = t.algorithms.w.rekeyBytes() | |
- } | |
+ | |
+ t.resetWriteThresholds() | |
// we have completed the key exchange. Since the | |
// reader is still blocked, it is safe to clear out | |
@@ -344,6 +353,17 @@ write: | |
// key exchange itself. | |
const packetRekeyThreshold = (1 << 31) | |
+func (t *handshakeTransport) resetReadThresholds() { | |
+ t.readPacketsLeft = packetRekeyThreshold | |
+ if t.config.RekeyThreshold > 0 { | |
+ t.readBytesLeft = int64(t.config.RekeyThreshold) | |
+ } else if t.algorithms != nil { | |
+ t.readBytesLeft = t.algorithms.r.rekeyBytes() | |
+ } else { | |
+ t.readBytesLeft = 1 << 30 | |
+ } | |
+} | |
+ | |
func (t *handshakeTransport) readOnePacket(first bool) ([]byte, error) { | |
p, err := t.conn.readPacket() | |
if err != nil { | |
@@ -391,12 +411,7 @@ func (t *handshakeTransport) readOnePacket(first bool) ([]byte, error) { | |
return nil, err | |
} | |
- t.readPacketsLeft = packetRekeyThreshold | |
- if t.config.RekeyThreshold > 0 { | |
- t.readBytesLeft = int64(t.config.RekeyThreshold) | |
- } else { | |
- t.readBytesLeft = t.algorithms.r.rekeyBytes() | |
- } | |
+ t.resetReadThresholds() | |
// By default, a key exchange is hidden from higher layers by | |
// translating it into msgIgnore. | |
@@ -574,7 +589,9 @@ func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error { | |
} | |
result.SessionID = t.sessionID | |
- t.conn.prepareKeyChange(t.algorithms, result) | |
+ if err := t.conn.prepareKeyChange(t.algorithms, result); err != nil { | |
+ return err | |
+ } | |
if err = t.conn.writePacket([]byte{msgNewKeys}); err != nil { | |
return err | |
} | |
@@ -614,11 +631,9 @@ func (t *handshakeTransport) client(kex kexAlgorithm, algs *algorithms, magics * | |
return nil, err | |
} | |
- if t.hostKeyCallback != nil { | |
- err = t.hostKeyCallback(t.dialAddress, t.remoteAddr, hostKey) | |
- if err != nil { | |
- return nil, err | |
- } | |
+ err = t.hostKeyCallback(t.dialAddress, t.remoteAddr, hostKey) | |
+ if err != nil { | |
+ return nil, err | |
} | |
return result, nil | |
diff --git a/ssh/handshake_test.go b/ssh/handshake_test.go | |
index 1b83112..91d4935 100644 | |
--- a/ssh/handshake_test.go | |
+++ b/ssh/handshake_test.go | |
@@ -40,9 +40,12 @@ func (t *testChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error | |
// therefore is buffered (net.Pipe deadlocks if both sides start with | |
// a write.) | |
func netPipe() (net.Conn, net.Conn, error) { | |
- listener, err := net.Listen("tcp", ":0") | |
+ listener, err := net.Listen("tcp", "127.0.0.1:0") | |
if err != nil { | |
- return nil, nil, err | |
+ listener, err = net.Listen("tcp", "[::1]:0") | |
+ if err != nil { | |
+ return nil, nil, err | |
+ } | |
} | |
defer listener.Close() | |
c1, err := net.Dial("tcp", listener.Addr().String()) | |
@@ -436,6 +439,7 @@ func testHandshakeErrorHandlingN(t *testing.T, readLimit, writeLimit int, couple | |
clientConf.SetDefaults() | |
clientConn := newHandshakeTransport(&errorKeyingTransport{b, -1, -1}, &clientConf, []byte{'a'}, []byte{'b'}) | |
clientConn.hostKeyAlgorithms = []string{key.PublicKey().Type()} | |
+ clientConn.hostKeyCallback = InsecureIgnoreHostKey() | |
go clientConn.readLoop() | |
go clientConn.kexLoop() | |
@@ -525,3 +529,31 @@ func TestDisconnect(t *testing.T) { | |
t.Errorf("readPacket 3 succeeded") | |
} | |
} | |
+ | |
+func TestHandshakeRekeyDefault(t *testing.T) { | |
+ clientConf := &ClientConfig{ | |
+ Config: Config{ | |
+ Ciphers: []string{"aes128-ctr"}, | |
+ }, | |
+ HostKeyCallback: InsecureIgnoreHostKey(), | |
+ } | |
+ trC, trS, err := handshakePair(clientConf, "addr", false) | |
+ if err != nil { | |
+ t.Fatalf("handshakePair: %v", err) | |
+ } | |
+ defer trC.Close() | |
+ defer trS.Close() | |
+ | |
+ trC.writePacket([]byte{msgRequestSuccess, 0, 0}) | |
+ trC.Close() | |
+ | |
+ rgb := (1024 + trC.readBytesLeft) >> 30 | |
+ wgb := (1024 + trC.writeBytesLeft) >> 30 | |
+ | |
+ if rgb != 64 { | |
+ t.Errorf("got rekey after %dG read, want 64G", rgb) | |
+ } | |
+ if wgb != 64 { | |
+ t.Errorf("got rekey after %dG write, want 64G", wgb) | |
+ } | |
+} | |
diff --git a/ssh/keys.go b/ssh/keys.go | |
index f38de98..cf68532 100644 | |
--- a/ssh/keys.go | |
+++ b/ssh/keys.go | |
@@ -824,7 +824,7 @@ func ParseDSAPrivateKey(der []byte) (*dsa.PrivateKey, error) { | |
// Implemented based on the documentation at | |
// https://github.com/openssh/openssh-portable/blob/master/PROTOCOL.key | |
-func parseOpenSSHPrivateKey(key []byte) (*ed25519.PrivateKey, error) { | |
+func parseOpenSSHPrivateKey(key []byte) (crypto.PrivateKey, error) { | |
magic := append([]byte("openssh-key-v1"), 0) | |
if !bytes.Equal(magic, key[0:len(magic)]) { | |
return nil, errors.New("ssh: invalid openssh private key format") | |
@@ -844,14 +844,15 @@ func parseOpenSSHPrivateKey(key []byte) (*ed25519.PrivateKey, error) { | |
return nil, err | |
} | |
+ if w.KdfName != "none" || w.CipherName != "none" { | |
+ return nil, errors.New("ssh: cannot decode encrypted private keys") | |
+ } | |
+ | |
pk1 := struct { | |
Check1 uint32 | |
Check2 uint32 | |
Keytype string | |
- Pub []byte | |
- Priv []byte | |
- Comment string | |
- Pad []byte `ssh:"rest"` | |
+ Rest []byte `ssh:"rest"` | |
}{} | |
if err := Unmarshal(w.PrivKeyBlock, &pk1); err != nil { | |
@@ -862,24 +863,75 @@ func parseOpenSSHPrivateKey(key []byte) (*ed25519.PrivateKey, error) { | |
return nil, errors.New("ssh: checkint mismatch") | |
} | |
- // we only handle ed25519 keys currently | |
- if pk1.Keytype != KeyAlgoED25519 { | |
- return nil, errors.New("ssh: unhandled key type") | |
- } | |
+ // we only handle ed25519 and rsa keys currently | |
+ switch pk1.Keytype { | |
+ case KeyAlgoRSA: | |
+ // https://github.com/openssh/openssh-portable/blob/master/sshkey.c#L2760-L2773 | |
+ key := struct { | |
+ N *big.Int | |
+ E *big.Int | |
+ D *big.Int | |
+ Iqmp *big.Int | |
+ P *big.Int | |
+ Q *big.Int | |
+ Comment string | |
+ Pad []byte `ssh:"rest"` | |
+ }{} | |
+ | |
+ if err := Unmarshal(pk1.Rest, &key); err != nil { | |
+ return nil, err | |
+ } | |
- for i, b := range pk1.Pad { | |
- if int(b) != i+1 { | |
- return nil, errors.New("ssh: padding not as expected") | |
+ for i, b := range key.Pad { | |
+ if int(b) != i+1 { | |
+ return nil, errors.New("ssh: padding not as expected") | |
+ } | |
} | |
- } | |
- if len(pk1.Priv) != ed25519.PrivateKeySize { | |
- return nil, errors.New("ssh: private key unexpected length") | |
- } | |
+ pk := &rsa.PrivateKey{ | |
+ PublicKey: rsa.PublicKey{ | |
+ N: key.N, | |
+ E: int(key.E.Int64()), | |
+ }, | |
+ D: key.D, | |
+ Primes: []*big.Int{key.P, key.Q}, | |
+ } | |
- pk := ed25519.PrivateKey(make([]byte, ed25519.PrivateKeySize)) | |
- copy(pk, pk1.Priv) | |
- return &pk, nil | |
+ if err := pk.Validate(); err != nil { | |
+ return nil, err | |
+ } | |
+ | |
+ pk.Precompute() | |
+ | |
+ return pk, nil | |
+ case KeyAlgoED25519: | |
+ key := struct { | |
+ Pub []byte | |
+ Priv []byte | |
+ Comment string | |
+ Pad []byte `ssh:"rest"` | |
+ }{} | |
+ | |
+ if err := Unmarshal(pk1.Rest, &key); err != nil { | |
+ return nil, err | |
+ } | |
+ | |
+ if len(key.Priv) != ed25519.PrivateKeySize { | |
+ return nil, errors.New("ssh: private key unexpected length") | |
+ } | |
+ | |
+ for i, b := range key.Pad { | |
+ if int(b) != i+1 { | |
+ return nil, errors.New("ssh: padding not as expected") | |
+ } | |
+ } | |
+ | |
+ pk := ed25519.PrivateKey(make([]byte, ed25519.PrivateKeySize)) | |
+ copy(pk, key.Priv) | |
+ return &pk, nil | |
+ default: | |
+ return nil, errors.New("ssh: unhandled key type") | |
+ } | |
} | |
// FingerprintLegacyMD5 returns the user presentation of the key's | |
diff --git a/ssh/knownhosts/knownhosts.go b/ssh/knownhosts/knownhosts.go | |
new file mode 100644 | |
index 0000000..ea92b29 | |
--- /dev/null | |
+++ b/ssh/knownhosts/knownhosts.go | |
@@ -0,0 +1,546 @@ | |
+// Copyright 2017 The Go Authors. All rights reserved. | |
+// Use of this source code is governed by a BSD-style | |
+// license that can be found in the LICENSE file. | |
+ | |
+// Package knownhosts implements a parser for the OpenSSH | |
+// known_hosts host key database. | |
+package knownhosts | |
+ | |
+import ( | |
+ "bufio" | |
+ "bytes" | |
+ "crypto/hmac" | |
+ "crypto/rand" | |
+ "crypto/sha1" | |
+ "encoding/base64" | |
+ "errors" | |
+ "fmt" | |
+ "io" | |
+ "net" | |
+ "os" | |
+ "strings" | |
+ | |
+ "golang.org/x/crypto/ssh" | |
+) | |
+ | |
+// See the sshd manpage | |
+// (http://man.openbsd.org/sshd#SSH_KNOWN_HOSTS_FILE_FORMAT) for | |
+// background. | |
+ | |
+type addr struct{ host, port string } | |
+ | |
+func (a *addr) String() string { | |
+ h := a.host | |
+ if strings.Contains(h, ":") { | |
+ h = "[" + h + "]" | |
+ } | |
+ return h + ":" + a.port | |
+} | |
+ | |
+type matcher interface { | |
+ match([]addr) bool | |
+} | |
+ | |
+type hostPattern struct { | |
+ negate bool | |
+ addr addr | |
+} | |
+ | |
+func (p *hostPattern) String() string { | |
+ n := "" | |
+ if p.negate { | |
+ n = "!" | |
+ } | |
+ | |
+ return n + p.addr.String() | |
+} | |
+ | |
+type hostPatterns []hostPattern | |
+ | |
+func (ps hostPatterns) match(addrs []addr) bool { | |
+ matched := false | |
+ for _, p := range ps { | |
+ for _, a := range addrs { | |
+ m := p.match(a) | |
+ if !m { | |
+ continue | |
+ } | |
+ if p.negate { | |
+ return false | |
+ } | |
+ matched = true | |
+ } | |
+ } | |
+ return matched | |
+} | |
+ | |
+// See | |
+// https://android.googlesource.com/platform/external/openssh/+/ab28f5495c85297e7a597c1ba62e996416da7c7e/addrmatch.c | |
+// The matching of * has no regard for separators, unlike filesystem globs | |
+func wildcardMatch(pat []byte, str []byte) bool { | |
+ for { | |
+ if len(pat) == 0 { | |
+ return len(str) == 0 | |
+ } | |
+ if len(str) == 0 { | |
+ return false | |
+ } | |
+ | |
+ if pat[0] == '*' { | |
+ if len(pat) == 1 { | |
+ return true | |
+ } | |
+ | |
+ for j := range str { | |
+ if wildcardMatch(pat[1:], str[j:]) { | |
+ return true | |
+ } | |
+ } | |
+ return false | |
+ } | |
+ | |
+ if pat[0] == '?' || pat[0] == str[0] { | |
+ pat = pat[1:] | |
+ str = str[1:] | |
+ } else { | |
+ return false | |
+ } | |
+ } | |
+} | |
+ | |
+func (l *hostPattern) match(a addr) bool { | |
+ return wildcardMatch([]byte(l.addr.host), []byte(a.host)) && l.addr.port == a.port | |
+} | |
+ | |
+type keyDBLine struct { | |
+ cert bool | |
+ matcher matcher | |
+ knownKey KnownKey | |
+} | |
+ | |
+func serialize(k ssh.PublicKey) string { | |
+ return k.Type() + " " + base64.StdEncoding.EncodeToString(k.Marshal()) | |
+} | |
+ | |
+func (l *keyDBLine) match(addrs []addr) bool { | |
+ return l.matcher.match(addrs) | |
+} | |
+ | |
+type hostKeyDB struct { | |
+ // Serialized version of revoked keys | |
+ revoked map[string]*KnownKey | |
+ lines []keyDBLine | |
+} | |
+ | |
+func newHostKeyDB() *hostKeyDB { | |
+ db := &hostKeyDB{ | |
+ revoked: make(map[string]*KnownKey), | |
+ } | |
+ | |
+ return db | |
+} | |
+ | |
+func keyEq(a, b ssh.PublicKey) bool { | |
+ return bytes.Equal(a.Marshal(), b.Marshal()) | |
+} | |
+ | |
+// IsAuthorityForHost can be used as a callback in ssh.CertChecker | |
+func (db *hostKeyDB) IsHostAuthority(remote ssh.PublicKey, address string) bool { | |
+ h, p, err := net.SplitHostPort(address) | |
+ if err != nil { | |
+ return false | |
+ } | |
+ a := addr{host: h, port: p} | |
+ | |
+ for _, l := range db.lines { | |
+ if l.cert && keyEq(l.knownKey.Key, remote) && l.match([]addr{a}) { | |
+ return true | |
+ } | |
+ } | |
+ return false | |
+} | |
+ | |
+// IsRevoked can be used as a callback in ssh.CertChecker | |
+func (db *hostKeyDB) IsRevoked(key *ssh.Certificate) bool { | |
+ _, ok := db.revoked[string(key.Marshal())] | |
+ return ok | |
+} | |
+ | |
+const markerCert = "@cert-authority" | |
+const markerRevoked = "@revoked" | |
+ | |
+func nextWord(line []byte) (string, []byte) { | |
+ i := bytes.IndexAny(line, "\t ") | |
+ if i == -1 { | |
+ return string(line), nil | |
+ } | |
+ | |
+ return string(line[:i]), bytes.TrimSpace(line[i:]) | |
+} | |
+ | |
+func parseLine(line []byte) (marker, host string, key ssh.PublicKey, err error) { | |
+ if w, next := nextWord(line); w == markerCert || w == markerRevoked { | |
+ marker = w | |
+ line = next | |
+ } | |
+ | |
+ host, line = nextWord(line) | |
+ if len(line) == 0 { | |
+ return "", "", nil, errors.New("knownhosts: missing host pattern") | |
+ } | |
+ | |
+ // ignore the keytype as it's in the key blob anyway. | |
+ _, line = nextWord(line) | |
+ if len(line) == 0 { | |
+ return "", "", nil, errors.New("knownhosts: missing key type pattern") | |
+ } | |
+ | |
+ keyBlob, _ := nextWord(line) | |
+ | |
+ keyBytes, err := base64.StdEncoding.DecodeString(keyBlob) | |
+ if err != nil { | |
+ return "", "", nil, err | |
+ } | |
+ key, err = ssh.ParsePublicKey(keyBytes) | |
+ if err != nil { | |
+ return "", "", nil, err | |
+ } | |
+ | |
+ return marker, host, key, nil | |
+} | |
+ | |
+func (db *hostKeyDB) parseLine(line []byte, filename string, linenum int) error { | |
+ marker, pattern, key, err := parseLine(line) | |
+ if err != nil { | |
+ return err | |
+ } | |
+ | |
+ if marker == markerRevoked { | |
+ db.revoked[string(key.Marshal())] = &KnownKey{ | |
+ Key: key, | |
+ Filename: filename, | |
+ Line: linenum, | |
+ } | |
+ | |
+ return nil | |
+ } | |
+ | |
+ entry := keyDBLine{ | |
+ cert: marker == markerCert, | |
+ knownKey: KnownKey{ | |
+ Filename: filename, | |
+ Line: linenum, | |
+ Key: key, | |
+ }, | |
+ } | |
+ | |
+ if pattern[0] == '|' { | |
+ entry.matcher, err = newHashedHost(pattern) | |
+ } else { | |
+ entry.matcher, err = newHostnameMatcher(pattern) | |
+ } | |
+ | |
+ if err != nil { | |
+ return err | |
+ } | |
+ | |
+ db.lines = append(db.lines, entry) | |
+ return nil | |
+} | |
+ | |
+func newHostnameMatcher(pattern string) (matcher, error) { | |
+ var hps hostPatterns | |
+ for _, p := range strings.Split(pattern, ",") { | |
+ if len(p) == 0 { | |
+ continue | |
+ } | |
+ | |
+ var a addr | |
+ var negate bool | |
+ if p[0] == '!' { | |
+ negate = true | |
+ p = p[1:] | |
+ } | |
+ | |
+ if len(p) == 0 { | |
+ return nil, errors.New("knownhosts: negation without following hostname") | |
+ } | |
+ | |
+ var err error | |
+ if p[0] == '[' { | |
+ a.host, a.port, err = net.SplitHostPort(p) | |
+ if err != nil { | |
+ return nil, err | |
+ } | |
+ } else { | |
+ a.host, a.port, err = net.SplitHostPort(p) | |
+ if err != nil { | |
+ a.host = p | |
+ a.port = "22" | |
+ } | |
+ } | |
+ hps = append(hps, hostPattern{ | |
+ negate: negate, | |
+ addr: a, | |
+ }) | |
+ } | |
+ return hps, nil | |
+} | |
+ | |
+// KnownKey represents a key declared in a known_hosts file. | |
+type KnownKey struct { | |
+ Key ssh.PublicKey | |
+ Filename string | |
+ Line int | |
+} | |
+ | |
+func (k *KnownKey) String() string { | |
+ return fmt.Sprintf("%s:%d: %s", k.Filename, k.Line, serialize(k.Key)) | |
+} | |
+ | |
+// KeyError is returned if we did not find the key in the host key | |
+// database, or there was a mismatch. Typically, in batch | |
+// applications, this should be interpreted as failure. Interactive | |
+// applications can offer an interactive prompt to the user. | |
+type KeyError struct { | |
+ // Want holds the accepted host keys. For each key algorithm, | |
+ // there can be one hostkey. If Want is empty, the host is | |
+ // unknown. If Want is non-empty, there was a mismatch, which | |
+ // can signify a MITM attack. | |
+ Want []KnownKey | |
+} | |
+ | |
+func (u *KeyError) Error() string { | |
+ if len(u.Want) == 0 { | |
+ return "knownhosts: key is unknown" | |
+ } | |
+ return "knownhosts: key mismatch" | |
+} | |
+ | |
+// RevokedError is returned if we found a key that was revoked. | |
+type RevokedError struct { | |
+ Revoked KnownKey | |
+} | |
+ | |
+func (r *RevokedError) Error() string { | |
+ return "knownhosts: key is revoked" | |
+} | |
+ | |
+// check checks a key against the host database. This should not be | |
+// used for verifying certificates. | |
+func (db *hostKeyDB) check(address string, remote net.Addr, remoteKey ssh.PublicKey) error { | |
+ if revoked := db.revoked[string(remoteKey.Marshal())]; revoked != nil { | |
+ return &RevokedError{Revoked: *revoked} | |
+ } | |
+ | |
+ host, port, err := net.SplitHostPort(remote.String()) | |
+ if err != nil { | |
+ return fmt.Errorf("knownhosts: SplitHostPort(%s): %v", remote, err) | |
+ } | |
+ | |
+ addrs := []addr{ | |
+ {host, port}, | |
+ } | |
+ | |
+ if address != "" { | |
+ host, port, err := net.SplitHostPort(address) | |
+ if err != nil { | |
+ return fmt.Errorf("knownhosts: SplitHostPort(%s): %v", address, err) | |
+ } | |
+ | |
+ addrs = append(addrs, addr{host, port}) | |
+ } | |
+ | |
+ return db.checkAddrs(addrs, remoteKey) | |
+} | |
+ | |
+// checkAddrs checks if we can find the given public key for any of | |
+// the given addresses. If we only find an entry for the IP address, | |
+// or only the hostname, then this still succeeds. | |
+func (db *hostKeyDB) checkAddrs(addrs []addr, remoteKey ssh.PublicKey) error { | |
+ // TODO(hanwen): are these the right semantics? What if there | |
+ // is just a key for the IP address, but not for the | |
+ // hostname? | |
+ | |
+ // Algorithm => key. | |
+ knownKeys := map[string]KnownKey{} | |
+ for _, l := range db.lines { | |
+ if l.match(addrs) { | |
+ typ := l.knownKey.Key.Type() | |
+ if _, ok := knownKeys[typ]; !ok { | |
+ knownKeys[typ] = l.knownKey | |
+ } | |
+ } | |
+ } | |
+ | |
+ keyErr := &KeyError{} | |
+ for _, v := range knownKeys { | |
+ keyErr.Want = append(keyErr.Want, v) | |
+ } | |
+ | |
+ // Unknown remote host. | |
+ if len(knownKeys) == 0 { | |
+ return keyErr | |
+ } | |
+ | |
+ // If the remote host starts using a different, unknown key type, we | |
+ // also interpret that as a mismatch. | |
+ if known, ok := knownKeys[remoteKey.Type()]; !ok || !keyEq(known.Key, remoteKey) { | |
+ return keyErr | |
+ } | |
+ | |
+ return nil | |
+} | |
+ | |
+// Read parses known_hosts data from r and adds its entries to db. | |
+// filename is passed to parseLine and used in error messages. Lines | |
+// that are empty or start with '#' after trimming whitespace are | |
+// skipped; line numbers are 1-based. | |
+func (db *hostKeyDB) Read(r io.Reader, filename string) error { | |
+ sc := bufio.NewScanner(r) | |
+ for lineNum := 1; sc.Scan(); lineNum++ { | |
+ line := bytes.TrimSpace(sc.Bytes()) | |
+ if len(line) == 0 || line[0] == '#' { | |
+ continue | |
+ } | |
+ if err := db.parseLine(line, filename, lineNum); err != nil { | |
+ return fmt.Errorf("knownhosts: %s:%d: %v", filename, lineNum, err) | |
+ } | |
+ } | |
+ return sc.Err() | |
+} | |
+ | |
+// New creates a host key callback from the given OpenSSH host key | |
+// files. The returned callback is for use in | |
+// ssh.ClientConfig.HostKeyCallback. Entries with hashed hostnames | |
+// ("|1|salt|hash") are matched as well as plain ones. | |
+// | |
+// Each file is opened, parsed and closed in turn; the first open or | |
+// parse error is returned. | |
+func New(files ...string) (ssh.HostKeyCallback, error) { | |
+ db := newHostKeyDB() | |
+ for _, fn := range files { | |
+ // Read each file inside its own function so the deferred Close | |
+ // runs as soon as that file is parsed, instead of keeping every | |
+ // file open until New returns. | |
+ err := func() error { | |
+ f, err := os.Open(fn) | |
+ if err != nil { | |
+ return err | |
+ } | |
+ defer f.Close() | |
+ return db.Read(f, fn) | |
+ }() | |
+ if err != nil { | |
+ return nil, err | |
+ } | |
+ } | |
+ | |
+ var certChecker ssh.CertChecker | |
+ certChecker.IsHostAuthority = db.IsHostAuthority | |
+ certChecker.IsRevoked = db.IsRevoked | |
+ certChecker.HostKeyFallback = db.check | |
+ | |
+ return certChecker.CheckHostKey, nil | |
+} | |
+ | |
+// Normalize normalizes an address into the form used in known_hosts. | |
+// An address without a port is treated as being on port 22. Port 22 is | |
+// dropped, any other port is written as "[host]:port", and a host that | |
+// contains ':' but is not already bracketed (a bare IPv6 address) is | |
+// wrapped in brackets. | |
+func Normalize(address string) string { | |
+ host, port, err := net.SplitHostPort(address) | |
+ if err != nil { | |
+ // No parsable port: the whole string is the host. | |
+ host, port = address, "22" | |
+ } | |
+ switch { | |
+ case port != "22": | |
+ return "[" + host + "]:" + port | |
+ case strings.Contains(host, ":") && !strings.HasPrefix(host, "["): | |
+ return "[" + host + "]" | |
+ default: | |
+ return host | |
+ } | |
+} | |
+ | |
+// Line returns a line to add append to the known_hosts files. | |
+func Line(addresses []string, key ssh.PublicKey) string { | |
+ var trimmed []string | |
+ for _, a := range addresses { | |
+ trimmed = append(trimmed, Normalize(a)) | |
+ } | |
+ | |
+ return strings.Join(trimmed, ",") + " " + serialize(key) | |
+} | |
+ | |
+// HashHostname hashes the given hostname. The hostname is not | |
+// normalized before hashing. The result has the form | |
+// "|1|base64(salt)|base64(hash)" and uses a fresh random salt, so | |
+// hashing the same hostname twice yields different strings. | |
+// It panics if crypto/rand fails. | |
+func HashHostname(hostname string) string { | |
+ // TODO(hanwen): check if we can safely normalize this always. | |
+ salt := make([]byte, sha1.Size) | |
+ | |
+ _, err := rand.Read(salt) | |
+ if err != nil { | |
+ panic(fmt.Sprintf("crypto/rand failure %v", err)) | |
+ } | |
+ | |
+ hash := hashHost(hostname, salt) | |
+ return encodeHash(sha1HashType, salt, hash) | |
+} | |
+ | |
+// decodeHash splits an encoded hashed host "|type|salt|hash" into its | |
+// type string and its base64-decoded salt and hash. | |
+func decodeHash(encoded string) (hashType string, salt, hash []byte, err error) { | |
+ if len(encoded) == 0 || encoded[0] != '|' { | |
+ err = errors.New("knownhosts: hashed host must start with '|'") | |
+ return | |
+ } | |
+ // Split after the leading '|' so the reported count refers to the | |
+ // same three fields named in "want 3" (splitting the whole string | |
+ // also counted the empty field before the first '|'). | |
+ components := strings.Split(encoded[1:], "|") | |
+ if len(components) != 3 { | |
+ err = fmt.Errorf("knownhosts: got %d components, want 3", len(components)) | |
+ return | |
+ } | |
+ | |
+ hashType = components[0] | |
+ if salt, err = base64.StdEncoding.DecodeString(components[1]); err != nil { | |
+ return | |
+ } | |
+ hash, err = base64.StdEncoding.DecodeString(components[2]) | |
+ return | |
+} | |
+ | |
+// encodeHash is the inverse of decodeHash: it formats typ, salt and | |
+// hash as "|typ|base64(salt)|base64(hash)". | |
+func encodeHash(typ string, salt []byte, hash []byte) string { | |
+ return strings.Join([]string{"", | |
+ typ, | |
+ base64.StdEncoding.EncodeToString(salt), | |
+ base64.StdEncoding.EncodeToString(hash), | |
+ }, "|") | |
+} | |
+ | |
+// hashHost returns the HMAC-SHA1 of hostname keyed with salt. | |
+// See https://android.googlesource.com/platform/external/openssh/+/ab28f5495c85297e7a597c1ba62e996416da7c7e/hostfile.c#120 | |
+func hashHost(hostname string, salt []byte) []byte { | |
+ mac := hmac.New(sha1.New, salt) | |
+ mac.Write([]byte(hostname)) // hash.Hash.Write never returns an error | |
+ return mac.Sum(nil) | |
+} | |
+ | |
+// hashedHost is a known_hosts host entry stored in hashed form. | |
+type hashedHost struct { | |
+ salt []byte | |
+ hash []byte | |
+} | |
+ | |
+// sha1HashType is the only hash type accepted by newHashedHost. | |
+const sha1HashType = "1" | |
+ | |
+// newHashedHost parses an encoded hashed host, rejecting any hash type | |
+// other than sha1HashType. | |
+func newHashedHost(encoded string) (*hashedHost, error) { | |
+ typ, salt, hash, err := decodeHash(encoded) | |
+ if err != nil { | |
+ return nil, err | |
+ } | |
+ | |
+ // The type field seems for future algorithm agility, but it's | |
+ // actually hardcoded in openssh currently, see | |
+ // https://android.googlesource.com/platform/external/openssh/+/ab28f5495c85297e7a597c1ba62e996416da7c7e/hostfile.c#120 | |
+ if typ != sha1HashType { | |
+ return nil, fmt.Errorf("knownhosts: got hash type %s, must be '1'", typ) | |
+ } | |
+ | |
+ return &hashedHost{salt: salt, hash: hash}, nil | |
+} | |
+ | |
+// match reports whether any of addrs, after Normalize, hashes to the | |
+// stored hash under the stored salt. | |
+func (h *hashedHost) match(addrs []addr) bool { | |
+ for i := range addrs { | |
+ candidate := Normalize(addrs[i].String()) | |
+ if bytes.Equal(hashHost(candidate, h.salt), h.hash) { | |
+ return true | |
+ } | |
+ } | |
+ return false | |
+} | |
diff --git a/ssh/knownhosts/knownhosts_test.go b/ssh/knownhosts/knownhosts_test.go | |
new file mode 100644 | |
index 0000000..be7cc0e | |
--- /dev/null | |
+++ b/ssh/knownhosts/knownhosts_test.go | |
@@ -0,0 +1,329 @@ | |
+// Copyright 2017 The Go Authors. All rights reserved. | |
+// Use of this source code is governed by a BSD-style | |
+// license that can be found in the LICENSE file. | |
+ | |
+package knownhosts | |
+ | |
+import ( | |
+ "bytes" | |
+ "fmt" | |
+ "net" | |
+ "reflect" | |
+ "testing" | |
+ | |
+ "golang.org/x/crypto/ssh" | |
+) | |
+ | |
+const ( | |
+ edKeyStr = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIGBAarftlLeoyf+v+nVchEZII/vna2PCV8FaX4vsF5BX" | |
+ alternateEdKeyStr = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIIXffBYeYL+WVzVru8npl5JHt2cjlr4ornFTWzoij9sx" | |
+ ecKeyStr = "ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBNLCu01+wpXe3xB5olXCN4SqU2rQu0qjSRKJO4Bg+JRCPU+ENcgdA5srTU8xYDz/GEa4dzK5ldPw4J/gZgSXCMs=" | |
+) | |
+ | |
+var ( | |
+ // Parsed forms of the key strings above; filled in by init. | |
+ ecKey, alternateEdKey, edKey ssh.PublicKey | |
+ | |
+ // testAddr is an IPv4 address on the default SSH port. | |
+ testAddr = &net.TCPAddr{ | |
+ IP: net.IP{198, 41, 30, 196}, | |
+ Port: 22, | |
+ } | |
+ | |
+ // testAddr6 is a 16-byte (IPv6) address on the default SSH port. | |
+ testAddr6 = &net.TCPAddr{ | |
+ IP: net.IP{198, 41, 30, 196, | |
+ 1, 2, 3, 4, | |
+ 1, 2, 3, 4, | |
+ 1, 2, 3, 4, | |
+ }, | |
+ Port: 22, | |
+ } | |
+) | |
+ | |
+func init() { | |
+ for _, k := range []struct { | |
+ dst *ssh.PublicKey | |
+ src string | |
+ }{ | |
+ {&ecKey, ecKeyStr}, | |
+ {&edKey, edKeyStr}, | |
+ {&alternateEdKey, alternateEdKeyStr}, | |
+ } { | |
+ parsed, _, _, _, err := ssh.ParseAuthorizedKey([]byte(k.src)) | |
+ if err != nil { | |
+ panic(err) | |
+ } | |
+ *k.dst = parsed | |
+ } | |
+} | |
+ | |
+// testDB parses s as known_hosts content named "testdb", failing the | |
+// test on any parse error. | |
+func testDB(t *testing.T, s string) *hostKeyDB { | |
+ db := newHostKeyDB() | |
+ if err := db.Read(bytes.NewBufferString(s), "testdb"); err != nil { | |
+ t.Fatalf("Read: %v", err) | |
+ } | |
+ return db | |
+} | |
+ | |
+// TestRevoked checks that a key listed under @revoked is rejected with a | |
+// *RevokedError naming the revoking line. | |
+func TestRevoked(t *testing.T) { | |
+ // Two leading blank lines put the @revoked entry on line 3. | |
+ db := testDB(t, "\n\n@revoked * "+edKeyStr+"\n") | |
+ want := &RevokedError{ | |
+ Revoked: KnownKey{ | |
+ Key: edKey, | |
+ Filename: "testdb", | |
+ Line: 3, | |
+ }, | |
+ } | |
+ err := db.check("", &net.TCPAddr{ | |
+ Port: 42, | |
+ }, edKey) | |
+ if err == nil { | |
+ t.Fatal("no error for revoked key") | |
+ } | |
+ if !reflect.DeepEqual(want, err) { | |
+ // Actual value first, to match the "got ..., want ..." wording. | |
+ t.Fatalf("got %#v, want %#v", err, want) | |
+ } | |
+} | |
+ | |
+// TestHostAuthority checks which addresses a @cert-authority entry covers, | |
+// including the default-port and bracketed-port forms. | |
+func TestHostAuthority(t *testing.T) { | |
+ cases := []struct { | |
+ authorityFor string | |
+ address string | |
+ good bool | |
+ }{ | |
+ {authorityFor: "localhost", address: "localhost:22", good: true}, | |
+ {authorityFor: "localhost", address: "localhost", good: false}, | |
+ {authorityFor: "localhost", address: "localhost:1234", good: false}, | |
+ {authorityFor: "[localhost]:1234", address: "localhost:1234", good: true}, | |
+ {authorityFor: "[localhost]:1234", address: "localhost:22", good: false}, | |
+ {authorityFor: "[localhost]:1234", address: "localhost", good: false}, | |
+ } | |
+ for _, tc := range cases { | |
+ db := testDB(t, `@cert-authority `+tc.authorityFor+` `+edKeyStr) | |
+ authorityKey := db.lines[0].knownKey.Key | |
+ if got := db.IsHostAuthority(authorityKey, tc.address); got != tc.good { | |
+ t.Errorf("IsHostAuthority: authority %s, address %s, wanted good = %v, got good = %v", | |
+ tc.authorityFor, tc.address, tc.good, got) | |
+ } | |
+ } | |
+} | |
+ | |
+// TestBracket checks entries written as "[host]:port" for a non-default port. | |
+func TestBracket(t *testing.T) { | |
+ db := testDB(t, `[git.eclipse.org]:29418,[198.41.30.196]:29418 `+edKeyStr) | |
+ | |
+ known := &net.TCPAddr{IP: net.IP{198, 41, 30, 196}, Port: 29418} | |
+ if err := db.check("git.eclipse.org:29418", known, edKey); err != nil { | |
+ t.Errorf("got error %v, want none", err) | |
+ } | |
+ | |
+ // Same host on another port, with a remote address that has no IP: | |
+ // no entry should match, so Want must be empty. | |
+ err := db.check("git.eclipse.org:29419", &net.TCPAddr{Port: 42}, edKey) | |
+ if err == nil { | |
+ t.Fatalf("no error for unknown address") | |
+ } | |
+ ke, ok := err.(*KeyError) | |
+ if !ok { | |
+ t.Fatalf("got type %T, want *KeyError", err) | |
+ } | |
+ if len(ke.Want) > 0 { | |
+ t.Fatalf("got Want %v, want []", ke.Want) | |
+ } | |
+} | |
+ | |
+// TestNewKeyType checks that a known host presenting a key of a type not | |
+// recorded for it is rejected, with the recorded keys reported. | |
+func TestNewKeyType(t *testing.T) { | |
+ str := fmt.Sprintf("%s %s", testAddr, edKeyStr) | |
+ db := testDB(t, str) | |
+ if err := db.check("", testAddr, ecKey); err == nil { | |
+ t.Fatalf("no error for key of an unrecorded type") | |
+ } else if ke, ok := err.(*KeyError); !ok { | |
+ t.Fatalf("got type %T, want *KeyError", err) | |
+ } else if len(ke.Want) == 0 { | |
+ t.Fatalf("got empty KeyError.Want") | |
+ } | |
+} | |
+ | |
+// TestSameKeyType checks that a different key of the recorded type is | |
+// rejected and that the recorded key is reported in KeyError.Want. | |
+func TestSameKeyType(t *testing.T) { | |
+ str := fmt.Sprintf("%s %s", testAddr, edKeyStr) | |
+ db := testDB(t, str) | |
+ if err := db.check("", testAddr, alternateEdKey); err == nil { | |
+ t.Fatalf("no error for mismatched key of the recorded type") | |
+ } else if ke, ok := err.(*KeyError); !ok { | |
+ t.Fatalf("got type %T, want *KeyError", err) | |
+ } else if len(ke.Want) == 0 { | |
+ t.Fatalf("got empty KeyError.Want") | |
+ } else if got, want := ke.Want[0].Key.Marshal(), edKey.Marshal(); !bytes.Equal(got, want) { | |
+ t.Fatalf("got key %q, want %q", got, want) | |
+ } | |
+} | |
+ | |
+// TestIPAddress checks that an entry keyed by an IPv4 address matches. | |
+func TestIPAddress(t *testing.T) { | |
+ db := testDB(t, fmt.Sprintf("%s %s", testAddr, edKeyStr)) | |
+ if err := db.check("", testAddr, edKey); err != nil { | |
+ t.Errorf("got error %q, want none", err) | |
+ } | |
+} | |
+ | |
+// TestIPv6Address checks that an entry keyed by an IPv6 address matches. | |
+func TestIPv6Address(t *testing.T) { | |
+ db := testDB(t, fmt.Sprintf("%s %s", testAddr6, edKeyStr)) | |
+ if err := db.check("", testAddr6, edKey); err != nil { | |
+ t.Errorf("got error %q, want none", err) | |
+ } | |
+} | |
+ | |
+// TestBasic checks comment/blank-line skipping, a successful match, and | |
+// the KeyError reported for a key of another type. | |
+func TestBasic(t *testing.T) { | |
+ // Line 1 is a comment and line 2 is blank, so server.org is on line 3. | |
+ db := testDB(t, fmt.Sprintf("#comment\n\nserver.org,%s %s\notherhost %s", testAddr, edKeyStr, ecKeyStr)) | |
+ if err := db.check("server.org:22", testAddr, edKey); err != nil { | |
+ t.Errorf("got error %q, want none", err) | |
+ } | |
+ | |
+ want := KnownKey{Key: edKey, Filename: "testdb", Line: 3} | |
+ err := db.check("server.org:22", testAddr, ecKey) | |
+ ke, isKeyErr := err.(*KeyError) | |
+ switch { | |
+ case err == nil: | |
+ t.Errorf("succeeded, want KeyError") | |
+ case !isKeyErr: | |
+ t.Errorf("got %T, want *KeyError", err) | |
+ case len(ke.Want) != 1: | |
+ t.Errorf("got %v, want 1 entry", ke) | |
+ case !reflect.DeepEqual(ke.Want[0], want): | |
+ t.Errorf("got %v, want %v", ke.Want[0], want) | |
+ } | |
+} | |
+ | |
+// TestNegate checks that a line listing the test IP but negating | |
+// server.org reports no known keys for server.org. | |
+func TestNegate(t *testing.T) { | |
+ db := testDB(t, fmt.Sprintf("%s,!server.org %s", testAddr, edKeyStr)) | |
+ err := db.check("server.org:22", testAddr, ecKey) | |
+ ke, isKeyErr := err.(*KeyError) | |
+ switch { | |
+ case err == nil: | |
+ t.Errorf("succeeded") | |
+ case !isKeyErr: | |
+ t.Errorf("got error type %T, want *KeyError", err) | |
+ case len(ke.Want) != 0: | |
+ t.Errorf("got expected keys %d (first of type %s), want []", len(ke.Want), ke.Want[0].Key.Type()) | |
+ } | |
+} | |
+ | |
+// TestWildcard checks that a '*' host pattern matches and that a key of | |
+// another type is rejected, listing the pattern line's key. | |
+func TestWildcard(t *testing.T) { | |
+ db := testDB(t, fmt.Sprintf("server*.domain %s", edKeyStr)) | |
+ | |
+ want := &KeyError{Want: []KnownKey{{Filename: "testdb", Line: 1, Key: edKey}}} | |
+ if got := db.check("server.domain:22", &net.TCPAddr{}, ecKey); !reflect.DeepEqual(got, want) { | |
+ t.Errorf("got %s, want %s", got, want) | |
+ } | |
+} | |
+ | |
+// TestLine checks that Line normalizes default and non-default ports for | |
+// hostnames and bracketed IPv6 addresses. | |
+func TestLine(t *testing.T) { | |
+ cases := []struct{ in, want string }{ | |
+ {"server.org", "server.org " + edKeyStr}, | |
+ {"server.org:22", "server.org " + edKeyStr}, | |
+ {"server.org:23", "[server.org]:23 " + edKeyStr}, | |
+ {"[c629:1ec4:102:304:102:304:102:304]:22", "[c629:1ec4:102:304:102:304:102:304] " + edKeyStr}, | |
+ {"[c629:1ec4:102:304:102:304:102:304]:23", "[c629:1ec4:102:304:102:304:102:304]:23 " + edKeyStr}, | |
+ } | |
+ for _, c := range cases { | |
+ if got := Line([]string{c.in}, edKey); got != c.want { | |
+ t.Errorf("Line(%q) = %q, want %q", c.in, got, c.want) | |
+ } | |
+ } | |
+} | |
+ | |
+func TestWildcardMatch(t *testing.T) { | |
+ for _, c := range []struct { | |
+ pat, str string | |
+ want bool | |
+ }{ | |
+ {"a?b", "abb", true}, | |
+ {"ab", "abc", false}, | |
+ {"abc", "ab", false}, | |
+ {"a*b", "axxxb", true}, | |
+ {"a*b", "axbxb", true}, | |
+ {"a*b", "axbxbc", false}, | |
+ {"a*?", "axbxc", true}, | |
+ {"a*b*", "axxbxxxxxx", true}, | |
+ {"a*b*c", "axxbxxxxxxc", true}, | |
+ {"a*b*?", "axxbxxxxxxc", true}, | |
+ {"a*b*z", "axxbxxbxxxz", true}, | |
+ {"a*b*z", "axxbxxzxxxz", true}, | |
+ {"a*b*z", "axxbxxzxxx", false}, | |
+ } { | |
+ got := wildcardMatch([]byte(c.pat), []byte(c.str)) | |
+ if got != c.want { | |
+ t.Errorf("wildcardMatch(%q, %q) = %v, want %v", c.pat, c.str, got, c.want) | |
+ } | |
+ | |
+ } | |
+} | |
+ | |
+// TODO(hanwen): test coverage for certificates. | |
+ | |
+const ( | |
+ testHostname = "hostname" | |
+ | |
+ // encodedTestHostnameHash is testHostname hashed by OpenSSH | |
+ // (generated with ssh-keygen -H -f). | |
+ encodedTestHostnameHash = "|1|IHXZvQMvTcZTUU29+2vXFgx8Frs=|UGccIWfRVDwilMBnA3WJoRAC75Y=" | |
+) | |
+ | |
+// TestHostHash checks decoding and re-hashing against an OpenSSH-made hash. | |
+func TestHostHash(t *testing.T) { | |
+ testHostHash(t, testHostname, encodedTestHostnameHash) | |
+} | |
+ | |
+// TestHashList checks that a hash produced by HashHostname round-trips. | |
+func TestHashList(t *testing.T) { | |
+ testHostHash(t, testHostname, HashHostname(testHostname)) | |
+} | |
+ | |
+// testHostHash decodes encoded and checks that re-encoding reproduces | |
+// it, that its type is sha1HashType, and that hashing hostname with the | |
+// decoded salt yields the decoded hash. | |
+func testHostHash(t *testing.T, hostname, encoded string) { | |
+ typ, salt, hash, err := decodeHash(encoded) | |
+ if err != nil { | |
+ t.Fatalf("decodeHash: %v", err) | |
+ } | |
+ | |
+ if reencoded := encodeHash(typ, salt, hash); reencoded != encoded { | |
+ t.Errorf("got encoding %s want %s", reencoded, encoded) | |
+ } | |
+ | |
+ if typ != sha1HashType { | |
+ t.Fatalf("got hash type %q, want %q", typ, sha1HashType) | |
+ } | |
+ | |
+ if got := hashHost(hostname, salt); !bytes.Equal(got, hash) { | |
+ t.Errorf("got hash %x want %x", got, hash) | |
+ } | |
+} | |
+ | |
+// TestNormalize covers IPv4, hostname and IPv6 inputs on default and | |
+// non-default ports. | |
+func TestNormalize(t *testing.T) { | |
+ cases := []struct{ in, want string }{ | |
+ {"127.0.0.1:22", "127.0.0.1"}, | |
+ {"[127.0.0.1]:22", "127.0.0.1"}, | |
+ {"[127.0.0.1]:23", "[127.0.0.1]:23"}, | |
+ {"127.0.0.1:23", "[127.0.0.1]:23"}, | |
+ {"[a.b.c]:22", "a.b.c"}, | |
+ {"[abcd:abcd:abcd:abcd]", "[abcd:abcd:abcd:abcd]"}, | |
+ {"[abcd:abcd:abcd:abcd]:22", "[abcd:abcd:abcd:abcd]"}, | |
+ {"[abcd:abcd:abcd:abcd]:23", "[abcd:abcd:abcd:abcd]:23"}, | |
+ } | |
+ for _, c := range cases { | |
+ if got := Normalize(c.in); got != c.want { | |
+ t.Errorf("Normalize(%q) = %q, want %q", c.in, got, c.want) | |
+ } | |
+ } | |
+} | |
+ | |
+// TestHashedHostkeyCheck checks matching against a hashed hostname entry. | |
+func TestHashedHostkeyCheck(t *testing.T) { | |
+ db := testDB(t, fmt.Sprintf("%s %s", HashHostname(testHostname), edKeyStr)) | |
+ target := testHostname + ":22" | |
+ | |
+ if err := db.check(target, testAddr, edKey); err != nil { | |
+ t.Errorf("check(%s): %v", testHostname, err) | |
+ } | |
+ | |
+ // Another key of the same type is a mismatch naming the hashed line. | |
+ want := &KeyError{Want: []KnownKey{{Filename: "testdb", Line: 1, Key: edKey}}} | |
+ if got := db.check(target, testAddr, alternateEdKey); !reflect.DeepEqual(got, want) { | |
+ t.Errorf("got error %v, want %v", got, want) | |
+ } | |
+} | |
diff --git a/ssh/server.go b/ssh/server.go | |
index 77c84d1..23b41d9 100644 | |
--- a/ssh/server.go | |
+++ b/ssh/server.go | |
@@ -45,6 +45,12 @@ type ServerConfig struct { | |
// authenticating. | |
NoClientAuth bool | |
+ // MaxAuthTries specifies the maximum number of authentication attempts | |
+ // permitted per connection. If set to a negative number, the number of | |
+ // attempts are unlimited. If set to zero, the number of attempts are limited | |
+ // to 6. | |
+ MaxAuthTries int | |
+ | |
// PasswordCallback, if non-nil, is called when a user | |
// attempts to authenticate using a password. | |
PasswordCallback func(conn ConnMetadata, password []byte) (*Permissions, error) | |
@@ -143,6 +149,10 @@ type ServerConn struct { | |
func NewServerConn(c net.Conn, config *ServerConfig) (*ServerConn, <-chan NewChannel, <-chan *Request, error) { | |
fullConf := *config | |
fullConf.SetDefaults() | |
+ if fullConf.MaxAuthTries == 0 { | |
+ fullConf.MaxAuthTries = 6 | |
+ } | |
+ | |
s := &connection{ | |
sshConn: sshConn{conn: c}, | |
} | |
@@ -267,8 +277,23 @@ func (s *connection) serverAuthenticate(config *ServerConfig) (*Permissions, err | |
var cache pubKeyCache | |
var perms *Permissions | |
+ authFailures := 0 | |
+ | |
userAuthLoop: | |
for { | |
+ if authFailures >= config.MaxAuthTries && config.MaxAuthTries > 0 { | |
+ discMsg := &disconnectMsg{ | |
+ Reason: 2, | |
+ Message: "too many authentication failures", | |
+ } | |
+ | |
+ if err := s.transport.writePacket(Marshal(discMsg)); err != nil { | |
+ return nil, err | |
+ } | |
+ | |
+ return nil, discMsg | |
+ } | |
+ | |
var userAuthReq userAuthRequestMsg | |
if packet, err := s.transport.readPacket(); err != nil { | |
return nil, err | |
@@ -289,6 +314,11 @@ userAuthLoop: | |
if config.NoClientAuth { | |
authErr = nil | |
} | |
+ | |
+ // allow initial attempt of 'none' without penalty | |
+ if authFailures == 0 { | |
+ authFailures-- | |
+ } | |
case "password": | |
if config.PasswordCallback == nil { | |
authErr = errors.New("ssh: password auth not configured") | |
@@ -360,6 +390,7 @@ userAuthLoop: | |
if isQuery { | |
// The client can query if the given public key | |
// would be okay. | |
+ | |
if len(payload) > 0 { | |
return nil, parseError(msgUserAuthRequest) | |
} | |
@@ -409,6 +440,8 @@ userAuthLoop: | |
break userAuthLoop | |
} | |
+ authFailures++ | |
+ | |
var failureMsg userAuthFailureMsg | |
if config.PasswordCallback != nil { | |
failureMsg.Methods = append(failureMsg.Methods, "password") | |
diff --git a/ssh/session_test.go b/ssh/session_test.go | |
index f35a378..7dce6dd 100644 | |
--- a/ssh/session_test.go | |
+++ b/ssh/session_test.go | |
@@ -59,7 +59,8 @@ func dial(handler serverType, t *testing.T) *Client { | |
}() | |
config := &ClientConfig{ | |
- User: "testuser", | |
+ User: "testuser", | |
+ HostKeyCallback: InsecureIgnoreHostKey(), | |
} | |
conn, chans, reqs, err := NewClientConn(c2, "", config) | |
@@ -641,7 +642,8 @@ func TestSessionID(t *testing.T) { | |
} | |
serverConf.AddHostKey(testSigners["ecdsa"]) | |
clientConf := &ClientConfig{ | |
- User: "user", | |
+ HostKeyCallback: InsecureIgnoreHostKey(), | |
+ User: "user", | |
} | |
go func() { | |
@@ -747,7 +749,9 @@ func TestHostKeyAlgorithms(t *testing.T) { | |
// By default, we get the preferred algorithm, which is ECDSA 256. | |
- clientConf := &ClientConfig{} | |
+ clientConf := &ClientConfig{ | |
+ HostKeyCallback: InsecureIgnoreHostKey(), | |
+ } | |
connect(clientConf, KeyAlgoECDSA256) | |
// Client asks for RSA explicitly. | |
diff --git a/ssh/streamlocal.go b/ssh/streamlocal.go | |
new file mode 100644 | |
index 0000000..a2dccc6 | |
--- /dev/null | |
+++ b/ssh/streamlocal.go | |
@@ -0,0 +1,115 @@ | |
+package ssh | |
+ | |
+import ( | |
+ "errors" | |
+ "io" | |
+ "net" | |
+) | |
+ | |
+// streamLocalChannelOpenDirectMsg is a struct used for SSH_MSG_CHANNEL_OPEN message | |
+// with "direct-streamlocal@openssh.com" string. | |
+// | |
+// See openssh-portable/PROTOCOL, section 2.4. connection: Unix domain socket forwarding | |
+// https://github.com/openssh/openssh-portable/blob/master/PROTOCOL#L235 | |
+type streamLocalChannelOpenDirectMsg struct { | |
+ socketPath string | |
+ // reserved0 and reserved1 are reserved protocol fields; | |
+ // dialStreamLocal leaves them at their zero values. | |
+ reserved0 string | |
+ reserved1 uint32 | |
+} | |
+ | |
+// forwardedStreamLocalPayload is a struct used for SSH_MSG_CHANNEL_OPEN message | |
+// with "forwarded-streamlocal@openssh.com" string. SocketPath names the | |
+// remote socket the forwarded connection arrived on. | |
+type forwardedStreamLocalPayload struct { | |
+ SocketPath string | |
+ Reserved0 string | |
+} | |
+ | |
+// streamLocalChannelForwardMsg is a struct used for SSH2_MSG_GLOBAL_REQUEST message | |
+// with "streamlocal-forward@openssh.com"/"cancel-streamlocal-forward@openssh.com" string. | |
+type streamLocalChannelForwardMsg struct { | |
+ socketPath string | |
+} | |
+ | |
+// ListenUnix is similar to ListenTCP but uses a Unix domain socket. | |
+// It sends a "streamlocal-forward@openssh.com" global request asking the | |
+// peer to listen on socketPath. Connections the peer forwards back for | |
+// that path are delivered through the returned listener's Accept. | |
+func (c *Client) ListenUnix(socketPath string) (net.Listener, error) { | |
+ m := streamLocalChannelForwardMsg{ | |
+ socketPath, | |
+ } | |
+ // send message | |
+ ok, _, err := c.SendRequest("streamlocal-forward@openssh.com", true, Marshal(&m)) | |
+ if err != nil { | |
+ return nil, err | |
+ } | |
+ if !ok { | |
+ return nil, errors.New("ssh: streamlocal-forward@openssh.com request denied by peer") | |
+ } | |
+ ch := c.forwards.add(&net.UnixAddr{Name: socketPath, Net: "unix"}) | |
+ | |
+ return &unixListener{socketPath, c, ch}, nil | |
+} | |
+ | |
+// dialStreamLocal opens a "direct-streamlocal@openssh.com" channel asking | |
+// the peer to connect to the Unix socket at socketPath. Requests arriving | |
+// on the new channel are discarded. | |
+func (c *Client) dialStreamLocal(socketPath string) (Channel, error) { | |
+ msg := streamLocalChannelOpenDirectMsg{ | |
+ socketPath: socketPath, | |
+ } | |
+ ch, in, err := c.OpenChannel("direct-streamlocal@openssh.com", Marshal(&msg)) | |
+ if err != nil { | |
+ return nil, err | |
+ } | |
+ go DiscardRequests(in) | |
+ return ch, nil | |
+} | |
+ | |
+// unixListener is the net.Listener returned by ListenUnix. It yields | |
+// connections the peer forwards for socketPath. | |
+type unixListener struct { | |
+ socketPath string | |
+ | |
+ conn *Client | |
+ in <-chan forward | |
+} | |
+ | |
+// Accept waits for and returns the next connection to the listener. | |
+// It returns io.EOF once the listener's forward channel is closed. | |
+// Accepted connections report the placeholder "@" as remote address. | |
+func (l *unixListener) Accept() (net.Conn, error) { | |
+ s, ok := <-l.in | |
+ if !ok { | |
+ return nil, io.EOF | |
+ } | |
+ ch, incoming, err := s.newCh.Accept() | |
+ if err != nil { | |
+ return nil, err | |
+ } | |
+ go DiscardRequests(incoming) | |
+ | |
+ return &chanConn{ | |
+ Channel: ch, | |
+ laddr: &net.UnixAddr{ | |
+ Name: l.socketPath, | |
+ Net: "unix", | |
+ }, | |
+ raddr: &net.UnixAddr{ | |
+ Name: "@", | |
+ Net: "unix", | |
+ }, | |
+ }, nil | |
+} | |
+ | |
+// Close closes the listener and asks the peer, via a | |
+// "cancel-streamlocal-forward@openssh.com" request, to stop forwarding. | |
+func (l *unixListener) Close() error { | |
+ // Removing the forward closes l.in, so Accept returns io.EOF. | |
+ l.conn.forwards.remove(&net.UnixAddr{Name: l.socketPath, Net: "unix"}) | |
+ m := streamLocalChannelForwardMsg{ | |
+ l.socketPath, | |
+ } | |
+ ok, _, err := l.conn.SendRequest("cancel-streamlocal-forward@openssh.com", true, Marshal(&m)) | |
+ if err == nil && !ok { | |
+ err = errors.New("ssh: cancel-streamlocal-forward@openssh.com failed") | |
+ } | |
+ return err | |
+} | |
+ | |
+// Addr returns the listener's network address. | |
+func (l *unixListener) Addr() net.Addr { | |
+ return &net.UnixAddr{ | |
+ Name: l.socketPath, | |
+ Net: "unix", | |
+ } | |
+} | |
diff --git a/ssh/tcpip.go b/ssh/tcpip.go | |
index 6151241..acf1717 100644 | |
--- a/ssh/tcpip.go | |
+++ b/ssh/tcpip.go | |
@@ -20,12 +20,20 @@ import ( | |
// addr. Incoming connections will be available by calling Accept on | |
// the returned net.Listener. The listener must be serviced, or the | |
// SSH connection may hang. | |
+// N must be "tcp", "tcp4", "tcp6", or "unix". | |
func (c *Client) Listen(n, addr string) (net.Listener, error) { | |
- laddr, err := net.ResolveTCPAddr(n, addr) | |
- if err != nil { | |
- return nil, err | |
+ switch n { | |
+ case "tcp", "tcp4", "tcp6": | |
+ laddr, err := net.ResolveTCPAddr(n, addr) | |
+ if err != nil { | |
+ return nil, err | |
+ } | |
+ return c.ListenTCP(laddr) | |
+ case "unix": | |
+ // addr is passed unchanged to ListenUnix as the socket path. | |
+ return c.ListenUnix(addr) | |
+ default: | |
+ return nil, fmt.Errorf("ssh: unsupported protocol: %s", n) | |
} | |
- return c.ListenTCP(laddr) | |
} | |
// Automatic port allocation is broken with OpenSSH before 6.0. See | |
@@ -116,7 +124,7 @@ func (c *Client) ListenTCP(laddr *net.TCPAddr) (net.Listener, error) { | |
} | |
// Register this forward, using the port number we obtained. | |
- ch := c.forwards.add(*laddr) | |
+ ch := c.forwards.add(laddr) | |
return &tcpListener{laddr, c, ch}, nil | |
} | |
@@ -131,7 +139,7 @@ type forwardList struct { | |
// forwardEntry represents an established mapping of a laddr on a | |
// remote ssh server to a channel connected to a tcpListener. | |
+// The listener may also be a unixListener, in which case laddr is a | |
+// *net.UnixAddr. | |
type forwardEntry struct { | |
- laddr net.TCPAddr | |
+ laddr net.Addr | |
c chan forward | |
} | |
@@ -139,16 +147,16 @@ type forwardEntry struct { | |
// arguments to add/remove/lookup should be address as specified in | |
// the original forward-request. | |
type forward struct { | |
- newCh NewChannel // the ssh client channel underlying this forward | |
- raddr *net.TCPAddr // the raddr of the incoming connection | |
+ newCh NewChannel // the ssh client channel underlying this forward | |
+ // raddr is a *net.TCPAddr for forwarded-tcpip channels and a | |
+ // placeholder *net.UnixAddr named "@" for streamlocal channels. | |
+ raddr net.Addr // the raddr of the incoming connection | |
} | |
-func (l *forwardList) add(addr net.TCPAddr) chan forward { | |
+func (l *forwardList) add(addr net.Addr) chan forward { | |
l.Lock() | |
defer l.Unlock() | |
f := forwardEntry{ | |
- addr, | |
- make(chan forward, 1), | |
+ laddr: addr, | |
+ c: make(chan forward, 1), | |
} | |
l.entries = append(l.entries, f) | |
return f.c | |
@@ -176,44 +184,69 @@ func parseTCPAddr(addr string, port uint32) (*net.TCPAddr, error) { | |
func (l *forwardList) handleChannels(in <-chan NewChannel) { | |
for ch := range in { | |
- var payload forwardedTCPPayload | |
- if err := Unmarshal(ch.ExtraData(), &payload); err != nil { | |
- ch.Reject(ConnectionFailed, "could not parse forwarded-tcpip payload: "+err.Error()) | |
- continue | |
+ var ( | |
+ laddr net.Addr | |
+ raddr net.Addr | |
+ err error | |
+ ) | |
+ switch channelType := ch.ChannelType(); channelType { | |
+ case "forwarded-tcpip": | |
+ var payload forwardedTCPPayload | |
+ if err = Unmarshal(ch.ExtraData(), &payload); err != nil { | |
+ ch.Reject(ConnectionFailed, "could not parse forwarded-tcpip payload: "+err.Error()) | |
+ continue | |
+ } | |
+ | |
+ // RFC 4254 section 7.2 specifies that incoming | |
+ // addresses should list the address, in string | |
+ // format. It is implied that this should be an IP | |
+ // address, as it would be impossible to connect to it | |
+ // otherwise. | |
+ laddr, err = parseTCPAddr(payload.Addr, payload.Port) | |
+ if err != nil { | |
+ ch.Reject(ConnectionFailed, err.Error()) | |
+ continue | |
+ } | |
+ raddr, err = parseTCPAddr(payload.OriginAddr, payload.OriginPort) | |
+ if err != nil { | |
+ ch.Reject(ConnectionFailed, err.Error()) | |
+ continue | |
+ } | |
+ | |
+ case "forwarded-streamlocal@openssh.com": | |
+ var payload forwardedStreamLocalPayload | |
+ if err = Unmarshal(ch.ExtraData(), &payload); err != nil { | |
+ ch.Reject(ConnectionFailed, "could not parse forwarded-streamlocal@openssh.com payload: "+err.Error()) | |
+ continue | |
+ } | |
+ laddr = &net.UnixAddr{ | |
+ Name: payload.SocketPath, | |
+ Net: "unix", | |
+ } | |
+ // The payload carries no originator address, so the remote | |
+ // end gets a placeholder name. | |
+ raddr = &net.UnixAddr{ | |
+ Name: "@", | |
+ Net: "unix", | |
+ } | |
+ default: | |
+ // Reject rather than panic: a panic in this library loop | |
+ // would crash the whole program over one bad channel. | |
+ ch.Reject(UnknownChannelType, "ssh: unknown channel type "+channelType) | |
+ continue | |
} | |
- | |
- // RFC 4254 section 7.2 specifies that incoming | |
- // addresses should list the address, in string | |
- // format. It is implied that this should be an IP | |
- // address, as it would be impossible to connect to it | |
- // otherwise. | |
- laddr, err := parseTCPAddr(payload.Addr, payload.Port) | |
- if err != nil { | |
- ch.Reject(ConnectionFailed, err.Error()) | |
- continue | |
- } | |
- raddr, err := parseTCPAddr(payload.OriginAddr, payload.OriginPort) | |
- if err != nil { | |
- ch.Reject(ConnectionFailed, err.Error()) | |
- continue | |
- } | |
- | |
- if ok := l.forward(*laddr, *raddr, ch); !ok { | |
+ if ok := l.forward(laddr, raddr, ch); !ok { | |
// Section 7.2, implementations MUST reject spurious incoming | |
// connections. | |
ch.Reject(Prohibited, "no forward for address") | |
continue | |
} | |
} | |
} | |
// remove removes the forward entry, and the channel feeding its | |
// listener. | |
-func (l *forwardList) remove(addr net.TCPAddr) { | |
+func (l *forwardList) remove(addr net.Addr) { | |
l.Lock() | |
defer l.Unlock() | |
for i, f := range l.entries { | |
- if addr.IP.Equal(f.laddr.IP) && addr.Port == f.laddr.Port { | |
+ if addr.Network() == f.laddr.Network() && addr.String() == f.laddr.String() { | |
l.entries = append(l.entries[:i], l.entries[i+1:]...) | |
close(f.c) | |
return | |
@@ -231,12 +264,12 @@ func (l *forwardList) closeAll() { | |
l.entries = nil | |
} | |
-func (l *forwardList) forward(laddr, raddr net.TCPAddr, ch NewChannel) bool { | |
+func (l *forwardList) forward(laddr, raddr net.Addr, ch NewChannel) bool { | |
l.Lock() | |
defer l.Unlock() | |
for _, f := range l.entries { | |
- if laddr.IP.Equal(f.laddr.IP) && laddr.Port == f.laddr.Port { | |
- f.c <- forward{ch, &raddr} | |
+ if laddr.Network() == f.laddr.Network() && laddr.String() == f.laddr.String() { | |
+ f.c <- forward{newCh: ch, raddr: raddr} | |
return true | |
} | |
} | |
@@ -262,7 +295,7 @@ func (l *tcpListener) Accept() (net.Conn, error) { | |
} | |
go DiscardRequests(incoming) | |
- return &tcpChanConn{ | |
+ return &chanConn{ | |
Channel: ch, | |
laddr: l.laddr, | |
raddr: s.raddr, | |
@@ -277,7 +310,7 @@ func (l *tcpListener) Close() error { | |
} | |
// this also closes the listener. | |
- l.conn.forwards.remove(*l.laddr) | |
+ l.conn.forwards.remove(l.laddr) | |
ok, _, err := l.conn.SendRequest("cancel-tcpip-forward", true, Marshal(&m)) | |
if err == nil && !ok { | |
err = errors.New("ssh: cancel-tcpip-forward failed") | |
@@ -293,29 +326,52 @@ func (l *tcpListener) Addr() net.Addr { | |
// Dial initiates a connection to the addr from the remote host. | |
// The resulting connection has a zero LocalAddr() and RemoteAddr(). | |
func (c *Client) Dial(n, addr string) (net.Conn, error) { | |
- // Parse the address into host and numeric port. | |
- host, portString, err := net.SplitHostPort(addr) | |
- if err != nil { | |
- return nil, err | |
- } | |
- port, err := strconv.ParseUint(portString, 10, 16) | |
- if err != nil { | |
- return nil, err | |
- } | |
- // Use a zero address for local and remote address. | |
- zeroAddr := &net.TCPAddr{ | |
- IP: net.IPv4zero, | |
- Port: 0, | |
- } | |
- ch, err := c.dial(net.IPv4zero.String(), 0, host, int(port)) | |
- if err != nil { | |
- return nil, err | |
+ var ch Channel | |
+ switch n { | |
+ case "tcp", "tcp4", "tcp6": | |
+ // Parse the address into host and numeric port. | |
+ host, portString, err := net.SplitHostPort(addr) | |
+ if err != nil { | |
+ return nil, err | |
+ } | |
+ port, err := strconv.ParseUint(portString, 10, 16) | |
+ if err != nil { | |
+ return nil, err | |
+ } | |
+ ch, err = c.dial(net.IPv4zero.String(), 0, host, int(port)) | |
+ if err != nil { | |
+ return nil, err | |
+ } | |
+ // Use a zero address for local and remote address. | |
+ zeroAddr := &net.TCPAddr{ | |
+ IP: net.IPv4zero, | |
+ Port: 0, | |
+ } | |
+ return &chanConn{ | |
+ Channel: ch, | |
+ laddr: zeroAddr, | |
+ raddr: zeroAddr, | |
+ }, nil | |
+ case "unix": | |
+ // For Unix sockets the local address is the placeholder "@" and | |
+ // the remote address is the dialed socket path. | |
+ var err error | |
+ ch, err = c.dialStreamLocal(addr) | |
+ if err != nil { | |
+ return nil, err | |
+ } | |
+ return &chanConn{ | |
+ Channel: ch, | |
+ laddr: &net.UnixAddr{ | |
+ Name: "@", | |
+ Net: "unix", | |
+ }, | |
+ raddr: &net.UnixAddr{ | |
+ Name: addr, | |
+ Net: "unix", | |
+ }, | |
+ }, nil | |
+ default: | |
+ return nil, fmt.Errorf("ssh: unsupported protocol: %s", n) | |
} | |
- return &tcpChanConn{ | |
- Channel: ch, | |
- laddr: zeroAddr, | |
- raddr: zeroAddr, | |
- }, nil | |
} | |
// DialTCP connects to the remote address raddr on the network net, | |
@@ -332,7 +388,7 @@ func (c *Client) DialTCP(n string, laddr, raddr *net.TCPAddr) (net.Conn, error) | |
if err != nil { | |
return nil, err | |
} | |
- return &tcpChanConn{ | |
+ return &chanConn{ | |
Channel: ch, | |
laddr: laddr, | |
raddr: raddr, | |
@@ -366,26 +422,26 @@ type tcpChan struct { | |
Channel // the backing channel | |
} | |
-// tcpChanConn fulfills the net.Conn interface without | |
+// chanConn fulfills the net.Conn interface without | |
// the tcpChan having to hold laddr or raddr directly. | |
+// It is used for both TCP and Unix domain socket channels. | |
-type tcpChanConn struct { | |
+type chanConn struct { | |
Channel | |
laddr, raddr net.Addr | |
} | |
// LocalAddr returns the local network address. | |
-func (t *tcpChanConn) LocalAddr() net.Addr { | |
+func (t *chanConn) LocalAddr() net.Addr { | |
return t.laddr | |
} | |
// RemoteAddr returns the remote network address. | |
-func (t *tcpChanConn) RemoteAddr() net.Addr { | |
+func (t *chanConn) RemoteAddr() net.Addr { | |
return t.raddr | |
} | |
// SetDeadline sets the read and write deadlines associated | |
// with the connection. | |
-func (t *tcpChanConn) SetDeadline(deadline time.Time) error { | |
+func (t *chanConn) SetDeadline(deadline time.Time) error { | |
if err := t.SetReadDeadline(deadline); err != nil { | |
return err | |
} | |
@@ -396,12 +452,14 @@ func (t *tcpChanConn) SetDeadline(deadline time.Time) error { | |
// A zero value for t means Read will not time out. | |
// After the deadline, the error from Read will implement net.Error | |
// with Timeout() == true. | |
-func (t *tcpChanConn) SetReadDeadline(deadline time.Time) error { | |
+func (t *chanConn) SetReadDeadline(deadline time.Time) error { | |
+ // for compatibility with previous version, | |
+ // the error message contains "tcpChan" | |
return errors.New("ssh: tcpChan: deadline not supported") | |
} | |
// SetWriteDeadline exists to satisfy the net.Conn interface | |
// but is not implemented by this type. It always returns an error. | |
-func (t *tcpChanConn) SetWriteDeadline(deadline time.Time) error { | |
+func (t *chanConn) SetWriteDeadline(deadline time.Time) error { | |
return errors.New("ssh: tcpChan: deadline not supported") | |
} | |
diff --git a/ssh/terminal/util_solaris.go b/ssh/terminal/util_solaris.go | |
index 07eb5ed..a2e1b57 100644 | |
--- a/ssh/terminal/util_solaris.go | |
+++ b/ssh/terminal/util_solaris.go | |
@@ -14,14 +14,12 @@ import ( | |
// State contains the state of a terminal. | |
type State struct { | |
- termios syscall.Termios | |
+ state *unix.Termios | |
} | |
// IsTerminal returns true if the given file descriptor is a terminal. | |
func IsTerminal(fd int) bool { | |
- // see: http://src.illumos.org/source/xref/illumos-gate/usr/src/lib/libbc/libc/gen/common/isatty.c | |
- var termio unix.Termio | |
- err := unix.IoctlSetTermio(fd, unix.TCGETA, &termio) | |
+ _, err := unix.IoctlGetTermio(fd, unix.TCGETA) | |
return err == nil | |
} | |
@@ -71,3 +69,60 @@ func ReadPassword(fd int) ([]byte, error) { | |
return ret, nil | |
} | |
+ | |
+// MakeRaw puts the terminal connected to the given file descriptor into raw | |
+// mode and returns the previous state of the terminal so that it can be | |
+// restored. | |
+// see http://cr.illumos.org/~webrev/andy_js/1060/ | |
+func MakeRaw(fd int) (*State, error) { | |
+ oldTermiosPtr, err := unix.IoctlGetTermios(fd, unix.TCGETS) | |
+ if err != nil { | |
+ return nil, err | |
+ } | |
+ oldTermios := *oldTermiosPtr | |
+ | |
+ newTermios := oldTermios | |
+ newTermios.Iflag &^= syscall.IGNBRK | syscall.BRKINT | syscall.PARMRK | syscall.ISTRIP | syscall.INLCR | syscall.IGNCR | syscall.ICRNL | syscall.IXON | |
+ newTermios.Oflag &^= syscall.OPOST | |
+ newTermios.Lflag &^= syscall.ECHO | syscall.ECHONL | syscall.ICANON | syscall.ISIG | syscall.IEXTEN | |
+ newTermios.Cflag &^= syscall.CSIZE | syscall.PARENB | |
+ newTermios.Cflag |= syscall.CS8 | |
+ newTermios.Cc[unix.VMIN] = 1 | |
+ newTermios.Cc[unix.VTIME] = 0 | |
+ | |
+ if err := unix.IoctlSetTermios(fd, unix.TCSETS, &newTermios); err != nil { | |
+ return nil, err | |
+ } | |
+ | |
+ return &State{ | |
+ state: oldTermiosPtr, | |
+ }, nil | |
+} | |
+ | |
+// Restore restores the terminal connected to the given file descriptor to a | |
+// previous state. | |
+func Restore(fd int, oldState *State) error { | |
+ return unix.IoctlSetTermios(fd, unix.TCSETS, oldState.state) | |
+} | |
+ | |
+// GetState returns the current state of a terminal which may be useful to | |
+// restore the terminal after a signal. | |
+func GetState(fd int) (*State, error) { | |
+ oldTermiosPtr, err := unix.IoctlGetTermios(fd, unix.TCGETS) | |
+ if err != nil { | |
+ return nil, err | |
+ } | |
+ | |
+ return &State{ | |
+ state: oldTermiosPtr, | |
+ }, nil | |
+} | |
+ | |
+// GetSize returns the dimensions of the given terminal. | |
+func GetSize(fd int) (width, height int, err error) { | |
+ ws, err := unix.IoctlGetWinsize(fd, unix.TIOCGWINSZ) | |
+ if err != nil { | |
+ return 0, 0, err | |
+ } | |
+ return int(ws.Col), int(ws.Row), nil | |
+} | |
diff --git a/ssh/test/cert_test.go b/ssh/test/cert_test.go | |
index 364790f..b231dd8 100644 | |
--- a/ssh/test/cert_test.go | |
+++ b/ssh/test/cert_test.go | |
@@ -7,12 +7,14 @@ | |
package test | |
import ( | |
+ "bytes" | |
"crypto/rand" | |
"testing" | |
"golang.org/x/crypto/ssh" | |
) | |
+// Test both logging in with a cert, and also that the certificate presented by an OpenSSH host can be validated correctly | |
func TestCertLogin(t *testing.T) { | |
s := newServer(t) | |
defer s.Shutdown() | |
@@ -37,11 +39,39 @@ func TestCertLogin(t *testing.T) { | |
conf := &ssh.ClientConfig{ | |
User: username(), | |
+ HostKeyCallback: (&ssh.CertChecker{ | |
+ IsHostAuthority: func(pk ssh.PublicKey, addr string) bool { | |
+ return bytes.Equal(pk.Marshal(), testPublicKeys["ca"].Marshal()) | |
+ }, | |
+ }).CheckHostKey, | |
} | |
conf.Auth = append(conf.Auth, ssh.PublicKeys(certSigner)) | |
- client, err := s.TryDial(conf) | |
- if err != nil { | |
- t.Fatalf("TryDial: %v", err) | |
+ | |
+ for _, test := range []struct { | |
+ addr string | |
+ succeed bool | |
+ }{ | |
+ {addr: "host.example.com:22", succeed: true}, | |
+ {addr: "host.example.com:10000", succeed: true}, // non-standard port must be OK | |
+ {addr: "host.example.com", succeed: false}, // port must be specified | |
+ {addr: "host.ex4mple.com:22", succeed: false}, // wrong host | |
+ } { | |
+ client, err := s.TryDialWithAddr(conf, test.addr) | |
+ | |
+ // Always close client if opened successfully | |
+ if err == nil { | |
+ client.Close() | |
+ } | |
+ | |
+ // Now evaluate whether the test failed or passed | |
+ if test.succeed { | |
+ if err != nil { | |
+ t.Fatalf("TryDialWithAddr: %v", err) | |
+ } | |
+ } else { | |
+ if err == nil { | |
+ t.Fatalf("TryDialWithAddr, unexpected success") | |
+ } | |
+ } | |
} | |
- client.Close() | |
} | |
diff --git a/ssh/test/dial_unix_test.go b/ssh/test/dial_unix_test.go | |
new file mode 100644 | |
index 0000000..091e48c | |
--- /dev/null | |
+++ b/ssh/test/dial_unix_test.go | |
@@ -0,0 +1,128 @@ | |
+// Copyright 2012 The Go Authors. All rights reserved. | |
+// Use of this source code is governed by a BSD-style | |
+// license that can be found in the LICENSE file. | |
+ | |
+// +build !windows | |
+ | |
+package test | |
+ | |
+// direct-tcpip and direct-streamlocal functional tests | |
+ | |
+import ( | |
+ "fmt" | |
+ "io" | |
+ "io/ioutil" | |
+ "net" | |
+ "strings" | |
+ "testing" | |
+) | |
+ | |
+type dialTester interface { | |
+ TestServerConn(t *testing.T, c net.Conn) | |
+ TestClientConn(t *testing.T, c net.Conn) | |
+} | |
+ | |
+func testDial(t *testing.T, n, listenAddr string, x dialTester) { | |
+ server := newServer(t) | |
+ defer server.Shutdown() | |
+ sshConn := server.Dial(clientConfig()) | |
+ defer sshConn.Close() | |
+ | |
+ l, err := net.Listen(n, listenAddr) | |
+ if err != nil { | |
+ t.Fatalf("Listen: %v", err) | |
+ } | |
+ defer l.Close() | |
+ | |
+ testData := fmt.Sprintf("hello from %s, %s", n, listenAddr) | |
+ go func() { | |
+ for { | |
+ c, err := l.Accept() | |
+ if err != nil { | |
+ break | |
+ } | |
+ x.TestServerConn(t, c) | |
+ | |
+ io.WriteString(c, testData) | |
+ c.Close() | |
+ } | |
+ }() | |
+ | |
+ conn, err := sshConn.Dial(n, l.Addr().String()) | |
+ if err != nil { | |
+ t.Fatalf("Dial: %v", err) | |
+ } | |
+ x.TestClientConn(t, conn) | |
+ defer conn.Close() | |
+ b, err := ioutil.ReadAll(conn) | |
+ if err != nil { | |
+ t.Fatalf("ReadAll: %v", err) | |
+ } | |
+ t.Logf("got %q", string(b)) | |
+ if string(b) != testData { | |
+ t.Fatalf("expected %q, got %q", testData, string(b)) | |
+ } | |
+} | |
+ | |
+type tcpDialTester struct { | |
+ listenAddr string | |
+} | |
+ | |
+func (x *tcpDialTester) TestServerConn(t *testing.T, c net.Conn) { | |
+ host := strings.Split(x.listenAddr, ":")[0] | |
+ prefix := host + ":" | |
+ if !strings.HasPrefix(c.LocalAddr().String(), prefix) { | |
+ t.Fatalf("expected to start with %q, got %q", prefix, c.LocalAddr().String()) | |
+ } | |
+ if !strings.HasPrefix(c.RemoteAddr().String(), prefix) { | |
+ t.Fatalf("expected to start with %q, got %q", prefix, c.RemoteAddr().String()) | |
+ } | |
+} | |
+ | |
+func (x *tcpDialTester) TestClientConn(t *testing.T, c net.Conn) { | |
+ // we use zero addresses. see *Client.Dial. | |
+ if c.LocalAddr().String() != "0.0.0.0:0" { | |
+ t.Fatalf("expected \"0.0.0.0:0\", got %q", c.LocalAddr().String()) | |
+ } | |
+ if c.RemoteAddr().String() != "0.0.0.0:0" { | |
+ t.Fatalf("expected \"0.0.0.0:0\", got %q", c.RemoteAddr().String()) | |
+ } | |
+} | |
+ | |
+func TestDialTCP(t *testing.T) { | |
+ x := &tcpDialTester{ | |
+ listenAddr: "127.0.0.1:0", | |
+ } | |
+ testDial(t, "tcp", x.listenAddr, x) | |
+} | |
+ | |
+type unixDialTester struct { | |
+ listenAddr string | |
+} | |
+ | |
+func (x *unixDialTester) TestServerConn(t *testing.T, c net.Conn) { | |
+ if c.LocalAddr().String() != x.listenAddr { | |
+ t.Fatalf("expected %q, got %q", x.listenAddr, c.LocalAddr().String()) | |
+ } | |
+ if c.RemoteAddr().String() != "@" { | |
+ t.Fatalf("expected \"@\", got %q", c.RemoteAddr().String()) | |
+ } | |
+} | |
+ | |
+func (x *unixDialTester) TestClientConn(t *testing.T, c net.Conn) { | |
+ if c.RemoteAddr().String() != x.listenAddr { | |
+ t.Fatalf("expected %q, got %q", x.listenAddr, c.RemoteAddr().String()) | |
+ } | |
+ if c.LocalAddr().String() != "@" { | |
+ t.Fatalf("expected \"@\", got %q", c.LocalAddr().String()) | |
+ } | |
+} | |
+ | |
+func TestDialUnix(t *testing.T) { | |
+ addr, cleanup := newTempSocket(t) | |
+ defer cleanup() | |
+ x := &unixDialTester{ | |
+ listenAddr: addr, | |
+ } | |
+ testDial(t, "unix", x.listenAddr, x) | |
+} | |
diff --git a/ssh/test/forward_unix_test.go b/ssh/test/forward_unix_test.go | |
index 877a88c..ea81937 100644 | |
--- a/ssh/test/forward_unix_test.go | |
+++ b/ssh/test/forward_unix_test.go | |
@@ -16,13 +16,17 @@ import ( | |
"time" | |
) | |
-func TestPortForward(t *testing.T) { | |
+type closeWriter interface { | |
+ CloseWrite() error | |
+} | |
+ | |
+func testPortForward(t *testing.T, n, listenAddr string) { | |
server := newServer(t) | |
defer server.Shutdown() | |
conn := server.Dial(clientConfig()) | |
defer conn.Close() | |
- sshListener, err := conn.Listen("tcp", "localhost:0") | |
+ sshListener, err := conn.Listen(n, listenAddr) | |
if err != nil { | |
t.Fatal(err) | |
} | |
@@ -41,14 +45,14 @@ func TestPortForward(t *testing.T) { | |
}() | |
forwardedAddr := sshListener.Addr().String() | |
- tcpConn, err := net.Dial("tcp", forwardedAddr) | |
+ netConn, err := net.Dial(n, forwardedAddr) | |
if err != nil { | |
- t.Fatalf("TCP dial failed: %v", err) | |
+ t.Fatalf("net dial failed: %v", err) | |
} | |
readChan := make(chan []byte) | |
go func() { | |
- data, _ := ioutil.ReadAll(tcpConn) | |
+ data, _ := ioutil.ReadAll(netConn) | |
readChan <- data | |
}() | |
@@ -62,14 +66,14 @@ func TestPortForward(t *testing.T) { | |
for len(sent) < 1000*1000 { | |
// Send random sized chunks | |
m := rand.Intn(len(data)) | |
- n, err := tcpConn.Write(data[:m]) | |
+ n, err := netConn.Write(data[:m]) | |
if err != nil { | |
break | |
} | |
sent = append(sent, data[:n]...) | |
} | |
- if err := tcpConn.(*net.TCPConn).CloseWrite(); err != nil { | |
- t.Errorf("tcpConn.CloseWrite: %v", err) | |
+ if err := netConn.(closeWriter).CloseWrite(); err != nil { | |
+ t.Errorf("netConn.CloseWrite: %v", err) | |
} | |
read := <-readChan | |
@@ -86,19 +90,29 @@ func TestPortForward(t *testing.T) { | |
} | |
// Check that the forward disappeared. | |
- tcpConn, err = net.Dial("tcp", forwardedAddr) | |
+ netConn, err = net.Dial(n, forwardedAddr) | |
if err == nil { | |
- tcpConn.Close() | |
+ netConn.Close() | |
t.Errorf("still listening to %s after closing", forwardedAddr) | |
} | |
} | |
-func TestAcceptClose(t *testing.T) { | |
+func TestPortForwardTCP(t *testing.T) { | |
+ testPortForward(t, "tcp", "localhost:0") | |
+} | |
+ | |
+func TestPortForwardUnix(t *testing.T) { | |
+ addr, cleanup := newTempSocket(t) | |
+ defer cleanup() | |
+ testPortForward(t, "unix", addr) | |
+} | |
+ | |
+func testAcceptClose(t *testing.T, n, listenAddr string) { | |
server := newServer(t) | |
defer server.Shutdown() | |
conn := server.Dial(clientConfig()) | |
- sshListener, err := conn.Listen("tcp", "localhost:0") | |
+ sshListener, err := conn.Listen(n, listenAddr) | |
if err != nil { | |
t.Fatal(err) | |
} | |
@@ -124,13 +138,23 @@ func TestAcceptClose(t *testing.T) { | |
} | |
} | |
+func TestAcceptCloseTCP(t *testing.T) { | |
+ testAcceptClose(t, "tcp", "localhost:0") | |
+} | |
+ | |
+func TestAcceptCloseUnix(t *testing.T) { | |
+ addr, cleanup := newTempSocket(t) | |
+ defer cleanup() | |
+ testAcceptClose(t, "unix", addr) | |
+} | |
+ | |
// Check that listeners exit if the underlying client transport dies. | |
-func TestPortForwardConnectionClose(t *testing.T) { | |
+func testPortForwardConnectionClose(t *testing.T, n, listenAddr string) { | |
server := newServer(t) | |
defer server.Shutdown() | |
conn := server.Dial(clientConfig()) | |
- sshListener, err := conn.Listen("tcp", "localhost:0") | |
+ sshListener, err := conn.Listen(n, listenAddr) | |
if err != nil { | |
t.Fatal(err) | |
} | |
@@ -158,3 +182,13 @@ func TestPortForwardConnectionClose(t *testing.T) { | |
t.Logf("quit as expected (error %v)", err) | |
} | |
} | |
+ | |
+func TestPortForwardConnectionCloseTCP(t *testing.T) { | |
+ testPortForwardConnectionClose(t, "tcp", "localhost:0") | |
+} | |
+ | |
+func TestPortForwardConnectionCloseUnix(t *testing.T) { | |
+ addr, cleanup := newTempSocket(t) | |
+ defer cleanup() | |
+ testPortForwardConnectionClose(t, "unix", addr) | |
+} | |
diff --git a/ssh/test/tcpip_test.go b/ssh/test/tcpip_test.go | |
deleted file mode 100644 | |
index a2eb935..0000000 | |
--- a/ssh/test/tcpip_test.go | |
+++ /dev/null | |
@@ -1,46 +0,0 @@ | |
-// Copyright 2012 The Go Authors. All rights reserved. | |
-// Use of this source code is governed by a BSD-style | |
-// license that can be found in the LICENSE file. | |
- | |
-// +build !windows | |
- | |
-package test | |
- | |
-// direct-tcpip functional tests | |
- | |
-import ( | |
- "io" | |
- "net" | |
- "testing" | |
-) | |
- | |
-func TestDial(t *testing.T) { | |
- server := newServer(t) | |
- defer server.Shutdown() | |
- sshConn := server.Dial(clientConfig()) | |
- defer sshConn.Close() | |
- | |
- l, err := net.Listen("tcp", "127.0.0.1:0") | |
- if err != nil { | |
- t.Fatalf("Listen: %v", err) | |
- } | |
- defer l.Close() | |
- | |
- go func() { | |
- for { | |
- c, err := l.Accept() | |
- if err != nil { | |
- break | |
- } | |
- | |
- io.WriteString(c, c.RemoteAddr().String()) | |
- c.Close() | |
- } | |
- }() | |
- | |
- conn, err := sshConn.Dial("tcp", l.Addr().String()) | |
- if err != nil { | |
- t.Fatalf("Dial: %v", err) | |
- } | |
- defer conn.Close() | |
-} | |
diff --git a/ssh/test/test_unix_test.go b/ssh/test/test_unix_test.go | |
index 3bfd881..e673536 100644 | |
--- a/ssh/test/test_unix_test.go | |
+++ b/ssh/test/test_unix_test.go | |
@@ -30,6 +30,7 @@ Protocol 2 | |
HostKey {{.Dir}}/id_rsa | |
HostKey {{.Dir}}/id_dsa | |
HostKey {{.Dir}}/id_ecdsa | |
+HostCertificate {{.Dir}}/id_rsa-cert.pub | |
Pidfile {{.Dir}}/sshd.pid | |
#UsePrivilegeSeparation no | |
KeyRegenerationInterval 3600 | |
@@ -119,6 +120,11 @@ func clientConfig() *ssh.ClientConfig { | |
ssh.PublicKeys(testSigners["user"]), | |
}, | |
HostKeyCallback: hostKeyDB().Check, | |
+ HostKeyAlgorithms: []string{ // by default, don't allow certs as this affects the hostKeyDB checker | |
+ ssh.KeyAlgoECDSA256, ssh.KeyAlgoECDSA384, ssh.KeyAlgoECDSA521, | |
+ ssh.KeyAlgoRSA, ssh.KeyAlgoDSA, | |
+ ssh.KeyAlgoED25519, | |
+ }, | |
} | |
return config | |
} | |
@@ -154,6 +160,12 @@ func unixConnection() (*net.UnixConn, *net.UnixConn, error) { | |
} | |
func (s *server) TryDial(config *ssh.ClientConfig) (*ssh.Client, error) { | |
+ return s.TryDialWithAddr(config, "") | |
+} | |
+ | |
+// addr is the user specified host:port. While we don't actually dial it, | |
+// we need to know this for host key matching | |
+func (s *server) TryDialWithAddr(config *ssh.ClientConfig, addr string) (*ssh.Client, error) { | |
sshd, err := exec.LookPath("sshd") | |
if err != nil { | |
s.t.Skipf("skipping test: %v", err) | |
@@ -179,7 +191,7 @@ func (s *server) TryDial(config *ssh.ClientConfig) (*ssh.Client, error) { | |
s.t.Fatalf("s.cmd.Start: %v", err) | |
} | |
s.clientConn = c1 | |
- conn, chans, reqs, err := ssh.NewClientConn(c1, "", config) | |
+ conn, chans, reqs, err := ssh.NewClientConn(c1, addr, config) | |
if err != nil { | |
return nil, err | |
} | |
@@ -250,6 +262,11 @@ func newServer(t *testing.T) *server { | |
writeFile(filepath.Join(dir, filename+".pub"), ssh.MarshalAuthorizedKey(testPublicKeys[k])) | |
} | |
+ for k, v := range testdata.SSHCertificates { | |
+ filename := "id_" + k + "-cert.pub" | |
+ writeFile(filepath.Join(dir, filename), v) | |
+ } | |
+ | |
var authkeys bytes.Buffer | |
for k, _ := range testdata.PEMBytes { | |
authkeys.Write(ssh.MarshalAuthorizedKey(testPublicKeys[k])) | |
@@ -266,3 +283,13 @@ func newServer(t *testing.T) *server { | |
}, | |
} | |
} | |
+ | |
+func newTempSocket(t *testing.T) (string, func()) { | |
+ dir, err := ioutil.TempDir("", "socket") | |
+ if err != nil { | |
+ t.Fatal(err) | |
+ } | |
+ deferFunc := func() { os.RemoveAll(dir) } | |
+ addr := filepath.Join(dir, "sock") | |
+ return addr, deferFunc | |
+} | |
diff --git a/ssh/testdata/keys.go b/ssh/testdata/keys.go | |
index 736dad9..3b3d26c 100644 | |
--- a/ssh/testdata/keys.go | |
+++ b/ssh/testdata/keys.go | |
@@ -48,12 +48,69 @@ AAAEAaYmXltfW6nhRo3iWGglRB48lYq0z0Q3I3KyrdutEr6j7d/uFLuDlRbBc4ZVOsx+Gb | |
HKuOrPtLHFvHsjWPwO+/AAAAE2dhcnRvbm1AZ2FydG9ubS14cHMBAg== | |
-----END OPENSSH PRIVATE KEY----- | |
`), | |
+ "rsa-openssh-format": []byte(`-----BEGIN OPENSSH PRIVATE KEY----- | |
+b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAlwAAAAdzc2gtcn | |
+NhAAAAAwEAAQAAAIEAwa48yfWFi3uIdqzuf9X7C2Zxfea/Iaaw0zIwHudpF8U92WVIiC5l | |
+oEuW1+OaVi3UWfIEjWMV1tHGysrHOwtwc34BPCJqJknUQO/KtDTBTJ4Pryhw1bWPC999Lz | |
+a+yrCTdNQYBzoROXKExZgPFh9pTMi5wqpHDuOQ2qZFIEI3lT0AAAIQWL0H31i9B98AAAAH | |
+c3NoLXJzYQAAAIEAwa48yfWFi3uIdqzuf9X7C2Zxfea/Iaaw0zIwHudpF8U92WVIiC5loE | |
+uW1+OaVi3UWfIEjWMV1tHGysrHOwtwc34BPCJqJknUQO/KtDTBTJ4Pryhw1bWPC999Lza+ | |
+yrCTdNQYBzoROXKExZgPFh9pTMi5wqpHDuOQ2qZFIEI3lT0AAAADAQABAAAAgCThyTGsT4 | |
+IARDxVMhWl6eiB2ZrgFgWSeJm/NOqtppWgOebsIqPMMg4UVuVFsl422/lE3RkPhVkjGXgE | |
+pWvZAdCnmLmApK8wK12vF334lZhZT7t3Z9EzJps88PWEHo7kguf285HcnUM7FlFeissJdk | |
+kXly34y7/3X/a6Tclm+iABAAAAQE0xR/KxZ39slwfMv64Rz7WKk1PPskaryI29aHE3mKHk | |
+pY2QA+P3QlrKxT/VWUMjHUbNNdYfJm48xu0SGNMRdKMAAABBAORh2NP/06JUV3J9W/2Hju | |
+X1ViJuqqcQnJPVzpgSL826EC2xwOECTqoY8uvFpUdD7CtpksIxNVqRIhuNOlz0lqEAAABB | |
+ANkaHTTaPojClO0dKJ/Zjs7pWOCGliebBYprQ/Y4r9QLBkC/XaWMS26gFIrjgC7D2Rv+rZ | |
+wSD0v0RcmkITP1ZR0AAAAYcHF1ZXJuYUBMdWNreUh5ZHJvLmxvY2FsAQID | |
+-----END OPENSSH PRIVATE KEY-----`), | |
"user": []byte(`-----BEGIN EC PRIVATE KEY----- | |
MHcCAQEEILYCAeq8f7V4vSSypRw7pxy8yz3V5W4qg8kSC3zJhqpQoAoGCCqGSM49 | |
AwEHoUQDQgAEYcO2xNKiRUYOLEHM7VYAp57HNyKbOdYtHD83Z4hzNPVC4tM5mdGD | |
PLL8IEwvYu2wq+lpXfGQnNMbzYf9gspG0w== | |
-----END EC PRIVATE KEY----- | |
`), | |
+ "ca": []byte(`-----BEGIN RSA PRIVATE KEY----- | |
+MIIEpAIBAAKCAQEAvg9dQ9IRG59lYJb+GESfKWTch4yBpr7Ydw1jkK6vvtrx9jLo | |
+5hkA8X6+ElRPRqTAZSlN5cBm6YCAcQIOsmXDUn6Oj1lVPQAoOjTBTvsjM3NjGhvv | |
+52kHTY0nsMsBeY9q5DTtlzmlYkVUq2a6Htgf2mNi01dIw5fJ7uTTo8EbNf7O0i3u | |
+c9a8P19HaZl5NKiWN4EIZkfB2WdXYRJCVBsGgQj3dE/GrEmH9QINq1A+GkNvK96u | |
+vZm8H1jjmuqzHplWa7lFeXcx8FTVTbVb/iJrZ2Lc/JvIPitKZWhqbR59yrGjpwEp | |
+Id7bo4WhO5L3OB0fSIJYvfu+o4WYnt4f3UzecwIDAQABAoIBABRD9yHgKErVuC2Q | |
+bA+SYZY8VvdtF/X7q4EmQFORDNRA7EPgMc03JU6awRGbQ8i4kHs46EFzPoXvWcKz | |
+AXYsO6N0Myc900Tp22A5d9NAHATEbPC/wdje7hRq1KyZONMJY9BphFv3nZbY5apR | |
+Dc90JBFZP5RhXjTc3n9GjvqLAKfFEKVmPRCvqxCOZunw6XR+SgIQLJo36nsIsbhW | |
+QUXIVaCI6cXMN8bRPm8EITdBNZu06Fpu4ZHm6VaxlXN9smERCDkgBSNXNWHKxmmA | |
+c3Glo2DByUr2/JFBOrLEe9fkYgr24KNCQkHVcSaFxEcZvTggr7StjKISVHlCNEaB | |
+7Q+kPoECgYEA3zE9FmvFGoQCU4g4Nl3dpQHs6kaAW8vJlrmq3xsireIuaJoa2HMe | |
+wYdIvgCnK9DIjyxd5OWnE4jXtAEYPsyGD32B5rSLQrRO96lgb3f4bESCLUb3Bsn/ | |
+sdgeE3p1xZMA0B59htqCrvVgN9k8WxyevBxYl3/gSBm/p8OVH1RTW/ECgYEA2f9Z | |
+95OLj0KQHQtxQXf+I3VjhCw3LkLW39QZOXVI0QrCJfqqP7uxsJXH9NYX0l0GFTcR | |
+kRrlyoaSU1EGQosZh+n1MvplGBTkTSV47/bPsTzFpgK2NfEZuFm9RoWgltS+nYeH | |
+Y2k4mnAN3PhReCMwuprmJz8GRLsO3Cs2s2YylKMCgYEA2UX+uO/q7jgqZ5UJW+ue | |
+1H5+W0aMuFA3i7JtZEnvRaUVFqFGlwXin/WJ2+WY1++k/rPrJ+Rk9IBXtBUIvEGw | |
+FC5TIfsKQsJyyWgqx/jbbtJ2g4s8+W/1qfTAuqeRNOg5d2DnRDs90wJuS4//0JaY | |
+9HkHyVwkQyxFxhSA/AHEMJECgYA2MvyFR1O9bIk0D3I7GsA+xKLXa77Ua53MzIjw | |
+9i4CezBGDQpjCiFli/fI8am+jY5DnAtsDknvjoG24UAzLy5L0mk6IXMdB6SzYYut | |
+7ak5oahqW+Y9hxIj+XvLmtGQbphtxhJtLu35x75KoBpxSh6FZpmuTEccs31AVCYn | |
+eFM/DQKBgQDOPUwbLKqVi6ddFGgrV9MrWw+SWsDa43bPuyvYppMM3oqesvyaX1Dt | |
+qDvN7owaNxNM4OnfKcZr91z8YPVCFo4RbBif3DXRzjNNBlxEjHBtuMOikwvsmucN | |
+vIrbeEpjTiUMTEAr6PoTiVHjsfS8WAM6MDlF5M+2PNswDsBpa2yLgA== | |
+-----END RSA PRIVATE KEY----- | |
+`), | |
+} | |
+ | |
+var SSHCertificates = map[string][]byte{ | |
+ // The following are corresponding certificates for the private keys above, signed by the CA key | |
+ // Generated by the following commands: | |
+ // | |
+ // 1. Assumes "rsa" key above in file named "rsa", write out the public key to "rsa.pub": | |
+ // ssh-keygen -y -f rsa > rsa.pu | |
+ // | |
+ // 2. Assumes "ca" key above in file named "ca", sign a cert for "rsa.pub": | |
+ // ssh-keygen -s ca -h -n host.example.com -V +500w -I host.example.com-key rsa.pub | |
+ "rsa": []byte(`[email protected] AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgLjYqmmuTSEmjVhSfLQphBSTJMLwIZhRgmpn8FHKLiEIAAAADAQABAAAAgQC8A6FGHDiWCSREAXCq6yBfNVr0xCVG2CzvktFNRpue+RXrGs/2a6ySEJQb3IYquw7HlJgu6fg3WIWhOmHCjfpG0PrL4CRwbqQ2LaPPXhJErWYejcD8Di00cF3677+G10KMZk9RXbmHtuBFZT98wxg8j+ZsBMqGM1+7yrWUvynswQAAAAAAAAAAAAAAAgAAABRob3N0LmV4YW1wbGUuY29tLWtleQAAABQAAAAQaG9zdC5leGFtcGxlLmNvbQAAAABZHN8UAAAAAGsjIYUAAAAAAAAAAAAAAAAAAAEXAAAAB3NzaC1yc2EAAAADAQABAAABAQC+D11D0hEbn2Vglv4YRJ8pZNyHjIGmvth3DWOQrq++2vH2MujmGQDxfr4SVE9GpMBlKU3lwGbpgIBxAg6yZcNSfo6PWVU9ACg6NMFO+yMzc2MaG+/naQdNjSewywF5j2rkNO2XOaViRVSrZroe2B/aY2LTV0jDl8nu5NOjwRs1/s7SLe5z1rw/X0dpmXk0qJY3gQhmR8HZZ1dhEkJUGwaBCPd0T8asSYf1Ag2rUD4aQ28r3q69mbwfWOOa6rMemVZruUV5dzHwVNVNtVv+ImtnYtz8m8g+K0plaGptHn3KsaOnASkh3tujhaE7kvc4HR9Igli9+76jhZie3h/dTN5zAAABDwAAAAdzc2gtcnNhAAABALeDea+60H6xJGhktAyosHaSY7AYzLocaqd8hJQjEIDifBwzoTlnBmcK9CxGhKuaoJFThdCLdaevCeOSuquh8HTkf+2ebZZc/G5T+2thPvPqmcuEcmMosWo+SIjYhbP3S6KD49aLC1X0kz8IBQeauFvURhkZ5ZjhA1L4aQYt9NjL73nqOl8PplRui+Ov5w8b4ldul4zOvYAFrzfcP6wnnXk3c1Zzwwf5wynD5jakO8GpYKBuhM7Z4crzkKSQjU3hla7xqgfomC5Gz4XbR2TNjcQiRrJQ0UlKtX3X3ObRCEhuvG0Kzjklhv+Ddw6txrhKjMjiSi/Yyius/AE8TmC1p4U= host.example.com | |
+`), | |
} | |
var PEMEncryptedKeys = []struct { | |
diff --git a/xts/xts.go b/xts/xts.go | |
index c9a283b..a7643fd 100644 | |
--- a/xts/xts.go | |
+++ b/xts/xts.go | |
@@ -23,6 +23,7 @@ package xts // import "golang.org/x/crypto/xts" | |
import ( | |
"crypto/cipher" | |
+ "encoding/binary" | |
"errors" | |
) | |
@@ -65,21 +66,20 @@ func (c *Cipher) Encrypt(ciphertext, plaintext []byte, sectorNum uint64) { | |
} | |
var tweak [blockSize]byte | |
- for i := 0; i < 8; i++ { | |
- tweak[i] = byte(sectorNum) | |
- sectorNum >>= 8 | |
- } | |
+ binary.LittleEndian.PutUint64(tweak[:8], sectorNum) | |
c.k2.Encrypt(tweak[:], tweak[:]) | |
- for i := 0; i < len(plaintext); i += blockSize { | |
- for j := 0; j < blockSize; j++ { | |
- ciphertext[i+j] = plaintext[i+j] ^ tweak[j] | |
+ for len(plaintext) > 0 { | |
+ for j := range tweak { | |
+ ciphertext[j] = plaintext[j] ^ tweak[j] | |
} | |
- c.k1.Encrypt(ciphertext[i:], ciphertext[i:]) | |
- for j := 0; j < blockSize; j++ { | |
- ciphertext[i+j] ^= tweak[j] | |
+ c.k1.Encrypt(ciphertext, ciphertext) | |
+ for j := range tweak { | |
+ ciphertext[j] ^= tweak[j] | |
} | |
+ plaintext = plaintext[blockSize:] | |
+ ciphertext = ciphertext[blockSize:] | |
mul2(&tweak) | |
} | |
@@ -97,21 +97,20 @@ func (c *Cipher) Decrypt(plaintext, ciphertext []byte, sectorNum uint64) { | |
} | |
var tweak [blockSize]byte | |
- for i := 0; i < 8; i++ { | |
- tweak[i] = byte(sectorNum) | |
- sectorNum >>= 8 | |
- } | |
+ binary.LittleEndian.PutUint64(tweak[:8], sectorNum) | |
c.k2.Encrypt(tweak[:], tweak[:]) | |
- for i := 0; i < len(plaintext); i += blockSize { | |
- for j := 0; j < blockSize; j++ { | |
- plaintext[i+j] = ciphertext[i+j] ^ tweak[j] | |
+ for len(ciphertext) > 0 { | |
+ for j := range tweak { | |
+ plaintext[j] = ciphertext[j] ^ tweak[j] | |
} | |
- c.k1.Decrypt(plaintext[i:], plaintext[i:]) | |
- for j := 0; j < blockSize; j++ { | |
- plaintext[i+j] ^= tweak[j] | |
+ c.k1.Decrypt(plaintext, plaintext) | |
+ for j := range tweak { | |
+ plaintext[j] ^= tweak[j] | |
} | |
+ plaintext = plaintext[blockSize:] | |
+ ciphertext = ciphertext[blockSize:] | |
mul2(&tweak) | |
} | |
diff --git a/xts/xts_test.go b/xts/xts_test.go | |
index 7a5e9fa..96d3b6c 100644 | |
--- a/xts/xts_test.go | |
+++ b/xts/xts_test.go | |
@@ -83,3 +83,23 @@ func TestXTS(t *testing.T) { | |
} | |
} | |
} | |
+ | |
+func TestShorterCiphertext(t *testing.T) { | |
+ // Decrypt used to panic if the input was shorter than the output. See | |
+ // https://go-review.googlesource.com/c/39954/ | |
+ c, err := NewCipher(aes.NewCipher, make([]byte, 32)) | |
+ if err != nil { | |
+ t.Fatalf("NewCipher failed: %s", err) | |
+ } | |
+ | |
+ plaintext := make([]byte, 32) | |
+ encrypted := make([]byte, 48) | |
+ decrypted := make([]byte, 48) | |
+ | |
+ c.Encrypt(encrypted, plaintext, 0) | |
+ c.Decrypt(decrypted, encrypted[:len(plaintext)], 0) | |
+ | |
+ if !bytes.Equal(plaintext, decrypted[:len(plaintext)]) { | |
+ t.Errorf("En/Decryption is not inverse") | |
+ } | |
+} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment