Skip to content

Instantly share code, notes, and snippets.

@danwakefield
Forked from creack/main.go
Last active August 29, 2015 14:18
Show Gist options
  • Save danwakefield/4b322fa6b25f8b999a3d to your computer and use it in GitHub Desktop.
Save danwakefield/4b322fa6b25f8b999a3d to your computer and use it in GitHub Desktop.
package main
import (
"crypto/rand"
"flag"
"fmt"
"io"
"log"
"net"
"os"
"golang.org/x/crypto/nacl/box"
)
// Dial generates a private/public key pair,
// connects to the server, perform the handshake
// and return a reader/writer.
// Dial generates an ephemeral key pair, connects to addr over TCP and performs
// the client side of the handshake:
//
//  1. send our 32-byte public key,
//  2. send a freshly generated 24-byte nonce,
//  3. read the server's 32-byte public key.
//
// It returns a ReadWriteCloser that encrypts writes and decrypts reads on the
// connection. On any handshake failure the TCP connection is closed before the
// error is returned, so no socket is leaked.
func Dial(addr string) (io.ReadWriteCloser, error) {
	pub, priv, err := box.GenerateKey(rand.Reader)
	if err != nil {
		return nil, err
	}
	conn, err := net.Dial("tcp", addr)
	if err != nil {
		return nil, err
	}
	// fail closes the half-initialised connection and reports err.
	fail := func(err error) (io.ReadWriteCloser, error) {
		conn.Close()
		return nil, err
	}

	// Exchange the keys and nonce.
	nonce := NewNonce()
	rpub := new([32]byte)
	if _, err := conn.Write(pub[:]); err != nil {
		return fail(err)
	}
	if _, err := conn.Write(nonce[:]); err != nil {
		return fail(err)
	}
	// A single Read on a stream may return fewer than 32 bytes; ReadFull
	// keeps reading until the whole key has arrived (or fails).
	if _, err := io.ReadFull(conn, rpub[:]); err != nil {
		return fail(err)
	}
	return NewSecureReadWriteCloser(conn, priv, rpub, nonce), nil
}
// Serve starts a secure echo server on the given listener.
// Serve runs a secure echo server on l.
//
// One key pair is generated for the lifetime of the server. For every accepted
// connection a goroutine performs the server side of the handshake:
//
//  1. send our 32-byte public key,
//  2. read the client's 32-byte public key,
//  3. read the client's 24-byte nonce.
//
// It then echoes every decrypted message back encrypted. A failure on one
// connection is logged and only that connection is dropped; it no longer
// panics and takes down the whole process. Serve returns only when l.Accept
// fails, e.g. after l is closed.
func Serve(l net.Listener) error {
	pub, priv, err := box.GenerateKey(rand.Reader)
	if err != nil {
		return err
	}
	for {
		conn, err := l.Accept()
		if err != nil {
			return err
		}
		go func(c net.Conn) {
			defer c.Close()
			var rpub [32]byte
			var nonce [24]byte
			if _, err := c.Write(pub[:]); err != nil {
				log.Printf("Serve - public key exchange failed: %v", err)
				return
			}
			// ReadFull: a single Read may return only part of a fixed-size field.
			if _, err := io.ReadFull(c, rpub[:]); err != nil {
				log.Printf("Serve - remote public key exchange failed: %v", err)
				return
			}
			if _, err := io.ReadFull(c, nonce[:]); err != nil {
				log.Printf("Serve - nonce exchange failed: %v", err)
				return
			}
			secureConn := NewSecureReadWriteCloser(c, priv, &rpub, nonce)
			if _, err := io.Copy(secureConn, secureConn); err != nil {
				log.Printf("Serve - echo failed: %v", err)
			}
		}(conn)
	}
}
// main runs in one of two modes:
//
//	-l <port>           serve the secure echo server on :<port>
//	<port> <message>    dial localhost:<port>, send message, print the echo
func main() {
	port := flag.Int("l", 0, "Listen mode. Specify port")
	flag.Parse()

	// Server mode
	if *port != 0 {
		l, err := net.Listen("tcp", fmt.Sprintf(":%d", *port))
		if err != nil {
			log.Fatal(err)
		}
		// log.Fatal calls os.Exit, which skips deferred calls, so close the
		// listener explicitly instead of deferring it.
		err = Serve(l)
		l.Close()
		log.Fatal(err)
	}

	// Client mode. Positional arguments come from flag.Args so they are found
	// correctly even when flags precede them on the command line.
	args := flag.Args()
	if len(args) != 2 {
		log.Fatalf("Usage: %s <port> <message>", os.Args[0])
	}
	msg := args[1]
	conn, err := Dial("localhost:" + args[0])
	if err != nil {
		log.Fatal(err)
	}
	if _, err := conn.Write([]byte(msg)); err != nil {
		log.Fatal(err)
	}
	// The server echoes the message back, so a buffer of the message length
	// is enough to hold the reply.
	buf := make([]byte, len(msg))
	n, err := conn.Read(buf)
	conn.Close()
	if err != nil {
		log.Fatal(err)
	}
	fmt.Printf("%s\n", buf[:n])
}
package main
import (
"fmt"
"io"
"io/ioutil"
"net"
"testing"
)
// TestReadWriterPing checks the round trip: a message written through a
// SecureWriter comes back as the same plaintext from a SecureReader built with
// the same keys, the two being joined by an in-memory pipe.
func TestReadWriterPing(t *testing.T) {
	priv := &[32]byte{'p', 'r', 'i', 'v'}
	pub := &[32]byte{'p', 'u', 'b'}

	pr, pw := io.Pipe()
	defer pw.Close()

	reader := NewSecureReader(pr, priv, pub)
	writer := NewSecureWriter(pw, priv, pub)

	// io.Pipe is synchronous, so the encrypting write must run concurrently
	// with the decrypting read below.
	go fmt.Fprintf(writer, "hello world\n")

	plain := make([]byte, 1024)
	n, err := reader.Read(plain)
	if err != nil {
		t.Fatal(err)
	}

	if got := string(plain[:n]); got != "hello world\n" {
		t.Fatalf("Unexpected result: %s != %s", got, "hello world")
	}
}
// TestSecureWriter inspects the raw bytes a SecureWriter puts on its transport.
// It checks that they are not the plaintext, and that two independently
// created writers produce different ciphertexts for the same message.
func TestSecureWriter(t *testing.T) {
	priv, pub := &[32]byte{'p', 'r', 'i', 'v'}, &[32]byte{'p', 'u', 'b'}

	// sealHello encrypts "hello world\n" with a brand-new SecureWriter and
	// returns exactly what reached the underlying pipe, read directly from
	// the transport rather than through a decoder.
	sealHello := func() []byte {
		pr, pw := io.Pipe()
		sw := NewSecureWriter(pw, priv, pub)
		go func() {
			fmt.Fprintf(sw, "hello world\n")
			pw.Close()
		}()
		raw, err := ioutil.ReadAll(pr)
		if err != nil {
			t.Fatal(err)
		}
		return raw
	}

	first := sealHello()
	// The transport must never carry the plaintext.
	if string(first) == "hello world\n" {
		t.Fatal("Unexpected result. The message is not encrypted.")
	}

	second := sealHello()
	// A second writer must not produce the same ciphertext.
	if string(first) == string(second) {
		t.Fatal("Unexpected result. The encrypted message is not unique.")
	}
}
// TestSecureEchoServer runs Serve on a random local port, connects with Dial
// and checks that a message sent through the secure connection comes back
// unchanged.
func TestSecureEchoServer(t *testing.T) {
	// Create a random listener
	l, err := net.Listen("tcp", ":0")
	if err != nil {
		t.Fatal(err)
	}
	defer l.Close()

	// Start the server
	go Serve(l)

	conn, err := Dial(l.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer conn.Close()

	expected := "hello world\n"
	// Fprint, not Fprintf: the message is data, not a format string (vet's
	// printf check rejects a non-constant format string).
	if _, err := fmt.Fprint(conn, expected); err != nil {
		t.Fatal(err)
	}

	buf := make([]byte, 2048)
	n, err := conn.Read(buf)
	if err != nil {
		t.Fatal(err)
	}
	if got := string(buf[:n]); got != expected {
		t.Fatalf("Unexpected result:\nGot:\t\t%s\nExpected:\t%s\n", got, expected)
	}
}
// TestSecureServe connects to Serve with a plain TCP connection. It checks
// that the first bytes the server sends are not an echo of the plaintext
// written by the client; they should be the server's public key.
func TestSecureServe(t *testing.T) {
	// Create a random listener
	l, err := net.Listen("tcp", ":0")
	if err != nil {
		t.Fatal(err)
	}
	defer l.Close()

	// Start the server
	go Serve(l)

	conn, err := net.Dial("tcp", l.Addr().String())
	if err != nil {
		t.Fatal(err)
	}

	unexpected := "hello world\n"
	// Fprint, not Fprintf: the payload must not be interpreted as a format
	// string (vet's printf check rejects a non-constant format string).
	if _, err := fmt.Fprint(conn, unexpected); err != nil {
		t.Fatal(err)
	}

	buf := make([]byte, 2048)
	n, err := conn.Read(buf)
	if err != nil {
		t.Fatal(err)
	}
	if got := string(buf[:n]); got == unexpected {
		t.Fatalf("Unexpected result:\nGot raw data instead of serialized key")
	}
}
// TestSecureDial points Dial at a fake server and checks that the bytes
// arriving at the server are not the plaintext message.
func TestSecureDial(t *testing.T) {
	// Create a random listener
	l, err := net.Listen("tcp", ":0")
	if err != nil {
		t.Fatal(err)
	}

	// Fake server: handles a single connection, sends an all-zero public key
	// and inspects the first bytes it receives. It reports problems with
	// t.Error because t.Fatal/FailNow may only be called from the goroutine
	// running the test.
	done := make(chan struct{})
	go func() {
		defer close(done)
		c, err := l.Accept()
		if err != nil {
			return // listener closed before any client connected
		}
		defer c.Close()
		key := [32]byte{}
		if _, err := c.Write(key[:]); err != nil {
			t.Error(err)
			return
		}
		buf := make([]byte, 2048)
		n, err := c.Read(buf)
		if err != nil {
			t.Error(err)
			return
		}
		if got := string(buf[:n]); got == "hello world\n" {
			t.Error("Unexpected result. Got raw data instead of encrypted")
		}
	}()
	// Close the listener (unblocking Accept if no client ever connected) and
	// wait for the fake server, so it can never report after the test has
	// returned. This runs after the deferred conn.Close below.
	defer func() {
		l.Close()
		<-done
	}()

	conn, err := Dial(l.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer conn.Close()

	expected := "hello world\n"
	// Fprint, not Fprintf: the payload is not a format string.
	if _, err := fmt.Fprint(conn, expected); err != nil {
		t.Fatal(err)
	}
}
package main
import (
"crypto/rand"
"crypto/subtle"
"fmt"
"io"
"io/ioutil"
"golang.org/x/crypto/nacl/box"
)
// Package-level state shared by NewSecureReader and NewSecureWriter.
var (
// nonce is generated once when the package is initialised. The first
// SecureReader and the first SecureWriter created start from it, so that such
// a pair can talk to each other.
nonce = NewNonce()
// replaceReaderNonce / replaceWriterNonce are set to true by the first
// NewSecureReader / NewSecureWriter call; later calls then start from a fresh
// NewNonce instead of the shared nonce above.
// NOTE(review): these are plain bools written without synchronization;
// constructing readers/writers from several goroutines races on them.
replaceReaderNonce = false
replaceWriterNonce = false
// empty is passed as the `out` slice to the box Seal/Open calls.
empty = []byte{}
)
// SecureWriter holds the underlying communication channel and encryption keys
// for using it. Each Write seals its plaintext with key and nonce and writes
// the sealed bytes to w.
type SecureWriter struct {
// w is the destination for sealed (encrypted) messages.
w io.WriteCloser
// key is the shared key computed by box.Precompute in NewSecureWriter.
key [32]byte
// nonce is passed to box.SealAfterPrecomputation by Write.
nonce [24]byte
}
// SecureReader holds the underlying communication channel and encryption keys
// for using it. Each Read reads sealed bytes from r and opens them with key
// and nonce.
type SecureReader struct {
// r is the source of sealed (encrypted) messages. NewSecureReader wraps a
// plain io.Reader with ioutil.NopCloser so Close is always available.
r io.ReadCloser
// key is the shared key computed by box.Precompute in NewSecureReader.
key [32]byte
// nonce is passed to box.OpenAfterPrecomputation by Read.
nonce [24]byte
}
// SecureReadWriteCloser holds a SecureReader and a SecureWriter together. Read
// is promoted from the embedded SecureReader and Write from the embedded
// SecureWriter. NewSecureReadWriteCloser builds it over a single connection
// and gives both halves the same starting nonce.
type SecureReadWriteCloser struct {
SecureReader
SecureWriter
}
// NewNonce returns a 24-byte nonce that is both random and sequential.
//
// Bytes 0-17 are filled from crypto/rand. Bytes 18-23 start at zero and act
// as a big-endian message counter that incrNonce advances, with byte 23 being
// least significant. The 6-byte counter gives 2^48 distinct nonces per random
// prefix before it wraps. A sequential counter also makes replayed or
// reordered messages fail to decrypt.
//
// It panics if the system random source fails. A nonce with a predictable
// prefix would undermine the encryption, so this is not silently ignored.
func NewNonce() [24]byte {
	var b [24]byte
	if _, err := rand.Read(b[:18]); err != nil {
		panic(fmt.Sprintf("NewNonce: reading random bytes: %v", err))
	}
	return b
}
// incrNonce adds one to the 6-byte big-endian counter held in n[18:24]
// (n[23] is least significant), carrying into the next byte whenever a byte
// wraps from 255 to 0. After six carries the counter silently wraps to zero.
// The random prefix n[0:18] is never touched.
//
// It always makes six subtle.ConstantTimeByteEq calls, padding with dummy
// ones after the carry stops. This is an attempt not to reveal through timing
// how far the carry propagated.
func incrNonce(n *[24]byte) {
	const lo, hi = 18, 23
	for idx := hi; idx >= lo; idx-- {
		n[idx]++
		if subtle.ConstantTimeByteEq(n[idx], 0) == 1 {
			continue // wrapped around: carry into the next more significant byte
		}
		// No carry: pad with dummy comparisons up to a total of six.
		for pad := idx - lo; pad > 0; pad-- {
			subtle.ConstantTimeByteEq(0, 0)
		}
		return
	}
}
// NewSecureWriter returns an io.Writer that seals every Write into a NaCl box.
// The box key is precomputed with box.Precompute from pub (the peer's public
// key) and priv (our private key), and the sealed bytes are written to w.
//
// w must also implement io.WriteCloser; NewSecureWriter panics otherwise.
// The first writer created starts from the package-level nonce, the same one
// the first reader gets. Every later writer starts from a fresh NewNonce.
// NOTE(review): replaceWriterNonce is unsynchronized package state, so
// concurrent calls race on it.
//
// A *SecureWriter is returned so that the nonce advanced by Write persists
// across calls.
func NewSecureWriter(w io.Writer, priv, pub *[32]byte) io.Writer {
	wc, ok := w.(io.WriteCloser)
	if !ok {
		panic(fmt.Sprintf("Could not cast %v to io.WriteCloser", w))
	}
	s := &SecureWriter{w: wc, nonce: nonce}
	// Successive writers get different nonces, while the very first
	// reader/writer pair shares the package-level nonce.
	if replaceWriterNonce {
		s.nonce = NewNonce()
	}
	replaceWriterNonce = true
	box.Precompute(&s.key, pub, priv)
	return s
}

// NewSecureReader returns an io.Reader that opens NaCl boxes read from r.
// It uses the key precomputed from pub (the peer's public key) and priv (our
// private key).
//
// If r is not an io.ReadCloser it is wrapped with ioutil.NopCloser. Nonce
// selection mirrors NewSecureWriter: the first reader gets the package-level
// nonce and later readers get a fresh NewNonce.
// NOTE(review): replaceReaderNonce is unsynchronized package state.
//
// A *SecureReader is returned so that the nonce advanced by Read persists
// across calls.
func NewSecureReader(r io.Reader, priv, pub *[32]byte) io.Reader {
	rc, ok := r.(io.ReadCloser)
	if !ok {
		rc = ioutil.NopCloser(r)
	}
	s := &SecureReader{r: rc, nonce: nonce}
	if replaceReaderNonce {
		s.nonce = NewNonce()
	}
	replaceReaderNonce = true
	box.Precompute(&s.key, pub, priv)
	return s
}

// NewSecureReadWriteCloser wraps conn in a SecureReader and a SecureWriter and
// exposes them through the io.ReadWriteCloser interface. Both halves start
// from the given nonce. It panics, via NewSecureWriter, if conn is not an
// io.WriteCloser.
//
// Closing the result closes conn exactly once, through the writer. The reader
// gets a non-closing view of conn, because closing a net.Conn twice returns
// an error even on a clean shutdown.
//
// NOTE(review): both peers derive the same shared key, and each side starts
// reading and writing at the same nonce. The first message sent in each
// direction is therefore sealed under an identical (key, nonce) pair. Using a
// distinct nonce per direction would fix this but changes the wire protocol.
func NewSecureReadWriteCloser(conn io.ReadWriter, priv, pub *[32]byte, nonce [24]byte) io.ReadWriteCloser {
	sr := NewSecureReader(conn, priv, pub).(*SecureReader)
	sw := NewSecureWriter(conn, priv, pub).(*SecureWriter)
	sr.nonce = nonce
	sw.nonce = nonce
	sr.r = ioutil.NopCloser(conn)
	return &SecureReadWriteCloser{SecureReader: *sr, SecureWriter: *sw}
}

// errDecrypt is returned by SecureReader.Read when a received message fails
// to open. This happens if it was tampered with, truncated, or sealed with a
// different key or nonce.
var errDecrypt = fmt.Errorf("secure reader: could not decrypt message")

// Write seals b with the current nonce, writes the sealed message to the
// underlying writer and then advances the nonce. The pointer receiver is
// required so the increment persists; otherwise every message would reuse
// the same nonce.
//
// On success it returns len(b), as the io.Writer contract requires (n must
// not exceed len(b); io.Copy rejects larger counts). It returns 0 and an
// error if the sealed message could not be written completely.
func (s *SecureWriter) Write(b []byte) (int, error) {
	msg := box.SealAfterPrecomputation(empty, b, &s.nonce, &s.key)
	n, err := s.w.Write(msg)
	if err != nil {
		return 0, err
	}
	if n != len(msg) {
		return 0, io.ErrShortWrite
	}
	incrNonce(&s.nonce)
	return len(b), nil
}

// Read performs one Read on the underlying reader, opens the sealed message it
// received with the current nonce, copies the plaintext into b and advances
// the nonce. It returns the plaintext length together with any error that
// accompanied the data, such as io.EOF.
//
// A message that fails to open yields errDecrypt instead of a panic.
// NOTE(review): there is no length framing. Each underlying Read is assumed
// to return exactly one sealed message, so a message split across reads, or
// larger than len(b), will fail to decrypt.
func (s *SecureReader) Read(b []byte) (int, error) {
	msg := make([]byte, len(b)+box.Overhead)
	n, err := s.r.Read(msg)
	if n == 0 {
		return 0, err
	}
	out, ok := box.OpenAfterPrecomputation(empty, msg[:n], &s.nonce, &s.key)
	if !ok {
		return 0, errDecrypt
	}
	copy(b, out)
	incrNonce(&s.nonce)
	return len(out), err
}

// Close closes the reader half and then the writer half, stopping at the
// first error.
func (s *SecureReadWriteCloser) Close() error {
	if err := s.SecureReader.Close(); err != nil {
		return err
	}
	return s.SecureWriter.Close()
}

// Close closes the underlying reader.
func (s *SecureReader) Close() error {
	return s.r.Close()
}

// Close closes the underlying writer.
func (s *SecureWriter) Close() error {
	return s.w.Close()
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment