diff --git a/cmd/tls/handlers/post.go b/cmd/tls/handlers/post.go index 2c82cf48..915300ad 100644 --- a/cmd/tls/handlers/post.go +++ b/cmd/tls/handlers/post.go @@ -1138,37 +1138,42 @@ func (h *HandlersTLS) EnrollPackageHandler(w http.ResponseWriter, r *http.Reques var fDesc, fName, fPath string switch packageVar { case settings.PackageDeb: - if strings.HasPrefix(env.DebPackage, "http") { + if strings.HasPrefix(env.DebPackage, "https://") { http.Redirect(w, r, env.DebPackage, http.StatusFound) return } fDesc = "Enrolling DEB Package for Linux" fName = genPackageFilename(env.Name, settings.PackageDeb, version.OsqueryVersion, version.OsctrlVersion) - fPath = fmt.Sprintf("%s/%s/%s", enrollPackagesPath, env.Name, env.DebPackage) + fPath, err = environments.PackageFilePath(enrollPackagesPath, env.Name, env.DebPackage) case settings.PackageRpm: - if strings.HasPrefix(env.RpmPackage, "http") { + if strings.HasPrefix(env.RpmPackage, "https://") { http.Redirect(w, r, env.RpmPackage, http.StatusFound) return } fDesc = "Enrolling RPM Package for Linux" fName = genPackageFilename(env.Name, settings.PackageRpm, version.OsqueryVersion, version.OsctrlVersion) - fPath = fmt.Sprintf("%s/%s/%s", enrollPackagesPath, env.Name, env.RpmPackage) + fPath, err = environments.PackageFilePath(enrollPackagesPath, env.Name, env.RpmPackage) case settings.PackagePkg: - if strings.HasPrefix(env.PkgPackage, "http") { + if strings.HasPrefix(env.PkgPackage, "https://") { http.Redirect(w, r, env.PkgPackage, http.StatusFound) return } fDesc = "Enrolling PKG Package for Mac" fName = genPackageFilename(env.Name, settings.PackagePkg, version.OsqueryVersion, version.OsctrlVersion) - fPath = fmt.Sprintf("%s/%s/%s", enrollPackagesPath, env.Name, env.PkgPackage) + fPath, err = environments.PackageFilePath(enrollPackagesPath, env.Name, env.PkgPackage) case settings.PackageMsi: - if strings.HasPrefix(env.MsiPackage, "http") { + if strings.HasPrefix(env.MsiPackage, "https://") { http.Redirect(w, r, env.MsiPackage, http.StatusFound) return } fDesc = "Enrolling MSI Package for Windows" fName = genPackageFilename(env.Name, settings.PackageMsi, defOsqueryVersion, version.OsctrlVersion) - fPath = fmt.Sprintf("%s/%s/%s", enrollPackagesPath, env.Name, env.MsiPackage) + fPath, err = environments.PackageFilePath(enrollPackagesPath, env.Name, env.MsiPackage) + } + if err != nil { + log.Err(err).Msg("invalid package path") + utils.HTTPResponse(w, "", http.StatusBadRequest, []byte("")) + return } // Initiate download fi, err := os.Stat(fPath) diff --git a/pkg/environments/environments.go b/pkg/environments/environments.go index 3c3d8ae1..7c55e042 100644 --- a/pkg/environments/environments.go +++ b/pkg/environments/environments.go @@ -415,6 +415,9 @@ func (environment *EnvManager) UpdateCertificate(idEnv, certificate string) erro // UpdateDebPackage to update DEB package for an environment func (environment *EnvManager) UpdateDebPackage(idEnv, debpackage string) error { + if err := ValidatePackageReference(debpackage); err != nil { + return fmt.Errorf("UpdateDebPackage %w", err) + } if err := environment.DB.Model(&TLSEnvironment{}).Where("name = ? OR uuid = ?", idEnv, idEnv).Update("deb_package", debpackage).Error; err != nil { return fmt.Errorf("UpdateDebPackage %w", err) } @@ -423,6 +426,9 @@ func (environment *EnvManager) UpdateDebPackage(idEnv, debpackage string) error // UpdateRpmPackage to update RPM package for an environment func (environment *EnvManager) UpdateRpmPackage(idEnv, rpmpackage string) error { + if err := ValidatePackageReference(rpmpackage); err != nil { + return fmt.Errorf("UpdateRpmPackage %w", err) + } if err := environment.DB.Model(&TLSEnvironment{}).Where("name = ? OR uuid = ?", idEnv, idEnv).Update("rpm_package", rpmpackage).Error; err != nil { return fmt.Errorf("UpdateRpmPackage %w", err) } @@ -431,6 +437,9 @@ func (environment *EnvManager) UpdateRpmPackage(idEnv, rpmpackage string) error // UpdateMsiPackage to update MSI package for an environment func (environment *EnvManager) UpdateMsiPackage(idEnv, msipackage string) error { + if err := ValidatePackageReference(msipackage); err != nil { + return fmt.Errorf("UpdateMsiPackage %w", err) + } if err := environment.DB.Model(&TLSEnvironment{}).Where("name = ? OR uuid = ?", idEnv, idEnv).Update("msi_package", msipackage).Error; err != nil { return fmt.Errorf("UpdateMsiPackage %w", err) } @@ -439,6 +448,9 @@ func (environment *EnvManager) UpdateMsiPackage(idEnv, msipackage string) error // UpdatePkgPackage to update PKG package for an environment func (environment *EnvManager) UpdatePkgPackage(idEnv, pkgpackage string) error { + if err := ValidatePackageReference(pkgpackage); err != nil { + return fmt.Errorf("UpdatePkgPackage %w", err) + } if err := environment.DB.Model(&TLSEnvironment{}).Where("name = ? OR uuid = ?", idEnv, idEnv).Update("pkg_package", pkgpackage).Error; err != nil { return fmt.Errorf("UpdatePkgPackage %w", err) } diff --git a/pkg/environments/package_test.go b/pkg/environments/package_test.go new file mode 100644 index 00000000..187e00eb --- /dev/null +++ b/pkg/environments/package_test.go @@ -0,0 +1,34 @@ +package environments + +import ( + "path/filepath" + "testing" +) + +func TestPackageReferences(t *testing.T) { + env := TLSEnvironment{Hostname: "host", UUID: "uuid", Secret: "secret"} + + if got := PackageDownloadURL(env, "osquery.deb"); got != "https://host/uuid/secret/package/osquery.deb" { + t.Fatalf("local package URL = %q", got) + } + if got := PackageDownloadURL(env, "https://example.com/osquery.deb"); got != "https://example.com/osquery.deb" { + t.Fatalf("HTTPS package URL = %q", got) + } + if got := PackageDownloadURL(env, "../secret"); got != "" { + t.Fatalf("traversal package URL = %q", got) + } + + got, err := PackageFilePath("packages", "prod", "osquery.deb") + if err != nil { + t.Fatal(err) + } + if want := filepath.Join("packages", "prod", "osquery.deb"); got != want { + t.Fatalf("package path = %q, want %q", got, want) + } + + for _, pkg := range []string{"../secret", "dir/osquery.deb", `dir\osquery.deb`, "/tmp/osquery.deb", ".", ".."} { + if _, err := PackageFilePath("packages", "prod", pkg); err == nil { + t.Fatalf("PackageFilePath accepted %q", pkg) + } + } +} diff --git a/pkg/environments/util.go b/pkg/environments/util.go index b01a49a2..6eed3798 100644 --- a/pkg/environments/util.go +++ b/pkg/environments/util.go @@ -3,6 +3,7 @@ package environments import ( "fmt" "os" + "path/filepath" "strings" "time" ) @@ -49,12 +50,37 @@ func PackageDownloadURL(env TLSEnvironment, pkg string) string { if pkg == "" { return "" } + if err := ValidatePackageReference(pkg); err != nil { + return "" + } if strings.HasPrefix(pkg, "https://") { return pkg } return fmt.Sprintf("https://%s/%s/%s/package/%s", env.Hostname, env.UUID, env.Secret, pkg) } +// ValidatePackageReference allows HTTPS package URLs or local package basenames. +func ValidatePackageReference(pkg string) error { + if pkg == "" || strings.HasPrefix(pkg, "https://") { + return nil + } + if pkg == "." || pkg == ".." || filepath.IsAbs(pkg) || strings.ContainsAny(pkg, `/\`) { + return fmt.Errorf("invalid package path %q", pkg) + } + return nil +} + +// PackageFilePath builds the local package path after validating the stored package value. +func PackageFilePath(packageRoot, envName, pkg string) (string, error) { + if err := ValidatePackageReference(pkg); err != nil { + return "", err + } + if strings.HasPrefix(pkg, "https://") { + return "", fmt.Errorf("package URL has no local file path") + } + return filepath.Join(packageRoot, envName, pkg), nil +} + // EnvironmentFinderID to find the environment and return its name based on the environment ID func EnvironmentFinderID(envID uint, envs []TLSEnvironment, uuid bool) string { if envID == 0 {