Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions cmd/tls/handlers/post.go
Original file line number Diff line number Diff line change
Expand Up @@ -1138,37 +1138,42 @@ func (h *HandlersTLS) EnrollPackageHandler(w http.ResponseWriter, r *http.Reques
var fDesc, fName, fPath string
switch packageVar {
case settings.PackageDeb:
if strings.HasPrefix(env.DebPackage, "http") {
if strings.HasPrefix(env.DebPackage, "https://") {
http.Redirect(w, r, env.DebPackage, http.StatusFound)
return
}
fDesc = "Enrolling DEB Package for Linux"
fName = genPackageFilename(env.Name, settings.PackageDeb, version.OsqueryVersion, version.OsctrlVersion)
fPath = fmt.Sprintf("%s/%s/%s", enrollPackagesPath, env.Name, env.DebPackage)
fPath, err = environments.PackageFilePath(enrollPackagesPath, env.Name, env.DebPackage)
case settings.PackageRpm:
if strings.HasPrefix(env.RpmPackage, "http") {
if strings.HasPrefix(env.RpmPackage, "https://") {
http.Redirect(w, r, env.RpmPackage, http.StatusFound)
return
}
fDesc = "Enrolling RPM Package for Linux"
fName = genPackageFilename(env.Name, settings.PackageRpm, version.OsqueryVersion, version.OsctrlVersion)
fPath = fmt.Sprintf("%s/%s/%s", enrollPackagesPath, env.Name, env.RpmPackage)
fPath, err = environments.PackageFilePath(enrollPackagesPath, env.Name, env.RpmPackage)
case settings.PackagePkg:
if strings.HasPrefix(env.PkgPackage, "http") {
if strings.HasPrefix(env.PkgPackage, "https://") {
http.Redirect(w, r, env.PkgPackage, http.StatusFound)
return
}
fDesc = "Enrolling PKG Package for Mac"
fName = genPackageFilename(env.Name, settings.PackagePkg, version.OsqueryVersion, version.OsctrlVersion)
fPath = fmt.Sprintf("%s/%s/%s", enrollPackagesPath, env.Name, env.PkgPackage)
fPath, err = environments.PackageFilePath(enrollPackagesPath, env.Name, env.PkgPackage)
case settings.PackageMsi:
if strings.HasPrefix(env.MsiPackage, "http") {
if strings.HasPrefix(env.MsiPackage, "https://") {
http.Redirect(w, r, env.MsiPackage, http.StatusFound)
return
}
fDesc = "Enrolling MSI Package for Windows"
fName = genPackageFilename(env.Name, settings.PackageMsi, defOsqueryVersion, version.OsctrlVersion)
fPath = fmt.Sprintf("%s/%s/%s", enrollPackagesPath, env.Name, env.MsiPackage)
fPath, err = environments.PackageFilePath(enrollPackagesPath, env.Name, env.MsiPackage)
}
if err != nil {
log.Err(err).Msg("invalid package path")
utils.HTTPResponse(w, "", http.StatusBadRequest, []byte(""))
return
}
// Initiate download
fi, err := os.Stat(fPath)
Expand Down
12 changes: 12 additions & 0 deletions pkg/environments/environments.go
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,9 @@ func (environment *EnvManager) UpdateCertificate(idEnv, certificate string) erro

// UpdateDebPackage to update DEB package for an environment
func (environment *EnvManager) UpdateDebPackage(idEnv, debpackage string) error {
if err := ValidatePackageReference(debpackage); err != nil {
return fmt.Errorf("UpdateDebPackage %w", err)
}
if err := environment.DB.Model(&TLSEnvironment{}).Where("name = ? OR uuid = ?", idEnv, idEnv).Update("deb_package", debpackage).Error; err != nil {
return fmt.Errorf("UpdateDebPackage %w", err)
}
Expand All @@ -423,6 +426,9 @@ func (environment *EnvManager) UpdateDebPackage(idEnv, debpackage string) error

// UpdateRpmPackage to update RPM package for an environment
func (environment *EnvManager) UpdateRpmPackage(idEnv, rpmpackage string) error {
if err := ValidatePackageReference(rpmpackage); err != nil {
return fmt.Errorf("UpdateRpmPackage %w", err)
}
if err := environment.DB.Model(&TLSEnvironment{}).Where("name = ? OR uuid = ?", idEnv, idEnv).Update("rpm_package", rpmpackage).Error; err != nil {
return fmt.Errorf("UpdateRpmPackage %w", err)
}
Expand All @@ -431,6 +437,9 @@ func (environment *EnvManager) UpdateRpmPackage(idEnv, rpmpackage string) error

// UpdateMsiPackage to update MSI package for an environment
func (environment *EnvManager) UpdateMsiPackage(idEnv, msipackage string) error {
if err := ValidatePackageReference(msipackage); err != nil {
return fmt.Errorf("UpdateMsiPackage %w", err)
}
if err := environment.DB.Model(&TLSEnvironment{}).Where("name = ? OR uuid = ?", idEnv, idEnv).Update("msi_package", msipackage).Error; err != nil {
return fmt.Errorf("UpdateMsiPackage %w", err)
}
Expand All @@ -439,6 +448,9 @@ func (environment *EnvManager) UpdateMsiPackage(idEnv, msipackage string) error

// UpdatePkgPackage to update PKG package for an environment
func (environment *EnvManager) UpdatePkgPackage(idEnv, pkgpackage string) error {
if err := ValidatePackageReference(pkgpackage); err != nil {
return fmt.Errorf("UpdatePkgPackage %w", err)
}
if err := environment.DB.Model(&TLSEnvironment{}).Where("name = ? OR uuid = ?", idEnv, idEnv).Update("pkg_package", pkgpackage).Error; err != nil {
return fmt.Errorf("UpdatePkgPackage %w", err)
}
Expand Down
34 changes: 34 additions & 0 deletions pkg/environments/package_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package environments

import (
"path/filepath"
"testing"
)

func TestPackageReferences(t *testing.T) {
env := TLSEnvironment{Hostname: "host", UUID: "uuid", Secret: "secret"}

if got := PackageDownloadURL(env, "osquery.deb"); got != "https://host/uuid/secret/package/osquery.deb" {
t.Fatalf("local package URL = %q", got)
}
if got := PackageDownloadURL(env, "https://example.com/osquery.deb"); got != "https://example.com/osquery.deb" {
t.Fatalf("HTTPS package URL = %q", got)
}
if got := PackageDownloadURL(env, "../secret"); got != "" {
t.Fatalf("traversal package URL = %q", got)
}

got, err := PackageFilePath("packages", "prod", "osquery.deb")
if err != nil {
t.Fatal(err)
}
if want := filepath.Join("packages", "prod", "osquery.deb"); got != want {
t.Fatalf("package path = %q, want %q", got, want)
}

for _, pkg := range []string{"../secret", "dir/osquery.deb", `dir\osquery.deb`, "/tmp/osquery.deb", ".", ".."} {
if _, err := PackageFilePath("packages", "prod", pkg); err == nil {
t.Fatalf("PackageFilePath accepted %q", pkg)
}
}
}
26 changes: 26 additions & 0 deletions pkg/environments/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package environments
import (
"fmt"
"os"
"path/filepath"
"strings"
"time"
)
Expand Down Expand Up @@ -49,12 +50,37 @@ func PackageDownloadURL(env TLSEnvironment, pkg string) string {
if pkg == "" {
return ""
}
if err := ValidatePackageReference(pkg); err != nil {
return ""
}
if strings.HasPrefix(pkg, "https://") {
return pkg
}
return fmt.Sprintf("https://%s/%s/%s/package/%s", env.Hostname, env.UUID, env.Secret, pkg)
}

// ValidatePackageReference allows HTTPS package URLs or local package basenames.
func ValidatePackageReference(pkg string) error {
if pkg == "" || strings.HasPrefix(pkg, "https://") {
return nil
}
if pkg == "." || pkg == ".." || filepath.IsAbs(pkg) || strings.ContainsAny(pkg, `/\`) {
return fmt.Errorf("invalid package path %q", pkg)
}
return nil
}

// PackageFilePath builds the local package path after validating the stored package value.
func PackageFilePath(packageRoot, envName, pkg string) (string, error) {
if err := ValidatePackageReference(pkg); err != nil {
return "", err
}
if strings.HasPrefix(pkg, "https://") {
return "", fmt.Errorf("package URL has no local file path")
}
return filepath.Join(packageRoot, envName, pkg), nil
}

// EnvironmentFinderID to find the environment and return its name based on the environment ID
func EnvironmentFinderID(envID uint, envs []TLSEnvironment, uuid bool) string {
if envID == 0 {
Expand Down
Loading