Skip to content

Commit

Permalink
added OAuth2 displayName and pkce options
Browse files Browse the repository at this point in the history
  • Loading branch information
ganigeorgiev committed Nov 29, 2023
1 parent 9957330 commit b283ee2
Show file tree
Hide file tree
Showing 65 changed files with 421 additions and 226 deletions.
9 changes: 8 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
Dao.SaveRequest(...) -> Dao.SaveLog(...)
```

- removed app.IsDebug() and the `--debug` flag
- removed `app.IsDebug()` and the `--debug` flag

- (@todo docs) Implemented `slog.Logger` via `app.Logger()`.
Logs db writes are debounced and batched. DB write happens on
Expand All @@ -36,6 +36,13 @@
- Added new `filesystem.Copy(src, dest)` method to copy existing files from one location to another.
_This is usually useful when duplicating records with file fields programmatically._

- Added new `displayName` field for each `listAuthMethods()` OAuth2 provider item.
_The value of the `displayName` property is currently configurable from the UI only for the OIDC providers._

- Added new `PKCE()` and `SetPKCE(enable)` OAuth2 methods to indicate whether the PKCE flow is supported or not.
_The PKCE value is currently configurable from the UI only for the OIDC providers._
_This was added to accommodate OIDC providers that may throw an error if unsupported PKCE params are submitted with the auth request (eg. LinkedIn; see [#3799](https://github.com/pocketbase/pocketbase/discussions/3799#discussioncomment-7640312))._


## v0.20.0-rc3

Expand Down
58 changes: 33 additions & 25 deletions apis/record_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,15 @@ func (api *recordAuthApi) authRefresh(c echo.Context) error {
}

type providerInfo struct {
Name string `json:"name"`
State string `json:"state"`
Name string `json:"name"`
DisplayName string `json:"displayName"`
State string `json:"state"`
AuthUrl string `json:"authUrl"`
// technically could be omitted if the provider doesn't support PKCE,
// but to avoid breaking existing typed clients we'll return them as empty string
CodeVerifier string `json:"codeVerifier"`
CodeChallenge string `json:"codeChallenge"`
CodeChallengeMethod string `json:"codeChallengeMethod"`
AuthUrl string `json:"authUrl"`
}

func (api *recordAuthApi) authMethods(c echo.Context) error {
Expand All @@ -91,9 +94,9 @@ func (api *recordAuthApi) authMethods(c echo.Context) error {
authOptions := collection.AuthOptions()

result := struct {
AuthProviders []providerInfo `json:"authProviders"`
UsernamePassword bool `json:"usernamePassword"`
EmailPassword bool `json:"emailPassword"`
AuthProviders []providerInfo `json:"authProviders"`
}{
UsernamePassword: authOptions.AllowUsernameAuth,
EmailPassword: authOptions.AllowEmailAuth,
Expand Down Expand Up @@ -125,36 +128,41 @@ func (api *recordAuthApi) authMethods(c echo.Context) error {
continue // skip provider
}

state := security.RandomString(30)
codeVerifier := security.RandomString(43)
codeChallenge := security.S256Challenge(codeVerifier)
codeChallengeMethod := "S256"
urlOpts := []oauth2.AuthCodeOption{
oauth2.SetAuthURLParam("code_challenge", codeChallenge),
oauth2.SetAuthURLParam("code_challenge_method", codeChallengeMethod),
info := providerInfo{
Name: name,
DisplayName: provider.DisplayName(),
State: security.RandomString(30),
}

if info.DisplayName == "" {
info.DisplayName = name
}

urlOpts := []oauth2.AuthCodeOption{}

// custom providers url options
switch name {
case auth.NameApple:
// see https://developer.apple.com/documentation/sign_in_with_apple/sign_in_with_apple_js/incorporating_sign_in_with_apple_into_other_platforms#3332113
urlOpts = append(urlOpts, oauth2.SetAuthURLParam("response_mode", "query"))
case auth.NameVK:
// vk currently doesn't support PKCE for server-side authorization
urlOpts = []oauth2.AuthCodeOption{}
}

result.AuthProviders = append(result.AuthProviders, providerInfo{
Name: name,
State: state,
CodeVerifier: codeVerifier,
CodeChallenge: codeChallenge,
CodeChallengeMethod: codeChallengeMethod,
AuthUrl: provider.BuildAuthUrl(
state,
urlOpts...,
) + "&redirect_uri=", // empty redirect_uri so that users can append their redirect url
})
if provider.PKCE() {
info.CodeVerifier = security.RandomString(43)
info.CodeChallenge = security.S256Challenge(info.CodeVerifier)
info.CodeChallengeMethod = "S256"
urlOpts = append(urlOpts,
oauth2.SetAuthURLParam("code_challenge", info.CodeChallenge),
oauth2.SetAuthURLParam("code_challenge_method", info.CodeChallengeMethod),
)
}

info.AuthUrl = provider.BuildAuthUrl(
info.State,
urlOpts...,
) + "&redirect_uri=" // empty redirect_uri so that users can append their redirect url

result.AuthProviders = append(result.AuthProviders, info)
}

// sort providers
Expand Down
14 changes: 8 additions & 6 deletions forms/record_oauth2_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ type RecordOAuth2Login struct {
// The authorization code returned from the initial request.
Code string `form:"code" json:"code"`

// The code verifier sent with the initial request as part of the code_challenge.
// The optional PKCE code verifier as part of the code_challenge sent with the initial request.
CodeVerifier string `form:"codeVerifier" json:"codeVerifier"`

// The redirect url sent with the initial request.
Expand Down Expand Up @@ -87,7 +87,6 @@ func (form *RecordOAuth2Login) Validate() error {
return validation.ValidateStruct(form,
validation.Field(&form.Provider, validation.Required, validation.By(form.checkProviderName)),
validation.Field(&form.Code, validation.Required),
validation.Field(&form.CodeVerifier, validation.Required),
validation.Field(&form.RedirectUrl, validation.Required),
)
}
Expand Down Expand Up @@ -142,11 +141,14 @@ func (form *RecordOAuth2Login) Submit(

provider.SetRedirectUrl(form.RedirectUrl)

var opts []oauth2.AuthCodeOption

if provider.PKCE() {
opts = append(opts, oauth2.SetAuthURLParam("code_verifier", form.CodeVerifier))
}

// fetch token
token, err := provider.FetchToken(
form.Code,
oauth2.SetAuthURLParam("code_verifier", form.CodeVerifier),
)
token, err := provider.FetchToken(form.Code, opts...)
if err != nil {
return nil, nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions forms/record_oauth2_login_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@ func TestUserOauth2LoginValidate(t *testing.T) {
"empty payload",
"users",
"{}",
[]string{"provider", "code", "codeVerifier", "redirectUrl"},
[]string{"provider", "code", "redirectUrl"},
},
{
"empty data",
"users",
`{"provider":"","code":"","codeVerifier":"","redirectUrl":""}`,
[]string{"provider", "code", "codeVerifier", "redirectUrl"},
[]string{"provider", "code", "redirectUrl"},
},
{
"missing provider",
Expand Down
2 changes: 1 addition & 1 deletion models/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ var _ Model = (*Request)(nil)
const (
RequestAuthGuest = "guest"
RequestAuthAdmin = "admin"
RequestAuthRecord = "auth_record"
RequestAuthRecord = "authRecord"
)

type Request struct {
Expand Down
10 changes: 10 additions & 0 deletions models/settings/settings.go
Original file line number Diff line number Diff line change
Expand Up @@ -620,6 +620,8 @@ type AuthProviderConfig struct {
AuthUrl string `form:"authUrl" json:"authUrl"`
TokenUrl string `form:"tokenUrl" json:"tokenUrl"`
UserApiUrl string `form:"userApiUrl" json:"userApiUrl"`
DisplayName string `form:"displayName" json:"displayName"`
PKCE *bool `form:"pkce" json:"pkce"`
}

// Validate makes `ProviderConfig` validatable by implementing [validation.Validatable] interface.
Expand Down Expand Up @@ -659,6 +661,14 @@ func (c AuthProviderConfig) SetupProvider(provider auth.Provider) error {
provider.SetTokenUrl(c.TokenUrl)
}

if c.DisplayName != "" {
provider.SetDisplayName(c.DisplayName)
}

if c.PKCE != nil {
provider.SetPKCE(*c.PKCE)
}

return nil
}

Expand Down
13 changes: 13 additions & 0 deletions models/settings/settings_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/pocketbase/pocketbase/models/settings"
"github.com/pocketbase/pocketbase/tools/auth"
"github.com/pocketbase/pocketbase/tools/mailer"
"github.com/pocketbase/pocketbase/tools/types"
)

func TestSettingsValidate(t *testing.T) {
Expand Down Expand Up @@ -941,6 +942,8 @@ func TestAuthProviderConfigValidate(t *testing.T) {
Enabled: true,
ClientId: "test",
ClientSecret: "test",
DisplayName: "test",
PKCE: types.Pointer(true),
AuthUrl: "https://example.com",
TokenUrl: "https://example.com",
UserApiUrl: "https://example.com",
Expand Down Expand Up @@ -978,6 +981,8 @@ func TestAuthProviderConfigSetupProvider(t *testing.T) {
AuthUrl: "test_AuthUrl",
UserApiUrl: "test_UserApiUrl",
TokenUrl: "test_TokenUrl",
DisplayName: "test_DisplayName",
PKCE: types.Pointer(true),
}
if err := c2.SetupProvider(provider); err != nil {
t.Error(err)
Expand All @@ -1002,4 +1007,12 @@ func TestAuthProviderConfigSetupProvider(t *testing.T) {
if provider.TokenUrl() != c2.TokenUrl {
t.Fatalf("Expected TokenUrl %s, got %s", c2.TokenUrl, provider.TokenUrl())
}

if provider.DisplayName() != c2.DisplayName {
t.Fatalf("Expected DisplayName %s, got %s", c2.DisplayName, provider.DisplayName())
}

if provider.PKCE() != *c2.PKCE {
t.Fatalf("Expected PKCE %v, got %v", *c2.PKCE, provider.PKCE())
}
}
10 changes: 6 additions & 4 deletions tools/auth/apple.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,12 @@ type Apple struct {
func NewAppleProvider() *Apple {
return &Apple{
baseProvider: &baseProvider{
scopes: nil, // custom scopes are currently not supported since they require a POST redirect
ctx: context.Background(),
authUrl: "https://appleid.apple.com/auth/authorize",
tokenUrl: "https://appleid.apple.com/auth/token",
ctx: context.Background(),
displayName: "Apple",
pkce: true,
scopes: nil, // custom scopes are currently not supported since they require a POST redirect
authUrl: "https://appleid.apple.com/auth/authorize",
tokenUrl: "https://appleid.apple.com/auth/token",
},
jwksUrl: "https://appleid.apple.com/auth/keys",
}
Expand Down
15 changes: 14 additions & 1 deletion tools/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ type AuthUser struct {
Username string `json:"username"`
Email string `json:"email"`
AvatarUrl string `json:"avatarUrl"`
RawUser map[string]any `json:"rawUser"`
AccessToken string `json:"accessToken"`
RefreshToken string `json:"refreshToken"`
Expiry types.DateTime `json:"expiry"`
RawUser map[string]any `json:"rawUser"`
}

// Provider defines a common interface for an OAuth2 client.
Expand All @@ -30,6 +30,19 @@ type Provider interface {
// SetContext assigns the specified context to the current provider.
SetContext(ctx context.Context)

// PKCE indicates whether the provider can use the PKCE flow.
PKCE() bool

// SetPKCE toggles the state whether the provider can use the PKCE flow or not.
SetPKCE(enable bool)

// DisplayName usually returns provider name as it is officially written
// and it could be used directly in the UI.
DisplayName() string

// SetDisplayName sets the provider's display name.
SetDisplayName(displayName string)

// Scopes returns the provider access permissions that will be requested.
Scopes() []string

Expand Down
24 changes: 23 additions & 1 deletion tools/auth/base_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@ import (
// baseProvider defines common fields and methods used by OAuth2 client providers.
type baseProvider struct {
ctx context.Context
scopes []string
clientId string
clientSecret string
displayName string
redirectUrl string
authUrl string
tokenUrl string
userApiUrl string
scopes []string
pkce bool
}

// Context implements Provider.Context() interface method.
Expand All @@ -31,6 +33,26 @@ func (p *baseProvider) SetContext(ctx context.Context) {
p.ctx = ctx
}

// PKCE implements Provider.PKCE() interface method.
func (p *baseProvider) PKCE() bool {
return p.pkce
}

// SetPKCE implements Provider.SetPKCE() interface method.
func (p *baseProvider) SetPKCE(enable bool) {
p.pkce = enable
}

// DisplayName implements Provider.DisplayName() interface method.
func (p *baseProvider) DisplayName() string {
return p.displayName
}

// SetDisplayName implements Provider.SetDisplayName() interface method.
func (p *baseProvider) SetDisplayName(displayName string) {
p.displayName = displayName
}

// Scopes implements Provider.Scopes() interface method.
func (p *baseProvider) Scopes() []string {
return p.scopes
Expand Down
Loading

0 comments on commit b283ee2

Please sign in to comment.