From 776ba599ddbd20dfba76f042a12f2a5bd1864c7e Mon Sep 17 00:00:00 2001 From: Jay Liu Date: Wed, 1 Apr 2026 22:06:28 +1100 Subject: [PATCH] **PR Title** Add SNI-based certificate selection to server handshake on `zig-0.15.x` **PR Description** This PR adds Server Name Indication (SNI) support to the TLS server handshake on the `zig-0.15.x` branch. The server now parses the `server_name` extension from `ClientHello`, stores the requested hostname in handshake state, and allows callers to select a certificate dynamically through a new `Options.sni_auth` callback. If the callback does not return a certificate, the existing `auth` field remains the default fallback. - Added `Options.sni_auth` for dynamic certificate selection. - Added server-side handshake state for the parsed SNI hostname. - Implemented parsing for the `server_name` TLS extension in `ClientHello`. - Switched certificate emission in the server flight from fixed `opt.auth` to the resolved handshake auth. - Deferred signature scheme validation until after certificate selection, so validation is performed against the actual certificate chosen for the handshake. - Added tests for: - existing `ClientHello` parsing behavior - SNI parsing and SNI-driven certificate selection Before this change, the server always used a single certificate from `auth`, regardless of the hostname requested by the client. That prevented serving multiple TLS hostnames from the same listener. This PR enables virtual hosting scenarios while keeping the existing API behavior intact for callers that only provide `auth`. New server option: ```zig sni_auth: ?SniAuth = null ``` With: ```zig pub const SniAuth = struct { ctx: *anyopaque, selectFn: *const fn (ctx: *anyopaque, server_name: ?[]const u8) ?*CertKeyPair, }; ``` Behavior: - `selectFn` is called after parsing SNI from `ClientHello` - if it returns a certificate, that certificate is used - if it returns `null`, the handshake falls back to `auth` This change is backward compatible: - existing users of `auth` continue to work unchanged - `sni_auth` is optional Validated locally with: ```sh zig fmt --check src/handshake_server.zig zig test src/handshake_server.zig ``` Relevant SNI tests pass, including the new SNI-specific handshake parsing test. Note: the `zig test src/handshake_server.zig` run later crashes in an existing unrelated client-side test (`handshake_client.test.handshake verify server finished message`), after the new handshake server tests have already passed. --- src/handshake_server.zig | 162 +++++++++++++++++++++++++++++++++++---- 1 file changed, 149 insertions(+), 13 deletions(-) diff --git a/src/handshake_server.zig b/src/handshake_server.zig index 2b57f6c..c65d607 100644 --- a/src/handshake_server.zig +++ b/src/handshake_server.zig @@ -30,6 +30,10 @@ pub const Options = struct { /// CertificateVerify message. auth: ?*CertKeyPair, + /// Optional certificate selector invoked after parsing SNI from ClientHello. + /// When it returns null, `auth` is used as the fallback certificate. + sni_auth: ?SniAuth = null, + /// If not null server will request client certificate. If auth_type is /// .request empty client certificate message will be accepted. /// Client certificate will be verified with root_ca certificates. @@ -37,6 +41,11 @@ pub const Options = struct { /// List of supported tls 1.3 cipher suites cipher_suites: []const CipherSuite = cipher_suites.tls13, + + pub const SniAuth = struct { + ctx: *anyopaque, + selectFn: *const fn (ctx: *anyopaque, server_name: ?[]const u8) ?*CertKeyPair, + }; }; pub const ClientAuth = struct { @@ -59,6 +68,7 @@ pub const ClientAuth = struct { pub const Handshake = struct { // public key len: x25519 = 32, secp256r1 = 65, secp384r1 = 97 const max_pub_key_len = 98; + const max_server_name_len = 255; const supported_named_groups = &[_]proto.NamedGroup{ .x25519, .secp256r1, .secp384r1 }; /// Underlying network connection stream reader/writer pair. @@ -76,6 +86,9 @@ pub const Handshake = struct { client_pub_key: []u8 = &.{}, server_pub_key_buf: [max_pub_key_len]u8 = undefined, server_pub_key: []u8 = &.{}, + server_name_buf: [max_server_name_len]u8 = undefined, + server_name: ?[]const u8 = null, + auth: ?*CertKeyPair = null, cipher: Cipher = undefined, transcript: Transcript = .{}, @@ -97,7 +110,7 @@ pub const Handshake = struct { pub fn handshake(h: *Self, opt: Options) !Cipher { h.initKeys(opt); - h.readClientHello(opt.cipher_suites) catch |err| { + h.readClientHello(opt) catch |err| { try h.writeAlert(null, err); return err; }; @@ -121,14 +134,14 @@ pub const Handshake = struct { fn initKeys(h: *Self, opt: Options) void { crypto.random.bytes(&h.server_random); - if (opt.auth) |a| { - // required signature scheme in client hello - h.signature_scheme = a.key.signature_scheme; - } + _ = opt; + h.auth = null; + h.server_name = null; + h.signature_scheme = @enumFromInt(0); } fn clientFlight1(h: *Self, opt: Options) !void { - try h.readClientHello(opt.cipher_suites); + try h.readClientHello(opt); h.transcript.use(h.cipher_suite.hash()); } @@ -178,7 +191,7 @@ pub const Handshake = struct { h.transcript.update(hw.buffered()); try h.writeEncrypted(&w, hw.buffered()); } - if (opt.auth) |auth| { + if (h.auth) |auth| { const cb = CertificateBuilder{ .cert_key_pair = auth, .transcript = &h.transcript, @@ -353,7 +366,34 @@ pub const Handshake = struct { try hw.int(u16, ext_len); // extensions length } - fn readClientHello(h: *Self, supported_cipher_suites: []const CipherSuite) !void { + fn resolveAuth(h: *Self, opt: Options) ?*CertKeyPair { + if (opt.sni_auth) |sni_auth| { + if (sni_auth.selectFn(sni_auth.ctx, h.server_name)) |auth| { + return auth; + } + } + return opt.auth; + } + + fn parseServerNameExtension(h: *Self, d: *record.Decoder, extension_end_idx: usize) !void { + const list_end_idx = try d.decode(u16) + d.idx; + if (list_end_idx != extension_end_idx) return error.TlsDecodeError; + + var host_name_found = false; + while (d.idx < list_end_idx) { + const name_type = try d.decode(u8); + const server_name = try d.slice(try d.decode(u16)); + + if (name_type == 0) { + if (host_name_found or server_name.len == 0) return error.TlsDecodeError; + h.server_name = try common.dupe(&h.server_name_buf, server_name); + host_name_found = true; + } + } + if (d.idx != list_end_idx) return error.TlsDecodeError; + } + + fn readClientHello(h: *Self, opt: Options) !void { var d = try Record.decoder(h.input); if (d.payload.len > max_cleartext_len) return error.TlsRecordOverflow; try d.expectContentType(.handshake); @@ -374,7 +414,7 @@ pub const Handshake = struct { while (d.idx < end_idx) { const cipher_suite = try d.decode(CipherSuite); - if (cipher_suites.includes(supported_cipher_suites, cipher_suite) and + if (cipher_suites.includes(opt.cipher_suites, cipher_suite) and @intFromEnum(h.cipher_suite) == 0) { h.cipher_suite = cipher_suite; @@ -386,13 +426,41 @@ pub const Handshake = struct { try d.skip(2); // compression methods var key_share_received = false; - // extensions + const extensions_start_idx = d.idx + 2; const extensions_end_idx = try d.decode(u16) + d.idx; + { + var ext_d = d; + ext_d.idx = extensions_start_idx; + while (ext_d.idx < extensions_end_idx) { + const extension_type = try ext_d.decode(proto.Extension); + const extension_end_idx = try ext_d.decode(u16) + ext_d.idx; + + switch (extension_type) { + .server_name => { + if (h.server_name != null) return error.TlsDecodeError; + try h.parseServerNameExtension(&ext_d, extension_end_idx); + }, + else => ext_d.idx = extension_end_idx, + } + } + if (ext_d.idx != extensions_end_idx) return error.TlsDecodeError; + } + + h.auth = h.resolveAuth(opt); + if (h.auth) |auth| { + h.signature_scheme = auth.key.signature_scheme; + } + + d.idx = extensions_start_idx; while (d.idx < extensions_end_idx) { const extension_type = try d.decode(proto.Extension); const extension_len = try d.decode(u16); + const extension_end_idx = d.idx + extension_len; switch (extension_type) { + .server_name => { + d.idx = extension_end_idx; + }, .supported_versions => { var tls_1_3_supported = false; const end_idx = try d.decode(u8) + d.idx; @@ -458,7 +526,7 @@ pub const Handshake = struct { } }, else => { - try d.skip(extension_len); + d.idx = extension_end_idx; }, } } @@ -470,6 +538,7 @@ pub const Handshake = struct { const testing = std.testing; const data13 = @import("testdata/tls13.zig"); const testu = @import("testu.zig"); +const ClientNonBlock = @import("handshake_client.zig").NonBlock; test "read client hello" { var reader: Io.Reader = .fixed(&data13.client_hello); @@ -477,8 +546,18 @@ test "read client hello" { .input = &reader, .output = undefined, }; - h.signature_scheme = .ecdsa_secp521r1_sha512; // this must be supported in signature_algorithms extension - try h.readClientHello(cipher_suites.tls13); + var auth = CertKeyPair{ + .bundle = .{}, + .key = .{ + .signature_scheme = .ecdsa_secp521r1_sha512, + .key = undefined, + }, + .ecdsa_key_pair = null, + }; + try h.readClientHello(.{ + .auth = &auth, + .cipher_suites = cipher_suites.tls13, + }); try testing.expectEqual(CipherSuite.AES_256_GCM_SHA384, h.cipher_suite); try testing.expectEqual(.x25519, h.named_group); @@ -486,6 +565,63 @@ test "read client hello" { try testing.expectEqualSlices(u8, &data13.client_public_key, h.client_pub_key); } +test "read client hello with server name and sni auth" { + var client = ClientNonBlock.init(.{ + .host = "google.com", + .root_ca = .{}, + .cipher_suites = cipher_suites.tls13, + }); + var client_hello_buf: [max_cleartext_len]u8 = undefined; + const client_res = try client.run(&.{}, &client_hello_buf); + + try testing.expect(client_res.send.len > 0); + + var reader: Io.Reader = .fixed(client_res.send); + var h: Handshake = .{ + .input = &reader, + .output = undefined, + }; + + var default_auth = CertKeyPair{ + .bundle = .{}, + .key = .{ + .signature_scheme = .rsa_pss_rsae_sha256, + .key = undefined, + }, + .ecdsa_key_pair = null, + }; + var google_auth = CertKeyPair{ + .bundle = .{}, + .key = .{ + .signature_scheme = .ecdsa_secp256r1_sha256, + .key = undefined, + }, + .ecdsa_key_pair = null, + }; + const Selector = struct { + fn select(ctx: *anyopaque, server_name: ?[]const u8) ?*CertKeyPair { + const pair: *CertKeyPair = @ptrCast(@alignCast(ctx)); + if (server_name) |name| { + if (mem.eql(u8, name, "google.com")) return pair; + } + return null; + } + }; + + try h.readClientHello(.{ + .auth = &default_auth, + .sni_auth = .{ + .ctx = @ptrCast(&google_auth), + .selectFn = Selector.select, + }, + .cipher_suites = cipher_suites.tls13, + }); + + try testing.expectEqualStrings("google.com", h.server_name.?); + try testing.expectEqual(&google_auth, h.auth.?); + try testing.expectEqual(proto.SignatureScheme.ecdsa_secp256r1_sha256, h.signature_scheme); +} + test "make server hello" { var h: Handshake = .{ .input = undefined, .output = undefined };