diff --git a/cmd/scan_direct.go b/cmd/scan_direct.go index f1996f7..673cd9c 100644 --- a/cmd/scan_direct.go +++ b/cmd/scan_direct.go @@ -3,8 +3,11 @@ package cmd import ( "context" "crypto/tls" + "errors" "fmt" + "io" "net" + "net/url" "strconv" "strings" "time" @@ -29,8 +32,12 @@ var ( directFlagTimeoutConnect int directFlagTimeoutRequest int directFlagTimeoutDNS int + directFlagVerbose bool + directFlagFollowRedirect bool ) +const directMaxRedirects = 10 + func init() { rootCmd.AddCommand(directCmd) @@ -42,6 +49,8 @@ func init() { directCmd.Flags().IntVar(&directFlagTimeoutConnect, "timeout-connect", 5, "TCP connect timeout in seconds") directCmd.Flags().IntVar(&directFlagTimeoutRequest, "timeout-request", 10, "Overall request timeout in seconds") directCmd.Flags().IntVar(&directFlagTimeoutDNS, "timeout-dns", 5, "DNS lookup timeout in seconds") + directCmd.Flags().BoolVarP(&directFlagVerbose, "verbose", "v", false, "log skipped hosts and the reason (refused, timeout, no response, etc.)") + directCmd.Flags().BoolVar(&directFlagFollowRedirect, "follow-redirects", false, "follow 3xx redirects and report status/server of each hop") } func parsePorts(portSpec string) ([]string, error) { @@ -95,88 +104,207 @@ func extractHTTPHeaders(response string) (statusCode int, server string, locatio return statusCode, server, location } -func scanDirect(ctx *queuescanner.Ctx, host string) { - ports, err := parsePorts(directFlagPort) - if err != nil { - return +func describeError(err error) string { + if err == nil { + return "unknown error" } + var dnsErr *net.DNSError + if errors.As(err, &dnsErr) { + return "dns lookup failed" + } + + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + return "timeout" + } + + switch { + case errors.Is(err, io.EOF): + return "no response (connection closed)" + case strings.Contains(err.Error(), "connection refused"): + return "connection refused" + case strings.Contains(err.Error(), "connection reset"): + return "connection reset" + case strings.Contains(err.Error(), "no route to host"): + return "no route to host" + case strings.Contains(err.Error(), "network is unreachable"): + return "network unreachable" + default: + return err.Error() + } +} + +func logSkip(ctx *queuescanner.Ctx, target, reason string) { + if directFlagVerbose { + ctx.Log(fmt.Sprintf("[skip] %-32s %s", target, reason)) + } +} + +func isHTTPSPort(port string) bool { + for _, httpsPort := range []string{"443", "8443", "9443", "10443"} { + if port == httpsPort { + return true + } + } + return false +} + +func isRedirect(statusCode int) bool { + return statusCode >= 300 && statusCode < 400 +} + +func directRequest(host, port, path string, useTLS bool, method string) (ipStr string, statusCode int, server, location string, err error) { lookupCtx, cancel := context.WithTimeout(context.Background(), time.Duration(directFlagTimeoutDNS)*time.Second) defer cancel() ips, err := net.DefaultResolver.LookupIP(lookupCtx, "ip4", host) - if err != nil || len(ips) == 0 { - return + if err != nil { + return "", 0, "", "", err + } + if len(ips) == 0 { + return "", 0, "", "", fmt.Errorf("no DNS records for %s", host) } + ipStr = ips[0].String() - ip := ips[0] - ipStr := ip.String() + address := net.JoinHostPort(ipStr, port) + network := "tcp4" - for _, port := range ports { - useTLS := false - commonHTTPSPorts := []string{"443", "8443", "9443", "10443"} - for _, httpsPort := range commonHTTPSPorts { - if port == httpsPort { - useTLS = true - break - } - } + dialer := &net.Dialer{ + Timeout: time.Duration(directFlagTimeoutConnect) * time.Second, + } - address := fmt.Sprintf("%s:%s", ipStr, port) - network := "tcp4" + var conn net.Conn + if useTLS { + conn, err = tls.DialWithDialer(dialer, network, address, &tls.Config{ + InsecureSkipVerify: true, + ServerName: host, + }) + } else { + conn, err = dialer.Dial(network, address) + } + if err != nil { + return ipStr, 0, "", "", err + } + defer conn.Close() - dialer := &net.Dialer{ - Timeout: time.Duration(directFlagTimeoutConnect) * time.Second, + conn.SetDeadline(time.Now().Add(time.Duration(directFlagTimeoutRequest) * time.Second)) + + if path == "" { + path = "/" + } + httpRequest := fmt.Sprintf("%s %s HTTP/1.1\r\nHost: %s\r\nUser-Agent: bugscanx-go/1.0\r\nAccept: */*\r\n\r\n", method, path, host) + + if _, err = conn.Write([]byte(httpRequest)); err != nil { + return ipStr, 0, "", "", err + } + + buffer := make([]byte, 4096) + n, err := conn.Read(buffer) + if err != nil { + return ipStr, 0, "", "", err + } + + statusCode, server, location = extractHTTPHeaders(string(buffer[:n])) + return ipStr, statusCode, server, location, nil +} + +func formatDirectResult(ipStr string, statusCode int, server, hostWithPort string) string { + return fmt.Sprintf("%-15s %-3d %-16s %s", ipStr, statusCode, server, hostWithPort) +} + +func followRedirects(base *url.URL, location, method string) []string { + var records []string + visited := make(map[string]bool) + recorded := map[string]bool{base.Host: true} + current := base + + for i := 0; i < directMaxRedirects; i++ { + ref, err := url.Parse(strings.TrimSpace(location)) + if err != nil { + break } - var conn net.Conn - if useTLS { - conn, err = tls.DialWithDialer(dialer, network, address, &tls.Config{ - InsecureSkipVerify: true, - ServerName: host, - }) - } else { - conn, err = dialer.Dial(network, address) + next := current.ResolveReference(ref) + absLocation := next.String() + if visited[absLocation] { + break + } + visited[absLocation] = true + + useTLS := next.Scheme == "https" + port := next.Port() + if port == "" { + if useTLS { + port = "443" + } else { + port = "80" + } } + + ipStr, statusCode, server, nextLocation, err := directRequest(next.Hostname(), port, next.RequestURI(), useTLS, method) if err != nil { - continue + break } - conn.SetDeadline(time.Now().Add(time.Duration(directFlagTimeoutRequest) * time.Second)) + hostWithPort := net.JoinHostPort(next.Hostname(), port) + if !recorded[hostWithPort] { + recorded[hostWithPort] = true + records = append(records, formatDirectResult(ipStr, statusCode, server, hostWithPort)) + } - method := directFlagMethod - if method == "" { - method = "HEAD" + if !isRedirect(statusCode) || nextLocation == "" { + break } - httpRequest := fmt.Sprintf("%s / HTTP/1.1\r\nHost: %s\r\nUser-Agent: bugscanx-go/1.0\r\nConnection: close\r\n\r\n", method, host) + current = next + location = nextLocation + } + + return records +} - _, err = conn.Write([]byte(httpRequest)) - if err != nil { - conn.Close() - continue - } +func scanDirect(ctx *queuescanner.Ctx, host string) { + ports, err := parsePorts(directFlagPort) + if err != nil { + return + } - buffer := make([]byte, 4096) - n, err := conn.Read(buffer) - conn.Close() + method := directFlagMethod + if method == "" { + method = "HEAD" + } + for _, port := range ports { + useTLS := isHTTPSPort(port) + hostWithPort := net.JoinHostPort(host, port) + + ipStr, statusCode, server, location, err := directRequest(host, port, "/", useTLS, method) if err != nil { + logSkip(ctx, hostWithPort, describeError(err)) continue } - response := string(buffer[:n]) - statusCode, server, location := extractHTTPHeaders(response) - if directFlagHideLocation != "" && location == directFlagHideLocation { continue } - hostWithPort := fmt.Sprintf("%s:%s", host, port) - formatted := fmt.Sprintf("%-15s %-3d %-16s %s", ipStr, statusCode, server, hostWithPort) + formatted := formatDirectResult(ipStr, statusCode, server, hostWithPort) ctx.ScanSuccess(formatted) ctx.Log(formatted) + + if directFlagFollowRedirect && isRedirect(statusCode) && location != "" { + scheme := "http" + if useTLS { + scheme = "https" + } + base := &url.URL{Scheme: scheme, Host: hostWithPort} + for _, record := range followRedirects(base, location, method) { + ctx.ScanSuccess(record) + ctx.Log(record) + } + } } }