diff --git a/protocol/triple/triple_protocol/handler.go b/protocol/triple/triple_protocol/handler.go index ff12d9bb18..5a6efca574 100644 --- a/protocol/triple/triple_protocol/handler.go +++ b/protocol/triple/triple_protocol/handler.go @@ -106,6 +106,7 @@ func generateUnaryHandlerFunc( request.header = conn.RequestHeader() // embed header in context so that user logic could process them via FromIncomingContext ctx = newIncomingContext(ctx, conn.RequestHeader()) + ctx = context.WithValue(ctx, handlerOutgoingKey{}, conn) response, err := untyped(ctx, request) diff --git a/protocol/triple/triple_protocol/triple_ext_test.go b/protocol/triple/triple_protocol/triple_ext_test.go index 639db4edb4..81fb82710a 100644 --- a/protocol/triple/triple_protocol/triple_ext_test.go +++ b/protocol/triple/triple_protocol/triple_ext_test.go @@ -522,6 +522,39 @@ func TestServer(t *testing.T) { }) } +func TestSetHeaderAndSetTrailerInUnaryHandler(t *testing.T) { + t.Parallel() + + handler := triple.NewUnaryHandler( + "/connect.ping.v1.PingService/Ping", + func() any { return new(pingv1.PingRequest) }, + func(ctx context.Context, req *triple.Request) (*triple.Response, error) { + if err := triple.SetHeader(ctx, http.Header{handlerHeader: []string{headerValue}}); err != nil { + return nil, err + } + if err := triple.SetTrailer(ctx, http.Header{handlerTrailer: []string{trailerValue}}); err != nil { + return nil, err + } + + msg := req.Msg.(*pingv1.PingRequest) + return triple.NewResponse(&pingv1.PingResponse{ + Number: msg.Number, + Text: msg.Text, + }), nil + }, + ) + server := httptest.NewServer(handler) + t.Cleanup(server.Close) + + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL) + request := triple.NewRequest(&pingv1.PingRequest{Number: 42}) + response := triple.NewResponse(&pingv1.PingResponse{}) + err := client.Ping(context.Background(), request, response) + assert.Nil(t, err) + assert.Equal(t, response.Header().Values(handlerHeader), []string{headerValue}) + assert.Equal(t, response.Trailer().Values(handlerTrailer), []string{trailerValue}) +} + func TestConcurrentStreams(t *testing.T) { if testing.Short() { t.Skipf("skipping %s test in short mode", t.Name())