diff --git a/grpc-gateway/service/getResources_test.go b/grpc-gateway/service/getResources_test.go index 65dc6d074..8ad37813f 100644 --- a/grpc-gateway/service/getResources_test.go +++ b/grpc-gateway/service/getResources_test.go @@ -218,8 +218,7 @@ func TestRequestHandlerGetResources(t *testing.T) { func TestRequestHandlerGetResourcesWithM2MTokenVerification(t *testing.T) { deviceID := test.MustFindDeviceByName(test.TestDeviceName) - - ctx, cancel := context.WithTimeout(context.Background(), config.TEST_TIMEOUT*100) + ctx, cancel := context.WithTimeout(context.Background(), config.TEST_TIMEOUT) defer cancel() grpcCfg := grpcgwTest.MakeConfig(t) @@ -272,6 +271,7 @@ func TestRequestHandlerGetResourcesWithM2MTokenVerification(t *testing.T) { tokenID, err := token.GetID() require.NoError(t, err) m2mOauthTest.BlacklistTokens(ctx, t, []string{tokenID}, validTokenStr) + // request should fail _, err = getResources(pkgGrpc.CtxWithToken(ctx, tokenStr), c, req) require.ErrorContains(t, err, pkgJwt.ErrBlackListedToken.Error()) @@ -285,17 +285,38 @@ func TestRequestHandlerGetResourcesWithM2MTokenVerification(t *testing.T) { require.NotEmpty(t, values) pbTest.CmpResourceValues(t, []*pb.Resource{exp.Clone()}, values) - // parallel requests -> cache should be used, only a single request should be made + // parallel whitelisted requests -> cache should be used, only a single request should be made var wg sync.WaitGroup - newValidTokenStr := m2mOauthTest.GetDefaultAccessToken(t) + tokenStr2 := m2mOauthTest.GetDefaultAccessToken(t) + for range 5 { + wg.Add(1) + go func() { + defer wg.Done() + values2, err2 := getResources(pkgGrpc.CtxWithToken(ctx, tokenStr2), c, req) + assert.NoError(t, err2) + assert.NotEmpty(t, values2) + pbTest.CmpResourceValues(t, []*pb.Resource{exp.Clone()}, values2) + }() + } + wg.Wait() + + // wait for expiration + time.Sleep(grpcCfg.APIs.GRPC.Authorization.TokenVerification.CacheExpiration) + + // blacklist the token + tokenStr3 := m2mOauthTest.GetDefaultAccessToken(t) + token, err = pkgJwt.ParseToken(tokenStr3) + require.NoError(t, err) + tokenID, err = token.GetID() + require.NoError(t, err) + m2mOauthTest.BlacklistTokens(ctx, t, []string{tokenID}, validTokenStr) + // parallel blacklisted requests -> cache should be used, only a single request should be made for range 5 { wg.Add(1) go func() { defer wg.Done() - values, err := getResources(pkgGrpc.CtxWithToken(ctx, newValidTokenStr), c, req) - assert.NoError(t, err) - assert.NotEmpty(t, values) - pbTest.CmpResourceValues(t, []*pb.Resource{exp.Clone()}, values) + _, err2 := getResources(pkgGrpc.CtxWithToken(ctx, tokenStr3), c, req) + assert.ErrorContains(t, err2, pkgJwt.ErrBlackListedToken.Error()) }() } wg.Wait() diff --git a/pkg/security/jwt/tokenCache.go b/pkg/security/jwt/tokenCache.go index e8b274583..97b7d9a92 100644 --- a/pkg/security/jwt/tokenCache.go +++ b/pkg/security/jwt/tokenCache.go @@ -124,6 +124,58 @@ func (tc *tokenIssuerCache) setTokenRecord(tokenID uuid.UUID, tr *tokenRecord) { tc.tokens[tokenID] = tf } +func (tc *tokenIssuerCache) checkExpirations(now time.Time) { + expired := make(map[uuid.UUID]*tokenRecord, 8) + tc.mutex.Lock() + for tokenID, tf := range tc.tokens { + if tr, ok := tf.tokenOrFuture.(*tokenRecord); ok && tr.IsExpired(now) { + if tr.onExpire != nil { + expired[tokenID] = tr + } + delete(tc.tokens, tokenID) + } + } + tc.mutex.Unlock() + for tokenID, tr := range expired { + tr.onExpire(tokenID) + } +} + +func (tc *tokenIssuerCache) verifyTokenByRequest(ctx context.Context, token, tokenID string) (*pb.Token, error) { + uri, err := url.Parse(tc.tokenEndpoint) + if err != nil { + return nil, fmt.Errorf("cannot parse tokenEndpoint %v: %w", tc.tokenEndpoint, err) + } + query := uri.Query() + query.Add("idFilter", tokenID) + query.Add("includeBlacklisted", "true") + uri.RawQuery = query.Encode() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, uri.String(), nil) + if err != nil { + return nil, fmt.Errorf("cannot create request for GET %v: %w", uri.String(), err) + } + + // TODO: "Accept" -> pktNetHttp.AcceptHeaderKey: import cycle, must move to another package + req.Header.Set("Accept", "application/protojson") + req.Header.Set("Authorization", "bearer "+token) + resp, err := tc.client.Do(req) + if err != nil { + return nil, fmt.Errorf("cannot send request for GET %v: %w", tc.tokenEndpoint, err) + } + + defer func() { + _ = resp.Body.Close() + }() + + var gotToken pb.Token + err = pkgHttpPb.Unmarshal(resp.StatusCode, resp.Body, &gotToken) + if err != nil { + return nil, err + } + return &gotToken, nil +} + type TokenCache struct { expiration time.Duration cache map[string]*tokenIssuerCache @@ -135,11 +187,9 @@ func NewTokenCache(clients map[string]*Client, expiration time.Duration, logger expiration: expiration, logger: logger, } - if len(clients) > 0 { - tc.cache = make(map[string]*tokenIssuerCache) - for issuer, client := range clients { - tc.cache[issuer] = newTokenIssuerCache(client) - } + tc.cache = make(map[string]*tokenIssuerCache) + for issuer, client := range clients { + tc.cache[issuer] = newTokenIssuerCache(client) } return tc } @@ -161,7 +211,7 @@ func (t *TokenCache) getValidUntil(token *pb.Token) time.Time { } func (t *TokenCache) VerifyTrust(ctx context.Context, issuer, token string, tokenClaims jwt.Claims) error { - ic, ok := t.cache[issuer] + tc, ok := t.cache[issuer] if !ok { t.logger.Debugf("client not set for issuer %v, trust verification skipped", issuer) return nil @@ -175,7 +225,7 @@ func (t *TokenCache) VerifyTrust(ctx context.Context, issuer, token string, toke return err } t.logger.Debugf("checking trust for issuer(%v) for token(id=%s)", issuer, tokenID) - tf, set := ic.getValidTokenRecordOrFuture(tokenUUID) + tf, set := tc.getValidTokenRecordOrFuture(tokenUUID) if set == nil { tv, errG := tf.Get(ctx) if errG != nil { @@ -188,43 +238,10 @@ func (t *TokenCache) VerifyTrust(ctx context.Context, issuer, token string, toke return nil } - uri, err := url.Parse(ic.tokenEndpoint) - if err != nil { - ic.removeToken(tokenUUID) - set(nil, err) - return fmt.Errorf("cannot parse tokenEndpoint %v: %w", ic.tokenEndpoint, err) - } - query := uri.Query() - query.Add("idFilter", tokenID) - query.Add("includeBlacklisted", "true") - uri.RawQuery = query.Encode() - t.logger.Debugf("requesting token(id=%s) verification by m2m", tokenID) - req, err := http.NewRequestWithContext(ctx, http.MethodGet, uri.String(), nil) - if err != nil { - ic.removeToken(tokenUUID) - set(nil, err) - return fmt.Errorf("cannot create request for GET %v: %w", uri.String(), err) - } - - req.Header.Set("Accept", "application/protojson") - req.Header.Set("Authorization", "bearer "+token) - resp, err := ic.client.Do(req) + respToken, err := tc.verifyTokenByRequest(ctx, token, tokenID) if err != nil { - ic.removeToken(tokenUUID) - set(nil, err) - return fmt.Errorf("cannot send request for GET %v: %w", ic.tokenEndpoint, err) - } - defer func() { - if errC := resp.Body.Close(); errC != nil { - t.logger.Errorf("cannot close response body: %w", errC) - } - }() - - var gotToken pb.Token - err = pkgHttpPb.Unmarshal(resp.StatusCode, resp.Body, &gotToken) - if err != nil { - ic.removeToken(tokenUUID) + tc.removeToken(tokenUUID) set(nil, err) return err } @@ -236,11 +253,11 @@ func (t *TokenCache) VerifyTrust(ctx context.Context, issuer, token string, toke } } - blacklisted := gotToken.GetBlacklisted().GetFlag() - validUntil := t.getValidUntil(&gotToken) + blacklisted := respToken.GetBlacklisted().GetFlag() + validUntil := t.getValidUntil(respToken) tr := newTokenRecord(blacklisted, validUntil, onExpire) t.logger.Debugf("token(id=%s) set (blacklisted=%v, validUntil=%v)", tokenID, blacklisted, validUntil) - ic.setTokenRecord(tokenUUID, tr) + tc.setTokenRecord(tokenUUID, tr) set(tr, nil) if blacklisted { @@ -248,3 +265,9 @@ func (t *TokenCache) VerifyTrust(ctx context.Context, issuer, token string, toke } return nil } + +func (t *TokenCache) CheckExpirations(now time.Time) { + for _, ic := range t.cache { + ic.checkExpirations(now) + } +} diff --git a/pkg/security/jwt/validator.go b/pkg/security/jwt/validator.go index 77cfb6498..e2f0fc63e 100644 --- a/pkg/security/jwt/validator.go +++ b/pkg/security/jwt/validator.go @@ -7,6 +7,7 @@ import ( "time" "github.com/golang-jwt/jwt/v5" + "github.com/plgd-dev/go-coap/v3/pkg/runner/periodic" "github.com/plgd-dev/hub/v2/pkg/log" pkgHttpUri "github.com/plgd-dev/hub/v2/pkg/net/http/uri" ) @@ -33,6 +34,7 @@ type config struct { verifyTrust bool clients map[string]*Client cacheExpiration time.Duration + stop <-chan struct{} } type Option interface { @@ -45,11 +47,12 @@ func (o optionFunc) apply(c *config) { o(c) } -func WithTrustVerification(clients map[string]*Client, cacheExpiration time.Duration) Option { +func WithTrustVerification(clients map[string]*Client, cacheExpiration time.Duration, stop <-chan struct{}) Option { return optionFunc(func(c *config) { c.verifyTrust = true c.clients = clients c.cacheExpiration = cacheExpiration + c.stop = stop }) } @@ -62,8 +65,13 @@ func NewValidator(keyCache KeyCacheI, logger log.Logger, opts ...Option) *Valida keys: keyCache, verifyTrust: c.verifyTrust, } - if c.verifyTrust { + if c.verifyTrust && len(c.clients) > 0 { v.tokenCache = NewTokenCache(c.clients, c.cacheExpiration, logger) + add := periodic.New(c.stop, c.cacheExpiration/2) + add(func(now time.Time) bool { + v.tokenCache.CheckExpirations(now) + return true + }) } return v } diff --git a/pkg/security/jwt/validator/validator.go b/pkg/security/jwt/validator/validator.go index a9c40709a..d0064d513 100644 --- a/pkg/security/jwt/validator/validator.go +++ b/pkg/security/jwt/validator/validator.go @@ -95,7 +95,7 @@ func New(ctx context.Context, config Config, fileWatcher *fsnotify.Watcher, logg var vopts []jwtValidator.Option if len(clients) > 0 { - vopts = append(vopts, jwtValidator.WithTrustVerification(clients, config.TokenVerification.CacheExpiration)) + vopts = append(vopts, jwtValidator.WithTrustVerification(clients, config.TokenVerification.CacheExpiration, ctx.Done())) } return &Validator{