diff --git a/config/headers_test.go b/config/headers_test.go index c807fbc3..6678f037 100644 --- a/config/headers_test.go +++ b/config/headers_test.go @@ -17,8 +17,15 @@ package config import ( + "fmt" + "io" + "net" "net/http" + "net/http/httptest" + "strings" "testing" + + "github.com/stretchr/testify/require" ) func TestReservedHeaders(t *testing.T) { @@ -29,3 +36,111 @@ func TestReservedHeaders(t *testing.T) { } } } + +func TestHeadersRoundTripperSameHost(t *testing.T) { + // All headers, including sensitive ones, must be forwarded on same-host requests. + for _, header := range []string{"Cookie", "X-Custom-Header"} { + t.Run(header, func(t *testing.T) { + received := "" + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + received = r.Header.Get(header) + fmt.Fprint(w, "ok") + })) + t.Cleanup(server.Close) + + headers := &Headers{ + Headers: map[string]Header{ + header: {Values: []string{"testvalue"}}, + }, + } + rt := NewHeadersRoundTripper(headers, http.DefaultTransport) + + req, err := http.NewRequest(http.MethodGet, server.URL, nil) + require.NoError(t, err) + + resp, err := rt.RoundTrip(req) + require.NoError(t, err) + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, "ok", strings.TrimSpace(string(body))) + require.Equalf(t, "testvalue", received, "header %q must be forwarded on same-host request", header) + }) + } +} + +func TestHeadersRoundTripperCrossHostRedirect(t *testing.T) { + // Cookie must be set on the initial request but stripped on cross-host redirects. + cookieOnRedirect := "" + target := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + cookieOnRedirect = r.Header.Get("Cookie") + fmt.Fprint(w, "ok") + })) + t.Cleanup(target.Close) + + // Use "localhost" as the redirect target hostname so that it differs from + // "127.0.0.1" used by the origin server, making it a cross-host redirect. + targetPort := target.Listener.Addr().(*net.TCPAddr).Port + targetURL := fmt.Sprintf("http://localhost:%d", targetPort) + + cookieOnOrigin := "" + origin := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + cookieOnOrigin = r.Header.Get("Cookie") + http.Redirect(w, r, targetURL, http.StatusFound) + })) + t.Cleanup(origin.Close) + + cfg := HTTPClientConfig{ + FollowRedirects: true, + HTTPHeaders: &Headers{ + Headers: map[string]Header{ + "Cookie": {Values: []string{"session=abc"}}, + }, + }, + } + client, err := NewClientFromConfig(cfg, "test") + require.NoError(t, err) + + resp, err := client.Get(origin.URL) + require.NoError(t, err) + defer resp.Body.Close() + _, err = io.ReadAll(resp.Body) + require.NoError(t, err) + + require.Equalf(t, "session=abc", cookieOnOrigin, "Cookie must be set on the initial request.") + require.Empty(t, cookieOnRedirect, "Cookie must not be forwarded on a cross-host redirect.") +} + +func TestHeadersRoundTripperSameHostRedirect(t *testing.T) { + // Cookie must be forwarded on same-host redirects. + mux := http.NewServeMux() + cookieOnRedirect := "" + mux.HandleFunc("/start", func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, "/end", http.StatusFound) + }) + mux.HandleFunc("/end", func(w http.ResponseWriter, r *http.Request) { + cookieOnRedirect = r.Header.Get("Cookie") + fmt.Fprint(w, "ok") + }) + server := httptest.NewServer(mux) + t.Cleanup(server.Close) + + cfg := HTTPClientConfig{ + FollowRedirects: true, + HTTPHeaders: &Headers{ + Headers: map[string]Header{ + "Cookie": {Values: []string{"session=abc"}}, + }, + }, + } + client, err := NewClientFromConfig(cfg, "test") + require.NoError(t, err) + + resp, err := client.Get(server.URL + "/start") + require.NoError(t, err) + defer resp.Body.Close() + _, err = io.ReadAll(resp.Body) + require.NoError(t, err) + + require.Equalf(t, "session=abc", cookieOnRedirect, "Cookie must be forwarded on a same-host redirect.") +} diff --git a/config/http_config.go b/config/http_config.go index 55cc5b07..2204cd45 100644 --- a/config/http_config.go +++ b/config/http_config.go @@ -608,10 +608,14 @@ func NewClientFromConfig(cfg HTTPClientConfig, name string, optFuncs ...HTTPClie return nil, err } client := newClient(rt) - if !cfg.FollowRedirects { - client.CheckRedirect = func(*http.Request, []*http.Request) error { + client.CheckRedirect = func(req *http.Request, via []*http.Request) error { + if !cfg.FollowRedirects { return http.ErrUseLastResponse } + if len(via) > 0 && !shouldSendCredentialsOnRedirect(via[0].URL, req.URL) { + *req = *req.WithContext(context.WithValue(req.Context(), crossHostRedirectKey{}, true)) + } + return nil } return client, nil } @@ -721,6 +725,11 @@ func NewRoundTripperFromConfigWithContext(ctx context.Context, cfg HTTPClientCon } if cfg.HTTPHeaders != nil { + // Strip sensitive headers added by headersRoundTripper on cross-host + // redirects before they reach the transport. + if cfg.FollowRedirects { + rt = &sensitiveHeadersStripRT{next: rt} + } rt = NewHeadersRoundTripper(cfg.HTTPHeaders, rt) } @@ -862,7 +871,7 @@ func NewAuthorizationCredentialsRoundTripper(authType string, authCredentials Se } func (rt *authorizationCredentialsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - if len(req.Header.Get("Authorization")) != 0 { + if len(req.Header.Get("Authorization")) != 0 || isCrossHostRedirect(req) { return rt.rt.RoundTrip(req) } @@ -900,7 +909,7 @@ func NewBasicAuthRoundTripper(username, password SecretReader, rt http.RoundTrip } func (rt *basicAuthRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - if len(req.Header.Get("Authorization")) != 0 { + if len(req.Header.Get("Authorization")) != 0 || isCrossHostRedirect(req) { return rt.rt.RoundTrip(req) } var username string @@ -1085,6 +1094,9 @@ func (rt *oauth2RoundTripper) RoundTrip(req *http.Request) (*http.Response, erro rt.mtx.RLock() currentRT := rt.lastRT rt.mtx.RUnlock() + if isCrossHostRedirect(req) { + return currentRT.Base.RoundTrip(req) + } return currentRT.RoundTrip(req) } @@ -1106,6 +1118,78 @@ func mapToValues(m map[string]string) url.Values { return v } +// crossHostRedirectKey is the context key used to mark cross-host redirects. +type crossHostRedirectKey struct{} + +// isCrossHostRedirect reports whether req was marked as a cross-host redirect +// by the CheckRedirect handler. +func isCrossHostRedirect(req *http.Request) bool { + return req.Context().Value(crossHostRedirectKey{}) != nil +} + +// sensitiveHeadersOnRedirect lists the headers that must not be forwarded when +// following a redirect to a different host, mirroring the list in +// makeHeadersCopier in net/http/client.go. +var sensitiveHeadersOnRedirect = map[string]struct{}{ + "Authorization": {}, + "Www-Authenticate": {}, + "Cookie": {}, + "Cookie2": {}, + "Proxy-Authorization": {}, + "Proxy-Authenticate": {}, +} + +// sensitiveHeadersStripRT strips sensitive headers from requests marked as +// cross-host redirects before passing them to the underlying transport. +type sensitiveHeadersStripRT struct { + next http.RoundTripper +} + +func (rt *sensitiveHeadersStripRT) RoundTrip(req *http.Request) (*http.Response, error) { + if isCrossHostRedirect(req) { + req = cloneRequest(req) + for h := range sensitiveHeadersOnRedirect { + req.Header.Del(h) + } + } + return rt.next.RoundTrip(req) +} + +func (rt *sensitiveHeadersStripRT) CloseIdleConnections() { + if ci, ok := rt.next.(closeIdler); ok { + ci.CloseIdleConnections() + } +} + +// shouldSendCredentialsOnRedirect reports whether credentials from a request +// to initial should be forwarded when redirecting to dest. It mirrors the +// logic in shouldCopyHeaderOnRedirect from net/http/client.go: credentials +// are forwarded when dest is the same host as, or a subdomain of, initial. +// Port is not considered, matching Go's standard library behaviour. +func shouldSendCredentialsOnRedirect(initial, dest *url.URL) bool { + ihost := strings.ToLower(initial.Hostname()) + dhost := strings.ToLower(dest.Hostname()) + return isDomainOrSubdomain(dhost, ihost) +} + +// isDomainOrSubdomain reports whether sub is a subdomain (or exact match) of +// parent. It mirrors isDomainOrSubdomain from net/http/client.go. +func isDomainOrSubdomain(sub, parent string) bool { + if sub == parent { + return true + } + // A colon means sub is an IPv6 address; a percent sign introduces an IPv6 + // zone ID. Neither can be a hostname, and both could otherwise pass the + // suffix check below (e.g. "::1%.www.example.com" ends with "example.com"). + if strings.ContainsAny(sub, ":%") { + return false + } + if !strings.HasSuffix(sub, parent) { + return false + } + return sub[len(sub)-len(parent)-1] == '.' +} + // cloneRequest returns a clone of the provided *http.Request. // The clone is a shallow copy of the struct and its Header map. func cloneRequest(r *http.Request) *http.Request { diff --git a/config/http_config_test.go b/config/http_config_test.go index 9968d37a..d79d7d2f 100644 --- a/config/http_config_test.go +++ b/config/http_config_test.go @@ -1380,6 +1380,130 @@ func TestDefaultFollowRedirect(t *testing.T) { } } +func TestCrossHostRedirectDropsCredentials(t *testing.T) { + for _, tc := range []struct { + name string + config HTTPClientConfig + }{ + { + name: "bearer token", + config: HTTPClientConfig{ + FollowRedirects: true, + Authorization: &Authorization{ + Type: "Bearer", + Credentials: "secret-token", + }, + }, + }, + { + name: "basic auth", + config: HTTPClientConfig{ + FollowRedirects: true, + BasicAuth: &BasicAuth{ + Username: "user", + Password: "pass", + }, + }, + }, + } { + t.Run(tc.name, func(t *testing.T) { + // target listens on 127.0.0.1 but origin redirects using "localhost" + // as the hostname. "127.0.0.1" and "localhost" are different hostname + // strings, so Go's redirect rules strip credentials on the redirect. + target := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Authorization") != "" { + http.Error(w, "credentials leaked to cross-host redirect target", http.StatusForbidden) + return + } + fmt.Fprint(w, ExpectedMessage) + })) + t.Cleanup(target.Close) + + // Build a redirect URL that uses "localhost" instead of "127.0.0.1". + targetPort := target.Listener.Addr().(*net.TCPAddr).Port + targetLocalhostURL := fmt.Sprintf("http://localhost:%d", targetPort) + + origin := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, targetLocalhostURL+r.URL.Path, http.StatusFound) + })) + t.Cleanup(origin.Close) + + client, err := NewClientFromConfig(tc.config, "test") + require.NoError(t, err) + + resp, err := client.Get(origin.URL) + require.NoError(t, err) + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, ExpectedMessage, strings.TrimSpace(string(body))) + }) + } +} + +func TestIsDomainOrSubdomain(t *testing.T) { + for _, tc := range []struct { + sub, parent string + want bool + }{ + {"example.com", "example.com", true}, + {"sub.example.com", "example.com", true}, + {"deep.sub.example.com", "example.com", true}, + {"notexample.com", "example.com", false}, + {"example.com", "sub.example.com", false}, + {"bar.com", "foo.com", false}, + {"127.0.0.1", "127.0.0.1", true}, + {"localhost", "127.0.0.1", false}, + {"127.0.0.1", "localhost", false}, + {"::1", "::1", true}, + {"::2", "::1", false}, + {"::1", "example.com", false}, + // Zone ID containing a hostname must not match as a subdomain. + {"::1%.www.example.com", "example.com", false}, + {"fe80::1%eth0", "eth0", false}, + } { + t.Run(tc.sub+"→"+tc.parent, func(t *testing.T) { + require.Equal(t, tc.want, isDomainOrSubdomain(tc.sub, tc.parent)) + }) + } +} + +func TestSameHostRedirectKeepsCredentials(t *testing.T) { + credsSeen := false + mux := http.NewServeMux() + mux.HandleFunc("/start", func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, "/end", http.StatusFound) + }) + mux.HandleFunc("/end", func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Authorization") != "" { + credsSeen = true + } + fmt.Fprint(w, ExpectedMessage) + }) + server := httptest.NewServer(mux) + t.Cleanup(server.Close) + + cfg := HTTPClientConfig{ + FollowRedirects: true, + Authorization: &Authorization{ + Type: "Bearer", + Credentials: "secret-token", + }, + } + client, err := NewClientFromConfig(cfg, "test") + require.NoError(t, err) + + resp, err := client.Get(server.URL + "/start") + require.NoError(t, err) + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, ExpectedMessage, strings.TrimSpace(string(body))) + require.Truef(t, credsSeen, "credentials should be forwarded on same-host redirect") +} + func TestValidateHTTPConfig(t *testing.T) { cfg, _, err := LoadHTTPConfigFile("testdata/http.conf.good.yml") if err != nil {