Skip to content

Commit

Permalink
fixing m2m-oauth-server-blacklist-client
Browse files Browse the repository at this point in the history
  • Loading branch information
jkralik committed Aug 5, 2024
1 parent 830fbf6 commit 2307169
Show file tree
Hide file tree
Showing 8 changed files with 127 additions and 81 deletions.
63 changes: 50 additions & 13 deletions m2m-oauth-server/service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"time"

oauthsigner "github.com/plgd-dev/hub/v2/m2m-oauth-server/oauthSigner"
"github.com/plgd-dev/hub/v2/m2m-oauth-server/pb"
grpcService "github.com/plgd-dev/hub/v2/m2m-oauth-server/service/grpc"
httpService "github.com/plgd-dev/hub/v2/m2m-oauth-server/service/http"
"github.com/plgd-dev/hub/v2/m2m-oauth-server/store"
Expand All @@ -16,9 +17,11 @@ import (
"github.com/plgd-dev/hub/v2/pkg/fn"
"github.com/plgd-dev/hub/v2/pkg/fsnotify"
"github.com/plgd-dev/hub/v2/pkg/log"
"github.com/plgd-dev/hub/v2/pkg/net/grpc"
"github.com/plgd-dev/hub/v2/pkg/net/listener"
otelClient "github.com/plgd-dev/hub/v2/pkg/opentelemetry/collector/client"
certManagerServer "github.com/plgd-dev/hub/v2/pkg/security/certManager/server"
"github.com/plgd-dev/hub/v2/pkg/security/jwt"
"github.com/plgd-dev/hub/v2/pkg/security/jwt/validator"
"github.com/plgd-dev/hub/v2/pkg/security/openid"
"github.com/plgd-dev/hub/v2/pkg/service"
Expand Down Expand Up @@ -62,8 +65,8 @@ func createStore(ctx context.Context, config storeConfig.Config, fileWatcher *fs
return s, nil
}

func newHttpService(ctx context.Context, config HTTPConfig, validatorConfig validator.Config, getOpenIDConfiguration validator.GetOpenIDConfigurationFunc, tlsConfig certManagerServer.Config, ss *grpcService.M2MOAuthServiceServer, fileWatcher *fsnotify.Watcher, logger log.Logger, tracerProvider trace.TracerProvider) (*httpService.Service, func(), error) {
httpValidator, err := validator.New(ctx, validatorConfig, fileWatcher, logger, tracerProvider, validator.WithGetOpenIDConfiguration(getOpenIDConfiguration))
func newHttpService(ctx context.Context, config HTTPConfig, validatorConfig validator.Config, getOpenIDConfiguration validator.GetOpenIDConfigurationFunc, trustVerification map[string]jwt.TokenIssuerClient, tlsConfig certManagerServer.Config, ss *grpcService.M2MOAuthServiceServer, fileWatcher *fsnotify.Watcher, logger log.Logger, tracerProvider trace.TracerProvider) (*httpService.Service, func(), error) {
httpValidator, err := validator.New(ctx, validatorConfig, fileWatcher, logger, tracerProvider, validator.WithGetOpenIDConfiguration(getOpenIDConfiguration), validator.WithCustomTokenIssuerClients(trustVerification))
if err != nil {
return nil, nil, fmt.Errorf("cannot create http validator: %w", err)
}
Expand All @@ -82,8 +85,8 @@ func newHttpService(ctx context.Context, config HTTPConfig, validatorConfig vali
return httpService, httpValidator.Close, nil
}

func newGrpcService(ctx context.Context, config grpcService.Config, getOpenIDConfiguration validator.GetOpenIDConfigurationFunc, ss *grpcService.M2MOAuthServiceServer, fileWatcher *fsnotify.Watcher, logger log.Logger, tracerProvider trace.TracerProvider) (*grpcService.Service, func(), error) {
grpcValidator, err := validator.New(ctx, config.Authorization.Config, fileWatcher, logger, tracerProvider, validator.WithGetOpenIDConfiguration(getOpenIDConfiguration))
func newGrpcService(ctx context.Context, config grpcService.Config, getOpenIDConfiguration validator.GetOpenIDConfigurationFunc, trustVerification map[string]jwt.TokenIssuerClient, ss *grpcService.M2MOAuthServiceServer, fileWatcher *fsnotify.Watcher, logger log.Logger, tracerProvider trace.TracerProvider) (*grpcService.Service, func(), error) {
grpcValidator, err := validator.New(ctx, config.Authorization.Config, fileWatcher, logger, tracerProvider, validator.WithGetOpenIDConfiguration(getOpenIDConfiguration), validator.WithCustomTokenIssuerClients(trustVerification))
if err != nil {
return nil, nil, fmt.Errorf("cannot create grpc validator: %w", err)
}
Expand All @@ -95,6 +98,33 @@ func newGrpcService(ctx context.Context, config grpcService.Config, getOpenIDCon
return grpcService, grpcValidator.Close, nil
}

type tokenIssuerClient struct {
store store.Store
ownerClaim string
}

func (c *tokenIssuerClient) VerifyTokenByRequest(ctx context.Context, accessToken, tokenID string) (*pb.Token, error) {
owner, err := grpc.ParseOwnerFromJwtToken(c.ownerClaim, accessToken)
if err != nil {
return nil, fmt.Errorf("cannot parse owner from token: %w", err)
}
var token *pb.Token
err = c.store.GetTokens(ctx, owner, &pb.GetTokensRequest{
IdFilter: []string{tokenID},
IncludeBlacklisted: true,
}, func(v *pb.Token) error {
token = v
return nil
})
if err != nil {
return nil, fmt.Errorf("cannot get token(%v): %w", tokenID, err)
}
if token == nil {
return nil, fmt.Errorf("token(%v) not found", tokenID)
}
return token, nil
}

func New(ctx context.Context, config Config, fileWatcher *fsnotify.Watcher, logger log.Logger) (*Service, error) {
otelClient, err := otelClient.New(ctx, config.Clients.OpenTelemetryCollector.Config, serviceName, fileWatcher, logger)
if err != nil {
Expand All @@ -104,13 +134,6 @@ func New(ctx context.Context, config Config, fileWatcher *fsnotify.Watcher, logg
closerFn.AddFunc(otelClient.Close)
tracerProvider := otelClient.GetTracerProvider()

getOpenIDCfg := func(ctx context.Context, c *http.Client, authority string) (openid.Config, error) {
if authority == config.OAuthSigner.GetAuthority() {
return httpService.GetOpenIDConfiguration(config.OAuthSigner.GetDomain()), nil
}
return openid.GetConfiguration(ctx, c, authority)
}

db, err := createStore(ctx, config.Clients.Storage, fileWatcher, logger, tracerProvider)
if err != nil {
closerFn.Execute()
Expand All @@ -122,6 +145,13 @@ func New(ctx context.Context, config Config, fileWatcher *fsnotify.Watcher, logg
}
})

getOpenIDCfg := func(ctx context.Context, c *http.Client, authority string) (openid.Config, error) {
if authority == config.OAuthSigner.GetAuthority() {
return httpService.GetOpenIDConfiguration(config.OAuthSigner.GetDomain()), nil
}
return openid.GetConfiguration(ctx, c, authority)
}

signer, err := oauthsigner.New(ctx, config.OAuthSigner, getOpenIDCfg, fileWatcher, logger, tracerProvider)
if err != nil {
closerFn.Execute()
Expand All @@ -131,14 +161,21 @@ func New(ctx context.Context, config Config, fileWatcher *fsnotify.Watcher, logg

m2mOAuthService := grpcService.NewM2MOAuthServerServer(db, signer, logger)

grpcService, grpcServiceClose, err := newGrpcService(ctx, config.APIs.GRPC, getOpenIDCfg, m2mOAuthService, fileWatcher, logger, tracerProvider)
customTokenIssuerClients := map[string]jwt.TokenIssuerClient{
config.OAuthSigner.GetDomain(): &tokenIssuerClient{
store: db,
ownerClaim: signer.GetOwnerClaim(),
},
}

grpcService, grpcServiceClose, err := newGrpcService(ctx, config.APIs.GRPC, getOpenIDCfg, customTokenIssuerClients, m2mOAuthService, fileWatcher, logger, tracerProvider)
if err != nil {
closerFn.Execute()
return nil, err
}
closerFn.AddFunc(grpcServiceClose)

httpService, httpServiceClose, err := newHttpService(ctx, config.APIs.HTTP, config.APIs.GRPC.Authorization.Config, getOpenIDCfg, config.APIs.GRPC.TLS,
httpService, httpServiceClose, err := newHttpService(ctx, config.APIs.HTTP, config.APIs.GRPC.Authorization.Config, getOpenIDCfg, customTokenIssuerClients, config.APIs.GRPC.TLS,
m2mOAuthService, fileWatcher, logger, tracerProvider)
if err != nil {
grpcService.Close()
Expand Down
14 changes: 2 additions & 12 deletions m2m-oauth-server/test/test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ var JWTPrivateKeyOAuthClient = oauthsigner.Client{
AllowedScopes: nil,
JWTPrivateKey: oauthsigner.PrivateKeyJWTConfig{
Enabled: true,
Authorization: MakeValidatorConfig(),
Authorization: config.MakeValidatorConfig(),
},
}

Expand All @@ -63,16 +63,6 @@ var OAuthClients = oauthsigner.OAuthClientsConfig{
&JWTPrivateKeyOAuthClient,
}

func MakeValidatorConfig() validator.Config {
c := config.MakeValidatorConfig()
// tokens are verified by the m2m-oauth-server, so we want to disable the verification here to avoid infinite loop
// of token verification
c.TokenVerification = validator.TokenTrustVerificationConfig{
Enabled: false,
}
return c
}

func MakeConfig(t require.TestingT) service.Config {
var cfg service.Config

Expand All @@ -91,7 +81,7 @@ func MakeConfig(t require.TestingT) service.Config {
HTTP: config.MakeHttpClientConfig(),
},
)
cfg.APIs.GRPC.Authorization.Config = MakeValidatorConfig()
cfg.APIs.GRPC.Authorization.Config = config.MakeValidatorConfig()
cfg.Clients.Storage = MakeStoreConfig()

cfg.OAuthSigner.PrivateKeyFile = urischeme.URIScheme(os.Getenv("M2M_OAUTH_SERVER_PRIVATE_KEY"))
Expand Down
92 changes: 49 additions & 43 deletions pkg/security/jwt/tokenCache.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,47 @@ import (
"go.uber.org/atomic"
)

type Client struct {
type HTTPClient struct {
*http.Client
tokenEndpoint string
}

func NewClient(client *http.Client, tokenEndpoint string) *Client {
return &Client{
func (c *HTTPClient) VerifyTokenByRequest(ctx context.Context, token, tokenID string) (*pb.Token, error) {
uri, err := url.Parse(c.tokenEndpoint)
if err != nil {
return nil, fmt.Errorf("cannot parse tokenEndpoint %v: %w", c.tokenEndpoint, err)
}
query := uri.Query()
query.Add("idFilter", tokenID)
query.Add("includeBlacklisted", "true")
uri.RawQuery = query.Encode()

req, err := http.NewRequestWithContext(ctx, http.MethodGet, uri.String(), nil)
if err != nil {
return nil, fmt.Errorf("cannot create request for GET %v: %w", uri.String(), err)
}

req.Header.Set("Accept", "application/protojson")
req.Header.Set("Authorization", "bearer "+token)
resp, err := c.Do(req)
if err != nil {
return nil, fmt.Errorf("cannot send request for GET %v: %w", c.tokenEndpoint, err)
}

defer func() {
_ = resp.Body.Close()
}()

var gotToken pb.Token
err = pkgHttpPb.Unmarshal(resp.StatusCode, resp.Body, &gotToken)
if err != nil {
return nil, err
}
return &gotToken, nil
}

func NewHTTPClient(client *http.Client, tokenEndpoint string) *HTTPClient {
return &HTTPClient{
Client: client,
tokenEndpoint: tokenEndpoint,
}
Expand Down Expand Up @@ -75,17 +109,19 @@ func (tf *tokenOrFuture) Get(ctx context.Context) (*tokenRecord, error) {
}

type tokenIssuerCache struct {
client *http.Client
tokenEndpoint string
tokens map[uuid.UUID]tokenOrFuture
mutex sync.Mutex
client TokenIssuerClient
tokens map[uuid.UUID]tokenOrFuture
mutex sync.Mutex
}

func newTokenIssuerCache(client *Client) *tokenIssuerCache {
type TokenIssuerClient interface {
VerifyTokenByRequest(ctx context.Context, token, tokenID string) (*pb.Token, error)
}

func newTokenIssuerCache(client TokenIssuerClient) *tokenIssuerCache {
return &tokenIssuerCache{
client: client.Client,
tokenEndpoint: client.tokenEndpoint,
tokens: make(map[uuid.UUID]tokenOrFuture),
client: client,
tokens: make(map[uuid.UUID]tokenOrFuture),
}
}

Expand Down Expand Up @@ -154,37 +190,7 @@ func (tc *tokenIssuerCache) checkExpirations(now time.Time) {
}

func (tc *tokenIssuerCache) verifyTokenByRequest(ctx context.Context, token, tokenID string) (*pb.Token, error) {
uri, err := url.Parse(tc.tokenEndpoint)
if err != nil {
return nil, fmt.Errorf("cannot parse tokenEndpoint %v: %w", tc.tokenEndpoint, err)
}
query := uri.Query()
query.Add("idFilter", tokenID)
query.Add("includeBlacklisted", "true")
uri.RawQuery = query.Encode()

req, err := http.NewRequestWithContext(ctx, http.MethodGet, uri.String(), nil)
if err != nil {
return nil, fmt.Errorf("cannot create request for GET %v: %w", uri.String(), err)
}

req.Header.Set("Accept", "application/protojson")
req.Header.Set("Authorization", "bearer "+token)
resp, err := tc.client.Do(req)
if err != nil {
return nil, fmt.Errorf("cannot send request for GET %v: %w", tc.tokenEndpoint, err)
}

defer func() {
_ = resp.Body.Close()
}()

var gotToken pb.Token
err = pkgHttpPb.Unmarshal(resp.StatusCode, resp.Body, &gotToken)
if err != nil {
return nil, err
}
return &gotToken, nil
return tc.client.VerifyTokenByRequest(ctx, token, tokenID)
}

type TokenCache struct {
Expand All @@ -193,7 +199,7 @@ type TokenCache struct {
logger log.Logger
}

func NewTokenCache(clients map[string]*Client, expiration time.Duration, logger log.Logger) *TokenCache {
func NewTokenCache(clients map[string]TokenIssuerClient, expiration time.Duration, logger log.Logger) *TokenCache {
tc := &TokenCache{
expiration: expiration,
logger: logger,
Expand Down
4 changes: 2 additions & 2 deletions pkg/security/jwt/tokenCache_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func TestTokenRecord_IsExpired(t *testing.T) {
}

func TestTokenIssuerCacheSetAndGetToken(t *testing.T) {
cache := newTokenIssuerCache(&Client{Client: &http.Client{}, tokenEndpoint: "http://example.com"})
cache := newTokenIssuerCache(&HTTPClient{Client: &http.Client{}, tokenEndpoint: "http://example.com"})

ctx, cancel := context.WithTimeout(context.Background(), TEST_TIMEOUT)
defer cancel()
Expand Down Expand Up @@ -110,7 +110,7 @@ func TestTokenIssuerCacheSetAndGetToken(t *testing.T) {
func TestTokenIssuerCacheCheckExpirations(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), TEST_TIMEOUT)
defer cancel()
cache := newTokenIssuerCache(&Client{Client: &http.Client{}, tokenEndpoint: "http://example.com"})
cache := newTokenIssuerCache(&HTTPClient{Client: &http.Client{}, tokenEndpoint: "http://example.com"})

now := time.Now()
tokenID1 := uuid.New()
Expand Down
4 changes: 2 additions & 2 deletions pkg/security/jwt/validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ var (

type config struct {
verifyTrust bool
clients map[string]*Client
clients map[string]TokenIssuerClient
cacheExpiration time.Duration
stop <-chan struct{}
}
Expand All @@ -47,7 +47,7 @@ func (o optionFunc) apply(c *config) {
o(c)
}

func WithTrustVerification(clients map[string]*Client, cacheExpiration time.Duration, stop <-chan struct{}) Option {
func WithTrustVerification(clients map[string]TokenIssuerClient, cacheExpiration time.Duration, stop <-chan struct{}) Option {
return optionFunc(func(c *config) {
c.verifyTrust = true
c.clients = clients
Expand Down
5 changes: 2 additions & 3 deletions pkg/security/jwt/validator/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,12 @@ func (c *AuthorityConfig) Validate() error {
}

type TokenTrustVerificationConfig struct {
Enabled bool `yaml:"enabled,omitempty" json:"enabled,omitempty"`
CacheExpiration time.Duration `yaml:"cacheExpiration,omitempty" json:"cacheExpiration,omitempty"`
}

func (c *TokenTrustVerificationConfig) Validate() error {
if c.Enabled && c.CacheExpiration == 0 {
return fmt.Errorf("cacheExpiration('%v')", c.CacheExpiration)
if c.CacheExpiration == 0 {
c.CacheExpiration = time.Second * 30
}
return nil
}
Expand Down
Loading

0 comments on commit 2307169

Please sign in to comment.