diff --git a/example/integration_test.zig b/example/integration_test.zig index d5fd4f7..492e9bb 100644 --- a/example/integration_test.zig +++ b/example/integration_test.zig @@ -110,6 +110,39 @@ test "server with ec key key pair" { thr.join(); } +test "server with ec key key pair from slices" { + const allocator = testing.allocator; + var threaded: std.Io.Threaded = .init(allocator, .{}); + defer threaded.deinit(); + const io = threaded.io(); + const now = try std.Io.Clock.real.now(io); + + var auth = try tls.config.CertKeyPair.fromSlice( + allocator, + io, + @embedFile("example/cert/localhost_ec/cert.pem"), + @embedFile("example/cert/localhost_ec/key.pem"), + ); + defer auth.deinit(allocator); + + var root_ca = try tls.config.cert.fromSlice(allocator, io, @embedFile("example/cert/minica.pem")); + defer root_ca.deinit(allocator); + + const opt: tls.config.Server = .{ .auth = &auth, .now = now }; + var server = try address.listen(io, .{}); + const thr = try std.Thread.spawn(.{}, acceptSend, .{ io, &server, opt, 3 }); + // client with insecure_skip_verify connects, server sends certificates but client skips verification + try connectReceive(io, server.socket.address, .{ .insecure_skip_verify = true, .host = host, .root_ca = .{}, .now = now }); + // client with root certificates connects; server certificates are validated + try connectReceive(io, server.socket.address, .{ .host = host, .root_ca = root_ca, .now = now }); + // client without insecure_skip_verify but not root ca fails; client can't verify server certificates + try testing.expectError( + error.CertificateIssuerNotFound, + connectReceive(io, server.socket.address, .{ .host = host, .root_ca = .{}, .now = now }), + ); + thr.join(); +} + test "server with rsa key key pair" { const allocator = testing.allocator; var threaded: std.Io.Threaded = .init(allocator, .{}); diff --git a/src/handshake_common.zig b/src/handshake_common.zig index 2b93208..f5d6b02 100644 --- a/src/handshake_common.zig +++ b/src/handshake_common.zig @@ -85,6 +85,18 @@ pub const CertKeyPair = struct { return .{ .bundle = bundle, .key = key, .ecdsa_key_pair = try EcdsaKeyPair.init(key) }; } + pub fn fromSlice( + allocator: mem.Allocator, + io: Io, + cert_slice: []const u8, + key_slice: []const u8, + ) !CertKeyPair { + const key = try PrivateKey.parsePem(key_slice); + const bundle = try cert.fromSlice(allocator, io, cert_slice); + + return .{ .bundle = bundle, .key = key, .ecdsa_key_pair = try EcdsaKeyPair.init(key) }; + } + pub fn deinit(c: *CertKeyPair, allocator: mem.Allocator) void { c.bundle.deinit(allocator); } @@ -140,6 +152,41 @@ pub const cert = struct { try bundle.rescan(allocator, io, try Io.Clock.real.now(io)); return bundle; } + + pub fn fromSlice(allocator: mem.Allocator, io: Io, slice: []const u8) !Bundle { + const base64 = std.base64.standard.decoderWithIgnore(" \t\r\n"); + const size = slice.len; + const ts = try Io.Clock.real.now(io); + + var bundle: Bundle = .{}; + + //Contains modified code from std.crypto.Certificate.Bundle.addCertsFromFile + const decoded_size_upper_bound = size / 4 * 3; + const needed_capacity = std.math.cast(u32, decoded_size_upper_bound + size) orelse + return Certificate.Bundle.AddCertsFromFileError.CertificateAuthorityBundleTooBig; + try bundle.bytes.ensureUnusedCapacity(allocator, needed_capacity); + const end_reserved: u32 = @intCast(bundle.bytes.items.len + decoded_size_upper_bound); + const buffer = bundle.bytes.allocatedSlice()[end_reserved..]; + @memcpy(buffer[0..size], slice); + const encoded_bytes = buffer[0..size]; + + const begin_marker = "-----BEGIN CERTIFICATE-----"; + const end_marker = "-----END CERTIFICATE-----"; + + var start_index: usize = 0; + while (mem.indexOfPos(u8, encoded_bytes, start_index, begin_marker)) |begin_marker_start| { + const cert_start = begin_marker_start + begin_marker.len; + const cert_end = mem.indexOfPos(u8, encoded_bytes, cert_start, end_marker) orelse + return Certificate.Bundle.AddCertsFromFileError.MissingEndCertificateMarker; + start_index = cert_end + end_marker.len; + const encoded_cert = mem.trim(u8, encoded_bytes[cert_start..cert_end], " \t\r\n"); + const decoded_start: u32 = @intCast(bundle.bytes.items.len); + const dest_buf = bundle.bytes.allocatedSlice()[decoded_start..]; + bundle.bytes.items.len += try base64.decode(dest_buf, encoded_cert); + try bundle.parseCert(allocator, decoded_start, ts.toSeconds()); + } + return bundle; + } }; pub const CertificateBuilder = struct {