|
const std = @import("std"); |
|
const SupportedVersions = @import("./supported_versions.zig").SupportedVersions; |
|
const sig = @import("./signature_scheme.zig"); |
|
const SignatureAlgorithm = sig.SignatureScheme; |
|
const SignatureAlgorithms = sig.SignatureSchemeList; |
|
const KeyShareClientHello = @import("./key_share.zig").KeyShareClientHello; |
|
const KeyShareEntry = @import("./key_share.zig").KeyShareEntry; |
|
const KeyShare = @import("./key_share.zig").KeyShare; |
|
const SupportedGroups = @import("./supported_groups.zig").SupportedGroups; |
|
|
|
/// Wire values of the TLS `extension_type` field (a big-endian u16 on the wire).
/// ref: https://datatracker.ietf.org/doc/html/rfc8446#appendix-B.3.1
pub const ExtensionType = enum(u16) {
    /// RFC 6066
    server_name = 0,
    /// RFC 6066
    max_fragment_length = 1,
    /// RFC 6066
    status_request = 5,
    /// RFC 8422, 7919
    supported_groups = 10,
    /// RFC 8446
    signature_algorithms = 13,
    /// RFC 5764
    use_srtp = 14,
    /// RFC 6520
    heartbeat = 15,
    /// RFC 7301
    application_layer_protocol_negotiation = 16,
    /// RFC 6962
    signed_certificate_timestamp = 18,
    /// RFC 7250
    client_certificate_type = 19,
    /// RFC 7250
    server_certificate_type = 20,
    /// RFC 7685
    padding = 21,
    /// Used but never assigned.
    RESERVED_1 = 40,
    /// RFC 8446
    pre_shared_key = 41,
    /// RFC 8446
    early_data = 42,
    /// RFC 8446
    supported_versions = 43,
    /// RFC 8446
    cookie = 44,
    /// RFC 8446
    psk_key_exchange_modes = 45,
    /// Used but never assigned.
    RESERVED_2 = 46,
    /// RFC 8446
    certificate_authorities = 47,
    /// RFC 8446
    oid_filters = 48,
    /// RFC 8446
    post_handshake_auth = 49,
    /// RFC 8446
    signature_algorithms_cert = 50,
    /// RFC 8446
    key_share = 51,
};
|
|
|
// Payload types for extensions whose encoding/decoding is not implemented yet.
// They carry no data; the Extension union encodes them with empty extension_data.

/// Placeholder payload for `server_name` (RFC 6066).
pub const ServerName = struct {};

/// Placeholder payload for `max_fragment_length` (RFC 6066).
pub const MaxFragmentLength = struct {};

/// Placeholder payload for `status_request` (RFC 6066).
pub const StatusRequest = struct {};

/// Placeholder payload for `use_srtp` (RFC 5764).
pub const UseSrtp = struct {};

/// Deprecated misspelling of `UseSrtp`; kept so existing references still compile.
pub const UseStrp = UseSrtp;

/// Placeholder payload for `heartbeat` (RFC 6520).
pub const Heartbeat = struct {};

/// Placeholder payload for `application_layer_protocol_negotiation` (RFC 7301).
pub const ApplicationLayerProtocolNegotiation = struct {};

/// Placeholder payload for `signed_certificate_timestamp` (RFC 6962).
pub const SignedCertificateTimestamp = struct {};

/// Placeholder payload for `client_certificate_type` (RFC 7250).
pub const ClientCertificateType = struct {};

/// Placeholder payload for `server_certificate_type` (RFC 7250).
pub const ServerCertificateType = struct {};

/// Placeholder payload for `padding` (RFC 7685).
pub const Padding = struct {};

/// Payload used for the reserved (never assigned) extension type values.
pub const Reserved = struct {};

/// Placeholder payload for `pre_shared_key` (RFC 8446).
pub const PreSharedKey = struct {};

/// Placeholder payload for `early_data` (RFC 8446).
pub const EarlyData = struct {};

/// Placeholder payload for `cookie` (RFC 8446).
pub const Cookie = struct {};

/// Placeholder payload for `psk_key_exchange_modes` (RFC 8446).
pub const PskKeyExchangeModes = struct {};

/// Placeholder payload for `certificate_authorities` (RFC 8446).
pub const CertificateAuthorities = struct {};

/// Deprecated misspelling of `CertificateAuthorities`; kept so existing references still compile.
pub const CertificateAurhorities = CertificateAuthorities;

/// Placeholder payload for `oid_filters` (RFC 8446).
pub const OidFilters = struct {};

/// Placeholder payload for `post_handshake_auth` (RFC 8446).
pub const PostHandshakeAuth = struct {};

/// Placeholder payload for `signature_algorithms_cert` (RFC 8446).
pub const SignatureAlgorithmsCert = struct {};
|
|
|
/// A single TLS extension: its `extension_type` tag together with the payload
/// carried in `extension_data`.
/// ref: https://datatracker.ietf.org/doc/html/rfc8446#appendix-B.3.1
pub const Extension = union(ExtensionType) {
    server_name: ServerName,
    max_fragment_length: MaxFragmentLength,
    status_request: StatusRequest,
    supported_groups: SupportedGroups,
    signature_algorithms: SignatureAlgorithms,
    use_srtp: UseStrp,
    heartbeat: Heartbeat,
    application_layer_protocol_negotiation: ApplicationLayerProtocolNegotiation,
    signed_certificate_timestamp: SignedCertificateTimestamp,
    client_certificate_type: ClientCertificateType,
    server_certificate_type: ServerCertificateType,
    padding: Padding,
    RESERVED_1: Reserved,
    pre_shared_key: PreSharedKey,
    early_data: EarlyData,
    supported_versions: SupportedVersions,
    cookie: Cookie,
    psk_key_exchange_modes: PskKeyExchangeModes,
    RESERVED_2: Reserved,
    certificate_authorities: CertificateAurhorities,
    oid_filters: OidFilters,
    post_handshake_auth: PostHandshakeAuth,
    signature_algorithms_cert: SignatureAlgorithmsCert,
    key_share: KeyShare,

    const Self = @This();

    /// Releases memory owned by the active payload.
    /// Variants without an implemented payload type need no cleanup.
    pub fn deinit(self: Self) void {
        switch (self) {
            .supported_versions => |s| s.deinit(),
            .signature_algorithms => |s| s.deinit(),
            .key_share => |k| k.deinit(),
            .supported_groups => |s| s.deinit(),
            else => {
                // TODO: unimplemented
            },
        }
    }

    /// Number of bytes in `extension_data`, excluding the 2-byte type and the
    /// 2-byte length prefix. Extensions whose payload is not implemented yet
    /// are treated as having empty data.
    fn dataLength(self: Self) usize {
        return switch (self) {
            .supported_versions => |supported_versions| supported_versions.encodedSize(),
            .signature_algorithms => |signature_algorithms| signature_algorithms.encodedSize(),
            .key_share => |key_share| key_share.encodedSize(),
            .supported_groups => |supported_groups| supported_groups.encodedSize(),
            else => @as(usize, 0), // TODO(magurotuna) unimplemented
        };
    }

    /// Writes `extension_type` (u16, big-endian), the u16 `extension_data`
    /// length, then the payload itself.
    ///
    /// The length prefix is always written, even for payloads that are not
    /// implemented yet. This keeps the output consistent with `encodedSize()`.
    /// Returns `error.ExtensionTooLarge` if the payload does not fit a u16 length.
    pub fn encode(self: Self, out_stream: anytype) !void {
        const data_length = self.dataLength();
        // @intCast below would be illegal behavior on overflow; reject explicitly.
        if (data_length > std.math.maxInt(u16)) return error.ExtensionTooLarge;

        // extension_type
        try out_stream.writeIntBig(u16, @enumToInt(self));

        // extension_data: length prefix followed by the payload
        try out_stream.writeIntBig(u16, @intCast(u16, data_length));
        switch (self) {
            .supported_versions => |supported_versions| try supported_versions.encode(out_stream),
            .signature_algorithms => |signature_algorithms| try signature_algorithms.encode(out_stream),
            .key_share => |key_share| try key_share.encode(out_stream),
            .supported_groups => |supported_groups| try supported_groups.encode(out_stream),
            else => {
                // TODO: unimplemented; encoded as empty extension_data
            },
        }
    }

    /// Total number of bytes `encode` writes: type + length prefix + data.
    pub fn encodedSize(self: Self) usize {
        const extension_type_size = @sizeOf(ExtensionType);
        const extension_data_length_size = 2;
        return extension_type_size + extension_data_length_size + self.dataLength();
    }

    const DecodeReturnType = struct {
        decoded: Self,
        bytes_read: usize,
    };

    /// Decodes one extension from `in_stream`.
    ///
    /// `bytes_read` is 2 for the type field, plus whatever the payload decoder
    /// reports. Only `supported_versions` and the server-side `key_share` are
    /// supported. Any other type, including values not in `ExtensionType`,
    /// yields `error.UnsupportedExtension`. The type comes from the peer, so it
    /// must not hit `unreachable` or an unchecked `@intToEnum`.
    pub fn decode(allocator: std.mem.Allocator, in_stream: anytype) !DecodeReturnType {
        const raw_type = try in_stream.readIntBig(u16);
        const ext_type = std.meta.intToEnum(ExtensionType, raw_type) catch return error.UnsupportedExtension;
        var bytes_read: usize = 2;

        const decoded = switch (ext_type) {
            .supported_versions => blk: {
                const res = try SupportedVersions.decode(allocator, in_stream);
                bytes_read += res.bytes_read;
                break :blk Extension{ .supported_versions = res.decoded };
            },
            .key_share => blk: {
                // currently only support decoding for message coming from server
                const res = try KeyShare.decode_server(in_stream);
                bytes_read += res.bytes_read;
                break :blk Extension{ .key_share = res.decoded };
            },
            // TODO: decoding other extensions is not implemented yet
            else => return error.UnsupportedExtension,
        };

        return DecodeReturnType{
            .decoded = decoded,
            .bytes_read = bytes_read,
        };
    }
};
|
|
|
test "Supported Versions Extension properly encoded" {
    const allocator = std.testing.allocator;

    var versions = std.ArrayList(u16).init(allocator);
    try versions.append(SupportedVersions.TLS_1_3);

    const supported_versions = SupportedVersions{ .versions = versions };
    const ext = Extension{ .supported_versions = supported_versions };
    defer ext.deinit();

    try std.testing.expectEqual(@as(usize, 7), ext.encodedSize());

    var out_buf: [1024]u8 = undefined;
    var slice_stream = std.io.fixedBufferStream(&out_buf);
    const out = slice_stream.writer();

    try ext.encode(out);

    const expected = [_]u8{
        0x00, 0x2b, // extension_type (supported_versions = 43)
        0x00, 0x03, // data_length
        0x02, // versions_length
        0x03, 0x04, // version (TLS v1.3)
    };

    // expectEqualSlices reports the first differing index on failure.
    try std.testing.expectEqualSlices(u8, &expected, slice_stream.getWritten());
}
|
|
|
test "Signature Algorithms Extension properly encoded" {
    const allocator = std.testing.allocator;

    var signature_algorithms = SignatureAlgorithms.init(allocator);
    try signature_algorithms.add_signature_scheme(SignatureAlgorithm.rsa_pss_pss_sha256);

    const ext = Extension{ .signature_algorithms = signature_algorithms };
    defer ext.deinit();

    try std.testing.expectEqual(@as(usize, 8), ext.encodedSize());

    var out_buf: [1024]u8 = undefined;
    var slice_stream = std.io.fixedBufferStream(&out_buf);
    const out = slice_stream.writer();

    try ext.encode(out);

    const expected = [_]u8{
        0x00, 0x0d, // extension_type (signature_algorithms = 13)
        0x00, 0x04, // data_length
        0x00, 0x02, // algo_length
        0x08, 0x09, // algo (rsa_pss_pss_sha256)
    };

    // expectEqualSlices reports the first differing index on failure.
    try std.testing.expectEqualSlices(u8, &expected, slice_stream.getWritten());
}
|
|
|
test "Key Share Extension properly encoded" {
    const allocator = std.testing.allocator;

    const key_share = KeyShareEntry{ .x25519 = [_]u8{0x00} ** 32 };

    var key_shares = KeyShareClientHello.init(allocator);
    try key_shares.add_key_share_entry(key_share);

    const ext = Extension{ .key_share = .{ .client = key_shares } };
    defer ext.deinit();

    try std.testing.expectEqual(@as(usize, 42), ext.encodedSize());

    var out_buf: [1024]u8 = undefined;
    var slice_stream = std.io.fixedBufferStream(&out_buf);
    const out = slice_stream.writer();

    try ext.encode(out);

    const expected = blk: {
        const extension_type = [_]u8{ 0x00, 0x33 }; // key_share = 51
        const data_length = [_]u8{ 0x00, 0x26 }; // 38 in decimal
        const key_share_length = [_]u8{ 0x00, 0x24 }; // 36 in decimal
        const group = [_]u8{ 0x00, 0x1d }; // x25519
        const length = [_]u8{ 0x00, 0x20 }; // 32 in decimal
        const data = [_]u8{0x00} ** 32;

        break :blk extension_type ++
            data_length ++
            key_share_length ++
            group ++
            length ++
            data;
    };

    // expectEqualSlices reports the first differing index on failure.
    try std.testing.expectEqualSlices(u8, &expected, slice_stream.getWritten());
}
|
|
|
test "Supported Groups Extension properly encoded" {
    const NamedGroup = @import("./named_group.zig").NamedGroup;
    const allocator = std.testing.allocator;

    var groups = SupportedGroups.init(allocator);
    try groups.add_group(NamedGroup.x25519);

    const ext = Extension{ .supported_groups = groups };
    defer ext.deinit();

    try std.testing.expectEqual(@as(usize, 8), ext.encodedSize());

    var out_buf: [1024]u8 = undefined;
    var slice_stream = std.io.fixedBufferStream(&out_buf);
    const out = slice_stream.writer();

    try ext.encode(out);

    const expected = [_]u8{
        0x00, 0x0a, // extension_type (supported_groups = 10)
        0x00, 0x04, // data_length
        0x00, 0x02, // named group list length
        0x00, 0x1d, // x25519
    };

    // expectEqualSlices reports the first differing index on failure.
    try std.testing.expectEqualSlices(u8, &expected, slice_stream.getWritten());
}
|
|
|
/// An ordered list of TLS extensions. It is encoded as a u16 total byte length
/// followed by each `Extension` back to back.
pub const Extensions = struct {
    const Self = @This();

    extensions: std.ArrayList(Extension),

    /// Creates an empty list backed by `allocator`.
    pub fn init(allocator: std.mem.Allocator) Self {
        return Self{
            .extensions = std.ArrayList(Extension).init(allocator),
        };
    }

    /// This type has the responsibility to free all the extensions space that it holds inside.
    /// Calls `deinit` on every held extension, then frees the list itself.
    pub fn deinit(self: Self) void {
        for (self.extensions.items) |ext| {
            ext.deinit();
        }

        self.extensions.deinit();
    }

    /// Appends `ext`. On success the list owns it and will `deinit` it.
    pub fn add_ext(self: *Self, ext: Extension) !void {
        try self.extensions.append(ext);
    }

    /// Sum of the encoded sizes of all extensions, excluding the u16 length prefix.
    fn dataLength(self: Self) usize {
        var total: usize = 0;
        for (self.extensions.items) |ext| {
            total += ext.encodedSize();
        }
        return total;
    }

    /// Writes the u16 total length followed by every extension in order.
    /// Returns `error.ExtensionsTooLarge` if the total does not fit in a u16.
    pub fn encode(self: Extensions, out_stream: anytype) !void {
        const length = self.dataLength();
        // @intCast below would be illegal behavior on overflow; reject explicitly.
        if (length > std.math.maxInt(u16)) return error.ExtensionsTooLarge;
        try out_stream.writeIntBig(u16, @intCast(u16, length));

        for (self.extensions.items) |ext| {
            try ext.encode(out_stream);
        }
    }

    /// Total number of bytes `encode` writes: the 2-byte length prefix plus all extensions.
    pub fn encodedSize(self: Extensions) usize {
        const size_length = 2;
        return size_length + self.dataLength();
    }

    /// Reads a u16 length, then decodes extensions until at least that many
    /// bytes (as reported by `Extension.decode`) have been consumed.
    ///
    /// On error, every extension decoded so far is released, along with the
    /// list. On success the caller owns the result and must call `deinit`.
    pub fn decode(allocator: std.mem.Allocator, in_stream: anytype) !Self {
        const length = try in_stream.readIntBig(u16);

        var exts = std.ArrayList(Extension).init(allocator);
        errdefer {
            // Free the payloads already decoded, not just the list's buffer.
            for (exts.items) |ext| ext.deinit();
            exts.deinit();
        }

        var i: usize = 0;
        while (i < length) {
            const res = try Extension.decode(allocator, in_stream);
            // If append fails, this extension is not in `exts` yet, so free it here.
            errdefer res.decoded.deinit();
            try exts.append(res.decoded);
            i += res.bytes_read;
        }

        return Self{ .extensions = exts };
    }
};
|
|
|
test "Extensions properly encoded" {
    var extensions_data = std.ArrayList(Extension).init(std.testing.allocator);

    var versions = std.ArrayList(u16).init(std.testing.allocator);
    try versions.append(SupportedVersions.TLS_1_3);

    const supported_versions = SupportedVersions{ .versions = versions };
    const ext = Extension{ .supported_versions = supported_versions };
    try extensions_data.append(ext);

    const extensions = Extensions{ .extensions = extensions_data };
    defer extensions.deinit();

    // 2-byte total length prefix + 7-byte supported_versions extension
    try std.testing.expectEqual(@as(usize, 9), extensions.encodedSize());

    var out_buf: [1024]u8 = undefined;
    var slice_stream = std.io.fixedBufferStream(&out_buf);
    const out = slice_stream.writer();

    try extensions.encode(out);

    const expected = [_]u8{
        0x00, 0x07, // total length of all extensions
        0x00, 0x2b, 0x00, 0x03, 0x02, 0x03, 0x04, // supported_versions extension
    };

    // expectEqualSlices reports the first differing index on failure.
    try std.testing.expectEqualSlices(u8, &expected, slice_stream.getWritten());
}