Skip to content

Commit

Permalink
Fix static analysis issues and add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Danielius1922 authored and Daniel Adam committed Jul 29, 2024
1 parent 90f952a commit 8d82170
Show file tree
Hide file tree
Showing 5 changed files with 274 additions and 16 deletions.
2 changes: 0 additions & 2 deletions grpc-gateway/service/getResources_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (
"github.com/plgd-dev/hub/v2/grpc-gateway/pb"
grpcgwTest "github.com/plgd-dev/hub/v2/grpc-gateway/test"
m2mOauthTest "github.com/plgd-dev/hub/v2/m2m-oauth-server/test"
"github.com/plgd-dev/hub/v2/pkg/log"
pkgGrpc "github.com/plgd-dev/hub/v2/pkg/net/grpc"
pkgJwt "github.com/plgd-dev/hub/v2/pkg/security/jwt"
"github.com/plgd-dev/hub/v2/resource-aggregate/commands"
Expand Down Expand Up @@ -223,7 +222,6 @@ func TestRequestHandlerGetResourcesWithM2MTokenVerification(t *testing.T) {

grpcCfg := grpcgwTest.MakeConfig(t)
grpcCfg.APIs.GRPC.Authorization.TokenVerification.CacheExpiration = time.Second * 2
grpcCfg.Log.Level = log.DebugLevel
tearDown := service.SetUp(ctx, t, service.WithGRPCGWConfig(grpcCfg))
defer tearDown()
validTokenStr := oauthTest.GetDefaultAccessToken(t)
Expand Down
1 change: 0 additions & 1 deletion pkg/net/http/pb/protojson.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ func (d *Decoder) Unmarshal(code int, input io.Reader, v protoreflect.ProtoMessa
if err != nil {
return err
}

d.logger.Debugf("data: %s\n", data)

if code != http.StatusOK {
Expand Down
102 changes: 102 additions & 0 deletions pkg/net/http/pb/protojson_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package pb_test

import (
"bytes"
"net/http"
"testing"

"github.com/plgd-dev/hub/v2/pkg/net/http/pb"
"github.com/plgd-dev/hub/v2/test"
"github.com/stretchr/testify/require"
"google.golang.org/genproto/googleapis/rpc/status"
grpcStatus "google.golang.org/grpc/status"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/types/known/structpb"
)

func TestUnmarshalError(t *testing.T) {
s := &status.Status{
Code: http.StatusInternalServerError,
Message: "test error",
}
data, err := protojson.Marshal(s)
require.NoError(t, err)

err = pb.UnmarshalError(data)
require.Error(t, err)

st, ok := grpcStatus.FromError(err)
require.True(t, ok)
require.Equal(t, s.GetCode(), int32(st.Code()))
require.Equal(t, s.GetMessage(), st.Message())
}

func TestUnmarshal(t *testing.T) {
tests := []struct {
name string
code int
input []byte
wantGrpcError error
wantErr bool
want protoreflect.ProtoMessage
}{
{
name: "Unmarshal success",
code: http.StatusOK,
input: func() []byte {
data, err := protojson.Marshal(structpb.NewStringValue("test"))
require.NoError(t, err)
return []byte(`{"result":` + string(data) + `}`)
}(),
want: structpb.NewStringValue("test"),
},
{
name: "Unmarshal error status",
code: http.StatusInternalServerError,
input: []byte(`{"code": 500, "message": "test error"}`),
wantGrpcError: grpcStatus.ErrorProto(&status.Status{
Code: http.StatusInternalServerError,
Message: "test error",
}),
},
{
name: "Unmarshal error status (2)",
code: http.StatusOK,
input: []byte(`{"error": {"code": 500, "message": "test error"}}`),
wantGrpcError: grpcStatus.ErrorProto(&status.Status{
Code: http.StatusInternalServerError,
Message: "test error",
}),
},
{
name: "Invalid JSON",
code: http.StatusOK,
input: []byte(`invalid json`),
wantErr: true,
},
{
name: "Empty result and error fields",
code: http.StatusOK,
input: []byte(`{}`),
want: &structpb.Struct{},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var v structpb.Value
err := pb.Unmarshal(tt.code, bytes.NewReader(tt.input), &v)
if tt.wantGrpcError != nil {
require.ErrorIs(t, err, tt.wantGrpcError)
return
}
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
test.CheckProtobufs(t, tt.want, &v, test.RequireToCheckFunc(require.Equal))
})
}
}
44 changes: 31 additions & 13 deletions pkg/security/jwt/tokenCache.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,15 +113,27 @@ func (tc *tokenIssuerCache) getValidTokenRecordOrFuture(tokenID uuid.UUID) (toke
return tf, nil
}

func (tc *tokenIssuerCache) removeToken(tokenID uuid.UUID) {
func (tc *tokenIssuerCache) removeTokenRecord(tokenID uuid.UUID) {
tc.mutex.Lock()
defer tc.mutex.Unlock()
delete(tc.tokens, tokenID)
}

func (tc *tokenIssuerCache) setTokenRecord(tokenID uuid.UUID, tr *tokenRecord) {
func (tc *tokenIssuerCache) removeTokenRecordAndSetErrorOnFuture(tokenUUID uuid.UUID, setTRFuture future.SetFunc, err error) {
tc.removeTokenRecord(tokenUUID)
setTRFuture(nil, err)
}

func (tc *tokenIssuerCache) setTokenRecord(tokenUUID uuid.UUID, tr *tokenRecord) {
tf := makeTokenOrFuture(tr, nil)
tc.mutex.Lock()
defer tc.mutex.Unlock()
tc.tokens[tokenID] = tf
tc.tokens[tokenUUID] = tf
}

func (tc *tokenIssuerCache) setTokenRecordAndWaitingFuture(tokenUUID uuid.UUID, tr *tokenRecord, setTRFuture future.SetFunc) {
tc.setTokenRecord(tokenUUID, tr)
setTRFuture(tr, nil)
}

func (tc *tokenIssuerCache) checkExpirations(now time.Time) {
Expand Down Expand Up @@ -156,7 +168,6 @@ func (tc *tokenIssuerCache) verifyTokenByRequest(ctx context.Context, token, tok
return nil, fmt.Errorf("cannot create request for GET %v: %w", uri.String(), err)
}

// TODO: "Accept" -> pktNetHttp.AcceptHeaderKey: import cycle, must move to another package
req.Header.Set("Accept", "application/protojson")
req.Header.Set("Authorization", "bearer "+token)
resp, err := tc.client.Do(req)
Expand Down Expand Up @@ -210,17 +221,26 @@ func (t *TokenCache) getValidUntil(token *pb.Token) time.Time {
return time.Now().Add(t.expiration)
}

func getTokenUUID(tokenClaims jwt.Claims) (string, uuid.UUID, error) {
tokenID, err := getID(tokenClaims)
if err != nil {
return "", uuid.Nil, err
}
tokenUUID, err := uuid.Parse(tokenID)
if err != nil {
return "", uuid.Nil, err
}
return tokenID, tokenUUID, nil
}

func (t *TokenCache) VerifyTrust(ctx context.Context, issuer, token string, tokenClaims jwt.Claims) error {
tc, ok := t.cache[issuer]
if !ok {
t.logger.Debugf("client not set for issuer %v, trust verification skipped", issuer)
return nil
}
tokenID, err := getID(tokenClaims)
if err != nil {
return err
}
tokenUUID, err := uuid.Parse(tokenID)

tokenID, tokenUUID, err := getTokenUUID(tokenClaims)
if err != nil {
return err
}
Expand All @@ -241,8 +261,7 @@ func (t *TokenCache) VerifyTrust(ctx context.Context, issuer, token string, toke
t.logger.Debugf("requesting token(id=%s) verification by m2m", tokenID)
respToken, err := tc.verifyTokenByRequest(ctx, token, tokenID)
if err != nil {
tc.removeToken(tokenUUID)
set(nil, err)
tc.removeTokenRecordAndSetErrorOnFuture(tokenUUID, set, err)
return err
}

Expand All @@ -257,8 +276,7 @@ func (t *TokenCache) VerifyTrust(ctx context.Context, issuer, token string, toke
validUntil := t.getValidUntil(respToken)
tr := newTokenRecord(blacklisted, validUntil, onExpire)
t.logger.Debugf("token(id=%s) set (blacklisted=%v, validUntil=%v)", tokenID, blacklisted, validUntil)
tc.setTokenRecord(tokenUUID, tr)
set(tr, nil)
tc.setTokenRecordAndWaitingFuture(tokenUUID, tr, set)

if blacklisted {
return ErrBlackListedToken
Expand Down
141 changes: 141 additions & 0 deletions pkg/security/jwt/tokenCache_internal_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
package jwt

import (
"context"
"errors"
"net/http"
"testing"
"time"

"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

const TEST_TIMEOUT = time.Second * 10

func TestTokenRecord_IsExpired(t *testing.T) {
now := time.Now()
tests := []struct {
name string
recordTime time.Time
want bool
}{
{"Not expired", now.Add(time.Hour), false},
{"Expired", now.Add(-time.Hour), true},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
record := newTokenRecord(false, tt.recordTime, nil)
require.Equal(t, tt.want, record.IsExpired(now))
})
}
}

func TestTokenIssuerCacheSetAndGetToken(t *testing.T) {
cache := newTokenIssuerCache(&Client{Client: &http.Client{}, tokenEndpoint: "http://example.com"})

ctx, cancel := context.WithTimeout(context.Background(), TEST_TIMEOUT)
defer cancel()
tokenID := uuid.New()
// token doesn't exist yet, so we should get a future with a set function
_, setTokenOrError := cache.getValidTokenRecordOrFuture(tokenID)
require.NotNil(t, setTokenOrError)

// we can wait on the future in other goroutine
// -> setting error on the future should unblock the goroutine
waiting := make(chan struct{})
done := make(chan struct{})
go func() {
defer close(done)
tf2, setToken2 := cache.getValidTokenRecordOrFuture(tokenID)
assert.Nil(t, setToken2)
close(waiting)
_, err := tf2.Get(ctx)
assert.Error(t, err)
}()

<-waiting
cache.removeTokenRecordAndSetErrorOnFuture(tokenID, setTokenOrError, errors.New("test"))
select {
case <-done:
case <-ctx.Done():
require.Fail(t, "timeout")
}

// get a new future
_, setTokenOrError = cache.getValidTokenRecordOrFuture(tokenID)
require.NotNil(t, setTokenOrError)

// -> setting an expired token record should result in a future with a set function being returned
expiredIDs := []uuid.UUID{}
tr := newTokenRecord(false, time.Now().Add(-time.Hour), func(u uuid.UUID) {
expiredIDs = append(expiredIDs, u)
})
cache.setTokenRecord(tokenID, tr)
_, setTokenOrError = cache.getValidTokenRecordOrFuture(tokenID)
require.NotNil(t, setTokenOrError)
require.Len(t, expiredIDs, 1)
require.Equal(t, tokenID, expiredIDs[0])

// -> finally, set valid token record
tr = newTokenRecord(false, time.Now().Add(time.Hour), nil)
waiting = make(chan struct{})
done = make(chan struct{})
go func() {
defer close(done)
tf2, setToken2 := cache.getValidTokenRecordOrFuture(tokenID)
assert.Nil(t, setToken2)
close(waiting)
result, err := tf2.Get(ctx)
assert.NoError(t, err)
assert.Equal(t, tr, result)
}()

<-waiting
cache.setTokenRecordAndWaitingFuture(tokenID, tr, setTokenOrError)
select {
case <-done:
case <-ctx.Done():
require.Fail(t, "timeout")
}

// cache should return a token record now, not a future
tf, _ := cache.getValidTokenRecordOrFuture(tokenID)
_, ok := tf.tokenOrFuture.(*tokenRecord)
require.True(t, ok)
}

func TestTokenIssuerCacheCheckExpirations(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), TEST_TIMEOUT)
defer cancel()
cache := newTokenIssuerCache(&Client{Client: &http.Client{}, tokenEndpoint: "http://example.com"})

now := time.Now()
tokenID1 := uuid.New()
expiredIDs := []uuid.UUID{}
onExpire := func(u uuid.UUID) {
expiredIDs = append(expiredIDs, u)
}
tokenRecord1 := newTokenRecord(false, now.Add(-time.Hour), onExpire)
cache.setTokenRecord(tokenID1, tokenRecord1)

tokenID2 := uuid.New()
tokenRecord2 := newTokenRecord(false, now.Add(time.Hour), onExpire)
cache.setTokenRecord(tokenID2, tokenRecord2)

cache.checkExpirations(now)

// tokenRecord1 should have been removed and we should get a future with a set function
_, setTf1 := cache.getValidTokenRecordOrFuture(tokenID1)
require.NotNil(t, setTf1)
require.Len(t, expiredIDs, 1)
require.Equal(t, tokenID1, expiredIDs[0])
// tokenRecord2 should still be there
tf2, setTf2 := cache.getValidTokenRecordOrFuture(tokenID2)
require.Nil(t, setTf2)
result, err := tf2.Get(ctx)
require.NoError(t, err)
require.Equal(t, tokenRecord2, result)
}

0 comments on commit 8d82170

Please sign in to comment.