Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 159 additions & 16 deletions src/handshake_server.zig
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ pub const Options = struct {
/// CertificateVerify message.
auth: ?*CertKeyPair,

/// Optional certificate selector invoked after parsing SNI from ClientHello.
/// When it returns null, `auth` is used as the fallback certificate.
sni_auth: ?SniAuth = null,

/// If not null server will request client certificate. If auth_type is
/// .request empty client certificate message will be accepted.
/// Client certificate will be verified with root_ca certificates.
Expand All @@ -44,6 +48,11 @@ pub const Options = struct {
alpn_protocols: []const []const u8 = &.{},

now: Io.Timestamp,

pub const SniAuth = struct {
ctx: *anyopaque,
selectFn: *const fn (ctx: *anyopaque, server_name: ?[]const u8) ?*CertKeyPair,
};
};

pub const ClientAuth = struct {
Expand All @@ -66,6 +75,7 @@ pub const ClientAuth = struct {
pub const Handshake = struct {
// public key len: x25519 = 32, secp256r1 = 65, secp384r1 = 97
const max_pub_key_len = 98;
const max_server_name_len = 255;
const supported_named_groups = &[_]proto.NamedGroup{ .x25519, .secp256r1, .secp384r1 };

/// Underlying network connection stream reader/writer pair.
Expand All @@ -83,6 +93,9 @@ pub const Handshake = struct {
client_pub_key: []u8 = &.{},
server_pub_key_buf: [max_pub_key_len]u8 = undefined,
server_pub_key: []u8 = &.{},
server_name_buf: [max_server_name_len]u8 = undefined,
server_name: ?[]const u8 = null,
auth: ?*CertKeyPair = null,

cipher: Cipher = undefined,
transcript: Transcript = .{},
Expand All @@ -106,7 +119,7 @@ pub const Handshake = struct {
pub fn handshake(h: *Self, opt: Options) !Cipher {
h.initKeys(opt);

h.readClientHello(opt.cipher_suites, opt.alpn_protocols) catch |err| {
h.readClientHello(opt) catch |err| {
try h.writeAlert(null, err);
return err;
};
Expand All @@ -130,14 +143,13 @@ pub const Handshake = struct {

fn initKeys(h: *Self, opt: Options) void {
opt.rng.bytes(&h.server_random);
if (opt.auth) |a| {
// required signature scheme in client hello
h.signature_scheme = a.key.signature_scheme;
}
h.auth = null;
h.server_name = null;
h.signature_scheme = @enumFromInt(0);
}

fn clientFlight1(h: *Self, opt: Options) !void {
try h.readClientHello(opt.cipher_suites, opt.alpn_protocols);
try h.readClientHello(opt);
h.transcript.use(h.cipher_suite.hash());
}

Expand Down Expand Up @@ -199,7 +211,7 @@ pub const Handshake = struct {
h.transcript.update(hw.buffered());
try h.writeEncrypted(&w, hw.buffered());
}
if (opt.auth) |auth| {
if (h.auth) |auth| {
const cb = CertificateBuilder{
.rng = opt.rng,
.cert_key_pair = auth,
Expand Down Expand Up @@ -379,7 +391,34 @@ pub const Handshake = struct {
try hw.int(u16, ext_len); // extensions length
}

fn readClientHello(h: *Self, supported_cipher_suites: []const CipherSuite, server_alpn_protocols: []const []const u8) !void {
fn resolveAuth(h: *Self, opt: Options) ?*CertKeyPair {
if (opt.sni_auth) |sni_auth| {
if (sni_auth.selectFn(sni_auth.ctx, h.server_name)) |auth| {
return auth;
}
}
return opt.auth;
}

fn parseServerNameExtension(h: *Self, d: *record.Decoder, extension_end_idx: usize) !void {
const list_end_idx = try d.decode(u16) + d.idx;
if (list_end_idx != extension_end_idx) return error.TlsDecodeError;

var host_name_found = false;
while (d.idx < list_end_idx) {
const name_type = try d.decode(u8);
const server_name = try d.slice(try d.decode(u16));

if (name_type == 0) {
if (host_name_found or server_name.len == 0) return error.TlsDecodeError;
h.server_name = try common.dupe(&h.server_name_buf, server_name);
host_name_found = true;
}
}
if (d.idx != list_end_idx) return error.TlsDecodeError;
}

fn readClientHello(h: *Self, opt: Options) !void {
var d = try Record.decoder(h.input);
if (d.payload.len > max_cleartext_len) return error.TlsRecordOverflow;
try d.expectContentType(.handshake);
Expand All @@ -400,7 +439,7 @@ pub const Handshake = struct {

while (d.idx < end_idx) {
const cipher_suite = try d.decode(CipherSuite);
if (cipher_suites.includes(supported_cipher_suites, cipher_suite) and
if (cipher_suites.includes(opt.cipher_suites, cipher_suite) and
@intFromEnum(h.cipher_suite) == 0)
{
h.cipher_suite = cipher_suite;
Expand All @@ -412,13 +451,41 @@ pub const Handshake = struct {
try d.skip(2); // compression methods

var key_share_received = false;
// extensions
const extensions_start_idx = d.idx + 2;
const extensions_end_idx = try d.decode(u16) + d.idx;
{
var ext_d = d;
ext_d.idx = extensions_start_idx;
while (ext_d.idx < extensions_end_idx) {
const extension_type = try ext_d.decode(proto.Extension);
const extension_end_idx = try ext_d.decode(u16) + ext_d.idx;

switch (extension_type) {
.server_name => {
if (h.server_name != null) return error.TlsDecodeError;
try h.parseServerNameExtension(&ext_d, extension_end_idx);
},
else => ext_d.idx = extension_end_idx,
}
}
if (ext_d.idx != extensions_end_idx) return error.TlsDecodeError;
}

h.auth = h.resolveAuth(opt);
if (h.auth) |auth| {
h.signature_scheme = auth.key.signature_scheme;
}

d.idx = extensions_start_idx;
while (d.idx < extensions_end_idx) {
const extension_type = try d.decode(proto.Extension);
const extension_len = try d.decode(u16);
const extension_end_idx = d.idx + extension_len;

switch (extension_type) {
.server_name => {
d.idx = extension_end_idx;
},
.supported_versions => {
var tls_1_3_supported = false;
const end_idx = try d.decode(u8) + d.idx;
Expand Down Expand Up @@ -485,14 +552,14 @@ pub const Handshake = struct {
},
.application_layer_protocol_negotiation => {
// RFC 7301: parse client ALPN extension and select a protocol
if (server_alpn_protocols.len > 0) {
if (opt.alpn_protocols.len > 0) {
const list_end = try d.decode(u16) + d.idx;
// Find the first server protocol that the client supports
// (server preference order)
var best_match: ?[]const u8 = null;
var best_server_idx: usize = server_alpn_protocols.len;
var best_server_idx: usize = opt.alpn_protocols.len;
const saved_idx = d.idx;
for (server_alpn_protocols, 0..) |server_proto, si| {
for (opt.alpn_protocols, 0..) |server_proto, si| {
d.idx = saved_idx;
while (d.idx < list_end) {
const proto_len = try d.decode(u8);
Expand All @@ -513,7 +580,7 @@ pub const Handshake = struct {
}
},
else => {
try d.skip(extension_len);
d.idx = extension_end_idx;
},
}
}
Expand All @@ -525,22 +592,98 @@ pub const Handshake = struct {
const testing = std.testing;
const data13 = @import("testdata/tls13.zig");
const testu = @import("testu.zig");
const ClientNonBlock = @import("handshake_client.zig").NonBlock;

test "read client hello" {
var reader: Io.Reader = .fixed(&data13.client_hello);
var h: Handshake = .{
.input = &reader,
.output = undefined,
};
h.signature_scheme = .ecdsa_secp521r1_sha512; // this must be supported in signature_algorithms extension
try h.readClientHello(cipher_suites.tls13, &.{});
const rsa_key = PrivateKey{
.signature_scheme = .ecdsa_secp521r1_sha512,
.key = undefined,
};
var auth = CertKeyPair{
.bundle = .empty,
.key = rsa_key,
.ecdsa_key_pair = null,
};
try h.readClientHello(.{
.rng = testu.random(0),
.auth = &auth,
.cipher_suites = cipher_suites.tls13,
.now = .zero,
});

try testing.expectEqual(CipherSuite.AES_256_GCM_SHA384, h.cipher_suite);
try testing.expectEqual(.x25519, h.named_group);
try testing.expectEqualSlices(u8, &data13.client_random, &h.client_random);
try testing.expectEqualSlices(u8, &data13.client_public_key, h.client_pub_key);
}

test "read client hello with server name and sni auth" {
const rng = testu.random(0);
var client = ClientNonBlock.init(.{
.rng = rng,
.host = "google.com",
.root_ca = .empty,
.cipher_suites = cipher_suites.tls13,
.now = .zero,
});
var client_hello_buf: [max_cleartext_len]u8 = undefined;
const client_res = try client.run(&.{}, &client_hello_buf);

try testing.expect(client_res.send.len > 0);

var reader: Io.Reader = .fixed(client_res.send);
var h: Handshake = .{
.input = &reader,
.output = undefined,
};

var default_auth = CertKeyPair{
.bundle = .empty,
.key = .{
.signature_scheme = .rsa_pss_rsae_sha256,
.key = undefined,
},
.ecdsa_key_pair = null,
};
var google_auth = CertKeyPair{
.bundle = .empty,
.key = .{
.signature_scheme = .ecdsa_secp256r1_sha256,
.key = undefined,
},
.ecdsa_key_pair = null,
};
const Selector = struct {
fn select(ctx: *anyopaque, server_name: ?[]const u8) ?*CertKeyPair {
const pair: *CertKeyPair = @ptrCast(@alignCast(ctx));
if (server_name) |name| {
if (mem.eql(u8, name, "google.com")) return pair;
}
return null;
}
};

try h.readClientHello(.{
.rng = rng,
.auth = &default_auth,
.sni_auth = .{
.ctx = @ptrCast(&google_auth),
.selectFn = Selector.select,
},
.cipher_suites = cipher_suites.tls13,
.now = .zero,
});

try testing.expectEqualStrings("google.com", h.server_name.?);
try testing.expectEqual(&google_auth, h.auth.?);
try testing.expectEqual(proto.SignatureScheme.ecdsa_secp256r1_sha256, h.signature_scheme);
}

test "make server hello" {
var h: Handshake = .{ .input = undefined, .output = undefined };

Expand Down
Loading