diff --git a/m2m-oauth-server/service/service.go b/m2m-oauth-server/service/service.go index 4ab56ac49..410a02087 100644 --- a/m2m-oauth-server/service/service.go +++ b/m2m-oauth-server/service/service.go @@ -7,6 +7,7 @@ import ( "time" oauthsigner "github.com/plgd-dev/hub/v2/m2m-oauth-server/oauthSigner" + "github.com/plgd-dev/hub/v2/m2m-oauth-server/pb" grpcService "github.com/plgd-dev/hub/v2/m2m-oauth-server/service/grpc" httpService "github.com/plgd-dev/hub/v2/m2m-oauth-server/service/http" "github.com/plgd-dev/hub/v2/m2m-oauth-server/store" @@ -16,9 +17,11 @@ import ( "github.com/plgd-dev/hub/v2/pkg/fn" "github.com/plgd-dev/hub/v2/pkg/fsnotify" "github.com/plgd-dev/hub/v2/pkg/log" + "github.com/plgd-dev/hub/v2/pkg/net/grpc" "github.com/plgd-dev/hub/v2/pkg/net/listener" otelClient "github.com/plgd-dev/hub/v2/pkg/opentelemetry/collector/client" certManagerServer "github.com/plgd-dev/hub/v2/pkg/security/certManager/server" + "github.com/plgd-dev/hub/v2/pkg/security/jwt" "github.com/plgd-dev/hub/v2/pkg/security/jwt/validator" "github.com/plgd-dev/hub/v2/pkg/security/openid" "github.com/plgd-dev/hub/v2/pkg/service" @@ -62,8 +65,8 @@ func createStore(ctx context.Context, config storeConfig.Config, fileWatcher *fs return s, nil } -func newHttpService(ctx context.Context, config HTTPConfig, validatorConfig validator.Config, getOpenIDConfiguration validator.GetOpenIDConfigurationFunc, tlsConfig certManagerServer.Config, ss *grpcService.M2MOAuthServiceServer, fileWatcher *fsnotify.Watcher, logger log.Logger, tracerProvider trace.TracerProvider) (*httpService.Service, func(), error) { - httpValidator, err := validator.New(ctx, validatorConfig, fileWatcher, logger, tracerProvider, validator.WithGetOpenIDConfiguration(getOpenIDConfiguration)) +func newHttpService(ctx context.Context, config HTTPConfig, validatorConfig validator.Config, getOpenIDConfiguration validator.GetOpenIDConfigurationFunc, trustVerification map[string]jwt.TokenIssuerClient, tlsConfig certManagerServer.Config, ss *grpcService.M2MOAuthServiceServer, fileWatcher *fsnotify.Watcher, logger log.Logger, tracerProvider trace.TracerProvider) (*httpService.Service, func(), error) { + httpValidator, err := validator.New(ctx, validatorConfig, fileWatcher, logger, tracerProvider, validator.WithGetOpenIDConfiguration(getOpenIDConfiguration), validator.WithCustomTokenIssuerClients(trustVerification)) if err != nil { return nil, nil, fmt.Errorf("cannot create http validator: %w", err) } @@ -82,8 +85,8 @@ func newHttpService(ctx context.Context, config HTTPConfig, validatorConfig vali return httpService, httpValidator.Close, nil } -func newGrpcService(ctx context.Context, config grpcService.Config, getOpenIDConfiguration validator.GetOpenIDConfigurationFunc, ss *grpcService.M2MOAuthServiceServer, fileWatcher *fsnotify.Watcher, logger log.Logger, tracerProvider trace.TracerProvider) (*grpcService.Service, func(), error) { - grpcValidator, err := validator.New(ctx, config.Authorization.Config, fileWatcher, logger, tracerProvider, validator.WithGetOpenIDConfiguration(getOpenIDConfiguration)) +func newGrpcService(ctx context.Context, config grpcService.Config, getOpenIDConfiguration validator.GetOpenIDConfigurationFunc, trustVerification map[string]jwt.TokenIssuerClient, ss *grpcService.M2MOAuthServiceServer, fileWatcher *fsnotify.Watcher, logger log.Logger, tracerProvider trace.TracerProvider) (*grpcService.Service, func(), error) { + grpcValidator, err := validator.New(ctx, config.Authorization.Config, fileWatcher, logger, tracerProvider, validator.WithGetOpenIDConfiguration(getOpenIDConfiguration), validator.WithCustomTokenIssuerClients(trustVerification)) if err != nil { return nil, nil, fmt.Errorf("cannot create grpc validator: %w", err) } @@ -95,6 +98,33 @@ func newGrpcService(ctx context.Context, config grpcService.Config, getOpenIDCon return grpcService, grpcValidator.Close, nil } +type tokenIssuerClient struct { + store store.Store + ownerClaim string +} + +func (c *tokenIssuerClient) VerifyTokenByRequest(ctx context.Context, accessToken, tokenID string) (*pb.Token, error) { + owner, err := grpc.ParseOwnerFromJwtToken(c.ownerClaim, accessToken) + if err != nil { + return nil, fmt.Errorf("cannot parse owner from token: %w", err) + } + var token *pb.Token + err = c.store.GetTokens(ctx, owner, &pb.GetTokensRequest{ + IdFilter: []string{tokenID}, + IncludeBlacklisted: true, + }, func(v *pb.Token) error { + token = v + return nil + }) + if err != nil { + return nil, fmt.Errorf("cannot get token(%v): %w", tokenID, err) + } + if token == nil { + return nil, fmt.Errorf("token(%v) not found", tokenID) + } + return token, nil +} + func New(ctx context.Context, config Config, fileWatcher *fsnotify.Watcher, logger log.Logger) (*Service, error) { otelClient, err := otelClient.New(ctx, config.Clients.OpenTelemetryCollector.Config, serviceName, fileWatcher, logger) if err != nil { @@ -104,13 +134,6 @@ func New(ctx context.Context, config Config, fileWatcher *fsnotify.Watcher, logg closerFn.AddFunc(otelClient.Close) tracerProvider := otelClient.GetTracerProvider() - getOpenIDCfg := func(ctx context.Context, c *http.Client, authority string) (openid.Config, error) { - if authority == config.OAuthSigner.GetAuthority() { - return httpService.GetOpenIDConfiguration(config.OAuthSigner.GetDomain()), nil - } - return openid.GetConfiguration(ctx, c, authority) - } - db, err := createStore(ctx, config.Clients.Storage, fileWatcher, logger, tracerProvider) if err != nil { closerFn.Execute() @@ -122,6 +145,13 @@ func New(ctx context.Context, config Config, fileWatcher *fsnotify.Watcher, logg } }) + getOpenIDCfg := func(ctx context.Context, c *http.Client, authority string) (openid.Config, error) { + if authority == config.OAuthSigner.GetAuthority() { + return httpService.GetOpenIDConfiguration(config.OAuthSigner.GetDomain()), nil + } + return openid.GetConfiguration(ctx, c, authority) + } + signer, err := oauthsigner.New(ctx, config.OAuthSigner, getOpenIDCfg, fileWatcher, logger, tracerProvider) if err != nil { closerFn.Execute() @@ -131,14 +161,21 @@ func New(ctx context.Context, config Config, fileWatcher *fsnotify.Watcher, logg m2mOAuthService := grpcService.NewM2MOAuthServerServer(db, signer, logger) - grpcService, grpcServiceClose, err := newGrpcService(ctx, config.APIs.GRPC, getOpenIDCfg, m2mOAuthService, fileWatcher, logger, tracerProvider) + customTokenIssuerClients := map[string]jwt.TokenIssuerClient{ + config.OAuthSigner.GetDomain(): &tokenIssuerClient{ + store: db, + ownerClaim: signer.GetOwnerClaim(), + }, + } + + grpcService, grpcServiceClose, err := newGrpcService(ctx, config.APIs.GRPC, getOpenIDCfg, customTokenIssuerClients, m2mOAuthService, fileWatcher, logger, tracerProvider) if err != nil { closerFn.Execute() return nil, err } closerFn.AddFunc(grpcServiceClose) - httpService, httpServiceClose, err := newHttpService(ctx, config.APIs.HTTP, config.APIs.GRPC.Authorization.Config, getOpenIDCfg, config.APIs.GRPC.TLS, + httpService, httpServiceClose, err := newHttpService(ctx, config.APIs.HTTP, config.APIs.GRPC.Authorization.Config, getOpenIDCfg, customTokenIssuerClients, config.APIs.GRPC.TLS, m2mOAuthService, fileWatcher, logger, tracerProvider) if err != nil { grpcService.Close() diff --git a/m2m-oauth-server/test/test.go b/m2m-oauth-server/test/test.go index b57b582a1..32ec47e20 100644 --- a/m2m-oauth-server/test/test.go +++ b/m2m-oauth-server/test/test.go @@ -54,7 +54,7 @@ var JWTPrivateKeyOAuthClient = oauthsigner.Client{ AllowedScopes: nil, JWTPrivateKey: oauthsigner.PrivateKeyJWTConfig{ Enabled: true, - Authorization: MakeValidatorConfig(), + Authorization: config.MakeValidatorConfig(), }, } @@ -63,16 +63,6 @@ var OAuthClients = oauthsigner.OAuthClientsConfig{ &JWTPrivateKeyOAuthClient, } -func MakeValidatorConfig() validator.Config { - c := config.MakeValidatorConfig() - // tokens are verified by the m2m-oauth-server, so we want to disable the verification here to avoid infinite loop - // of token verification - c.TokenVerification = validator.TokenTrustVerificationConfig{ - Enabled: false, - } - return c -} - func MakeConfig(t require.TestingT) service.Config { var cfg service.Config @@ -91,7 +81,7 @@ func MakeConfig(t require.TestingT) service.Config { HTTP: config.MakeHttpClientConfig(), }, ) - cfg.APIs.GRPC.Authorization.Config = MakeValidatorConfig() + cfg.APIs.GRPC.Authorization.Config = config.MakeValidatorConfig() cfg.Clients.Storage = MakeStoreConfig() cfg.OAuthSigner.PrivateKeyFile = urischeme.URIScheme(os.Getenv("M2M_OAUTH_SERVER_PRIVATE_KEY")) diff --git a/pkg/security/jwt/tokenCache.go b/pkg/security/jwt/tokenCache.go index 032061386..e9f77362b 100644 --- a/pkg/security/jwt/tokenCache.go +++ b/pkg/security/jwt/tokenCache.go @@ -17,13 +17,47 @@ import ( "go.uber.org/atomic" ) -type Client struct { +type HTTPClient struct { *http.Client tokenEndpoint string } -func NewClient(client *http.Client, tokenEndpoint string) *Client { - return &Client{ +func (c *HTTPClient) VerifyTokenByRequest(ctx context.Context, token, tokenID string) (*pb.Token, error) { + uri, err := url.Parse(c.tokenEndpoint) + if err != nil { + return nil, fmt.Errorf("cannot parse tokenEndpoint %v: %w", c.tokenEndpoint, err) + } + query := uri.Query() + query.Add("idFilter", tokenID) + query.Add("includeBlacklisted", "true") + uri.RawQuery = query.Encode() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, uri.String(), nil) + if err != nil { + return nil, fmt.Errorf("cannot create request for GET %v: %w", uri.String(), err) + } + + req.Header.Set("Accept", "application/protojson") + req.Header.Set("Authorization", "bearer "+token) + resp, err := c.Do(req) + if err != nil { + return nil, fmt.Errorf("cannot send request for GET %v: %w", c.tokenEndpoint, err) + } + + defer func() { + _ = resp.Body.Close() + }() + + var gotToken pb.Token + err = pkgHttpPb.Unmarshal(resp.StatusCode, resp.Body, &gotToken) + if err != nil { + return nil, err + } + return &gotToken, nil +} + +func NewHTTPClient(client *http.Client, tokenEndpoint string) *HTTPClient { + return &HTTPClient{ Client: client, tokenEndpoint: tokenEndpoint, } @@ -75,17 +109,19 @@ func (tf *tokenOrFuture) Get(ctx context.Context) (*tokenRecord, error) { } type tokenIssuerCache struct { - client *http.Client - tokenEndpoint string - tokens map[uuid.UUID]tokenOrFuture - mutex sync.Mutex + client TokenIssuerClient + tokens map[uuid.UUID]tokenOrFuture + mutex sync.Mutex } -func newTokenIssuerCache(client *Client) *tokenIssuerCache { +type TokenIssuerClient interface { + VerifyTokenByRequest(ctx context.Context, token, tokenID string) (*pb.Token, error) +} + +func newTokenIssuerCache(client TokenIssuerClient) *tokenIssuerCache { return &tokenIssuerCache{ - client: client.Client, - tokenEndpoint: client.tokenEndpoint, - tokens: make(map[uuid.UUID]tokenOrFuture), + client: client, + tokens: make(map[uuid.UUID]tokenOrFuture), } } @@ -154,37 +190,7 @@ func (tc *tokenIssuerCache) checkExpirations(now time.Time) { } func (tc *tokenIssuerCache) verifyTokenByRequest(ctx context.Context, token, tokenID string) (*pb.Token, error) { - uri, err := url.Parse(tc.tokenEndpoint) - if err != nil { - return nil, fmt.Errorf("cannot parse tokenEndpoint %v: %w", tc.tokenEndpoint, err) - } - query := uri.Query() - query.Add("idFilter", tokenID) - query.Add("includeBlacklisted", "true") - uri.RawQuery = query.Encode() - - req, err := http.NewRequestWithContext(ctx, http.MethodGet, uri.String(), nil) - if err != nil { - return nil, fmt.Errorf("cannot create request for GET %v: %w", uri.String(), err) - } - - req.Header.Set("Accept", "application/protojson") - req.Header.Set("Authorization", "bearer "+token) - resp, err := tc.client.Do(req) - if err != nil { - return nil, fmt.Errorf("cannot send request for GET %v: %w", tc.tokenEndpoint, err) - } - - defer func() { - _ = resp.Body.Close() - }() - - var gotToken pb.Token - err = pkgHttpPb.Unmarshal(resp.StatusCode, resp.Body, &gotToken) - if err != nil { - return nil, err - } - return &gotToken, nil + return tc.client.VerifyTokenByRequest(ctx, token, tokenID) } type TokenCache struct { @@ -193,7 +199,7 @@ type TokenCache struct { logger log.Logger } -func NewTokenCache(clients map[string]*Client, expiration time.Duration, logger log.Logger) *TokenCache { +func NewTokenCache(clients map[string]TokenIssuerClient, expiration time.Duration, logger log.Logger) *TokenCache { tc := &TokenCache{ expiration: expiration, logger: logger, diff --git a/pkg/security/jwt/tokenCache_internal_test.go b/pkg/security/jwt/tokenCache_internal_test.go index e0e8ec240..3b0e99684 100644 --- a/pkg/security/jwt/tokenCache_internal_test.go +++ b/pkg/security/jwt/tokenCache_internal_test.go @@ -34,7 +34,7 @@ func TestTokenRecord_IsExpired(t *testing.T) { } func TestTokenIssuerCacheSetAndGetToken(t *testing.T) { - cache := newTokenIssuerCache(&Client{Client: &http.Client{}, tokenEndpoint: "http://example.com"}) + cache := newTokenIssuerCache(&HTTPClient{Client: &http.Client{}, tokenEndpoint: "http://example.com"}) ctx, cancel := context.WithTimeout(context.Background(), TEST_TIMEOUT) defer cancel() @@ -110,7 +110,7 @@ func TestTokenIssuerCacheSetAndGetToken(t *testing.T) { func TestTokenIssuerCacheCheckExpirations(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), TEST_TIMEOUT) defer cancel() - cache := newTokenIssuerCache(&Client{Client: &http.Client{}, tokenEndpoint: "http://example.com"}) + cache := newTokenIssuerCache(&HTTPClient{Client: &http.Client{}, tokenEndpoint: "http://example.com"}) now := time.Now() tokenID1 := uuid.New() diff --git a/pkg/security/jwt/validator.go b/pkg/security/jwt/validator.go index e2f0fc63e..b9765568c 100644 --- a/pkg/security/jwt/validator.go +++ b/pkg/security/jwt/validator.go @@ -32,7 +32,7 @@ var ( type config struct { verifyTrust bool - clients map[string]*Client + clients map[string]TokenIssuerClient cacheExpiration time.Duration stop <-chan struct{} } @@ -47,7 +47,7 @@ func (o optionFunc) apply(c *config) { o(c) } -func WithTrustVerification(clients map[string]*Client, cacheExpiration time.Duration, stop <-chan struct{}) Option { +func WithTrustVerification(clients map[string]TokenIssuerClient, cacheExpiration time.Duration, stop <-chan struct{}) Option { return optionFunc(func(c *config) { c.verifyTrust = true c.clients = clients diff --git a/pkg/security/jwt/validator/config.go b/pkg/security/jwt/validator/config.go index 9a971f1c5..fdd3f882d 100644 --- a/pkg/security/jwt/validator/config.go +++ b/pkg/security/jwt/validator/config.go @@ -23,13 +23,12 @@ func (c *AuthorityConfig) Validate() error { } type TokenTrustVerificationConfig struct { - Enabled bool `yaml:"enabled,omitempty" json:"enabled,omitempty"` CacheExpiration time.Duration `yaml:"cacheExpiration,omitempty" json:"cacheExpiration,omitempty"` } func (c *TokenTrustVerificationConfig) Validate() error { - if c.Enabled && c.CacheExpiration == 0 { - return fmt.Errorf("cacheExpiration('%v')", c.CacheExpiration) + if c.CacheExpiration == 0 { + c.CacheExpiration = time.Second * 30 } return nil } diff --git a/pkg/security/jwt/validator/validator.go b/pkg/security/jwt/validator/validator.go index d0064d513..1565f00a8 100644 --- a/pkg/security/jwt/validator/validator.go +++ b/pkg/security/jwt/validator/validator.go @@ -44,7 +44,8 @@ func (v *Validator) GetParser() *jwtValidator.Validator { type GetOpenIDConfigurationFunc func(ctx context.Context, c *http.Client, authority string) (openid.Config, error) type Options struct { - getOpenIDConfiguration GetOpenIDConfigurationFunc + getOpenIDConfiguration GetOpenIDConfigurationFunc + customTokenIssuerClients map[string]jwtValidator.TokenIssuerClient } func WithGetOpenIDConfiguration(f GetOpenIDConfigurationFunc) func(o *Options) { @@ -53,9 +54,16 @@ func WithGetOpenIDConfiguration(f GetOpenIDConfigurationFunc) func(o *Options) { } } +func WithCustomTokenIssuerClients(clients map[string]jwtValidator.TokenIssuerClient) func(o *Options) { + return func(o *Options) { + o.customTokenIssuerClients = clients + } +} + func New(ctx context.Context, config Config, fileWatcher *fsnotify.Watcher, logger log.Logger, tracerProvider trace.TracerProvider, opts ...func(o *Options)) (*Validator, error) { options := Options{ - getOpenIDConfiguration: openid.GetConfiguration, + getOpenIDConfiguration: openid.GetConfiguration, + customTokenIssuerClients: make(map[string]jwtValidator.TokenIssuerClient), } for _, o := range opts { o(&options) @@ -64,7 +72,7 @@ func New(ctx context.Context, config Config, fileWatcher *fsnotify.Watcher, logg keys := jwtValidator.NewMultiKeyCache() var onClose fn.FuncList openIDConfigurations := make([]openid.Config, 0, len(config.Endpoints)) - clients := make(map[string]*jwtValidator.Client, len(config.Endpoints)) + clients := make(map[string]jwtValidator.TokenIssuerClient, len(config.Endpoints)) for _, authority := range config.Endpoints { httpClient, err := client.New(authority.HTTP, fileWatcher, logger, tracerProvider) if err != nil { @@ -77,6 +85,9 @@ func New(ctx context.Context, config Config, fileWatcher *fsnotify.Watcher, logg if options.getOpenIDConfiguration == nil { return nil, errors.New("GetOpenIDConfiguration is nil") } + if options.customTokenIssuerClients == nil { + return nil, errors.New("customTokenIssuerClients is nil") + } openIDCfg, err := options.getOpenIDConfiguration(ctx2, httpClient.HTTP(), authority.Authority) if err != nil { @@ -88,8 +99,12 @@ func New(ctx context.Context, config Config, fileWatcher *fsnotify.Watcher, logg issuer := pkgHttpUri.CanonicalURI(openIDCfg.Issuer) keys.Add(issuer, openIDCfg.JWKSURL, httpClient.HTTP()) openIDConfigurations = append(openIDConfigurations, openIDCfg) - if config.TokenVerification.Enabled && openIDCfg.PlgdTokensEndpoint != "" { - clients[issuer] = jwtValidator.NewClient(httpClient.HTTP(), openIDCfg.PlgdTokensEndpoint) + if openIDCfg.PlgdTokensEndpoint != "" { + if opts, ok := options.customTokenIssuerClients[issuer]; ok { + clients[issuer] = opts + continue + } + clients[issuer] = jwtValidator.NewHTTPClient(httpClient.HTTP(), openIDCfg.PlgdTokensEndpoint) } } diff --git a/test/config/config.go b/test/config/config.go index 32d8a82ee..5511fcd79 100644 --- a/test/config/config.go +++ b/test/config/config.go @@ -307,7 +307,6 @@ func MakeValidatorConfig() validator.Config { }, }, TokenVerification: validator.TokenTrustVerificationConfig{ - Enabled: true, CacheExpiration: VALIDATOR_CACHE_EXPIRATION, }, }