Skip to content

Commit

Permalink
Adding middleware to inject Auth token for internal requests to front…
Browse files Browse the repository at this point in the history
  • Loading branch information
iamrodrigo authored Aug 30, 2021
1 parent fb8e782 commit 5191468
Show file tree
Hide file tree
Showing 17 changed files with 164 additions and 103 deletions.
66 changes: 54 additions & 12 deletions client/clientBean.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"sync/atomic"
"time"

clientworker "go.uber.org/cadence/worker"
"go.uber.org/yarpc"
"go.uber.org/yarpc/api/peer"
"go.uber.org/yarpc/api/transport"
Expand All @@ -42,6 +43,8 @@ import (
"github.com/uber/cadence/client/frontend"
"github.com/uber/cadence/client/history"
"github.com/uber/cadence/client/matching"
"github.com/uber/cadence/common"
"github.com/uber/cadence/common/authorization"
"github.com/uber/cadence/common/cluster"
"github.com/uber/cadence/common/log"
"github.com/uber/cadence/common/log/tag"
Expand All @@ -66,10 +69,14 @@ type (
SetRemoteFrontendClient(cluster string, client frontend.Client)
}

DispatcherOptions struct {
AuthProvider clientworker.AuthorizationProvider
}

// DispatcherProvider provides a dispatcher to a given address
DispatcherProvider interface {
GetTChannel(name string, address string) (*yarpc.Dispatcher, error)
GetGRPC(name string, address string) (*yarpc.Dispatcher, error)
GetTChannel(name string, address string, options *DispatcherOptions) (*yarpc.Dispatcher, error)
GetGRPC(name string, address string, options *DispatcherOptions) (*yarpc.Dispatcher, error)
}

clientBeanImpl struct {
Expand Down Expand Up @@ -119,14 +126,26 @@ func NewClientBean(factory Factory, dispatcherProvider DispatcherProvider, clust
continue
}

var dispatcherOptions *DispatcherOptions
if info.AuthorizationProvider.Enable {
authProvider, err := authorization.GetAuthProviderClient(info.AuthorizationProvider.PrivateKey)
if err != nil {
return nil, err
}
dispatcherOptions = &DispatcherOptions{
AuthProvider: authProvider,
}
}

var dispatcher *yarpc.Dispatcher
var err error
switch info.RPCTransport {
case tchannel.TransportName:
dispatcher, err = dispatcherProvider.GetTChannel(info.RPCName, info.RPCAddress)
dispatcher, err = dispatcherProvider.GetTChannel(info.RPCName, info.RPCAddress, dispatcherOptions)
case grpc.TransportName:
dispatcher, err = dispatcherProvider.GetGRPC(info.RPCName, info.RPCAddress)
dispatcher, err = dispatcherProvider.GetGRPC(info.RPCName, info.RPCAddress, dispatcherOptions)
}

if err != nil {
return nil, err
}
Expand Down Expand Up @@ -265,7 +284,7 @@ func NewDNSYarpcDispatcherProvider(logger log.Logger, interval time.Duration) Di
}
}

func (p *dnsDispatcherProvider) GetTChannel(serviceName string, address string) (*yarpc.Dispatcher, error) {
func (p *dnsDispatcherProvider) GetTChannel(serviceName string, address string, options *DispatcherOptions) (*yarpc.Dispatcher, error) {
tchanTransport, err := tchannel.NewTransport(
tchannel.ServiceName(serviceName),
// this aim to get rid of the annoying popup about accepting incoming network connections
Expand All @@ -284,10 +303,10 @@ func (p *dnsDispatcherProvider) GetTChannel(serviceName string, address string)
outbound := tchanTransport.NewOutbound(peerList)

p.logger.Info("Creating TChannel dispatcher outbound", tag.Address(address))
return p.createOutboundDispatcher(serviceName, outbound)
return p.createOutboundDispatcher(serviceName, outbound, options)
}

func (p *dnsDispatcherProvider) GetGRPC(serviceName string, address string) (*yarpc.Dispatcher, error) {
func (p *dnsDispatcherProvider) GetGRPC(serviceName string, address string, options *DispatcherOptions) (*yarpc.Dispatcher, error) {
grpcTransport := grpc.NewTransport()

peerList := roundrobin.New(grpcTransport)
Expand All @@ -299,27 +318,50 @@ func (p *dnsDispatcherProvider) GetGRPC(serviceName string, address string) (*ya
outbound := grpcTransport.NewOutbound(peerList)

p.logger.Info("Creating GRPC dispatcher outbound", tag.Address(address))
return p.createOutboundDispatcher(serviceName, outbound)
return p.createOutboundDispatcher(serviceName, outbound, options)
}

func (p *dnsDispatcherProvider) createOutboundDispatcher(serviceName string, outbound transport.UnaryOutbound) (*yarpc.Dispatcher, error) {
// Attach the outbound to the dispatcher (this will add middleware/logging/etc)
dispatcher := yarpc.NewDispatcher(yarpc.Config{
func (p *dnsDispatcherProvider) createOutboundDispatcher(serviceName string, outbound transport.UnaryOutbound, options *DispatcherOptions) (*yarpc.Dispatcher, error) {
cfg := yarpc.Config{
Name: crossDCCaller,
Outbounds: yarpc.Outbounds{
serviceName: transport.Outbounds{
Unary: outbound,
ServiceName: serviceName,
},
},
})
}
if options != nil && options.AuthProvider != nil {
cfg.OutboundMiddleware = yarpc.OutboundMiddleware{
Unary: &outboundMiddleware{authProvider: options.AuthProvider},
}
}

// Attach the outbound to the dispatcher (this will add middleware/logging/etc)
dispatcher := yarpc.NewDispatcher(cfg)

if err := dispatcher.Start(); err != nil {
return nil, err
}
return dispatcher, nil
}

type outboundMiddleware struct {
authProvider clientworker.AuthorizationProvider
}

func (om *outboundMiddleware) Call(ctx context.Context, request *transport.Request, out transport.UnaryOutbound) (*transport.Response, error) {
if om.authProvider != nil {
token, err := om.authProvider.GetAuthToken()
if err != nil {
return nil, err
}
request.Headers = request.Headers.
With(common.AuthorizationTokenHeaderName, string(token))
}
return out.Call(ctx, request)
}

func newDNSUpdater(list peer.List, dnsPort string, interval time.Duration, logger log.Logger) (*dnsUpdater, error) {
ss := strings.Split(dnsPort, ":")
if len(ss) != 2 {
Expand Down
27 changes: 21 additions & 6 deletions client/clientBean_mock.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 13 additions & 1 deletion cmd/server/cadence/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"github.com/uber/cadence/common"
"github.com/uber/cadence/common/archiver"
"github.com/uber/cadence/common/archiver/provider"
"github.com/uber/cadence/common/authorization"
"github.com/uber/cadence/common/blobstore/filestore"
"github.com/uber/cadence/common/cluster"
"github.com/uber/cadence/common/config"
Expand Down Expand Up @@ -213,7 +214,18 @@ func (s *server) startService() common.Daemon {
}
}

dispatcher, err := params.DispatcherProvider.GetTChannel(common.FrontendServiceName, s.cfg.PublicClient.HostPort)
var options *client.DispatcherOptions
if s.cfg.Authorization.OAuthAuthorizer.Enable {
clusterName := s.cfg.ClusterGroupMetadata.CurrentClusterName
authProvider, err := authorization.GetAuthProviderClient(s.cfg.ClusterGroupMetadata.ClusterGroup[clusterName].AuthorizationProvider.PrivateKey)
if err != nil {
log.Fatalf("failed to create AuthProvider: %v", err.Error())
}
options = &client.DispatcherOptions{
AuthProvider: authProvider,
}
}
dispatcher, err := params.DispatcherProvider.GetTChannel(common.FrontendServiceName, s.cfg.PublicClient.HostPort, options)
if err != nil {
log.Fatalf("failed to construct dispatcher: %v", err)
}
Expand Down
12 changes: 12 additions & 0 deletions common/authorization/authorizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ package authorization

import (
"context"
"fmt"
"io/ioutil"

clientworker "go.uber.org/cadence/worker"

"github.com/uber/cadence/common/types"
)
Expand Down Expand Up @@ -84,3 +88,11 @@ func NewPermission(permission string) Permission {
type Authorizer interface {
Authorize(ctx context.Context, attributes *Attributes) (Result, error)
}

func GetAuthProviderClient(privateKey string) (clientworker.AuthorizationProvider, error) {
pk, err := ioutil.ReadFile(privateKey)
if err != nil {
return nil, fmt.Errorf("invalid private key path %s", privateKey)
}
return clientworker.NewAdminJwtAuthorizationProvider(pk), nil
}
5 changes: 2 additions & 3 deletions common/authorization/factory_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,8 @@ func cfgOAuth() config.Authorization {
OAuthAuthorizer: config.OAuthAuthorizer{
Enable: true,
JwtCredentials: config.JwtCredentials{
Algorithm: jwt.RS256.String(),
PublicKey: "public",
PrivateKey: "private",
Algorithm: jwt.RS256.String(),
PublicKey: "public",
},
MaxJwtTTL: 12345,
},
Expand Down
4 changes: 4 additions & 0 deletions common/authorization/oauthAuthorizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ func (a *oauthAuthority) Authorize(
return Result{Decision: DecisionDeny}, err
}
token := call.Header(common.AuthorizationTokenHeaderName)
if token == "" {
a.log.Debug("request is not authorized", tag.Error(fmt.Errorf("token is not set in header")))
return Result{Decision: DecisionDeny}, nil
}
claims, err := a.parseToken(token, verifier)
if err != nil {
a.log.Debug("request is not authorized", tag.Error(err))
Expand Down
20 changes: 17 additions & 3 deletions common/authorization/oauthAutorizer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,8 @@ func (s *oauthSuite) SetupTest() {
s.cfg = config.OAuthAuthorizer{
Enable: true,
JwtCredentials: config.JwtCredentials{
Algorithm: jwt.RS256.String(),
PublicKey: "../../config/credentials/keytest.pub",
PrivateKey: "../../config/credentials/keytest",
Algorithm: jwt.RS256.String(),
PublicKey: "../../config/credentials/keytest.pub",
},
MaxJwtTTL: 300000001,
}
Expand Down Expand Up @@ -139,6 +138,21 @@ func (s *oauthSuite) TestItIsAdmin() {
s.Equal(result.Decision, DecisionAllow)
}

func (s *oauthSuite) TestEmptyToken() {
ctx := context.Background()
ctx, call := encoding.NewInboundCall(ctx)
err := call.ReadFromRequest(&transport.Request{
Headers: transport.NewHeaders().With(common.AuthorizationTokenHeaderName, ""),
})
s.NoError(err)
authorizer := NewOAuthAuthorizer(s.cfg, s.logger, s.domainCache)
s.logger.On("Debug", "request is not authorized", mock.MatchedBy(func(t []tag.Tag) bool {
return fmt.Sprintf("%v", t[0].Field().Interface) == "token is not set in header"
}))
result, _ := authorizer.Authorize(ctx, &s.att)
s.Equal(result.Decision, DecisionDeny)
}

func (s *oauthSuite) TestGetDomainError() {
s.domainCache.EXPECT().GetDomain(s.att.DomainName).Return(nil, fmt.Errorf("error")).Times(1)
authorizer := NewOAuthAuthorizer(s.cfg, s.logger, s.domainCache)
Expand Down
3 changes: 0 additions & 3 deletions common/config/authorization.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,6 @@ func (a *Authorization) validateOAuth() error {
if oauthConfig.MaxJwtTTL <= 0 {
return fmt.Errorf("[OAuthConfig] MaxTTL must be greater than 0")
}
if oauthConfig.JwtCredentials.PrivateKey == "" {
return fmt.Errorf("[OAuthConfig] PrivateKey can't be empty")
}
if oauthConfig.JwtCredentials.PublicKey == "" {
return fmt.Errorf("[OAuthConfig] PublicKey can't be empty")
}
Expand Down
35 changes: 6 additions & 29 deletions common/config/authorization_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,34 +56,13 @@ func TestTTLIsZero(t *testing.T) {
assert.EqualError(t, err, "[OAuthConfig] MaxTTL must be greater than 0")
}

func TestPrivateKeyIsEmpty(t *testing.T) {
cfg := Authorization{
OAuthAuthorizer: OAuthAuthorizer{
Enable: true,
JwtCredentials: JwtCredentials{
Algorithm: "",
PublicKey: "",
PrivateKey: "",
},
MaxJwtTTL: 1000000,
},
NoopAuthorizer: NoopAuthorizer{
Enable: false,
},
}

err := cfg.Validate()
assert.EqualError(t, err, "[OAuthConfig] PrivateKey can't be empty")
}

func TestPublicKeyIsEmpty(t *testing.T) {
cfg := Authorization{
OAuthAuthorizer: OAuthAuthorizer{
Enable: true,
JwtCredentials: JwtCredentials{
Algorithm: "",
PublicKey: "",
PrivateKey: "private",
Algorithm: "",
PublicKey: "",
},
MaxJwtTTL: 1000000,
},
Expand All @@ -101,9 +80,8 @@ func TestAlgorithmIsInvalid(t *testing.T) {
OAuthAuthorizer: OAuthAuthorizer{
Enable: true,
JwtCredentials: JwtCredentials{
Algorithm: "SHA256",
PublicKey: "public",
PrivateKey: "private",
Algorithm: "SHA256",
PublicKey: "public",
},
MaxJwtTTL: 1000000,
},
Expand All @@ -121,9 +99,8 @@ func TestCorrectValidation(t *testing.T) {
OAuthAuthorizer: OAuthAuthorizer{
Enable: true,
JwtCredentials: JwtCredentials{
Algorithm: "RS256",
PublicKey: "public",
PrivateKey: "private",
Algorithm: "RS256",
PublicKey: "public",
},
MaxJwtTTL: 1000000,
},
Expand Down
11 changes: 11 additions & 0 deletions common/config/cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,17 @@ type (
// Allowed values: tchannel|grpc
// Default: tchannel
RPCTransport string `yaml:"rpcTransport"`
// AuthorizationProvider contains the information to authorize the cluster
AuthorizationProvider AuthorizationProvider `yaml:"authorizationProvider"`
}

AuthorizationProvider struct {
// Enable indicates if the auth provider is enabled
Enable bool `yaml:"enable"`
// Type auth provider type
Type string `yaml:"type"` // only supports OAuthAuthorization
// PrivateKey is the private key path
PrivateKey string `yaml:"privateKey"`
}
)

Expand Down
2 changes: 0 additions & 2 deletions common/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,6 @@ type (
Algorithm string `yaml:"algorithm"`
// Public Key Path for verifying JWT token passed in from external clients
PublicKey string `yaml:"publicKey"`
// Private Key Path for creating JWT token
PrivateKey string `yaml:"privateKey"`
}

// Service contains the service specific config items
Expand Down
Loading

0 comments on commit 5191468

Please sign in to comment.