From f52e258862d67a7946c9f75d901e02eb98b4dd30 Mon Sep 17 00:00:00 2001 From: Jay Liu Date: Wed, 1 Apr 2026 21:31:52 +1100 Subject: [PATCH] **PR Title** Add SNI-based certificate selection to TLS 1.3 server handshake **PR Description** This change adds Server Name Indication (SNI) support to the TLS 1.3 server handshake. The server now parses the `server_name` extension from `ClientHello`, stores the requested hostname in handshake state, and can select a certificate dynamically through a new `Options.sni_auth` callback. If the callback does not return a certificate, the existing `auth` field is still used as the default fallback. - Added `Options.sni_auth` to support dynamic certificate selection based on SNI. - Added `Handshake.server_name` state to retain the hostname parsed from `ClientHello`. - Implemented parsing for the `server_name` TLS extension. - Changed server certificate selection to use the resolved SNI-specific certificate instead of always using `opt.auth`. - Deferred signature scheme validation until after certificate selection, so validation is performed against the actual certificate chosen for the handshake. - Added tests covering: - existing `ClientHello` parsing behavior - SNI parsing and SNI-driven certificate selection Before this change, the server always used a single fixed certificate from `opt.auth`, regardless of the hostname requested by the client. That made virtual hosting impossible. This patch enables: - serving multiple hostnames from the same listener - choosing the correct certificate during handshake - preserving backward compatibility for callers that only use `auth` New server option: ```zig sni_auth: ?SniAuth = null ``` Where `SniAuth` is: ```zig pub const SniAuth = struct { ctx: *anyopaque, selectFn: *const fn (ctx: *anyopaque, server_name: ?[]const u8) ?*CertKeyPair, }; ``` Behavior: - `selectFn` is called after parsing SNI from `ClientHello` - if it returns a certificate, that certificate is used - if it returns `null`, the handshake falls back to `auth` This change is backward compatible: - existing users of `auth` continue to work unchanged - `sni_auth` is optional I could not use full `zig test` output as the validation signal because the current repository does not compile cleanly against the local Zig standard library version for unrelated reasons. The modified file passes formatting checks with: ```sh zig fmt --check src/handshake_server.zig ``` --- src/handshake_server.zig | 175 +++++++++++++++++++++++++++++++++++---- 1 file changed, 159 insertions(+), 16 deletions(-) diff --git a/src/handshake_server.zig b/src/handshake_server.zig index 19c8c40..9374622 100644 --- a/src/handshake_server.zig +++ b/src/handshake_server.zig @@ -31,6 +31,10 @@ pub const Options = struct { /// CertificateVerify message. auth: ?*CertKeyPair, + /// Optional certificate selector invoked after parsing SNI from ClientHello. + /// When it returns null, `auth` is used as the fallback certificate. + sni_auth: ?SniAuth = null, + /// If not null server will request client certificate. If auth_type is /// .request empty client certificate message will be accepted. /// Client certificate will be verified with root_ca certificates. @@ -44,6 +48,11 @@ pub const Options = struct { alpn_protocols: []const []const u8 = &.{}, now: Io.Timestamp, + + pub const SniAuth = struct { + ctx: *anyopaque, + selectFn: *const fn (ctx: *anyopaque, server_name: ?[]const u8) ?*CertKeyPair, + }; }; pub const ClientAuth = struct { @@ -66,6 +75,7 @@ pub const ClientAuth = struct { pub const Handshake = struct { // public key len: x25519 = 32, secp256r1 = 65, secp384r1 = 97 const max_pub_key_len = 98; + const max_server_name_len = 255; const supported_named_groups = &[_]proto.NamedGroup{ .x25519, .secp256r1, .secp384r1 }; /// Underlying network connection stream reader/writer pair. @@ -83,6 +93,9 @@ pub const Handshake = struct { client_pub_key: []u8 = &.{}, server_pub_key_buf: [max_pub_key_len]u8 = undefined, server_pub_key: []u8 = &.{}, + server_name_buf: [max_server_name_len]u8 = undefined, + server_name: ?[]const u8 = null, + auth: ?*CertKeyPair = null, cipher: Cipher = undefined, transcript: Transcript = .{}, @@ -106,7 +119,7 @@ pub const Handshake = struct { pub fn handshake(h: *Self, opt: Options) !Cipher { h.initKeys(opt); - h.readClientHello(opt.cipher_suites, opt.alpn_protocols) catch |err| { + h.readClientHello(opt) catch |err| { try h.writeAlert(null, err); return err; }; @@ -130,14 +143,13 @@ pub const Handshake = struct { fn initKeys(h: *Self, opt: Options) void { opt.rng.bytes(&h.server_random); - if (opt.auth) |a| { - // required signature scheme in client hello - h.signature_scheme = a.key.signature_scheme; - } + h.auth = null; + h.server_name = null; + h.signature_scheme = @enumFromInt(0); } fn clientFlight1(h: *Self, opt: Options) !void { - try h.readClientHello(opt.cipher_suites, opt.alpn_protocols); + try h.readClientHello(opt); h.transcript.use(h.cipher_suite.hash()); } @@ -199,7 +211,7 @@ pub const Handshake = struct { h.transcript.update(hw.buffered()); try h.writeEncrypted(&w, hw.buffered()); } - if (opt.auth) |auth| { + if (h.auth) |auth| { const cb = CertificateBuilder{ .rng = opt.rng, .cert_key_pair = auth, @@ -379,7 +391,34 @@ pub const Handshake = struct { try hw.int(u16, ext_len); // extensions length } - fn readClientHello(h: *Self, supported_cipher_suites: []const CipherSuite, server_alpn_protocols: []const []const u8) !void { + fn resolveAuth(h: *Self, opt: Options) ?*CertKeyPair { + if (opt.sni_auth) |sni_auth| { + if (sni_auth.selectFn(sni_auth.ctx, h.server_name)) |auth| { + return auth; + } + } + return opt.auth; + } + + fn parseServerNameExtension(h: *Self, d: *record.Decoder, extension_end_idx: usize) !void { + const list_end_idx = try d.decode(u16) + d.idx; + if (list_end_idx != extension_end_idx) return error.TlsDecodeError; + + var host_name_found = false; + while (d.idx < list_end_idx) { + const name_type = try d.decode(u8); + const server_name = try d.slice(try d.decode(u16)); + + if (name_type == 0) { + if (host_name_found or server_name.len == 0) return error.TlsDecodeError; + h.server_name = try common.dupe(&h.server_name_buf, server_name); + host_name_found = true; + } + } + if (d.idx != list_end_idx) return error.TlsDecodeError; + } + + fn readClientHello(h: *Self, opt: Options) !void { var d = try Record.decoder(h.input); if (d.payload.len > max_cleartext_len) return error.TlsRecordOverflow; try d.expectContentType(.handshake); @@ -400,7 +439,7 @@ pub const Handshake = struct { while (d.idx < end_idx) { const cipher_suite = try d.decode(CipherSuite); - if (cipher_suites.includes(supported_cipher_suites, cipher_suite) and + if (cipher_suites.includes(opt.cipher_suites, cipher_suite) and @intFromEnum(h.cipher_suite) == 0) { h.cipher_suite = cipher_suite; @@ -412,13 +451,41 @@ pub const Handshake = struct { try d.skip(2); // compression methods var key_share_received = false; - // extensions + const extensions_start_idx = d.idx + 2; const extensions_end_idx = try d.decode(u16) + d.idx; + { + var ext_d = d; + ext_d.idx = extensions_start_idx; + while (ext_d.idx < extensions_end_idx) { + const extension_type = try ext_d.decode(proto.Extension); + const extension_end_idx = try ext_d.decode(u16) + ext_d.idx; + + switch (extension_type) { + .server_name => { + if (h.server_name != null) return error.TlsDecodeError; + try h.parseServerNameExtension(&ext_d, extension_end_idx); + }, + else => ext_d.idx = extension_end_idx, + } + } + if (ext_d.idx != extensions_end_idx) return error.TlsDecodeError; + } + + h.auth = h.resolveAuth(opt); + if (h.auth) |auth| { + h.signature_scheme = auth.key.signature_scheme; + } + + d.idx = extensions_start_idx; while (d.idx < extensions_end_idx) { const extension_type = try d.decode(proto.Extension); const extension_len = try d.decode(u16); + const extension_end_idx = d.idx + extension_len; switch (extension_type) { + .server_name => { + d.idx = extension_end_idx; + }, .supported_versions => { var tls_1_3_supported = false; const end_idx = try d.decode(u8) + d.idx; @@ -485,14 +552,14 @@ pub const Handshake = struct { }, .application_layer_protocol_negotiation => { // RFC 7301: parse client ALPN extension and select a protocol - if (server_alpn_protocols.len > 0) { + if (opt.alpn_protocols.len > 0) { const list_end = try d.decode(u16) + d.idx; // Find the first server protocol that the client supports // (server preference order) var best_match: ?[]const u8 = null; - var best_server_idx: usize = server_alpn_protocols.len; + var best_server_idx: usize = opt.alpn_protocols.len; const saved_idx = d.idx; - for (server_alpn_protocols, 0..) |server_proto, si| { + for (opt.alpn_protocols, 0..) |server_proto, si| { d.idx = saved_idx; while (d.idx < list_end) { const proto_len = try d.decode(u8); @@ -513,7 +580,7 @@ pub const Handshake = struct { } }, else => { - try d.skip(extension_len); + d.idx = extension_end_idx; }, } } @@ -525,6 +592,7 @@ pub const Handshake = struct { const testing = std.testing; const data13 = @import("testdata/tls13.zig"); const testu = @import("testu.zig"); +const ClientNonBlock = @import("handshake_client.zig").NonBlock; test "read client hello" { var reader: Io.Reader = .fixed(&data13.client_hello); @@ -532,8 +600,21 @@ test "read client hello" { .input = &reader, .output = undefined, }; - h.signature_scheme = .ecdsa_secp521r1_sha512; // this must be supported in signature_algorithms extension - try h.readClientHello(cipher_suites.tls13, &.{}); + const rsa_key = PrivateKey{ + .signature_scheme = .ecdsa_secp521r1_sha512, + .key = undefined, + }; + var auth = CertKeyPair{ + .bundle = .empty, + .key = rsa_key, + .ecdsa_key_pair = null, + }; + try h.readClientHello(.{ + .rng = testu.random(0), + .auth = &auth, + .cipher_suites = cipher_suites.tls13, + .now = .zero, + }); try testing.expectEqual(CipherSuite.AES_256_GCM_SHA384, h.cipher_suite); try testing.expectEqual(.x25519, h.named_group); @@ -541,6 +622,68 @@ test "read client hello" { try testing.expectEqualSlices(u8, &data13.client_public_key, h.client_pub_key); } +test "read client hello with server name and sni auth" { + const rng = testu.random(0); + var client = ClientNonBlock.init(.{ + .rng = rng, + .host = "google.com", + .root_ca = .empty, + .cipher_suites = cipher_suites.tls13, + .now = .zero, + }); + var client_hello_buf: [max_cleartext_len]u8 = undefined; + const client_res = try client.run(&.{}, &client_hello_buf); + + try testing.expect(client_res.send.len > 0); + + var reader: Io.Reader = .fixed(client_res.send); + var h: Handshake = .{ + .input = &reader, + .output = undefined, + }; + + var default_auth = CertKeyPair{ + .bundle = .empty, + .key = .{ + .signature_scheme = .rsa_pss_rsae_sha256, + .key = undefined, + }, + .ecdsa_key_pair = null, + }; + var google_auth = CertKeyPair{ + .bundle = .empty, + .key = .{ + .signature_scheme = .ecdsa_secp256r1_sha256, + .key = undefined, + }, + .ecdsa_key_pair = null, + }; + const Selector = struct { + fn select(ctx: *anyopaque, server_name: ?[]const u8) ?*CertKeyPair { + const pair: *CertKeyPair = @ptrCast(@alignCast(ctx)); + if (server_name) |name| { + if (mem.eql(u8, name, "google.com")) return pair; + } + return null; + } + }; + + try h.readClientHello(.{ + .rng = rng, + .auth = &default_auth, + .sni_auth = .{ + .ctx = @ptrCast(&google_auth), + .selectFn = Selector.select, + }, + .cipher_suites = cipher_suites.tls13, + .now = .zero, + }); + + try testing.expectEqualStrings("google.com", h.server_name.?); + try testing.expectEqual(&google_auth, h.auth.?); + try testing.expectEqual(proto.SignatureScheme.ecdsa_secp256r1_sha256, h.signature_scheme); +} + test "make server hello" { var h: Handshake = .{ .input = undefined, .output = undefined };