Skip to content

Instantly share code, notes, and snippets.

@ngrilly
Created December 4, 2013 09:49
Show Gist options
  • Save ngrilly/7784968 to your computer and use it in GitHub Desktop.
Save ngrilly/7784968 to your computer and use it in GitHub Desktop.
Sharing authentication cookies between Python and Go. This code is published in response to this discussion on golang-nuts: https://groups.google.com/forum/#!topic/golang-nuts/_nUQ1brlPBY This code is not really ready for public consumption: documentation is missing in some places, tests are missing in the Python version, code is not general eno…
package auth
import (
"net/http"
"net/url"
"strings"
"time"
"lib/crypto"
"lib/utils"
)
// Authenticator issues and verifies authentication cookies that carry a
// user ID, signed together with the time of signing.
//
// The cookie value is produced by a crypto.TimestampSigner using the "auth"
// context, so any other component using the same secret keys and context can
// verify it.
type Authenticator struct {
	// CookieName is the name of the cookie holding the signed user ID.
	CookieName string

	// CookieMaxAge is the validity period of the cookie, in seconds. Login
	// sets the cookie's Expires attribute from it and Authenticate rejects
	// cookies whose signature is older than this.
	CookieMaxAge int64

	// signer signs and verifies cookie values.
	signer *crypto.TimestampSigner
}

// NewAuthenticator returns an Authenticator for the cookie cookieName, valid
// for cookieMaxAge seconds. Cookies are signed with the first of secretKeys
// and accepted when any of them verifies the signature.
func NewAuthenticator(secretKeys []*crypto.SecretKey, cookieName string, cookieMaxAge int64) *Authenticator {
	a := &Authenticator{
		CookieName:   cookieName,
		CookieMaxAge: cookieMaxAge,
	}
	a.signer = crypto.NewTimestampSigner(secretKeys, "auth")
	return a
}

// NewDefaultAuthenticator returns an Authenticator using the cookie name
// "userid" and a validity period of one day.
func NewDefaultAuthenticator(secretKeys []*crypto.SecretKey) *Authenticator {
	const oneDay = 24 * 60 * 60 // seconds
	return NewAuthenticator(secretKeys, "userid", oneDay)
}
// Authenticate returns the user ID stored in r's authentication cookie.
//
// It returns an error when the cookie is absent, when its signature does not
// verify against any configured key, when it was signed more than
// CookieMaxAge seconds ago, or when the stored ID cannot be unescaped.
func (a *Authenticator) Authenticate(r *http.Request) (string, error) {
	c, err := r.Cookie(a.CookieName)
	if err != nil {
		return "", err // http.ErrNoCookie when the cookie is absent
	}
	escapedID, err := a.signer.Unsign(c.Value, a.CookieMaxAge)
	if err != nil {
		return "", err
	}
	// The ID is stored percent-encoded; undo that before returning it.
	return url.QueryUnescape(escapedID)
}
// Login sets an authentication cookie for userID on w that expires
// CookieMaxAge seconds from now.
func (a *Authenticator) Login(w http.ResponseWriter, userID string) {
	lifetime := time.Second * time.Duration(a.CookieMaxAge)
	a.setCookie(w, userID, time.Now().Add(lifetime))
}

// Logout clears the authentication cookie by re-issuing it for an empty user
// ID with an expiry date at the Unix epoch, i.e. already in the past.
func (a *Authenticator) Logout(w http.ResponseWriter) {
	epoch := time.Unix(0, 0)
	a.setCookie(w, "", epoch)
}
// setCookie queues a Set-Cookie header for the authentication cookie carrying
// userID (percent-encoded, then signed with a timestamp). If w already has a
// Set-Cookie header for this cookie name, that header is replaced instead, so
// a response never carries two authentication cookies.
//
// Spaces in userID are encoded as "%20" rather than the "+" produced by
// url.QueryEscape. The Python implementation decodes the ID with
// urllib.unquote, which does not turn "+" into a space, so "+" would make
// "john doe" read back as "john+doe" there. url.QueryUnescape (used by
// Authenticate) decodes both forms, so previously issued cookies stay valid.
func (a *Authenticator) setCookie(w http.ResponseWriter, userID string, expires time.Time) {
	// QueryEscape only emits "+" for spaces (a literal "+" becomes "%2B").
	escapedID := strings.Replace(url.QueryEscape(userID), "+", "%20", -1)
	cookie := http.Cookie{
		Name:     a.CookieName,
		Path:     "/",
		Value:    a.signer.Sign(escapedID),
		Expires:  expires,
		HttpOnly: true,
	}
	header := cookie.String()
	prefix := a.CookieName + "="
	values := w.Header()["Set-Cookie"]
	for i, v := range values {
		if strings.HasPrefix(v, prefix) {
			values[i] = header
			return
		}
	}
	w.Header().Add("Set-Cookie", header)
}
// RedirectToLogin answers r with a 303 See Other redirect to loginPath. It
// adds a from_page query parameter holding r's request URI (path and query)
// so the login page can send the user back afterwards. Query parameters
// already present in loginPath are kept.
//
// loginPath may be a path such as "/login" or an absolute URL; the scheme and
// host of an absolute URL are preserved. It panics if loginPath cannot be
// parsed.
func RedirectToLogin(w http.ResponseWriter, r *http.Request, loginPath string) {
	u, err := url.Parse(loginPath)
	utils.PanicOnError(err)
	query := u.Query()
	query.Set("from_page", r.URL.RequestURI())
	u.RawQuery = query.Encode()
	// u.String, not u.RequestURI: RequestURI drops the scheme and host, which
	// would send an absolute login URL to the current host instead.
	http.Redirect(w, r, u.String(), http.StatusSeeOther)
}
"""
Ideas and code borrowed from:
https://docs.djangoproject.com/en/1.5/topics/http/sessions/
https://github.com/facebook/tornado/blob/master/tornado/web.py
https://github.com/bbangert/beaker/blob/master/beaker/session.py
http://www.senchalabs.org/connect/cookieSession.html
http://www.gorillatoolkit.org/pkg/securecookie
http://security.stackexchange.com/questions/30707/demystifying-web-authentication-stateless-session-cookies
"""
from email.utils import formatdate
import logging
import time
import urllib
import cherrypy
import crypto
# Module logger; used to report cookies that fail signature/expiry checks.
logger = logging.getLogger(name=__name__)
def authenticate(
        secret_keys,
        login_path='/auth/login',
        authorize=None,
        cookie_name='userid',
        cookie_max_age=3600 * 24
):
    """
    Authenticate and authorize the user before serving a request.

    Before the request handler runs, this reads the authentication cookie and
    sets cherrypy.request.login to the user ID it carries, or to None when the
    cookie is missing, has an invalid signature, or has expired. The request
    handler can then modify cherrypy.request.login to log a user in or out:

        # Login
        cherrypy.request.login = userid
        # Logout
        cherrypy.request.login = None

    A 'before_finalize' hook then writes the matching Set-Cookie header.

    Parameters:
    - secret_keys is an instance of crypto.SecretKeys.
    - login_path is the login form URL.
    - authorize is an optional callable that can raise a 403 or 404 HTTP error
      if the user is not authorized. When authorize is given, unauthenticated
      users are redirected to the login page instead.
    - cookie_name is the name of the cookie storing the user ID.
    - cookie_max_age is the cookie validity time in seconds (24 hours by
      default). If it is None, the cookie gets no Expires attribute and its
      age is not checked.
    """
    signer = crypto.TimestampSigner(secret_keys, 'auth')

    # set_cookie is called after the handler to send the Set-Cookie HTTP header
    def set_cookie():
        user_id = cherrypy.request.login
        # Nothing to send: no cookie came in and nobody is logged in.
        if cookie_name not in cherrypy.request.cookie and user_id is None:
            return
        cookie = cherrypy.response.cookie
        if user_id is None:
            # Logout: overwrite the cookie with an empty, already expired one.
            cookie[cookie_name] = ''
            cookie[cookie_name]['Expires'] = 'Thu, 01 Jan 1970 00:00:00 GMT'
        else:
            # Under Python 2, urllib.quote raises KeyError for non-ASCII
            # unicode, so encode the ID to UTF-8 bytes before quoting.
            quoted_id = urllib.quote(crypto.to_bytes(user_id), '')
            # Re-signing whenever this hook runs refreshes the timestamp.
            cookie[cookie_name] = signer.sign(quoted_id)
            if cookie_max_age:
                cookie[cookie_name]['Expires'] = formatdate(time.time() + cookie_max_age, usegmt=True)
        cookie[cookie_name]['Path'] = '/'
        cookie[cookie_name]['HTTPOnly'] = True

    # Read and unsign cookie
    cherrypy.request.login = None
    cookie = cherrypy.request.cookie.get(cookie_name)
    if cookie is not None:
        try:
            value = signer.unsign(cookie.value, cookie_max_age)
        except crypto.Error as e:
            # Bad signature or expired cookie: treat the user as anonymous.
            logger.warning(e)
        else:
            cherrypy.request.login = urllib.unquote(value)

    # Call set_cookie after the handler
    cherrypy.request.hooks.attach('before_finalize', set_cookie)

    # Check authorization or redirect an unauthenticated user to the login page
    if authorize:
        if cherrypy.request.login is None:
            separator = '&' if '?' in login_path else '?'
            # NOTE(review): from_page holds only path_info (no query string),
            # unlike the Go RedirectToLogin, which uses the full request URI.
            url = login_path + separator + 'from_page=' + urllib.quote_plus(cherrypy.request.path_info)
            raise cherrypy.HTTPRedirect(url)
        authorize()
package auth
import (
"net/http"
"net/http/cookiejar"
"net/http/httptest"
"testing"
"lib/crypto"
)
// authenticator is shared by the tests in this file. TestAll temporarily
// changes its CookieMaxAge. Each test key decodes to 32 bytes.
var authenticator = NewAuthenticator(
	crypto.SecretKeys("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa=,bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb="),
	"userid", // cookie name
	60,       // cookie max age, in seconds
)
// TestAll drives a test server through a login/logout cycle with a cookie
// jar and checks the status code of /page after each step.
func TestAll(t *testing.T) {
	server := newTestServer()
	defer server.Close()
	jar, err := cookiejar.New(nil)
	if err != nil {
		t.Fatal(err)
	}
	client := http.Client{Jar: jar}

	expect := func(got, want int) {
		if got != want {
			t.Errorf("Status code: got %d, want %d", got, want)
		}
	}
	get := func(path string) int {
		return responseStatus(client.Get(server.URL + path))
	}

	// Without any cookie, /page is forbidden.
	expect(get("/page"), http.StatusForbidden)

	// Log the user in; /page is now allowed.
	expect(responseStatus(client.Post(server.URL+"/login", "", nil)), http.StatusOK)
	expect(get("/page"), http.StatusOK)

	// A negative max age makes the cookie count as expired.
	authenticator.CookieMaxAge = -1
	expect(get("/page"), http.StatusForbidden)
	authenticator.CookieMaxAge = 60
	// With the max age restored, the same cookie is accepted again.
	expect(get("/page"), http.StatusOK)

	// Log the user out; /page is forbidden again.
	expect(get("/logout"), http.StatusOK)
	expect(get("/page"), http.StatusForbidden)
}
// TestAddOrReplaceCookie checks that a second Login on the same response
// replaces the first Set-Cookie header rather than adding a second one.
func TestAddOrReplaceCookie(t *testing.T) {
	w := httptest.NewRecorder()
	authenticator.Login(w, "john")
	authenticator.Login(w, "mary")
	// w.Header() exposes the headers set so far; the HeaderMap field is
	// deprecated.
	cookies := w.Header()["Set-Cookie"]
	if l := len(cookies); l != 1 {
		t.Fatalf("Expected 1 cookie, got %d cookies", l)
	}
	// The surviving cookie must be the one issued by the second Login.
	const prefix = "userid=mary|"
	if c := cookies[0]; len(c) < len(prefix) || c[:len(prefix)] != prefix {
		t.Errorf("Expected the cookie for mary, got %q", c)
	}
}
// TestRedirectToLogin checks that RedirectToLogin answers with 303 See Other
// and a Location carrying the original request URI in from_page.
func TestRedirectToLogin(t *testing.T) {
	r, err := http.NewRequest("GET", "http://localhost/tickets?priority=high", nil)
	if err != nil {
		t.Fatal(err)
	}
	w := httptest.NewRecorder()
	RedirectToLogin(w, r, "/login")
	if w.Code != http.StatusSeeOther {
		t.Errorf("RedirectToLogin returned status %d, want %d", w.Code, http.StatusSeeOther)
	}
	l := w.Header().Get("Location")
	if l != "/login?from_page=%2Ftickets%3Fpriority%3Dhigh" {
		t.Errorf("RedirectToLogin returned invalid header Location: %v", l)
	}
}
// newTestServer starts an httptest.Server for the shared authenticator.
// /login and /logout log the fixed user "john" in and out. /page answers
// 403 Forbidden unless the request authenticates as that user, and 200 OK
// otherwise. The caller must Close the server.
func newTestServer() *httptest.Server {
	const userID = "john"

	login := func(w http.ResponseWriter, _ *http.Request) {
		authenticator.Login(w, userID)
	}
	logout := func(w http.ResponseWriter, _ *http.Request) {
		authenticator.Logout(w)
	}
	page := func(w http.ResponseWriter, r *http.Request) {
		// On error Authenticate returns "", which never matches userID.
		id, _ := authenticator.Authenticate(r)
		if id == userID {
			return // implicit 200 OK
		}
		w.WriteHeader(http.StatusForbidden)
	}

	mux := http.NewServeMux()
	mux.HandleFunc("/login", login)
	mux.HandleFunc("/logout", logout)
	mux.HandleFunc("/page", page)
	return httptest.NewServer(mux)
}
// responseStatus takes the results of an http.Client call, closes the
// response body and returns the status code. It returns 0 if the call failed
// or the body could not be closed.
func responseStatus(r *http.Response, err error) int {
	if err != nil {
		return 0
	}
	if closeErr := r.Body.Close(); closeErr != nil {
		return 0
	}
	return r.StatusCode
}
// Cryptographic utilities.
//
// Ideas and code borrowed from:
// https://docs.djangoproject.com/en/1.5/topics/signing/
// http://pythonhosted.org/itsdangerous/
// http://www.w3.org/TR/WebCryptoAPI/
package crypto
import (
"bytes"
"crypto/hmac"
hash "crypto/sha256"
"crypto/subtle"
"encoding/base64"
"fmt"
"strconv"
"strings"
"time"
)
// sep separates a value from its timestamp and from its signature. Values
// are split at the last occurrence of sep, so a value may itself contain it.
const sep = "|"
// SecretKeys parses commaSeparatedKeys, a comma-separated list of URL-safe,
// "="-padded base64 keys, and returns the decoded keys in order. Such a key
// can be generated in Python with base64.urlsafe_b64encode(os.urandom(32)).
//
// A list is used to allow key rotation: new messages are signed with the
// first key, and received messages are verified with every key in the list
// until one of them verifies the signature.
//
// Whitespace around each key is ignored. SecretKeys panics if a key is not
// valid base64 or decodes to fewer than hash.Size bytes.
func SecretKeys(commaSeparatedKeys string) []*SecretKey {
	parts := strings.Split(commaSeparatedKeys, ",")
	keys := make([]*SecretKey, 0, len(parts))
	for _, part := range parts {
		raw, err := base64.URLEncoding.DecodeString(strings.TrimSpace(part))
		switch {
		case err != nil:
			panic(fmt.Errorf("crypto: could not base64 decode key %s: %s", part, err))
		case len(raw) < hash.Size:
			panic(fmt.Errorf("crypto: key is too short (got %d bytes, expected %d bytes)", len(raw), hash.Size))
		}
		keys = append(keys, &SecretKey{key: raw})
	}
	return keys
}

// SecretKey is a master secret from which a separate HMAC key is derived for
// each context.
type SecretKey struct {
	key []byte // raw (decoded) master key bytes
}

// DerivedKey returns HMAC-SHA256(master key, context): a key for use in the
// given context only. An application can thus keep a single master secret
// and use it in several incompatible contexts. Because the derivation is an
// HMAC, exposing a derived key does not expose the master key.
//
// Reference:
// http://en.wikipedia.org/wiki/Key_derivation_function
func (k *SecretKey) DerivedKey(context string) []byte {
	mac := hmac.New(hash.New, k.key)
	mac.Write([]byte(context))
	return mac.Sum(nil)
}

// Hmac returns HMAC-SHA256 of msg under the key derived for context.
//
// context identifies the part of the application using the MAC. Each part
// must use its own context, so that a MAC produced in one place (for example
// an activation link) cannot be replayed by a malicious user somewhere else
// (for example as a signed cookie).
func (k *SecretKey) Hmac(context string, msg []byte) []byte {
	subkey := k.DerivedKey(context)
	mac := hmac.New(hash.New, subkey)
	mac.Write(msg)
	return mac.Sum(nil)
}

// Base64Hmac returns Hmac(context, msg) encoded with the URL-safe base64
// alphabet, with the trailing "=" padding removed.
func (k *SecretKey) Base64Hmac(context string, msg []byte) []byte {
	digest := k.Hmac(context, msg)
	encoded := []byte(base64.URLEncoding.EncodeToString(digest))
	return bytes.TrimRight(encoded, "=")
}
// Signer appends a message authentication code (MAC) to string values and
// verifies it later. Its context is mixed into the MAC key, so MACs made for
// one purpose are not accepted for another.
type Signer struct {
	keys    []*SecretKey
	context string
}

// NewSigner returns a Signer that signs with keys[0] and accepts signatures
// made with any of keys. Sign panics if keys is empty.
func NewSigner(keys []*SecretKey, context string) *Signer {
	return &Signer{keys: keys, context: context}
}

// Sign returns value followed by sep and the unpadded URL-safe base64 MAC of
// value.
func (s *Signer) Sign(value string) string {
	mac := s.keys[0].Base64Hmac(s.context, []byte(value))
	return value + sep + string(mac)
}

// Unsign splits signedValue at its last sep. It returns the part before the
// separator if the part after it is a valid MAC of that part under any of
// the signer's keys. MACs are compared in constant time.
func (s *Signer) Unsign(signedValue string) (string, error) {
	cut := strings.LastIndex(signedValue, sep)
	if cut < 0 {
		return "", fmt.Errorf("crypto: missing signature in %s", signedValue)
	}
	payload, sig := []byte(signedValue[:cut]), []byte(signedValue[cut+len(sep):])
	for _, k := range s.keys {
		if subtle.ConstantTimeCompare(sig, k.Base64Hmac(s.context, payload)) == 1 {
			return string(payload), nil
		}
	}
	return "", fmt.Errorf("crypto: invalid signature in %s", signedValue)
}
// TimestampSigner is a Signer that also embeds the signing time, in Unix
// seconds, so that Unsign can reject values older than a maximum age.
type TimestampSigner struct {
	*Signer
	// timeFunc returns the current Unix time in seconds. It can be
	// overridden for testing.
	timeFunc func() int64
}

// NewTimestampSigner returns a TimestampSigner over keys and context that
// reads the system clock.
func NewTimestampSigner(keys []*SecretKey, context string) *TimestampSigner {
	return &TimestampSigner{
		Signer:   NewSigner(keys, context),
		timeFunc: func() int64 { return time.Now().UTC().Unix() },
	}
}

// Sign returns value|timestamp|signature, where timestamp is the current
// Unix time in decimal and the signature covers value|timestamp.
func (s *TimestampSigner) Sign(value string) string {
	// TODO: Encode the timestamp like itsdangerous?
	stamped := value + sep + strconv.FormatInt(s.timeFunc(), 10)
	return s.Signer.Sign(stamped)
}

// Unsign verifies signedValue and returns the original value if it was
// signed no more than maxAge seconds ago. A value exactly maxAge seconds old
// is still accepted.
//
// It returns an error if the signature is missing or invalid, if the
// timestamp is missing or not an integer, or if the value has expired. A
// timestamp in the future gives a negative age, which is never treated as
// expired for a non-negative maxAge.
func (s *TimestampSigner) Unsign(signedValue string, maxAge int64) (string, error) {
	unsignedValue, err := s.Signer.Unsign(signedValue)
	if err != nil {
		return "", err
	}
	i := strings.LastIndex(unsignedValue, sep)
	if i == -1 {
		return "", fmt.Errorf("crypto: missing timestamp in %s", signedValue)
	}
	timestamp, err := strconv.ParseInt(unsignedValue[i+1:], 10, 64)
	if err != nil {
		return "", fmt.Errorf("crypto: invalid timestamp in %s", signedValue)
	}
	// TODO: implement minAge ?
	age := s.timeFunc() - timestamp
	if age > maxAge {
		return "", fmt.Errorf("crypto: expired signature in %s (age %d > %d seconds)", signedValue, age, maxAge)
	}
	return unsignedValue[:i], nil
}
"""
Cryptographic utilities.
Ideas and code borrowed from:
https://docs.djangoproject.com/en/1.5/topics/signing/
http://pythonhosted.org/itsdangerous/
http://www.w3.org/TR/WebCryptoAPI/
"""
import base64
import hashlib
import hmac
from itertools import izip
import time
import sys
# Separator between a value, its timestamp and its signature.
SEP = b'|'
# Hash constructor used for both key derivation and message MACs.
DIGESTMOD = hashlib.sha256
# Digest length in bytes; secret keys must be at least this long.
digest_size = len(DIGESTMOD().digest())
class Error(Exception):
    """Raised when a key, signature or timestamp fails validation."""
class SecretKeys(list):
    """
    A list of SecretKey objects used to sign messages.

    Keeping several keys allows key rotation: new messages are signed with the
    first key, and received messages are verified using all keys in the list
    until one of them verifies the signature.
    """

    def __init__(self, comma_separated_keys):
        """
        comma_separated_keys is a string of URL-safe base64 keys separated by
        commas; whitespace around each key is ignored. A key can be generated
        with base64.urlsafe_b64encode(os.urandom(32)).

        Raises Error if a key cannot be decoded or decodes to fewer than
        digest_size bytes.
        """
        self.extend(
            SecretKey(self._decode_key(encoded))
            for encoded in to_bytes(comma_separated_keys).split(',')
        )

    @staticmethod
    def _decode_key(encoded):
        # Decode one base64 key and enforce the minimum key length.
        try:
            raw = base64.urlsafe_b64decode(encoded.strip())
        except TypeError as e:
            raise Error("Could not base64 decode key {0}: {1}".format(encoded, e))
        if len(raw) < digest_size:
            raise Error("Key is too short (got {0} bytes, expected {1} bytes)".format(len(raw), digest_size))
        return raw
class SecretKey(object):
    """A master secret from which a separate HMAC key is derived per context."""

    def __init__(self, key):
        # Raw (already base64-decoded) master key bytes.
        self._key = key

    def derived_key(self, context):
        """
        Return HMAC(master key, context): a key intended for use in the given
        context only.

        This lets the application keep a single secret key and reuse it in
        several incompatible contexts in the form of derived keys. Because the
        derivation is an HMAC, the master key stays protected even if one of
        the derived keys is exposed.

        Reference:
        http://en.wikipedia.org/wiki/Key_derivation_function
        """
        mac = hmac.new(self._key, context, DIGESTMOD)
        return mac.digest()

    def hmac(self, context, msg):
        """
        Return the digest of msg, using the key derived for context.

        context identifies the part of the application using the MAC. Each
        part must use its own context, so that a MAC generated in one place
        (for example an activation link) cannot be reused by a malicious user
        in another (for example a signed cookie).
        """
        subkey = self.derived_key(context)
        return hmac.new(subkey, msg, DIGESTMOD).digest()

    def base64_hmac(self, context, msg):
        """Return hmac(context, msg) as URL-safe base64 without '=' padding."""
        encoded = base64.urlsafe_b64encode(self.hmac(context, msg))
        return encoded.rstrip(b'=')
class Signer(object):
    """
    Sign and verify messages using message authentication codes (MAC).

    Values are signed with the first key of secret_keys and verified against
    every key, which allows key rotation. context is passed to the key's MAC
    computation so that signatures made for one purpose are not accepted for
    another.
    """

    def __init__(self, secret_keys, context):
        self.keys = secret_keys
        self.context = to_bytes(context)

    def sign(self, value):
        """
        Return the original value completed with a signature.

        The result is to_bytes(value) + SEP + signature.
        """
        value = to_bytes(value)
        # MAC exactly the bytes that are emitted: unsign() recomputes the MAC
        # over the bytes it receives. Passing unicode to hmac instead relies on
        # an implicit ASCII encoding, which fails for non-ASCII text.
        return value + SEP + self.keys[0].base64_hmac(self.context, value)

    def unsign(self, signed_value):
        """
        Return the original value (without the signature) if the signature is
        valid, or raise an Error otherwise.

        The signature is taken after the last SEP, so the value itself may
        contain SEP.
        """
        signed_value = to_bytes(signed_value)
        if SEP not in signed_value:
            raise Error("Missing signature in {0}".format(signed_value))
        value, sig = signed_value.rsplit(SEP, 1)
        for key in self.keys:
            if compare_digest(sig, key.base64_hmac(self.context, value)):
                return value
        raise Error("Invalid signature in {0}".format(signed_value))
class TimestampSigner(Signer):
    """Signer that also embeds the signing time (Unix seconds) in the value."""

    def sign(self, value):
        """Return value|timestamp|signature for the current time."""
        # TODO: Encode the timestamp like itsdangerous?
        stamped = SEP.join([to_bytes(value), str(int(time.time()))])
        return Signer.sign(self, stamped)

    def unsign(self, signed_value, max_age=None):
        """
        Retrieve the original value and check it wasn't signed more than
        max_age seconds ago.

        Raises Error for a bad signature, a missing timestamp, or, when
        max_age is given, a non-integer or expired timestamp. If max_age is
        None, the timestamp is neither parsed nor checked.
        """
        unsigned_value = Signer.unsign(self, signed_value)
        if SEP not in unsigned_value:
            raise Error("Missing timestamp in {0}".format(signed_value))
        value, timestamp = unsigned_value.rsplit(SEP, 1)
        if max_age is None:
            return value
        try:
            signed_at = int(timestamp)
        except ValueError:
            raise Error("Invalid timestamp in {0}".format(signed_value))
        age = time.time() - signed_at
        if age > max_age:
            raise Error("Expired signature in {0} (age {1} > {2} seconds)".format(signed_value, age, max_age))
        return value
def to_bytes(s, encoding='utf-8', errors='strict'):
    """
    Return s encoded with encoding if it is a text string; return any other
    object (byte strings included) unchanged.

    Text means unicode on Python 2 and str on Python 3.
    """
    try:
        text_type = unicode  # Python 2
    except NameError:
        text_type = str  # Python 3: the unicode builtin no longer exists
    if isinstance(s, text_type):
        s = s.encode(encoding, errors)
    return s
def _compare_digest(a, b):
"""
Return True if the two strings are equal, False otherwise, using a constant
time comparison designed to prevent a timing attack.
For the sake of simplicity, this function executes in constant time only
when the two strings have the same length. It short-circuits when they
have different lengths. Since we only use it to compare hashes of
known expected length, this is acceptable.
>>> compare_digest('abc', 'abc')
True
>>> compare_digest('abc', 'xyz')
False
"""
if len(a) != len(b):
return False
result = 0
for x, y in izip(a, b):
result |= ord(x) ^ ord(y)
return result == 0
# Prefer the standard library's hmac.compare_digest (Python 3.3+ and 2.7.7+)
# when this interpreter provides it; otherwise fall back to _compare_digest.
if hasattr(hmac, 'compare_digest'):
    compare_digest = hmac.compare_digest
else:
    compare_digest = _compare_digest
package crypto
import "testing"
// Test keys in URL-safe base64. Each decodes to 32 bytes, and they differ
// only in their first character, so each is a distinct key.
const (
	key1 = "1evsWQ8Z4LXca49FUL6SEl-xJL4UayDIlqmRBO9zi5k="
	key2 = "2evsWQ8Z4LXca49FUL6SEl-xJL4UayDIlqmRBO9zi5k="
	key3 = "3evsWQ8Z4LXca49FUL6SEl-xJL4UayDIlqmRBO9zi5k="
)

// context is the signing context used by every test in this file.
const context = "test"
// secretKeys signs with key1 and also accepts key2.
var secretKeys = SecretKeys(key1 + "," + key2)

// rotatedKeys simulates a key rotation: key3 becomes the signing key while
// key1, the previous signing key, is still accepted.
var rotatedKeys = SecretKeys(key3 + "," + key1)
// TestKeySize checks that SecretKeys panics when given a key that is too
// short.
func TestKeySize(t *testing.T) {
	defer func() {
		if recover() != nil {
			return
		}
		t.Errorf("Expected panic when key size is incorrect")
	}()
	SecretKeys("tooshortkey=")
}
// TestSign checks that a signed value round-trips through Unsign and that a
// tampered value is rejected.
func TestSign(t *testing.T) {
	signer := NewSigner(secretKeys, context)
	signedValue := signer.Sign("test")
	unsignedValue, err := signer.Unsign(signedValue)
	if err != nil {
		t.Fatalf("Unsign(%q): %v", signedValue, err)
	}
	if unsignedValue != "test" {
		t.Errorf("Expected %s, got %s", "test", unsignedValue)
	}
	// Changing the value must invalidate the signature.
	if _, err := signer.Unsign("X" + signedValue); err == nil {
		t.Errorf("Expected tampered value to be rejected")
	}
}
// TestTimestampSign checks Unsign's max-age handling with a controlled clock:
// a value signed at t=0 and read at t=10 is valid for maxAge 20, still valid
// at exactly maxAge 10, and expired for maxAge 5.
func TestTimestampSign(t *testing.T) {
	signer := NewTimestampSigner(secretKeys, context)
	signer.timeFunc = func() int64 { return 0 }
	signedValue := signer.Sign("test")
	signer.timeFunc = func() int64 { return 10 }

	unsignedValue, err := signer.Unsign(signedValue, 20)
	if err != nil || unsignedValue != "test" {
		t.Errorf("Expected valid timestamped value, got %q, %v", unsignedValue, err)
	}

	// Age equal to maxAge is still accepted (only age > maxAge expires).
	unsignedValue, err = signer.Unsign(signedValue, 10)
	if err != nil || unsignedValue != "test" {
		t.Errorf("Expected valid timestamped value at age == maxAge, got %q, %v", unsignedValue, err)
	}

	unsignedValue, err = signer.Unsign(signedValue, 5)
	if err == nil || unsignedValue != "" {
		t.Errorf("Expected expired timestamped value, got %q, %v", unsignedValue, err)
	}
}
// TestKeyRotation checks that a value signed with key1 as the signing key
// still verifies after rotation, when key1 is only the second key.
func TestKeyRotation(t *testing.T) {
	signer := NewSigner(secretKeys, context)
	signedValue := signer.Sign("rotation")
	signer = NewSigner(rotatedKeys, context)
	unsignedValue, err := signer.Unsign(signedValue)
	if err != nil {
		t.Fatalf("Unsign after rotation: %v", err)
	}
	if unsignedValue != "rotation" {
		t.Errorf("Expected %s, got %s", "rotation", unsignedValue)
	}
}
package utils
// PanicOnError panics with err when err is non-nil and does nothing
// otherwise.
func PanicOnError(err error) {
	if err == nil {
		return
	}
	panic(err)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment