Skip to content

Instantly share code, notes, and snippets.

@igolaizola
Last active June 16, 2020 12:59
Show Gist options
  • Save igolaizola/98a39856c2413565f59371fe68fa8863 to your computer and use it in GitHub Desktop.
Save igolaizola/98a39856c2413565f59371fe68fa8863 to your computer and use it in GitHub Desktop.
Hybrid DTLS: a DTLS client connection that also accepts client hello messages, acting as a DTLS server as well.
package hdtls
import (
"bytes"
"context"
"errors"
"io"
"net"
"sync"
"time"
"github.com/pion/dtls/v2"
)
// Client performs a DTLS client handshake over conn and returns the resulting
// hybrid connection. While the returned Conn is open, incoming client hello
// messages are also detected; for each one a DTLS server handshake is run and,
// when it succeeds, the Conn switches over to that newly handshaked session.
func Client(conn net.Conn, config *dtls.Config) (*Conn, error) {
	handshake := dtls.Client
	return hybrid(conn, config, handshake)
}
// Resume builds a hybrid connection on top of an already established DTLS
// session described by state (for example one obtained from ConnectionState),
// instead of running a fresh client handshake. Incoming client hellos are
// handled the same way as for Client.
func Resume(state *dtls.State, conn net.Conn, config *dtls.Config) (*Conn, error) {
	resume := func(c net.Conn, cfg *dtls.Config) (*dtls.Conn, error) {
		return dtls.Resume(state, c, cfg)
	}
	return hybrid(conn, config, resume)
}
// hybrid wraps conn in a checkConn, establishes the initial DTLS session with
// connFunc, and starts a background goroutine that runs a DTLS server
// handshake for every client hello detected by the checkConn. When such a
// server handshake succeeds, the returned Conn switches to the new session.
//
// The background goroutine stops when the returned Conn is closed (via
// cancel). If connFunc fails, the goroutine is stopped here and the error is
// returned.
func hybrid(conn net.Conn, config *dtls.Config, connFunc func(net.Conn, *dtls.Config) (*dtls.Conn, error)) (*Conn, error) {
	cc := &checkConn{
		Conn:   conn,
		hello:  make(chan *checkConn),
		config: config,
		done:   make(chan bool),
	}
	ctx, cancel := context.WithCancel(context.Background())
	wc := &Conn{
		lock:   &sync.RWMutex{},
		cancel: cancel,
	}
	go func() {
		hello := cc.hello
		for {
			var next *checkConn
			select {
			case <-ctx.Done():
				return
			case next = <-hello:
			}
			// done is the channel the checkConn.Read that detected the hello
			// is blocked on; the candidate session gets a fresh channel.
			done := next.done
			next.done = make(chan bool)
			dtlsConn, err := dtls.Server(next, config)
			if err != nil {
				// Let the blocked reader keep the old session and mark the
				// discarded candidate checkConn as closed.
				// NOTE(review): if dtls.Server itself calls next.Close() on
				// failure, this second close would panic — confirm against
				// pion/dtls behaviour.
				done <- false
				close(next.done)
				continue
			}
			wc.set(dtlsConn)
			done <- true
		}
	}()
	dtlsConn, err := connFunc(cc, config)
	if err != nil {
		// Without this the hello-watching goroutine above would leak forever,
		// since no Conn is returned that could cancel it.
		cancel()
		return nil, err
	}
	wc.set(dtlsConn)
	return wc, nil
}
// Conn is a net.Conn backed by whichever dtls.Conn is currently active. The
// active session may be replaced at runtime after a successful server
// handshake, so methods look it up on every call instead of caching it.
type Conn struct {
	// lock guards dtlsConn.
	lock *sync.RWMutex
	// dtlsConn is the currently active DTLS session.
	dtlsConn *dtls.Conn
	// cancel stops the background goroutine started by hybrid.
	cancel context.CancelFunc
}
// set makes conn the active DTLS session. The session it replaces is not
// closed here.
func (c *Conn) set(conn *dtls.Conn) {
	c.lock.Lock()
	c.dtlsConn = conn
	c.lock.Unlock()
}
// conn returns the currently active DTLS session, read under the read lock.
func (c *Conn) conn() *dtls.Conn {
	c.lock.RLock()
	active := c.dtlsConn
	c.lock.RUnlock()
	return active
}
// Read implements net.Conn.Read on the active DTLS session. If the read fails
// with errNewSession (the session was replaced by a successful server
// handshake while reading), the read is retried once on the new session.
// NOTE(review): this relies on dtls.Conn.Read surfacing errNewSession from the
// underlying checkConn in a form errors.Is recognises — confirm with pion/dtls.
func (c *Conn) Read(b []byte) (int, error) {
	n, err := c.conn().Read(b)
	if !errors.Is(err, errNewSession) {
		return n, err
	}
	return c.conn().Read(b)
}
// Close implements net.Conn.Close: it closes the active DTLS session and then
// cancels the background goroutine that watches for client hellos.
func (c *Conn) Close() error {
	defer c.cancel()
	active := c.conn()
	return active.Close()
}
// ConnectionState returns the dtls.State of the currently active session.
func (c *Conn) ConnectionState() dtls.State {
	active := c.conn()
	return active.ConnectionState()
}
// Write implements net.Conn.Write by writing b on the active DTLS session.
func (c *Conn) Write(b []byte) (int, error) {
	active := c.conn()
	return active.Write(b)
}
// LocalAddr implements net.Conn.LocalAddr using the active DTLS session.
func (c *Conn) LocalAddr() net.Addr {
	active := c.conn()
	return active.LocalAddr()
}
// RemoteAddr implements net.Conn.RemoteAddr using the active DTLS session.
func (c *Conn) RemoteAddr() net.Addr {
	active := c.conn()
	return active.RemoteAddr()
}
// SetDeadline implements net.Conn.SetDeadline on the active DTLS session only;
// a session installed later does not inherit the deadline.
func (c *Conn) SetDeadline(t time.Time) error {
	active := c.conn()
	return active.SetDeadline(t)
}
// SetReadDeadline implements net.Conn.SetReadDeadline on the active DTLS
// session only; a session installed later does not inherit the deadline.
func (c *Conn) SetReadDeadline(t time.Time) error {
	active := c.conn()
	return active.SetReadDeadline(t)
}
// SetWriteDeadline implements net.Conn.SetWriteDeadline on the active DTLS
// session only; a session installed later does not inherit the deadline.
func (c *Conn) SetWriteDeadline(t time.Time) error { return c.conn().SetWriteDeadline(t) }
var (
	// errNewSession is returned by checkConn.Read after a detected client
	// hello led to a successful server handshake, signalling that the session
	// being read from has been replaced.
	errNewSession = errors.New("new session")
)
// checkConn wraps the transport shared by all DTLS sessions of a hybrid Conn
// and inspects incoming datagrams for client hellos.
type checkConn struct {
	net.Conn // underlying transport, shared with checkConns of later sessions

	// first, when non-empty, is returned by the next Read before anything is
	// read from the transport (used to replay a captured client hello).
	first []byte
	// hello hands checkConns carrying a detected client hello to the goroutine
	// started by hybrid.
	hello chan *checkConn
	// config is passed to isClientHello when probing datagrams.
	config *dtls.Config
	// done is closed by Close and also carries the server handshake outcome
	// back to a Read that is waiting on it.
	done chan bool
}
// Write forwards b to the underlying transport, or fails with
// io.ErrClosedPipe if a receive from done succeeds immediately.
// NOTE(review): the non-blocking receive can also consume a handshake result
// sent on done for a concurrently blocked Read, leaving that Read stuck —
// consider separate channels for "closed" and "handshake result".
func (c *checkConn) Write(b []byte) (int, error) {
	select {
	case <-c.done:
		return 0, io.ErrClosedPipe
	default:
		return c.Conn.Write(b)
	}
}
// Read returns the next datagram for the DTLS session using this checkConn.
//
// A pending captured client hello (first) is replayed before anything is read
// from the transport. When a freshly read datagram is a client hello, a copy
// of this checkConn carrying that datagram is handed to the goroutine started
// by hybrid, and Read blocks until it reports the server handshake outcome:
// on success errNewSession is returned; on failure the datagram is returned
// to the caller unchanged. Read returns io.EOF once done is closed.
func (c *checkConn) Read(b []byte) (int, error) {
	select {
	case <-c.done:
		return 0, io.EOF
	default:
	}
	if len(c.first) > 0 {
		// Report the number of bytes actually copied so that n <= len(b)
		// holds even when b is shorter than the captured datagram.
		n := copy(b, c.first)
		c.first = nil
		return n, nil
	}
	n, err := c.Conn.Read(b)
	if err != nil {
		return n, err
	}
	if !isClientHello(b[:n], c.config) {
		return n, nil
	}
	// Clone the hello so the replayed datagram does not alias the caller's
	// buffer.
	candidate := &checkConn{
		first:  append([]byte(nil), b[:n]...),
		Conn:   c.Conn,
		hello:  c.hello,
		config: c.config,
		done:   c.done,
	}
	select {
	case c.hello <- candidate:
	case <-c.done:
		// Closed while waiting for the handshake goroutine (which exits once
		// the hybrid Conn is closed): give up instead of blocking forever.
		return 0, io.EOF
	}
	if ok := <-c.done; !ok {
		return n, nil
	}
	return 0, errNewSession
}
// Close closes done so that later Read and Write calls fail. The underlying
// transport is not closed here. Calling Close twice panics, because done
// would be closed twice.
func (c *checkConn) Close() error {
	done := c.done
	close(done)
	var err error
	return err
}
// isClientHello probes data by feeding it to a throwaway DTLS server whose
// connection fails every Write with errIsClientHello. It reports true only
// when the server's error matches errIsClientHello, i.e. the server tried to
// answer the datagram.
// NOTE(review): this assumes pion/dtls only writes in response to a client
// hello and wraps the Write error so errors.Is finds it — confirm.
func isClientHello(data []byte, config *dtls.Config) bool {
	probe := isClientHelloConn{Reader: bytes.NewReader(data)}
	if _, err := dtls.Server(probe, config); err != nil {
		return errors.Is(err, errIsClientHello)
	}
	return false
}
var (
	// errIsClientHello is what every isClientHelloConn.Write returns; a probe
	// that fails with it means the DTLS server tried to reply to the datagram.
	errIsClientHello = errors.New("client hello")
)

// isClientHelloConn is a fake connection used to probe a single datagram:
// reads come from the embedded Reader, writes always fail with
// errIsClientHello, and Close does nothing.
type isClientHelloConn struct {
	io.Reader
}

// Write never sends anything and always reports errIsClientHello.
func (isClientHelloConn) Write([]byte) (int, error) {
	return 0, errIsClientHello
}

// Close is a no-op.
func (isClientHelloConn) Close() error {
	return nil
}
// emptyAddr is a net.Addr whose network name and address string are both
// empty.
type emptyAddr struct{}

// Network returns the empty string.
func (emptyAddr) Network() string {
	return ""
}

// String returns the empty string.
func (emptyAddr) String() string {
	return ""
}
// LocalAddr returns an empty address; the probe has no real endpoint.
func (isClientHelloConn) LocalAddr() net.Addr {
	return emptyAddr{}
}

// RemoteAddr returns an empty address; the probe has no real endpoint.
func (isClientHelloConn) RemoteAddr() net.Addr {
	return emptyAddr{}
}

// SetDeadline ignores the deadline and always succeeds.
func (isClientHelloConn) SetDeadline(time.Time) error {
	return nil
}

// SetReadDeadline ignores the deadline and always succeeds.
func (isClientHelloConn) SetReadDeadline(time.Time) error {
	return nil
}

// SetWriteDeadline ignores the deadline and always succeeds.
func (isClientHelloConn) SetWriteDeadline(time.Time) error {
	return nil
}
package hdtls
import (
"bytes"
"crypto/tls"
"fmt"
"net"
"sync"
"testing"
"time"
"github.com/pion/dtls/v2"
"github.com/pion/dtls/v2/pkg/crypto/selfsign"
)
// fatal fails the test immediately with err.
//
// Before failing, it makes sure the caller's deferred receive on errChan
// cannot block forever: if the remote goroutine has not reported yet, a nil
// placeholder is queued in errChan's buffer (errChan is assumed to be
// buffered with capacity >= 1, as in TestClient). errChan is deliberately
// left open. Closing it would make a later send from the remote goroutine
// panic with "send on closed channel" and crash the whole test binary. Such a
// send now simply blocks, leaking that goroutine in an already-failed test.
func fatal(t *testing.T, errChan chan error, err error) {
	t.Helper()
	select {
	case errChan <- nil:
	default:
		// The buffer already holds the remote goroutine's result.
	}
	t.Fatal(err)
}
// TestClient checks that a hybrid Client keeps exchanging data after the peer
// starts a new handshake, as a DTLS client, over a different transport.
//
// The remote side runs Client over a backupConn that uses rc1 first and
// switches to rc2 once an operation on rc1 fails. The local side:
//  1. acts as a DTLS server over localConn1 and exchanges one echoed message;
//  2. closes localConn1, which makes the remote transport fail over to rc2;
//  3. acts as a DTLS client over localConn2 and exchanges a second echoed
//     message on the rehandshaked session.
func TestClient(t *testing.T) {
	certificate, err := selfsign.GenerateSelfSigned()
	if err != nil {
		t.Fatal(err)
	}
	// Generate connections
	localConn1, rc1 := net.Pipe()
	localConn2, rc2 := net.Pipe()
	// Remote transport: rc1 first, rc2 after rc1 fails.
	remoteConn := &backupConn{curr: rc1, next: rc2}
	// Launch remote in another goroutine
	// errChan carries the remote goroutine's final result. The deferred
	// receive below waits for it and fails the test on a non-nil error.
	errChan := make(chan error, 1)
	defer func() {
		err = <-errChan
		if err != nil {
			t.Fatal(err)
		}
	}()
	config := &dtls.Config{
		Certificates:         []tls.Certificate{certificate},
		InsecureSkipVerify:   true,
		ExtendedMasterSecret: dtls.RequireExtendedMasterSecret,
	}
	go func() {
		var remote net.Conn
		var errR error
		remote, errR = Client(remoteConn, config)
		if errR != nil {
			errChan <- errR
			return
		}
		// Loop of read write
		// Echo two messages back, one for each session the local side opens.
		for i := 0; i < 2; i++ {
			recv := make([]byte, 1024)
			var n int
			n, errR = remote.Read(recv)
			if errR != nil {
				errChan <- errR
				return
			}
			if _, errR = remote.Write(recv[:n]); errR != nil {
				errChan <- errR
				return
			}
		}
		errChan <- nil
	}()
	// First session: the local side is the DTLS server.
	var local net.Conn
	local, err = dtls.Server(localConn1, config)
	if err != nil {
		fatal(t, errChan, err)
	}
	// The closure captures the variable, so this closes whichever session
	// local refers to when the test returns (the second one on success).
	defer func() {
		_ = local.Close()
	}()
	// Test write and read
	message := []byte("Hello")
	if _, err = local.Write(message); err != nil {
		fatal(t, errChan, err)
	}
	recv := make([]byte, 1024)
	var n int
	n, err = local.Read(recv)
	if err != nil {
		fatal(t, errChan, err)
	}
	if !bytes.Equal(message, recv[:n]) {
		fatal(t, errChan, fmt.Errorf("messages missmatch: %s != %s", message, recv[:n]))
	}
	// Break the first transport; the remote backupConn fails over to rc2.
	// NOTE(review): the first dtls session itself is never closed explicitly.
	if err = localConn1.Close(); err != nil {
		fatal(t, errChan, err)
	}
	// Second session: the local side is now the DTLS client, so the remote
	// hybrid Conn has to accept its client hello and switch sessions.
	local, err = dtls.Client(localConn2, config)
	if err != nil {
		fatal(t, errChan, err)
	}
	// Test write and read on rehandshaked connection
	if _, err = local.Write(message); err != nil {
		fatal(t, errChan, err)
	}
	recv = make([]byte, 1024)
	n, err = local.Read(recv)
	if err != nil {
		fatal(t, errChan, err)
	}
	if !bytes.Equal(message, recv[:n]) {
		fatal(t, errChan, fmt.Errorf("messages missmatch: %s != %s", message, recv[:n]))
	}
}
// backupConn is a test transport that starts on curr and, the first time an
// operation on it fails, permanently switches to next and retries there. It
// lets a test simulate the peer moving to a new transport.
//
// Read and Write may be called concurrently by the DTLS layer, so curr and
// next are only accessed while holding mux. Connections are compared by
// identity, so their dynamic types must be comparable (net.Pipe ends are).
type backupConn struct {
	curr net.Conn
	next net.Conn
	mux  sync.Mutex
}

// current returns the connection operations should currently use.
func (b *backupConn) current() net.Conn {
	b.mux.Lock()
	defer b.mux.Unlock()
	return b.curr
}

// failover is called after an operation on failed returned an error. It
// reports whether the caller should retry. If failed is still current and a
// backup exists, it switches to the backup. If another goroutine already
// switched away from failed, it just asks for a retry on the new current
// connection. The swap happens at most once, so retries are bounded.
func (b *backupConn) failover(failed net.Conn) bool {
	b.mux.Lock()
	defer b.mux.Unlock()
	if b.curr != failed {
		return true
	}
	if b.next == nil {
		return false
	}
	b.curr, b.next = b.next, nil
	return true
}

// Read reads from the current connection, failing over to the backup once.
func (b *backupConn) Read(data []byte) (n int, err error) {
	for {
		conn := b.current()
		n, err = conn.Read(data)
		if err == nil || !b.failover(conn) {
			return n, err
		}
	}
}

// Write writes to the current connection, failing over to the backup once.
func (b *backupConn) Write(data []byte) (n int, err error) {
	for {
		conn := b.current()
		n, err = conn.Write(data)
		if err == nil || !b.failover(conn) {
			return n, err
		}
	}
}

// Close is a no-op; the test owns and closes the underlying pipes.
func (b *backupConn) Close() error {
	return nil
}

// LocalAddr returns nil; the test does not need addresses.
func (b *backupConn) LocalAddr() net.Addr {
	return nil
}

// RemoteAddr returns nil; the test does not need addresses.
func (b *backupConn) RemoteAddr() net.Addr {
	return nil
}

// SetDeadline is a no-op.
func (b *backupConn) SetDeadline(t time.Time) error {
	return nil
}

// SetReadDeadline is a no-op.
func (b *backupConn) SetReadDeadline(t time.Time) error {
	return nil
}

// SetWriteDeadline is a no-op.
func (b *backupConn) SetWriteDeadline(t time.Time) error {
	return nil
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment