diff --git a/src/handshake_server.zig b/src/handshake_server.zig index 2b57f6c..c65d607 100644 --- a/src/handshake_server.zig +++ b/src/handshake_server.zig @@ -30,6 +30,10 @@ pub const Options = struct { /// CertificateVerify message. auth: ?*CertKeyPair, + /// Optional certificate selector invoked after parsing SNI from ClientHello. + /// When it returns null, `auth` is used as the fallback certificate. + sni_auth: ?SniAuth = null, + /// If not null server will request client certificate. If auth_type is /// .request empty client certificate message will be accepted. /// Client certificate will be verified with root_ca certificates. @@ -37,6 +41,11 @@ pub const Options = struct { /// List of supported tls 1.3 cipher suites cipher_suites: []const CipherSuite = cipher_suites.tls13, + + pub const SniAuth = struct { + ctx: *anyopaque, + selectFn: *const fn (ctx: *anyopaque, server_name: ?[]const u8) ?*CertKeyPair, + }; }; pub const ClientAuth = struct { @@ -59,6 +68,7 @@ pub const ClientAuth = struct { pub const Handshake = struct { // public key len: x25519 = 32, secp256r1 = 65, secp384r1 = 97 const max_pub_key_len = 98; + const max_server_name_len = 255; const supported_named_groups = &[_]proto.NamedGroup{ .x25519, .secp256r1, .secp384r1 }; /// Underlying network connection stream reader/writer pair. @@ -76,6 +86,9 @@ pub const Handshake = struct { client_pub_key: []u8 = &.{}, server_pub_key_buf: [max_pub_key_len]u8 = undefined, server_pub_key: []u8 = &.{}, + server_name_buf: [max_server_name_len]u8 = undefined, + server_name: ?[]const u8 = null, + auth: ?*CertKeyPair = null, cipher: Cipher = undefined, transcript: Transcript = .{}, @@ -97,7 +110,7 @@ pub const Handshake = struct { pub fn handshake(h: *Self, opt: Options) !Cipher { h.initKeys(opt); - h.readClientHello(opt.cipher_suites) catch |err| { + h.readClientHello(opt) catch |err| { try h.writeAlert(null, err); return err; }; @@ -121,14 +134,14 @@ pub const Handshake = struct { fn initKeys(h: *Self, opt: Options) void { crypto.random.bytes(&h.server_random); - if (opt.auth) |a| { - // required signature scheme in client hello - h.signature_scheme = a.key.signature_scheme; - } + _ = opt; + h.auth = null; + h.server_name = null; + h.signature_scheme = @enumFromInt(0); } fn clientFlight1(h: *Self, opt: Options) !void { - try h.readClientHello(opt.cipher_suites); + try h.readClientHello(opt); h.transcript.use(h.cipher_suite.hash()); } @@ -178,7 +191,7 @@ pub const Handshake = struct { h.transcript.update(hw.buffered()); try h.writeEncrypted(&w, hw.buffered()); } - if (opt.auth) |auth| { + if (h.auth) |auth| { const cb = CertificateBuilder{ .cert_key_pair = auth, .transcript = &h.transcript, @@ -353,7 +366,34 @@ pub const Handshake = struct { try hw.int(u16, ext_len); // extensions length } - fn readClientHello(h: *Self, supported_cipher_suites: []const CipherSuite) !void { + fn resolveAuth(h: *Self, opt: Options) ?*CertKeyPair { + if (opt.sni_auth) |sni_auth| { + if (sni_auth.selectFn(sni_auth.ctx, h.server_name)) |auth| { + return auth; + } + } + return opt.auth; + } + + fn parseServerNameExtension(h: *Self, d: *record.Decoder, extension_end_idx: usize) !void { + const list_end_idx = try d.decode(u16) + d.idx; + if (list_end_idx != extension_end_idx) return error.TlsDecodeError; + + var host_name_found = false; + while (d.idx < list_end_idx) { + const name_type = try d.decode(u8); + const server_name = try d.slice(try d.decode(u16)); + + if (name_type == 0) { + if (host_name_found or server_name.len == 0) return error.TlsDecodeError; + h.server_name = try common.dupe(&h.server_name_buf, server_name); + host_name_found = true; + } + } + if (d.idx != list_end_idx) return error.TlsDecodeError; + } + + fn readClientHello(h: *Self, opt: Options) !void { var d = try Record.decoder(h.input); if (d.payload.len > max_cleartext_len) return error.TlsRecordOverflow; try d.expectContentType(.handshake); @@ -374,7 +414,7 @@ pub const Handshake = struct { while (d.idx < end_idx) { const cipher_suite = try d.decode(CipherSuite); - if (cipher_suites.includes(supported_cipher_suites, cipher_suite) and + if (cipher_suites.includes(opt.cipher_suites, cipher_suite) and @intFromEnum(h.cipher_suite) == 0) { h.cipher_suite = cipher_suite; @@ -386,13 +426,41 @@ pub const Handshake = struct { try d.skip(2); // compression methods var key_share_received = false; - // extensions + const extensions_start_idx = d.idx + 2; const extensions_end_idx = try d.decode(u16) + d.idx; + { + var ext_d = d; + ext_d.idx = extensions_start_idx; + while (ext_d.idx < extensions_end_idx) { + const extension_type = try ext_d.decode(proto.Extension); + const extension_end_idx = try ext_d.decode(u16) + ext_d.idx; + + switch (extension_type) { + .server_name => { + if (h.server_name != null) return error.TlsDecodeError; + try h.parseServerNameExtension(&ext_d, extension_end_idx); + }, + else => ext_d.idx = extension_end_idx, + } + } + if (ext_d.idx != extensions_end_idx) return error.TlsDecodeError; + } + + h.auth = h.resolveAuth(opt); + if (h.auth) |auth| { + h.signature_scheme = auth.key.signature_scheme; + } + + d.idx = extensions_start_idx; while (d.idx < extensions_end_idx) { const extension_type = try d.decode(proto.Extension); const extension_len = try d.decode(u16); + const extension_end_idx = d.idx + extension_len; switch (extension_type) { + .server_name => { + d.idx = extension_end_idx; + }, .supported_versions => { var tls_1_3_supported = false; const end_idx = try d.decode(u8) + d.idx; @@ -458,7 +526,7 @@ pub const Handshake = struct { } }, else => { - try d.skip(extension_len); + d.idx = extension_end_idx; }, } } @@ -470,6 +538,7 @@ pub const Handshake = struct { const testing = std.testing; const data13 = @import("testdata/tls13.zig"); const testu = @import("testu.zig"); +const ClientNonBlock = @import("handshake_client.zig").NonBlock; test "read client hello" { var reader: Io.Reader = .fixed(&data13.client_hello); @@ -477,8 +546,18 @@ test "read client hello" { .input = &reader, .output = undefined, }; - h.signature_scheme = .ecdsa_secp521r1_sha512; // this must be supported in signature_algorithms extension - try h.readClientHello(cipher_suites.tls13); + var auth = CertKeyPair{ + .bundle = .{}, + .key = .{ + .signature_scheme = .ecdsa_secp521r1_sha512, + .key = undefined, + }, + .ecdsa_key_pair = null, + }; + try h.readClientHello(.{ + .auth = &auth, + .cipher_suites = cipher_suites.tls13, + }); try testing.expectEqual(CipherSuite.AES_256_GCM_SHA384, h.cipher_suite); try testing.expectEqual(.x25519, h.named_group); @@ -486,6 +565,63 @@ test "read client hello" { try testing.expectEqualSlices(u8, &data13.client_public_key, h.client_pub_key); } +test "read client hello with server name and sni auth" { + var client = ClientNonBlock.init(.{ + .host = "google.com", + .root_ca = .{}, + .cipher_suites = cipher_suites.tls13, + }); + var client_hello_buf: [max_cleartext_len]u8 = undefined; + const client_res = try client.run(&.{}, &client_hello_buf); + + try testing.expect(client_res.send.len > 0); + + var reader: Io.Reader = .fixed(client_res.send); + var h: Handshake = .{ + .input = &reader, + .output = undefined, + }; + + var default_auth = CertKeyPair{ + .bundle = .{}, + .key = .{ + .signature_scheme = .rsa_pss_rsae_sha256, + .key = undefined, + }, + .ecdsa_key_pair = null, + }; + var google_auth = CertKeyPair{ + .bundle = .{}, + .key = .{ + .signature_scheme = .ecdsa_secp256r1_sha256, + .key = undefined, + }, + .ecdsa_key_pair = null, + }; + const Selector = struct { + fn select(ctx: *anyopaque, server_name: ?[]const u8) ?*CertKeyPair { + const pair: *CertKeyPair = @ptrCast(@alignCast(ctx)); + if (server_name) |name| { + if (mem.eql(u8, name, "google.com")) return pair; + } + return null; + } + }; + + try h.readClientHello(.{ + .auth = &default_auth, + .sni_auth = .{ + .ctx = @ptrCast(&google_auth), + .selectFn = Selector.select, + }, + .cipher_suites = cipher_suites.tls13, + }); + + try testing.expectEqualStrings("google.com", h.server_name.?); + try testing.expectEqual(&google_auth, h.auth.?); + try testing.expectEqual(proto.SignatureScheme.ecdsa_secp256r1_sha256, h.signature_scheme); +} + test "make server hello" { var h: Handshake = .{ .input = undefined, .output = undefined };