Skip to content

Commit

Permalink
Periodically clean-up token cache
Browse files Browse the repository at this point in the history
  • Loading branch information
Danielius1922 authored and Daniel Adam committed Jul 29, 2024
1 parent 433b803 commit 90f952a
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 56 deletions.
37 changes: 29 additions & 8 deletions grpc-gateway/service/getResources_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,8 +218,7 @@ func TestRequestHandlerGetResources(t *testing.T) {

func TestRequestHandlerGetResourcesWithM2MTokenVerification(t *testing.T) {
deviceID := test.MustFindDeviceByName(test.TestDeviceName)

ctx, cancel := context.WithTimeout(context.Background(), config.TEST_TIMEOUT*100)
ctx, cancel := context.WithTimeout(context.Background(), config.TEST_TIMEOUT)
defer cancel()

grpcCfg := grpcgwTest.MakeConfig(t)
Expand Down Expand Up @@ -272,6 +271,7 @@ func TestRequestHandlerGetResourcesWithM2MTokenVerification(t *testing.T) {
tokenID, err := token.GetID()
require.NoError(t, err)
m2mOauthTest.BlacklistTokens(ctx, t, []string{tokenID}, validTokenStr)
// request should fail
_, err = getResources(pkgGrpc.CtxWithToken(ctx, tokenStr), c, req)
require.ErrorContains(t, err, pkgJwt.ErrBlackListedToken.Error())

Expand All @@ -285,17 +285,38 @@ func TestRequestHandlerGetResourcesWithM2MTokenVerification(t *testing.T) {
require.NotEmpty(t, values)
pbTest.CmpResourceValues(t, []*pb.Resource{exp.Clone()}, values)

// parallel requests -> cache should be used, only a single request should be made
// parallel whitelisted requests -> cache should be used, only a single request should be made
var wg sync.WaitGroup
newValidTokenStr := m2mOauthTest.GetDefaultAccessToken(t)
tokenStr2 := m2mOauthTest.GetDefaultAccessToken(t)
for range 5 {
wg.Add(1)
go func() {
defer wg.Done()
values2, err2 := getResources(pkgGrpc.CtxWithToken(ctx, tokenStr2), c, req)
assert.NoError(t, err2)
assert.NotEmpty(t, values2)
pbTest.CmpResourceValues(t, []*pb.Resource{exp.Clone()}, values2)
}()
}
wg.Wait()

// wait for expiration
time.Sleep(grpcCfg.APIs.GRPC.Authorization.TokenVerification.CacheExpiration)

// blacklist the token
tokenStr3 := m2mOauthTest.GetDefaultAccessToken(t)
token, err = pkgJwt.ParseToken(tokenStr3)
require.NoError(t, err)
tokenID, err = token.GetID()
require.NoError(t, err)
m2mOauthTest.BlacklistTokens(ctx, t, []string{tokenID}, validTokenStr)
// parallel blacklisted requests -> cache should be used, only a single request should be made
for range 5 {
wg.Add(1)
go func() {
defer wg.Done()
values, err := getResources(pkgGrpc.CtxWithToken(ctx, newValidTokenStr), c, req)
assert.NoError(t, err)
assert.NotEmpty(t, values)
pbTest.CmpResourceValues(t, []*pb.Resource{exp.Clone()}, values)
_, err2 := getResources(pkgGrpc.CtxWithToken(ctx, tokenStr3), c, req)
assert.ErrorContains(t, err2, pkgJwt.ErrBlackListedToken.Error())
}()
}
wg.Wait()
Expand Down
113 changes: 68 additions & 45 deletions pkg/security/jwt/tokenCache.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,58 @@ func (tc *tokenIssuerCache) setTokenRecord(tokenID uuid.UUID, tr *tokenRecord) {
tc.tokens[tokenID] = tf
}

func (tc *tokenIssuerCache) checkExpirations(now time.Time) {
expired := make(map[uuid.UUID]*tokenRecord, 8)
tc.mutex.Lock()
for tokenID, tf := range tc.tokens {
if tr, ok := tf.tokenOrFuture.(*tokenRecord); ok && tr.IsExpired(now) {
if tr.onExpire != nil {
expired[tokenID] = tr
}
delete(tc.tokens, tokenID)
}
}
tc.mutex.Unlock()
for tokenID, tr := range expired {
tr.onExpire(tokenID)
}
}

func (tc *tokenIssuerCache) verifyTokenByRequest(ctx context.Context, token, tokenID string) (*pb.Token, error) {
uri, err := url.Parse(tc.tokenEndpoint)
if err != nil {
return nil, fmt.Errorf("cannot parse tokenEndpoint %v: %w", tc.tokenEndpoint, err)
}
query := uri.Query()
query.Add("idFilter", tokenID)
query.Add("includeBlacklisted", "true")
uri.RawQuery = query.Encode()

req, err := http.NewRequestWithContext(ctx, http.MethodGet, uri.String(), nil)
if err != nil {
return nil, fmt.Errorf("cannot create request for GET %v: %w", uri.String(), err)
}

// TODO: "Accept" -> pktNetHttp.AcceptHeaderKey: import cycle, must move to another package
req.Header.Set("Accept", "application/protojson")
req.Header.Set("Authorization", "bearer "+token)
resp, err := tc.client.Do(req)
if err != nil {
return nil, fmt.Errorf("cannot send request for GET %v: %w", tc.tokenEndpoint, err)
}

defer func() {
_ = resp.Body.Close()
}()

var gotToken pb.Token
err = pkgHttpPb.Unmarshal(resp.StatusCode, resp.Body, &gotToken)
if err != nil {
return nil, err
}
return &gotToken, nil
}

type TokenCache struct {
expiration time.Duration
cache map[string]*tokenIssuerCache
Expand All @@ -135,11 +187,9 @@ func NewTokenCache(clients map[string]*Client, expiration time.Duration, logger
expiration: expiration,
logger: logger,
}
if len(clients) > 0 {
tc.cache = make(map[string]*tokenIssuerCache)
for issuer, client := range clients {
tc.cache[issuer] = newTokenIssuerCache(client)
}
tc.cache = make(map[string]*tokenIssuerCache)
for issuer, client := range clients {
tc.cache[issuer] = newTokenIssuerCache(client)
}
return tc
}
Expand All @@ -161,7 +211,7 @@ func (t *TokenCache) getValidUntil(token *pb.Token) time.Time {
}

func (t *TokenCache) VerifyTrust(ctx context.Context, issuer, token string, tokenClaims jwt.Claims) error {
ic, ok := t.cache[issuer]
tc, ok := t.cache[issuer]
if !ok {
t.logger.Debugf("client not set for issuer %v, trust verification skipped", issuer)
return nil
Expand All @@ -175,7 +225,7 @@ func (t *TokenCache) VerifyTrust(ctx context.Context, issuer, token string, toke
return err
}
t.logger.Debugf("checking trust for issuer(%v) for token(id=%s)", issuer, tokenID)
tf, set := ic.getValidTokenRecordOrFuture(tokenUUID)
tf, set := tc.getValidTokenRecordOrFuture(tokenUUID)
if set == nil {
tv, errG := tf.Get(ctx)
if errG != nil {
Expand All @@ -188,43 +238,10 @@ func (t *TokenCache) VerifyTrust(ctx context.Context, issuer, token string, toke
return nil
}

uri, err := url.Parse(ic.tokenEndpoint)
if err != nil {
ic.removeToken(tokenUUID)
set(nil, err)
return fmt.Errorf("cannot parse tokenEndpoint %v: %w", ic.tokenEndpoint, err)
}
query := uri.Query()
query.Add("idFilter", tokenID)
query.Add("includeBlacklisted", "true")
uri.RawQuery = query.Encode()

t.logger.Debugf("requesting token(id=%s) verification by m2m", tokenID)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, uri.String(), nil)
if err != nil {
ic.removeToken(tokenUUID)
set(nil, err)
return fmt.Errorf("cannot create request for GET %v: %w", uri.String(), err)
}

req.Header.Set("Accept", "application/protojson")
req.Header.Set("Authorization", "bearer "+token)
resp, err := ic.client.Do(req)
respToken, err := tc.verifyTokenByRequest(ctx, token, tokenID)
if err != nil {
ic.removeToken(tokenUUID)
set(nil, err)
return fmt.Errorf("cannot send request for GET %v: %w", ic.tokenEndpoint, err)
}
defer func() {
if errC := resp.Body.Close(); errC != nil {
t.logger.Errorf("cannot close response body: %w", errC)
}
}()

var gotToken pb.Token
err = pkgHttpPb.Unmarshal(resp.StatusCode, resp.Body, &gotToken)
if err != nil {
ic.removeToken(tokenUUID)
tc.removeToken(tokenUUID)
set(nil, err)
return err
}
Expand All @@ -236,15 +253,21 @@ func (t *TokenCache) VerifyTrust(ctx context.Context, issuer, token string, toke
}
}

blacklisted := gotToken.GetBlacklisted().GetFlag()
validUntil := t.getValidUntil(&gotToken)
blacklisted := respToken.GetBlacklisted().GetFlag()
validUntil := t.getValidUntil(respToken)
tr := newTokenRecord(blacklisted, validUntil, onExpire)
t.logger.Debugf("token(id=%s) set (blacklisted=%v, validUntil=%v)", tokenID, blacklisted, validUntil)
ic.setTokenRecord(tokenUUID, tr)
tc.setTokenRecord(tokenUUID, tr)
set(tr, nil)

if blacklisted {
return ErrBlackListedToken
}
return nil
}

func (t *TokenCache) CheckExpirations(now time.Time) {
for _, ic := range t.cache {
ic.checkExpirations(now)
}
}
12 changes: 10 additions & 2 deletions pkg/security/jwt/validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"time"

"github.com/golang-jwt/jwt/v5"
"github.com/plgd-dev/go-coap/v3/pkg/runner/periodic"
"github.com/plgd-dev/hub/v2/pkg/log"
pkgHttpUri "github.com/plgd-dev/hub/v2/pkg/net/http/uri"
)
Expand All @@ -33,6 +34,7 @@ type config struct {
verifyTrust bool
clients map[string]*Client
cacheExpiration time.Duration
stop <-chan struct{}
}

type Option interface {
Expand All @@ -45,11 +47,12 @@ func (o optionFunc) apply(c *config) {
o(c)
}

func WithTrustVerification(clients map[string]*Client, cacheExpiration time.Duration) Option {
func WithTrustVerification(clients map[string]*Client, cacheExpiration time.Duration, stop <-chan struct{}) Option {
return optionFunc(func(c *config) {
c.verifyTrust = true
c.clients = clients
c.cacheExpiration = cacheExpiration
c.stop = stop
})
}

Expand All @@ -62,8 +65,13 @@ func NewValidator(keyCache KeyCacheI, logger log.Logger, opts ...Option) *Valida
keys: keyCache,
verifyTrust: c.verifyTrust,
}
if c.verifyTrust {
if c.verifyTrust && len(c.clients) > 0 {
v.tokenCache = NewTokenCache(c.clients, c.cacheExpiration, logger)
add := periodic.New(c.stop, c.cacheExpiration/2)
add(func(now time.Time) bool {
v.tokenCache.CheckExpirations(now)
return true
})
}
return v
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/security/jwt/validator/validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ func New(ctx context.Context, config Config, fileWatcher *fsnotify.Watcher, logg

var vopts []jwtValidator.Option
if len(clients) > 0 {
vopts = append(vopts, jwtValidator.WithTrustVerification(clients, config.TokenVerification.CacheExpiration))
vopts = append(vopts, jwtValidator.WithTrustVerification(clients, config.TokenVerification.CacheExpiration, ctx.Done()))
}

return &Validator{
Expand Down

0 comments on commit 90f952a

Please sign in to comment.