Skip to content

Instantly share code, notes, and snippets.

@cretz
Last active June 14, 2018 19:25
Show Gist options
  • Save cretz/8436539bb688a59fc468a5e934eedf07 to your computer and use it in GitHub Desktop.
Save cretz/8436539bb688a59fc468a5e934eedf07 to your computer and use it in GitHub Desktop.
Chromecast server in Go (device auth fails unless authCert is generated with a Google-owned CA)
package main
import (
"crypto"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/binary"
"encoding/pem"
"fmt"
"io"
"log"
"math/big"
"net"
"os"
"time"
"github.com/cretz/temp/server/pb/cast_channel"
"github.com/golang/protobuf/proto"
"github.com/grandcat/zeroconf"
)
// id is the device identifier advertised in the "id=" mDNS TXT entry by
// broadcastMdns. It is simply a randomly generated v4 UUID.
const id string = "fb40bc4b-1ef8-4e97-839f-b4a9cf8e5c10"
// main runs the fake receiver and terminates the process via log.Fatal if
// run reports an error.
func main() {
	err := run()
	if err == nil {
		return
	}
	log.Fatal(err)
}
func run() error {
debugf("Generating peer and auth certs...")
peerCert, err := generateCert()
if err != nil {
return fmt.Errorf("Unable to generate cert: %v", err)
}
// TODO: this has to come from a specific in-Chrome CA or auth fails
authCert, err := generateCert()
if err != nil {
return fmt.Errorf("Unable to generate cert: %v", err)
}
debugf("Starting listener...")
l, err := tls.Listen("tcp4", ":50000", &tls.Config{
Certificates: []tls.Certificate{peerCert.cert},
})
if err != nil {
return fmt.Errorf("Unable to listen: %v", err)
}
defer l.Close()
debugf("Registering mDNS...")
mdns, err := broadcastMdns("127.0.0.1", 50000)
if err != nil {
return fmt.Errorf("Unable to register mDNS: %v", err)
}
defer mdns.Shutdown()
debugf("Waiting for connection...")
conn, err := l.Accept()
if err != nil {
return fmt.Errorf("Failed accepting connection: %v", err)
}
debugf("Got conn from %v, reading...", conn.(*tls.Conn).RemoteAddr())
defer conn.Close()
debugf("Getting auth message...")
msg, err := readCastMessage(conn)
if err != nil {
return fmt.Errorf("Failed to read cast message: %v", err)
}
if msg.GetProtocolVersion() != cast_channel.CastMessage_CASTV2_1_0 {
return fmt.Errorf("Invalid version: %v", msg.GetProtocolVersion())
} else if msg.GetNamespace() != "urn:x-cast:com.google.cast.tp.deviceauth" {
return fmt.Errorf("Expected auth namespace, got: %v", msg.GetNamespace())
} else if msg.GetPayloadType() != cast_channel.CastMessage_BINARY {
return fmt.Errorf("Expected binary payload, got: %v", msg.GetPayloadType())
}
authReq := &cast_channel.DeviceAuthMessage{}
if err = proto.Unmarshal(msg.PayloadBinary, authReq); err != nil {
return fmt.Errorf("Unable to get auth message: %v", err)
} else if authReq.Challenge == nil {
return fmt.Errorf("Missing challenge")
}
debugf("Got auth message: %v", proto.MarshalTextString(authReq))
debugf("Responding to auth message")
authResp := &cast_channel.DeviceAuthMessage{
Response: &cast_channel.AuthResponse{
Signature: nil, // This is set below
ClientAuthCertificate: authCert.derBytes,
IntermediateCertificate: [][]byte{},
SignatureAlgorithm: authReq.Challenge.SignatureAlgorithm,
SenderNonce: authReq.Challenge.SenderNonce,
HashAlgorithm: authReq.Challenge.HashAlgorithm,
Crl: nil,
},
}
// Hash
var hash crypto.Hash
switch authReq.Challenge.GetHashAlgorithm() {
case cast_channel.HashAlgorithm_SHA1:
hash = crypto.SHA1
case cast_channel.HashAlgorithm_SHA256:
hash = crypto.SHA256
default:
return fmt.Errorf("Unrecognized hash algorithm: %v", authReq.Challenge.GetHashAlgorithm())
}
toSign := make([]byte, 0, len(authReq.Challenge.SenderNonce)+len(peerCert.derBytes))
toSign = append(toSign, authReq.Challenge.SenderNonce...)
toSign = append(toSign, peerCert.derBytes...)
hasher := hash.New()
if _, err = hasher.Write(toSign); err != nil {
return fmt.Errorf("Failed hashing: %v", err)
}
hashed := hasher.Sum(nil)
// Do the signature
switch authReq.Challenge.GetSignatureAlgorithm() {
case cast_channel.SignatureAlgorithm_RSASSA_PKCS1v15:
authResp.Response.Signature, err = rsa.SignPKCS1v15(rand.Reader, authCert.privKey, hash, hashed)
if err != nil {
return fmt.Errorf("Failed signing: %v", err)
}
case cast_channel.SignatureAlgorithm_RSASSA_PSS:
authResp.Response.Signature, err = rsa.SignPSS(rand.Reader, authCert.privKey, hash, hashed, nil)
if err != nil {
return fmt.Errorf("Failed signing: %v", err)
}
default:
return fmt.Errorf("Unknown sig algo: %v", authReq.Challenge.GetSignatureAlgorithm())
}
debugf("Sending auth message: %v", proto.MarshalTextString(authResp))
if err = sendProtoMessage(conn, msg.GetNamespace(), authResp); err != nil {
return fmt.Errorf("Failed sending auth message: %v", err)
}
debugf("Waiting for next message...")
if msg, err = readCastMessage(conn); err != nil {
return fmt.Errorf("Failed to read cast message: %v", err)
}
return nil
}
// certInfo bundles an RSA private key with the self-signed certificate issued
// for it, in each encoding the rest of the program needs.
type certInfo struct {
	// privKey is the RSA private key the certificate belongs to.
	privKey *rsa.PrivateKey
	// derBytes is the raw DER-encoded certificate.
	derBytes []byte
	// certPEM and keyPEM are the PEM encodings of the certificate and of the
	// PKCS#1-marshalled private key.
	certPEM, keyPEM []byte
	// cert is the key pair in the form a tls.Config expects.
	cert tls.Certificate
}
// generateCert creates a fresh 2048-bit RSA key and a self-signed X.509
// certificate for it. The certificate:
//   - names organization "Acme Co",
//   - lists 127.0.0.1 as its only IP address,
//   - permits key-encipherment and digital-signature use for TLS server auth,
//   - is valid from ten minutes ago for 24 hours.
//
// It returns the key together with the DER, PEM and tls.Certificate forms of
// the certificate. On failure it returns a nil certInfo.
func generateCert() (cert *certInfo, err error) {
	key, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		return nil, fmt.Errorf("Unable to generate key: %v", err)
	}

	// Backdate slightly so small clock skew doesn't make the cert "not yet valid".
	validFrom := time.Now().Add(-10 * time.Minute)
	serialLimit := new(big.Int).Lsh(big.NewInt(1), 128)
	serial, err := rand.Int(rand.Reader, serialLimit)
	if err != nil {
		return nil, fmt.Errorf("Unable to create serial number: %v", err)
	}

	tmpl := &x509.Certificate{
		SerialNumber:          serial,
		Subject:               pkix.Name{Organization: []string{"Acme Co"}},
		NotBefore:             validFrom,
		NotAfter:              validFrom.Add(24 * time.Hour),
		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		BasicConstraintsValid: true,
		IPAddresses:           []net.IP{net.ParseIP("127.0.0.1")},
	}
	// Self-signed: the template acts as both subject and issuer.
	der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key)
	if err != nil {
		return nil, fmt.Errorf("Unable to create cert: %v", err)
	}

	certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der})
	keyPEM := pem.EncodeToMemory(&pem.Block{
		Type:  "RSA PRIVATE KEY",
		Bytes: x509.MarshalPKCS1PrivateKey(key),
	})
	pair, err := tls.X509KeyPair(certPEM, keyPEM)
	if err != nil {
		return nil, fmt.Errorf("Unable to make key pair: %v", err)
	}

	return &certInfo{
		privKey:  key,
		derBytes: der,
		certPEM:  certPEM,
		keyPEM:   keyPEM,
		cert:     pair,
	}, nil
}
// broadcastMdns registers a "TestCast" instance of the "_googlecast._tcp"
// service in the "local." domain. The record points at ip:port under this
// machine's hostname and carries TXT entries for the device id and friendly
// name. The caller owns the returned server (run shuts it down with
// Shutdown). The trailing nil is zeroconf's interface list; presumably nil
// means "all multicast interfaces" — NOTE(review): confirm in zeroconf docs.
func broadcastMdns(ip string, port int) (*zeroconf.Server, error) {
	hostname, err := os.Hostname()
	if err != nil {
		return nil, fmt.Errorf("Failed getting hostname: %v", err)
	}
	txtRecords := []string{"id=" + id, "fn=TestCast v1"}
	addrs := []string{ip}
	return zeroconf.RegisterProxy("TestCast", "_googlecast._tcp", "local.", port, hostname, addrs, txtRecords, nil)
}
// maxCastMessageSize caps the body length readCastMessage will accept. The
// length prefix comes straight off the network, so without a cap a peer
// could force an allocation of up to 4 GiB. Chromium's cast_channel framer
// limits messages to 64 KiB, so larger frames are treated as malformed.
// NOTE(review): adjust if talking to senders with a different limit.
const maxCastMessageSize = 64 * 1024

// readCastMessage reads one framed CastMessage from r. A frame is a 4-byte
// big-endian body length followed by that many bytes of protobuf. It returns
// an error if the header or body cannot be fully read, or if the declared
// length exceeds maxCastMessageSize. It also errors if the body does not
// unmarshal as a CastMessage.
func readCastMessage(r io.Reader) (*cast_channel.CastMessage, error) {
	debugf("Getting message size...")
	var header [4]byte
	if _, err := io.ReadFull(r, header[:]); err != nil {
		return nil, fmt.Errorf("Failed reading size: %v", err)
	}
	msgSize := binary.BigEndian.Uint32(header[:])
	debugf("Read size num of %v", msgSize)
	if msgSize > maxCastMessageSize {
		return nil, fmt.Errorf("Message size %v exceeds maximum of %v", msgSize, maxCastMessageSize)
	}

	debugf("Getting cast message...")
	body := make([]byte, msgSize)
	if _, err := io.ReadFull(r, body); err != nil {
		return nil, fmt.Errorf("Unable to read msg: %v", err)
	}
	msg := &cast_channel.CastMessage{}
	if err := proto.Unmarshal(body, msg); err != nil {
		return nil, fmt.Errorf("Unable to unmarshal msg: %v", err)
	}
	debugf("Got message: %v", proto.MarshalTextString(msg))
	return msg, nil
}
// sendProtoMessage serializes msg with proto.Marshal and sends the bytes as
// the binary payload of a CastMessage on namespace, via sendBinaryCastMessage.
func sendProtoMessage(w io.Writer, namespace string, msg proto.Message) error {
	payload, err := proto.Marshal(msg)
	if err == nil {
		return sendBinaryCastMessage(w, namespace, payload)
	}
	return fmt.Errorf("Failed marshalling message: %v", err)
}
// sendBinaryCastMessage wraps msg in a CASTV2_1_0 CastMessage with a BINARY
// payload, addressed from "receiver-0" to "sender-0" on namespace. It writes
// the result to w as one frame: a 4-byte big-endian length followed by the
// marshalled CastMessage.
func sendBinaryCastMessage(w io.Writer, namespace string, msg []byte) error {
	version := cast_channel.CastMessage_CASTV2_1_0
	sourceID := "receiver-0"
	destinationID := "sender-0"
	payloadType := cast_channel.CastMessage_BINARY
	castMsg := &cast_channel.CastMessage{
		ProtocolVersion: &version,
		SourceId:        &sourceID,
		DestinationId:   &destinationID,
		Namespace:       &namespace,
		PayloadType:     &payloadType,
		PayloadBinary:   msg,
	}
	byts, err := proto.Marshal(castMsg)
	if err != nil {
		return fmt.Errorf("Unable to marshal cast message: %v", err)
	}

	// Assemble prefix and body in one buffer so the frame goes out in a
	// single Write. On a *tls.Conn that typically yields one TLS record per
	// frame instead of a separate tiny record for the 4-byte prefix.
	frame := make([]byte, 4+len(byts))
	binary.BigEndian.PutUint32(frame[:4], uint32(len(byts)))
	copy(frame[4:], byts)
	if _, err = w.Write(frame); err != nil {
		return fmt.Errorf("Unable to write bytes: %v", err)
	}
	return nil
}
// debug switches the verbose progress logging done through debugf on or off.
const debug = true

// debugf logs a printf-style message through the standard logger when debug
// is enabled, and does nothing otherwise.
func debugf(format string, v ...interface{}) {
	if !debug {
		return
	}
	log.Printf(format, v...)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment