From 8d82170661ecf13293abfe0524d67e759b2c5d21 Mon Sep 17 00:00:00 2001 From: Daniel Adam Date: Mon, 29 Jul 2024 09:16:00 +0200 Subject: [PATCH] Fix static analysis issues and add tests --- grpc-gateway/service/getResources_test.go | 2 - pkg/net/http/pb/protojson.go | 1 - pkg/net/http/pb/protojson_test.go | 102 ++++++++++++++ pkg/security/jwt/tokenCache.go | 44 ++++-- pkg/security/jwt/tokenCache_internal_test.go | 141 +++++++++++++++++++ 5 files changed, 274 insertions(+), 16 deletions(-) create mode 100644 pkg/net/http/pb/protojson_test.go create mode 100644 pkg/security/jwt/tokenCache_internal_test.go diff --git a/grpc-gateway/service/getResources_test.go b/grpc-gateway/service/getResources_test.go index 8ad37813f..df2950b0b 100644 --- a/grpc-gateway/service/getResources_test.go +++ b/grpc-gateway/service/getResources_test.go @@ -13,7 +13,6 @@ import ( "github.com/plgd-dev/hub/v2/grpc-gateway/pb" grpcgwTest "github.com/plgd-dev/hub/v2/grpc-gateway/test" m2mOauthTest "github.com/plgd-dev/hub/v2/m2m-oauth-server/test" - "github.com/plgd-dev/hub/v2/pkg/log" pkgGrpc "github.com/plgd-dev/hub/v2/pkg/net/grpc" pkgJwt "github.com/plgd-dev/hub/v2/pkg/security/jwt" "github.com/plgd-dev/hub/v2/resource-aggregate/commands" @@ -223,7 +222,6 @@ func TestRequestHandlerGetResourcesWithM2MTokenVerification(t *testing.T) { grpcCfg := grpcgwTest.MakeConfig(t) grpcCfg.APIs.GRPC.Authorization.TokenVerification.CacheExpiration = time.Second * 2 - grpcCfg.Log.Level = log.DebugLevel tearDown := service.SetUp(ctx, t, service.WithGRPCGWConfig(grpcCfg)) defer tearDown() validTokenStr := oauthTest.GetDefaultAccessToken(t) diff --git a/pkg/net/http/pb/protojson.go b/pkg/net/http/pb/protojson.go index 619a5cbfe..af3399580 100644 --- a/pkg/net/http/pb/protojson.go +++ b/pkg/net/http/pb/protojson.go @@ -43,7 +43,6 @@ func (d *Decoder) Unmarshal(code int, input io.Reader, v protoreflect.ProtoMessa if err != nil { return err } - d.logger.Debugf("data: %s\n", data) if code != http.StatusOK { diff --git a/pkg/net/http/pb/protojson_test.go b/pkg/net/http/pb/protojson_test.go new file mode 100644 index 000000000..2a06c6cd3 --- /dev/null +++ b/pkg/net/http/pb/protojson_test.go @@ -0,0 +1,102 @@ +package pb_test + +import ( + "bytes" + "net/http" + "testing" + + "github.com/plgd-dev/hub/v2/pkg/net/http/pb" + "github.com/plgd-dev/hub/v2/test" + "github.com/stretchr/testify/require" + "google.golang.org/genproto/googleapis/rpc/status" + grpcStatus "google.golang.org/grpc/status" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/known/structpb" +) + +func TestUnmarshalError(t *testing.T) { + s := &status.Status{ + Code: http.StatusInternalServerError, + Message: "test error", + } + data, err := protojson.Marshal(s) + require.NoError(t, err) + + err = pb.UnmarshalError(data) + require.Error(t, err) + + st, ok := grpcStatus.FromError(err) + require.True(t, ok) + require.Equal(t, s.GetCode(), int32(st.Code())) + require.Equal(t, s.GetMessage(), st.Message()) +} + +func TestUnmarshal(t *testing.T) { + tests := []struct { + name string + code int + input []byte + wantGrpcError error + wantErr bool + want protoreflect.ProtoMessage + }{ + { + name: "Unmarshal success", + code: http.StatusOK, + input: func() []byte { + data, err := protojson.Marshal(structpb.NewStringValue("test")) + require.NoError(t, err) + return []byte(`{"result":` + string(data) + `}`) + }(), + want: structpb.NewStringValue("test"), + }, + { + name: "Unmarshal error status", + code: http.StatusInternalServerError, + input: []byte(`{"code": 500, "message": "test error"}`), + wantGrpcError: grpcStatus.ErrorProto(&status.Status{ + Code: http.StatusInternalServerError, + Message: "test error", + }), + }, + { + name: "Unmarshal error status (2)", + code: http.StatusOK, + input: []byte(`{"error": {"code": 500, "message": "test error"}}`), + wantGrpcError: grpcStatus.ErrorProto(&status.Status{ + Code: http.StatusInternalServerError, + Message: "test error", + }), + }, + { + name: "Invalid JSON", + code: http.StatusOK, + input: []byte(`invalid json`), + wantErr: true, + }, + { + name: "Empty result and error fields", + code: http.StatusOK, + input: []byte(`{}`), + want: &structpb.Struct{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var v structpb.Value + err := pb.Unmarshal(tt.code, bytes.NewReader(tt.input), &v) + if tt.wantGrpcError != nil { + require.ErrorIs(t, err, tt.wantGrpcError) + return + } + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + test.CheckProtobufs(t, tt.want, &v, test.RequireToCheckFunc(require.Equal)) + }) + } +} diff --git a/pkg/security/jwt/tokenCache.go b/pkg/security/jwt/tokenCache.go index 97b7d9a92..032061386 100644 --- a/pkg/security/jwt/tokenCache.go +++ b/pkg/security/jwt/tokenCache.go @@ -113,15 +113,27 @@ func (tc *tokenIssuerCache) getValidTokenRecordOrFuture(tokenID uuid.UUID) (toke return tf, nil } -func (tc *tokenIssuerCache) removeToken(tokenID uuid.UUID) { +func (tc *tokenIssuerCache) removeTokenRecord(tokenID uuid.UUID) { + tc.mutex.Lock() + defer tc.mutex.Unlock() delete(tc.tokens, tokenID) } -func (tc *tokenIssuerCache) setTokenRecord(tokenID uuid.UUID, tr *tokenRecord) { +func (tc *tokenIssuerCache) removeTokenRecordAndSetErrorOnFuture(tokenUUID uuid.UUID, setTRFuture future.SetFunc, err error) { + tc.removeTokenRecord(tokenUUID) + setTRFuture(nil, err) +} + +func (tc *tokenIssuerCache) setTokenRecord(tokenUUID uuid.UUID, tr *tokenRecord) { tf := makeTokenOrFuture(tr, nil) tc.mutex.Lock() defer tc.mutex.Unlock() - tc.tokens[tokenID] = tf + tc.tokens[tokenUUID] = tf +} + +func (tc *tokenIssuerCache) setTokenRecordAndWaitingFuture(tokenUUID uuid.UUID, tr *tokenRecord, setTRFuture future.SetFunc) { + tc.setTokenRecord(tokenUUID, tr) + setTRFuture(tr, nil) } func (tc *tokenIssuerCache) checkExpirations(now time.Time) { @@ -156,7 +168,6 @@ func (tc *tokenIssuerCache) verifyTokenByRequest(ctx context.Context, token, tok return nil, fmt.Errorf("cannot create request for GET %v: %w", uri.String(), err) } - // TODO: "Accept" -> pktNetHttp.AcceptHeaderKey: import cycle, must move to another package req.Header.Set("Accept", "application/protojson") req.Header.Set("Authorization", "bearer "+token) resp, err := tc.client.Do(req) @@ -210,17 +221,26 @@ func (t *TokenCache) getValidUntil(token *pb.Token) time.Time { return time.Now().Add(t.expiration) } +func getTokenUUID(tokenClaims jwt.Claims) (string, uuid.UUID, error) { + tokenID, err := getID(tokenClaims) + if err != nil { + return "", uuid.Nil, err + } + tokenUUID, err := uuid.Parse(tokenID) + if err != nil { + return "", uuid.Nil, err + } + return tokenID, tokenUUID, nil +} + func (t *TokenCache) VerifyTrust(ctx context.Context, issuer, token string, tokenClaims jwt.Claims) error { tc, ok := t.cache[issuer] if !ok { t.logger.Debugf("client not set for issuer %v, trust verification skipped", issuer) return nil } - tokenID, err := getID(tokenClaims) - if err != nil { - return err - } - tokenUUID, err := uuid.Parse(tokenID) + + tokenID, tokenUUID, err := getTokenUUID(tokenClaims) if err != nil { return err } @@ -241,8 +261,7 @@ func (t *TokenCache) VerifyTrust(ctx context.Context, issuer, token string, toke t.logger.Debugf("requesting token(id=%s) verification by m2m", tokenID) respToken, err := tc.verifyTokenByRequest(ctx, token, tokenID) if err != nil { - tc.removeToken(tokenUUID) - set(nil, err) + tc.removeTokenRecordAndSetErrorOnFuture(tokenUUID, set, err) return err } @@ -257,8 +276,7 @@ func (t *TokenCache) VerifyTrust(ctx context.Context, issuer, token string, toke validUntil := t.getValidUntil(respToken) tr := newTokenRecord(blacklisted, validUntil, onExpire) t.logger.Debugf("token(id=%s) set (blacklisted=%v, validUntil=%v)", tokenID, blacklisted, validUntil) - tc.setTokenRecord(tokenUUID, tr) - set(tr, nil) + tc.setTokenRecordAndWaitingFuture(tokenUUID, tr, set) if blacklisted { return ErrBlackListedToken diff --git a/pkg/security/jwt/tokenCache_internal_test.go b/pkg/security/jwt/tokenCache_internal_test.go new file mode 100644 index 000000000..e0e8ec240 --- /dev/null +++ b/pkg/security/jwt/tokenCache_internal_test.go @@ -0,0 +1,141 @@ +package jwt + +import ( + "context" + "errors" + "net/http" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const TEST_TIMEOUT = time.Second * 10 + +func TestTokenRecord_IsExpired(t *testing.T) { + now := time.Now() + tests := []struct { + name string + recordTime time.Time + want bool + }{ + {"Not expired", now.Add(time.Hour), false}, + {"Expired", now.Add(-time.Hour), true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + record := newTokenRecord(false, tt.recordTime, nil) + require.Equal(t, tt.want, record.IsExpired(now)) + }) + } +} + +func TestTokenIssuerCacheSetAndGetToken(t *testing.T) { + cache := newTokenIssuerCache(&Client{Client: &http.Client{}, tokenEndpoint: "http://example.com"}) + + ctx, cancel := context.WithTimeout(context.Background(), TEST_TIMEOUT) + defer cancel() + tokenID := uuid.New() + // token doesn't exist yet, so we should get a future with a set function + _, setTokenOrError := cache.getValidTokenRecordOrFuture(tokenID) + require.NotNil(t, setTokenOrError) + + // we can wait on the future in other goroutine + // -> setting error on the future should unblock the goroutine + waiting := make(chan struct{}) + done := make(chan struct{}) + go func() { + defer close(done) + tf2, setToken2 := cache.getValidTokenRecordOrFuture(tokenID) + assert.Nil(t, setToken2) + close(waiting) + _, err := tf2.Get(ctx) + assert.Error(t, err) + }() + + <-waiting + cache.removeTokenRecordAndSetErrorOnFuture(tokenID, setTokenOrError, errors.New("test")) + select { + case <-done: + case <-ctx.Done(): + require.Fail(t, "timeout") + } + + // get a new future + _, setTokenOrError = cache.getValidTokenRecordOrFuture(tokenID) + require.NotNil(t, setTokenOrError) + + // -> setting an expired token record should result in a future with a set function being returned + expiredIDs := []uuid.UUID{} + tr := newTokenRecord(false, time.Now().Add(-time.Hour), func(u uuid.UUID) { + expiredIDs = append(expiredIDs, u) + }) + cache.setTokenRecord(tokenID, tr) + _, setTokenOrError = cache.getValidTokenRecordOrFuture(tokenID) + require.NotNil(t, setTokenOrError) + require.Len(t, expiredIDs, 1) + require.Equal(t, tokenID, expiredIDs[0]) + + // -> finally, set valid token record + tr = newTokenRecord(false, time.Now().Add(time.Hour), nil) + waiting = make(chan struct{}) + done = make(chan struct{}) + go func() { + defer close(done) + tf2, setToken2 := cache.getValidTokenRecordOrFuture(tokenID) + assert.Nil(t, setToken2) + close(waiting) + result, err := tf2.Get(ctx) + assert.NoError(t, err) + assert.Equal(t, tr, result) + }() + + <-waiting + cache.setTokenRecordAndWaitingFuture(tokenID, tr, setTokenOrError) + select { + case <-done: + case <-ctx.Done(): + require.Fail(t, "timeout") + } + + // cache should return a token record now, not a future + tf, _ := cache.getValidTokenRecordOrFuture(tokenID) + _, ok := tf.tokenOrFuture.(*tokenRecord) + require.True(t, ok) +} + +func TestTokenIssuerCacheCheckExpirations(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TEST_TIMEOUT) + defer cancel() + cache := newTokenIssuerCache(&Client{Client: &http.Client{}, tokenEndpoint: "http://example.com"}) + + now := time.Now() + tokenID1 := uuid.New() + expiredIDs := []uuid.UUID{} + onExpire := func(u uuid.UUID) { + expiredIDs = append(expiredIDs, u) + } + tokenRecord1 := newTokenRecord(false, now.Add(-time.Hour), onExpire) + cache.setTokenRecord(tokenID1, tokenRecord1) + + tokenID2 := uuid.New() + tokenRecord2 := newTokenRecord(false, now.Add(time.Hour), onExpire) + cache.setTokenRecord(tokenID2, tokenRecord2) + + cache.checkExpirations(now) + + // tokenRecord1 should have been removed and we should get a future with a set function + _, setTf1 := cache.getValidTokenRecordOrFuture(tokenID1) + require.NotNil(t, setTf1) + require.Len(t, expiredIDs, 1) + require.Equal(t, tokenID1, expiredIDs[0]) + // tokenRecord2 should still be there + tf2, setTf2 := cache.getValidTokenRecordOrFuture(tokenID2) + require.Nil(t, setTf2) + result, err := tf2.Get(ctx) + require.NoError(t, err) + require.Equal(t, tokenRecord2, result) +}