diff --git a/common/httputilz/httputilz.go b/common/httputilz/httputilz.go index 94bafa881..492fba691 100644 --- a/common/httputilz/httputilz.go +++ b/common/httputilz/httputilz.go @@ -24,8 +24,8 @@ func DumpRequest(req *retryablehttp.Request) (string, error) { } // ParseRequest from raw string -func ParseRequest(req string, unsafe bool) (method, path string, headers map[string]string, body string, err error) { - headers = make(map[string]string) +func ParseRequest(req string, unsafe bool) (method, path string, headers map[string][]string, body string, err error) { + headers = make(map[string][]string) reader := bufio.NewReader(strings.NewReader(req)) s, err := reader.ReadString('\n') if err != nil { @@ -68,7 +68,7 @@ func ParseRequest(req string, unsafe bool) (method, path string, headers map[str value = strings.TrimSpace(value) } - headers[key] = value + headers[key] = append(headers[key], value) } // Handle case with the full http url in path. In that case, @@ -81,7 +81,7 @@ func ParseRequest(req string, unsafe bool) (method, path string, headers map[str return } path = parts[1] - headers["Host"] = parsed.Host + headers["Host"] = []string{parsed.Host} } else { path = parts[1] } diff --git a/common/httputilz/httputilz_test.go b/common/httputilz/httputilz_test.go new file mode 100644 index 000000000..ffc4f1645 --- /dev/null +++ b/common/httputilz/httputilz_test.go @@ -0,0 +1,93 @@ +package httputilz + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestParseRequestPreservesDuplicateHeaders(t *testing.T) { + raw := strings.Join([]string{ + "GET /anything HTTP/1.1", + "Host: example.com", + "X-Test: one", + "X-Test: two", + "", + "", + }, "\r\n") + + method, path, headers, _, err := ParseRequest(raw, false) + require.NoError(t, err) + require.Equal(t, "GET", method) + require.Equal(t, "/anything", path) + require.Equal(t, []string{"one", "two"}, headers["X-Test"]) +} + +func TestParseRequestFullURLPathSetsHost(t *testing.T) { + raw := strings.Join([]string{ + "GET https://example.com/anything HTTP/1.1", + "Host: ignored.example", + "", + "", + }, "\r\n") + + _, path, headers, _, err := ParseRequest(raw, false) + require.NoError(t, err) + require.Equal(t, "https://example.com/anything", path) + require.Equal(t, []string{"example.com"}, headers["Host"]) +} + +func TestParseRequestParsesBody(t *testing.T) { + raw := strings.Join([]string{ + "POST /submit HTTP/1.1", + "Host: example.com", + "Content-Type: application/json", + "", + `{"a":1}`, + }, "\r\n") + + method, path, headers, body, err := ParseRequest(raw, false) + require.NoError(t, err) + require.Equal(t, "POST", method) + require.Equal(t, "/submit", path) + require.Equal(t, []string{"application/json"}, headers["Content-Type"]) + require.Equal(t, `{"a":1}`, body) +} + +func TestParseRequestSafeStripsContentLength(t *testing.T) { + raw := strings.Join([]string{ + "GET /anything HTTP/1.1", + "Host: example.com", + "Content-Length: 0", + "", + "", + }, "\r\n") + + _, _, headers, _, err := ParseRequest(raw, false) + require.NoError(t, err) + require.NotContains(t, headers, "Content-Length") +} + +func TestParseRequestUnsafePreservesRawHeaders(t *testing.T) { + raw := strings.Join([]string{ + "GET /anything HTTP/1.1", + "Host: example.com", + "Content-Length: 0", + "X-Test: one", + "X-Test: two", + "", + "", + }, "\r\n") + + _, _, headers, _, err := ParseRequest(raw, true) + require.NoError(t, err) + // unsafe mode keeps content-length and does not trim values + require.Contains(t, headers, "Content-Length") + require.Equal(t, []string{" one", " two"}, headers["X-Test"]) +} + +func TestParseRequestMalformed(t *testing.T) { + _, _, _, _, err := ParseRequest("GET\r\n\r\n", false) + require.Error(t, err) +} diff --git a/common/httpx/httpx.go b/common/httpx/httpx.go index 15c567f06..27f4d63bc 100644 --- a/common/httpx/httpx.go +++ b/common/httpx/httpx.go @@ -7,6 +7,7 @@ import ( "io" "net" "net/http" + "net/textproto" "net/url" "os" "strconv" @@ -36,7 +37,7 @@ type HTTPX struct { Filters []Filter Options *Options htmlPolicy *bluemonday.Policy - CustomHeaders map[string]string + CustomHeaders map[string][]string cdn *cdncheck.Client Dialer *fastdialer.Dialer NetworkPolicy *networkpolicy.NetworkPolicy @@ -434,19 +435,31 @@ func (h *HTTPX) NewRequestWithContext(ctx context.Context, method, targetURL str } // SetCustomHeaders on the provided request -func (h *HTTPX) SetCustomHeaders(r *retryablehttp.Request, headers map[string]string) { - for name, value := range headers { - switch strings.ToLower(name) { - case "host": - r.Host = value - if h.Options.Unsafe { - r.Header.Set("Host", value) +func (h *HTTPX) SetCustomHeaders(r *retryablehttp.Request, headers map[string][]string) { + // Coalesce values by canonical header key first. net/http canonicalizes keys + // on Del/Add, so case-variant duplicates (e.g. "X-Test" and "x-test") would + // otherwise have the second key's Del wipe the values added for the first. + normalized := make(map[string][]string, len(headers)) + for name, values := range headers { + canonical := textproto.CanonicalMIMEHeaderKey(name) + normalized[canonical] = append(normalized[canonical], values...) + } + + for name, values := range normalized { + r.Header.Del(name) + for _, value := range values { + switch strings.ToLower(name) { + case "host": + r.Host = value + if h.Options.Unsafe { + r.Header.Add("Host", value) + } + case "cookie": + // cookies are set in the default branch, and reset during the follow redirect flow + fallthrough + default: + r.Header.Add(name, value) } - case "cookie": - // cookies are set in the default branch, and reset during the follow redirect flow - fallthrough - default: - r.Header.Set(name, value) } } if h.Options.RandomAgent { diff --git a/common/httpx/httpx_test.go b/common/httpx/httpx_test.go index 9131e3b93..0dd9cbbca 100644 --- a/common/httpx/httpx_test.go +++ b/common/httpx/httpx_test.go @@ -29,6 +29,79 @@ func TestDo(t *testing.T) { }) } +func TestSetCustomHeaders(t *testing.T) { + h := &HTTPX{Options: &Options{}} + + t.Run("duplicate values preserved in order", func(t *testing.T) { + req, err := retryablehttp.NewRequest(http.MethodGet, "https://example.com", nil) + require.NoError(t, err) + h.SetCustomHeaders(req, map[string][]string{"X-Test": {"one", "two"}}) + require.Equal(t, []string{"one", "two"}, req.Header.Values("X-Test")) + }) + + t.Run("case-variant duplicates are coalesced", func(t *testing.T) { + req, err := retryablehttp.NewRequest(http.MethodGet, "https://example.com", nil) + require.NoError(t, err) + h.SetCustomHeaders(req, map[string][]string{"X-Test": {"one"}, "x-test": {"two"}}) + require.ElementsMatch(t, []string{"one", "two"}, req.Header.Values("X-Test")) + }) + + t.Run("custom header replaces existing value", func(t *testing.T) { + req, err := retryablehttp.NewRequest(http.MethodGet, "https://example.com", nil) + require.NoError(t, err) + req.Header.Set("User-Agent", "default-agent") + h.SetCustomHeaders(req, map[string][]string{"User-Agent": {"custom-agent"}}) + require.Equal(t, []string{"custom-agent"}, req.Header.Values("User-Agent")) + }) + + t.Run("host header sets request host", func(t *testing.T) { + req, err := retryablehttp.NewRequest(http.MethodGet, "https://example.com", nil) + require.NoError(t, err) + h.SetCustomHeaders(req, map[string][]string{"Host": {"custom.host"}}) + require.Equal(t, "custom.host", req.Host) + require.Empty(t, req.Header.Values("Host")) + }) + + t.Run("multiple distinct headers preserved", func(t *testing.T) { + req, err := retryablehttp.NewRequest(http.MethodGet, "https://example.com", nil) + require.NoError(t, err) + h.SetCustomHeaders(req, map[string][]string{"X-One": {"1"}, "X-Two": {"2"}}) + require.Equal(t, []string{"1"}, req.Header.Values("X-One")) + require.Equal(t, []string{"2"}, req.Header.Values("X-Two")) + }) + + t.Run("multiple cookie values preserved", func(t *testing.T) { + req, err := retryablehttp.NewRequest(http.MethodGet, "https://example.com", nil) + require.NoError(t, err) + h.SetCustomHeaders(req, map[string][]string{"Cookie": {"a=1", "b=2"}}) + require.Equal(t, []string{"a=1", "b=2"}, req.Header.Values("Cookie")) + }) + + t.Run("empty value applied as-is", func(t *testing.T) { + req, err := retryablehttp.NewRequest(http.MethodGet, "https://example.com", nil) + require.NoError(t, err) + h.SetCustomHeaders(req, map[string][]string{"X-Empty": {""}}) + require.Equal(t, []string{""}, req.Header.Values("X-Empty")) + }) + + t.Run("unsafe raw header line stored verbatim as key", func(t *testing.T) { + hu := &HTTPX{Options: &Options{Unsafe: true}} + req, err := retryablehttp.NewRequest(http.MethodGet, "https://example.com", nil) + require.NoError(t, err) + // in unsafe mode the runner stores the whole raw header line as the key + // with an empty value; it must survive canonicalization untouched + hu.SetCustomHeaders(req, map[string][]string{"X-Test: one": {""}}) + require.Equal(t, []string{""}, req.Header.Values("X-Test: one")) + }) +} + +func TestParseCustomCookies(t *testing.T) { + options := &Options{CustomHeaders: map[string][]string{"Cookie": {"a=1", "b=2"}}} + options.parseCustomCookies() + require.True(t, options.hasCustomCookies()) + require.Len(t, options.customCookies, 2) +} + func TestHTTP11DisablesRetryableHTTP2FallbackClient(t *testing.T) { options := DefaultOptions options.Protocol = HTTP11 diff --git a/common/httpx/option.go b/common/httpx/option.go index fb1087296..56d11cbc0 100644 --- a/common/httpx/option.go +++ b/common/httpx/option.go @@ -36,7 +36,7 @@ type Options struct { Timeout time.Duration // RetryMax is the maximum number of retries RetryMax int - CustomHeaders map[string]string + CustomHeaders map[string][]string // VHostSimilarityRatio 1 - 100 VHostSimilarityRatio int FollowRedirects bool @@ -90,7 +90,7 @@ func (options *Options) parseCustomCookies() { // parse and fill the custom field for k, v := range options.CustomHeaders { if strings.EqualFold(k, "cookie") { - req := http.Request{Header: http.Header{"Cookie": []string{v}}} + req := http.Request{Header: http.Header{"Cookie": v}} options.customCookies = req.Cookies() } } diff --git a/runner/runner.go b/runner/runner.go index 352e4c74c..6d6110ec6 100644 --- a/runner/runner.go +++ b/runner/runner.go @@ -30,14 +30,14 @@ import ( "github.com/PuerkitoBio/goquery" "github.com/corona10/goimagehash" "github.com/gocarina/gocsv" + "github.com/happyhackingspace/dit" "github.com/mfonda/simhash" asnmap "github.com/projectdiscovery/asnmap/libs" "github.com/projectdiscovery/fastdialer/fastdialer" + "github.com/projectdiscovery/httpx/common/authprovider" "github.com/projectdiscovery/httpx/common/customextract" "github.com/projectdiscovery/httpx/common/hashes/jarm" "github.com/projectdiscovery/httpx/common/inputformats" - "github.com/happyhackingspace/dit" - "github.com/projectdiscovery/httpx/common/authprovider" "github.com/projectdiscovery/httpx/static" "github.com/projectdiscovery/mapcidr/asn" "github.com/projectdiscovery/networkpolicy" @@ -238,12 +238,12 @@ func New(options *Options) (*Runner, error) { httpxOptions.Protocol = httpx.Proto(options.Protocol) var key, value string - httpxOptions.CustomHeaders = make(map[string]string) + httpxOptions.CustomHeaders = make(map[string][]string) for _, customHeader := range options.CustomHeaders { tokens := strings.SplitN(customHeader, ":", two) // rawhttp skips all checks if options.Unsafe { - httpxOptions.CustomHeaders[customHeader] = "" + httpxOptions.CustomHeaders[customHeader] = []string{""} continue } @@ -253,7 +253,7 @@ func New(options *Options) (*Runner, error) { } key = strings.TrimSpace(tokens[0]) value = strings.TrimSpace(tokens[1]) - httpxOptions.CustomHeaders[key] = value + httpxOptions.CustomHeaders[key] = append(httpxOptions.CustomHeaders[key], value) } httpxOptions.SniName = options.SniName @@ -278,7 +278,7 @@ func New(options *Options) (*Runner, error) { scanopts.Methods = append(scanopts.Methods, rrMethod) scanopts.RequestURI = rrPath for name, value := range rrHeaders { - httpxOptions.CustomHeaders[name] = value + httpxOptions.CustomHeaders[name] = append(httpxOptions.CustomHeaders[name], value...) } scanopts.RequestBody = rrBody options.rawRequest = string(rawRequest)