diff --git a/README.md b/README.md index 5477473..101be30 100644 --- a/README.md +++ b/README.md @@ -245,6 +245,27 @@ Input validation helpers. ### `libCallApi` Utilities for calling external APIs and handling auth/multi-call scenarios. +Remote APIs can authenticate with OAuth2 (`client_credentials`, `refresh_token`, optional `password` grant) or fall back to BasicAuth when `grant-type` is not configured. + +Example `param.yaml`: + +```yaml +remoteApis: + partner-api: + domain: https://api.partner.com + name: partner-api + auth: + grant-type: client_credentials + auth-uri: https://auth.partner.com/oauth/token + client-id: partner-client +``` + +Secure values (existing pattern): + +- `remote-api#partner-api#client-secret` +- `remote-api#partner-api#client-id` +- `remote-api#partner-api#auth-uri` (alias: `auth-url`) + ### `libCrypto` Cryptographic and security primitives. diff --git a/go.mod b/go.mod index 6be1773..a4350fb 100644 --- a/go.mod +++ b/go.mod @@ -106,6 +106,7 @@ require ( golang.org/x/arch v0.20.0 // indirect golang.org/x/crypto v0.41.0 // indirect golang.org/x/net v0.43.0 // indirect + golang.org/x/oauth2 v0.36.0 // indirect golang.org/x/sync v0.16.0 // indirect golang.org/x/sys v0.35.0 // indirect golang.org/x/text v0.28.0 // indirect diff --git a/go.sum b/go.sum index 255d58d..3e37cf6 100644 --- a/go.sum +++ b/go.sum @@ -246,6 +246,8 @@ golang.org/x/mod v0.27.0 h1:kb+q2PyFnEADO2IEF935ehFUXlWiNjJWtRNgBLSfbxQ= golang.org/x/mod v0.27.0/go.mod h1:rWI627Fq0DEoudcK+MBkNkCe0EetEaDSwJJkCcjpazc= golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE= golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg= +golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs= +golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q= golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/libApplication/init.go b/libApplication/init.go index e9206e8..c25b506 100644 --- a/libApplication/init.go +++ b/libApplication/init.go @@ -113,6 +113,13 @@ func InitializeApp[T any](app Application[T]) *App[T] { cache, lock := libCallApi.InitTokenCache() api.TokenCache = cache api.TokenCacheLock = lock + if api.AuthData.GrantType != "" { + auth, err := libCallApi.NewOAuth2AuthFromAuthData(api.AuthData, libCallApi.NewTokenHTTPClient()) + if err != nil { + log.Fatal("InitializeApp: OAuth2 auth for ", id, "=>", err) + } + api.Auth = auth + } wsParams.RemoteApis[id] = api } diff --git a/libCallApi/auth.go b/libCallApi/auth.go index 7b59a4e..7e1e27d 100644 --- a/libCallApi/auth.go +++ b/libCallApi/auth.go @@ -1,6 +1,7 @@ package libCallApi import ( + "context" "encoding/base64" "fmt" "sync" @@ -8,7 +9,6 @@ import ( "github.com/hmmftg/requestCore/libError" "github.com/hmmftg/requestCore/status" - "github.com/hmmftg/requestCore/webFramework" ) type Auth struct { @@ -47,7 +47,8 @@ func InitTokenCache() (*TokenCache, *sync.Mutex) { } type AuthSystem interface { - Login(w webFramework.WebFramework) (*TokenCache, libError.Error) + Login(ctx context.Context) (*TokenCache, libError.Error) + Refresh(ctx context.Context, refreshToken string) (*TokenCache, libError.Error) } func (api RemoteApi) GetBasicAuthHeader() string { @@ -73,19 +74,38 @@ func (api RemoteApi) GetAuthHeader() (string, error) { return fmt.Sprintf("%s %s", api.TokenCache.AccessToken.Type, api.TokenCache.AccessToken.Token), nil } -func (api *RemoteApi) handleToken(w webFramework.WebFramework) libError.Error { +func (api *RemoteApi) handleToken(ctx context.Context) libError.Error { api.TokenCacheLock.Lock() defer api.TokenCacheLock.Unlock() if api.TokenCache.AccessToken != nil && !api.TokenCache.Expired() { - // another thread handles login before us return nil } + if api.Auth == nil { + return libError.NewWithDescription( + status.InternalServerError, + "AUTH_SYSTEM_NOT_CONFIGURED", + "auth system of api %s is not configured", + api.Name, + ) + } + + if api.TokenCache.RefreshToken != nil && api.TokenCache.RefreshToken.Token != "" { + tokens, err := api.Auth.Refresh(ctx, api.TokenCache.RefreshToken.Token) + if err == nil { + api.TokenCache.AccessToken = tokens.AccessToken + if tokens.RefreshToken != nil { + api.TokenCache.RefreshToken = tokens.RefreshToken + } + return nil + } + } + api.TokenCache.AccessToken = nil api.TokenCache.RefreshToken = nil - tokens, err := api.Auth.Login(w) + tokens, err := api.Auth.Login(ctx) if err != nil { return err } @@ -94,21 +114,15 @@ func (api *RemoteApi) handleToken(w webFramework.WebFramework) libError.Error { return nil } -func (api *RemoteApi) Authenticate(w webFramework.WebFramework) libError.Error { +func (api *RemoteApi) Authenticate(ctx context.Context) libError.Error { if api.TokenCacheLock == nil { return libError.NewWithDescription(status.InternalServerError, "TOKEN_CACHE_NOT_INITIALIZED", "token cache lock of api %s is null", api.Name) } - if api.TokenCache.AccessToken == nil { - err := api.handleToken(w) - if err != nil { - return err - } + if api.TokenCache == nil { + return libError.NewWithDescription(status.InternalServerError, "TOKEN_CACHE_NOT_INITIALIZED", "token cache of api %s is null", api.Name) } - if api.TokenCache.Expired() { - err := api.handleToken(w) - if err != nil { - return err - } + if api.TokenCache.AccessToken == nil || api.TokenCache.Expired() { + return api.handleToken(ctx) } return nil } diff --git a/libCallApi/auth_headers.go b/libCallApi/auth_headers.go new file mode 100644 index 0000000..badd2fa --- /dev/null +++ b/libCallApi/auth_headers.go @@ -0,0 +1,43 @@ +package libCallApi + +import ( + "context" + + "github.com/hmmftg/requestCore/libError" + "github.com/hmmftg/requestCore/status" +) + +func (api *RemoteApi) EnsureAuthorization(ctx context.Context, headers map[string]string) libError.Error { + if headers == nil { + return libError.NewWithDescription( + status.InternalServerError, + "AUTH_HEADERS_NIL", + "headers map is nil for api %s", + api.Name, + ) + } + if _, ok := headers["Authorization"]; ok { + return nil + } + if api.Auth != nil { + if err := api.Authenticate(ctx); err != nil { + return err + } + authHeader, err := api.GetAuthHeader() + if err != nil { + return libError.NewWithDescription( + status.InternalServerError, + "AUTH_HEADER_FAILED", + "failed to build auth header for api %s: %v", + api.Name, + err, + ) + } + headers["Authorization"] = authHeader + return nil + } + if api.AuthData.User != "" && api.AuthData.Password != "" { + headers["Authorization"] = api.GetBasicAuthHeader() + } + return nil +} diff --git a/libCallApi/auth_test.go b/libCallApi/auth_test.go new file mode 100644 index 0000000..e43c830 --- /dev/null +++ b/libCallApi/auth_test.go @@ -0,0 +1,191 @@ +package libCallApi_test + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/hmmftg/requestCore/libCallApi" + "github.com/hmmftg/requestCore/libError" + "gotest.tools/v3/assert" +) + +type countingAuth struct { + logins atomic.Int32 + refreshes atomic.Int32 +} + +func (a *countingAuth) Login(ctx context.Context) (*libCallApi.TokenCache, libError.Error) { + a.logins.Add(1) + return &libCallApi.TokenCache{ + AccessToken: &libCallApi.OAuth2Token{ + Token: "access-token", + Type: "Bearer", + TimeTaken: time.Now(), + ValidUntil: time.Hour, + }, + }, nil +} + +func (a *countingAuth) Refresh(ctx context.Context, refreshToken string) (*libCallApi.TokenCache, libError.Error) { + a.refreshes.Add(1) + return &libCallApi.TokenCache{ + AccessToken: &libCallApi.OAuth2Token{ + Token: "refreshed-access-token", + Type: "Bearer", + TimeTaken: time.Now(), + ValidUntil: time.Hour, + }, + }, nil +} + +func TestAuthenticate_CacheHitAvoidsSecondLogin(t *testing.T) { + auth := &countingAuth{} + cache, lock := libCallApi.InitTokenCache() + api := &libCallApi.RemoteApi{ + Name: "test-api", + Auth: auth, + TokenCache: cache, + TokenCacheLock: lock, + } + + err := api.Authenticate(context.Background()) + assert.NilError(t, err) + err = api.Authenticate(context.Background()) + assert.NilError(t, err) + assert.Equal(t, auth.logins.Load(), int32(1)) +} + +func TestAuthenticate_ExpiredTokenTriggersRefresh(t *testing.T) { + auth := &countingAuth{} + cache, lock := libCallApi.InitTokenCache() + cache.AccessToken = &libCallApi.OAuth2Token{ + Token: "expired-token", + Type: "Bearer", + TimeTaken: time.Now().Add(-2 * time.Hour), + ValidUntil: time.Hour, + } + cache.RefreshToken = &libCallApi.OAuth2Token{ + Token: "refresh-token", + TimeTaken: time.Now(), + ValidUntil: time.Hour, + } + api := &libCallApi.RemoteApi{ + Name: "test-api", + Auth: auth, + TokenCache: cache, + TokenCacheLock: lock, + } + + err := api.Authenticate(context.Background()) + assert.NilError(t, err) + assert.Equal(t, auth.refreshes.Load(), int32(1)) + assert.Equal(t, auth.logins.Load(), int32(0)) + assert.Equal(t, api.TokenCache.AccessToken.Token, "refreshed-access-token") +} + +func TestAuthenticate_ConcurrentLoginOnce(t *testing.T) { + auth := &countingAuth{} + cache, lock := libCallApi.InitTokenCache() + api := &libCallApi.RemoteApi{ + Name: "test-api", + Auth: auth, + TokenCache: cache, + TokenCacheLock: lock, + } + + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + err := api.Authenticate(context.Background()) + assert.NilError(t, err) + }() + } + wg.Wait() + assert.Equal(t, auth.logins.Load(), int32(1)) +} + +func TestEnsureAuthorization_PreservesExplicitHeader(t *testing.T) { + api := &libCallApi.RemoteApi{ + Name: "test-api", + Auth: &countingAuth{}, + } + headers := map[string]string{ + "Authorization": "Bearer explicit-token", + } + err := api.EnsureAuthorization(context.Background(), headers) + assert.NilError(t, err) + assert.Equal(t, headers["Authorization"], "Bearer explicit-token") +} + +func TestEnsureAuthorization_BasicAuthFallback(t *testing.T) { + api := &libCallApi.RemoteApi{ + Name: "test-api", + AuthData: libCallApi.Auth{ + User: "user", + Password: "pass", + }, + } + headers := map[string]string{} + err := api.EnsureAuthorization(context.Background(), headers) + assert.NilError(t, err) + assert.Equal(t, headers["Authorization"], api.GetBasicAuthHeader()) +} + +func TestPrepareCall_OAuthAuthorizationHeader(t *testing.T) { + var tokenCalls atomic.Int32 + tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + tokenCalls.Add(1) + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": "oauth-access-token", + "token_type": "Bearer", + "expires_in": 3600, + }) + })) + defer tokenServer.Close() + + apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, r.Header.Get("Authorization"), "Bearer oauth-access-token") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]string{"status": "ok"}) + })) + defer apiServer.Close() + + auth, err := libCallApi.NewOAuth2AuthFromAuthData(libCallApi.Auth{ + GrantType: libCallApi.GrantTypeClientCredentials, + AuthURI: tokenServer.URL, + ClientID: "client-id", + ClientSecret: "client-secret", + }, tokenServer.Client()) + assert.NilError(t, err) + + cache, lock := libCallApi.InitTokenCache() + remoteApi := libCallApi.RemoteApi{ + Name: "partner-api", + Domain: apiServer.URL, + Auth: auth, + TokenCache: cache, + TokenCacheLock: lock, + } + + callData := libCallApi.CallData[map[string]string]{ + Api: remoteApi, + Path: "", + Method: http.MethodGet, + Headers: map[string]string{}, + BodyType: libCallApi.Empty, + Context: context.Background(), + } + + req, err := libCallApi.PrepareCall(callData) + assert.NilError(t, err) + assert.Equal(t, req.Header.Get("Authorization"), "Bearer oauth-access-token") +} diff --git a/libCallApi/callApi.go b/libCallApi/callApi.go index 0b65ebf..6aa1088 100644 --- a/libCallApi/callApi.go +++ b/libCallApi/callApi.go @@ -37,11 +37,17 @@ func (m RemoteApiModel) ConsumeRestBasicAuthApi(requestJson []byte, apiName, pat timeoutSeconds, _ := strconv.Atoi(timeOutString) httpClient.Timeout = time.Duration(timeoutSeconds * int(time.Second)) } - req, err := http.NewRequest(method, m.RemoteApiList[apiName].Domain+"/"+path, bytes.NewBuffer(requestJson)) + api := m.RemoteApiList[apiName] + if headers == nil { + headers = make(map[string]string) + } + if err := (&api).EnsureAuthorization(context.Background(), headers); err != nil { + return nil, "AUTH_FAILED", err + } + req, err := http.NewRequest(method, api.Domain+"/"+path, bytes.NewBuffer(requestJson)) if err != nil { return nil, "Generate Request Failed", err } - req.SetBasicAuth(m.RemoteApiList[apiName].AuthData.User, m.RemoteApiList[apiName].AuthData.Password) req.Header.Add("Content-Type", contentType) for header, value := range headers { req.Header.Add(header, value) @@ -85,13 +91,17 @@ func (m RemoteApiModel) ConsumeRestApi(requestJson []byte, apiName, path, conten timeoutSeconds, _ := strconv.Atoi(timeOutString) httpClient.Timeout = time.Duration(timeoutSeconds * int(time.Second)) } - req, err := http.NewRequest(method, m.RemoteApiList[apiName].Domain+"/"+path, bytes.NewBuffer(requestJson)) + api := m.RemoteApiList[apiName] + if headers == nil { + headers = make(map[string]string) + } + if err := (&api).EnsureAuthorization(context.Background(), headers); err != nil { + return nil, "AUTH_FAILED", http.StatusInternalServerError, err + } + req, err := http.NewRequest(method, api.Domain+"/"+path, bytes.NewBuffer(requestJson)) if err != nil { return nil, "Generate Request Failed", http.StatusInternalServerError, err } - if _, ok := headers["Authorization"]; !ok { - req.SetBasicAuth(m.RemoteApiList[apiName].AuthData.User, m.RemoteApiList[apiName].AuthData.Password) - } req.Header.Add("Content-Type", contentType) for header, value := range headers { req.Header.Add(header, value) @@ -288,8 +298,11 @@ func PrepareCall[Resp any](c CallData[Resp]) (*http.Request, error) { propagator.Inject(ctx, propagation.HeaderCarrier(req.Header)) } - if _, ok := c.Headers["Authorization"]; !ok { - req.SetBasicAuth(c.Api.AuthData.User, c.Api.AuthData.Password) + if c.Headers == nil { + c.Headers = make(map[string]string) + } + if err := c.Api.EnsureAuthorization(ctx, c.Headers); err != nil { + return nil, err } switch c.BodyType { case JSON: diff --git a/libCallApi/oauth2_auth.go b/libCallApi/oauth2_auth.go new file mode 100644 index 0000000..44f99b7 --- /dev/null +++ b/libCallApi/oauth2_auth.go @@ -0,0 +1,133 @@ +package libCallApi + +import ( + "context" + "net/http" + "time" + + "github.com/hmmftg/requestCore/libError" + "github.com/hmmftg/requestCore/status" + "golang.org/x/oauth2" + "golang.org/x/oauth2/clientcredentials" +) + +const tokenExpirySkew = 30 * time.Second + +type OAuth2Auth struct { + grantType string + user string + password string + cfg oauth2.Config + httpClient *http.Client +} + +func (a OAuth2Auth) Login(ctx context.Context) (*TokenCache, libError.Error) { + switch a.grantType { + case GrantTypeClientCredentials: + return a.loginClientCredentials(ctx) + case GrantTypePassword: + return a.loginPassword(ctx) + default: + return nil, libError.NewWithDescription( + status.InternalServerError, + "OAUTH2_UNSUPPORTED_GRANT", + "unsupported grant type %s", + a.grantType, + ) + } +} + +func (a OAuth2Auth) Refresh(ctx context.Context, refreshToken string) (*TokenCache, libError.Error) { + if refreshToken == "" { + return nil, libError.NewWithDescription( + status.InternalServerError, + "OAUTH2_NO_REFRESH_TOKEN", + "empty refresh token", + ) + } + ctx = a.withHTTPClient(ctx) + ts := a.cfg.TokenSource(ctx, &oauth2.Token{RefreshToken: refreshToken}) + tok, err := ts.Token() + if err != nil { + return nil, libError.NewWithDescription( + status.InternalServerError, + "OAUTH2_REFRESH_FAILED", + "refresh token failed: %v", + err, + ) + } + return oauth2TokenToCache(tok), nil +} + +func (a OAuth2Auth) loginClientCredentials(ctx context.Context) (*TokenCache, libError.Error) { + cc := &clientcredentials.Config{ + ClientID: a.cfg.ClientID, + ClientSecret: a.cfg.ClientSecret, + TokenURL: a.cfg.Endpoint.TokenURL, + } + ctx = a.withHTTPClient(ctx) + tok, err := cc.Token(ctx) + if err != nil { + return nil, libError.NewWithDescription( + status.InternalServerError, + "OAUTH2_LOGIN_FAILED", + "client credentials login failed: %v", + err, + ) + } + return oauth2TokenToCache(tok), nil +} + +func (a OAuth2Auth) loginPassword(ctx context.Context) (*TokenCache, libError.Error) { + ctx = a.withHTTPClient(ctx) + tok, err := a.cfg.PasswordCredentialsToken(ctx, a.user, a.password) + if err != nil { + return nil, libError.NewWithDescription( + status.InternalServerError, + "OAUTH2_LOGIN_FAILED", + "password grant login failed: %v", + err, + ) + } + return oauth2TokenToCache(tok), nil +} + +func (a OAuth2Auth) withHTTPClient(ctx context.Context) context.Context { + if a.httpClient == nil { + return ctx + } + return context.WithValue(ctx, oauth2.HTTPClient, a.httpClient) +} + +func oauth2TokenToCache(tok *oauth2.Token) *TokenCache { + tokenType := tok.TokenType + if tokenType == "" { + tokenType = "Bearer" + } + + validUntil := time.Until(tok.Expiry) - tokenExpirySkew + if validUntil < 0 { + validUntil = 0 + } + + cache := &TokenCache{ + AccessToken: &OAuth2Token{ + Token: tok.AccessToken, + Type: tokenType, + TimeTaken: time.Now(), + ValidUntil: validUntil, + }, + } + if scope, ok := tok.Extra("scope").(string); ok { + cache.AccessToken.Scope = scope + } + if tok.RefreshToken != "" { + cache.RefreshToken = &OAuth2Token{ + Token: tok.RefreshToken, + Type: tokenType, + TimeTaken: time.Now(), + ValidUntil: 365 * 24 * time.Hour, + } + } + return cache +} diff --git a/libCallApi/oauth2_auth_test.go b/libCallApi/oauth2_auth_test.go new file mode 100644 index 0000000..b19cb69 --- /dev/null +++ b/libCallApi/oauth2_auth_test.go @@ -0,0 +1,91 @@ +package libCallApi_test + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + + "github.com/hmmftg/requestCore/libCallApi" + "gotest.tools/v3/assert" +) + +func TestOAuth2Auth_ClientCredentialsLogin(t *testing.T) { + var tokenCalls atomic.Int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + tokenCalls.Add(1) + assert.Equal(t, r.Method, http.MethodPost) + err := r.ParseForm() + assert.NilError(t, err) + assert.Equal(t, "client_credentials", r.Form.Get("grant_type")) + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test-access-token", + "token_type": "Bearer", + "expires_in": 3600, + "scope": "api.read", + }) + })) + defer server.Close() + + auth, err := libCallApi.NewOAuth2AuthFromAuthData(libCallApi.Auth{ + GrantType: libCallApi.GrantTypeClientCredentials, + AuthURI: server.URL + "/token", + ClientID: "client-id", + ClientSecret: "client-secret", + }, server.Client()) + assert.NilError(t, err) + + cache, loginErr := auth.Login(context.Background()) + assert.NilError(t, loginErr) + assert.Equal(t, tokenCalls.Load(), int32(1)) + assert.Equal(t, cache.AccessToken.Token, "test-access-token") + assert.Equal(t, cache.AccessToken.Type, "Bearer") + assert.Equal(t, cache.AccessToken.Scope, "api.read") + assert.Assert(t, cache.AccessToken.ValidUntil > 0) + assert.Assert(t, !cache.Expired()) +} + +func TestOAuth2Auth_RefreshToken(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + err := r.ParseForm() + assert.NilError(t, err) + assert.Equal(t, "refresh_token", r.Form.Get("grant_type")) + assert.Equal(t, "old-refresh", r.Form.Get("refresh_token")) + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": "new-access-token", + "token_type": "Bearer", + "expires_in": 3600, + "refresh_token": "new-refresh", + }) + })) + defer server.Close() + + auth, err := libCallApi.NewOAuth2AuthFromAuthData(libCallApi.Auth{ + GrantType: libCallApi.GrantTypeClientCredentials, + AuthURI: server.URL + "/token", + ClientID: "client-id", + ClientSecret: "client-secret", + }, server.Client()) + assert.NilError(t, err) + + cache, refreshErr := auth.Refresh(context.Background(), "old-refresh") + assert.NilError(t, refreshErr) + assert.Equal(t, cache.AccessToken.Token, "new-access-token") + assert.Equal(t, cache.RefreshToken.Token, "new-refresh") +} + +func TestNewOAuth2AuthFromAuthData_Validation(t *testing.T) { + _, err := libCallApi.NewOAuth2AuthFromAuthData(libCallApi.Auth{}, nil) + assert.ErrorContains(t, err, "grant-type is required") + + _, err = libCallApi.NewOAuth2AuthFromAuthData(libCallApi.Auth{ + GrantType: libCallApi.GrantTypeClientCredentials, + }, nil) + assert.ErrorContains(t, err, "auth-uri is required") +} diff --git a/libCallApi/oauth2_factory.go b/libCallApi/oauth2_factory.go new file mode 100644 index 0000000..1fb04ea --- /dev/null +++ b/libCallApi/oauth2_factory.go @@ -0,0 +1,60 @@ +package libCallApi + +import ( + "crypto/tls" + "fmt" + "net/http" + + "golang.org/x/oauth2" +) + +const ( + GrantTypeClientCredentials = "client_credentials" + GrantTypePassword = "password" +) + +func NewTokenHTTPClient() *http.Client { + return &http.Client{ + Timeout: defaultTimeOut, + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + }, + } +} + +func NewOAuth2AuthFromAuthData(auth Auth, httpClient *http.Client) (AuthSystem, error) { + if auth.GrantType == "" { + return nil, fmt.Errorf("grant-type is required") + } + if auth.AuthURI == "" { + return nil, fmt.Errorf("auth-uri is required") + } + if auth.ClientID == "" { + return nil, fmt.Errorf("client-id is required") + } + if auth.ClientSecret == "" { + return nil, fmt.Errorf("client-secret is required") + } + if auth.GrantType == GrantTypePassword && (auth.User == "" || auth.Password == "") { + return nil, fmt.Errorf("user and password are required for password grant") + } + if httpClient == nil { + httpClient = NewTokenHTTPClient() + } + + return OAuth2Auth{ + grantType: auth.GrantType, + user: auth.User, + password: auth.Password, + cfg: oauth2.Config{ + ClientID: auth.ClientID, + ClientSecret: auth.ClientSecret, + Endpoint: oauth2.Endpoint{ + TokenURL: auth.AuthURI, + }, + }, + httpClient: httpClient, + }, nil +} diff --git a/libParams/security.go b/libParams/security.go index fa9fda9..a3346f0 100644 --- a/libParams/security.go +++ b/libParams/security.go @@ -154,7 +154,7 @@ func DecryptParams[T any](keyByte, ivByte []byte, params *ApplicationParams[T]) api.AuthData.ClientID = current.Value case "client-secret": api.AuthData.ClientSecret = current.Value - case "auth-url": + case "auth-url", "auth-uri": api.AuthData.AuthURI = current.Value } params.RemoteApis[tags[1]] = api