Skip to content

Commit

Permalink
[main] Login email before username (grafana#57400)
Browse files Browse the repository at this point in the history
* Swap order of login fields

* Validate email field before validating the username field.

Co-authored-by: linoman <[email protected]>
  • Loading branch information
kalleep and linoman authored Oct 21, 2022
1 parent 7e631e7 commit d55e5b8
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 21 deletions.
42 changes: 21 additions & 21 deletions pkg/services/user/userimpl/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,27 +170,30 @@ func (ss *sqlStore) GetByLogin(ctx context.Context, query *user.GetUserByLoginQu
return user.ErrUserNotFound
}

// Try and find the user by login first.
// It's not sufficient to assume that a LoginOrEmail with an "@" is an email.
where := "login=?"
if ss.cfg.CaseInsensitiveLogin {
where = "LOWER(login)=LOWER(?)"
}

has, err := sess.Where(ss.notServiceAccountFilter()).Where(where, query.LoginOrEmail).Get(usr)
if err != nil {
return err
}

if !has && strings.Contains(query.LoginOrEmail, "@") {
// If the user wasn't found, and it contains an "@" fallback to finding the
// user by email.
var where string
var has bool
var err error

// Since username can be an email address, attempt login with email address
// first if the login field has the "@" symbol.
if strings.Contains(query.LoginOrEmail, "@") {
where = "email=?"
if ss.cfg.CaseInsensitiveLogin {
where = "LOWER(email)=LOWER(?)"
}
usr = &user.User{}
has, err = sess.Where(ss.notServiceAccountFilter()).Where(where, query.LoginOrEmail).Get(usr)

if err != nil {
return err
}
}

// Look for the login field instead of email
if !has {
where = "login=?"
if ss.cfg.CaseInsensitiveLogin {
where = "LOWER(login)=LOWER(?)"
}
has, err = sess.Where(ss.notServiceAccountFilter()).Where(where, query.LoginOrEmail).Get(usr)
}

Expand All @@ -199,18 +202,15 @@ func (ss *sqlStore) GetByLogin(ctx context.Context, query *user.GetUserByLoginQu
} else if !has {
return user.ErrUserNotFound
}

if ss.cfg.CaseInsensitiveLogin {
if err := ss.userCaseInsensitiveLoginConflict(ctx, sess, usr.Login, usr.Email); err != nil {
return err
}
}
return nil
})
if err != nil {
return nil, err
}
return usr, nil

return usr, err
}

func (ss *sqlStore) GetByEmail(ctx context.Context, query *user.GetUserByEmailQuery) (*user.User, error) {
Expand Down
41 changes: 41 additions & 0 deletions pkg/services/user/userimpl/store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,47 @@ func TestIntegrationUserDataAccess(t *testing.T) {
require.Error(t, err)
})

t.Run("GetByLogin - user2 uses user1.email as login", func(t *testing.T) {
// create user_1
user1 := &user.User{
Email: "[email protected]",
Name: "user_1",
Login: "user_1",
Password: "user_1_password",
Created: time.Now(),
Updated: time.Now(),
IsDisabled: true,
}
_, err := userStore.Insert(context.Background(), user1)
require.Nil(t, err)

// create user_2
user2 := &user.User{
Email: "[email protected]",
Name: "user_2",
Login: "[email protected]",
Password: "user_2_password",
Created: time.Now(),
Updated: time.Now(),
IsDisabled: true,
}
_, err = userStore.Insert(context.Background(), user2)
require.Nil(t, err)

// query user database for user_1 email
query := user.GetUserByLoginQuery{LoginOrEmail: "[email protected]"}
result, err := userStore.GetByLogin(context.Background(), &query)
require.Nil(t, err)

// expect user_1 as result
require.Equal(t, user1.Email, result.Email)
require.Equal(t, user1.Login, result.Login)
require.Equal(t, user1.Name, result.Name)
require.NotEqual(t, user2.Email, result.Email)
require.NotEqual(t, user2.Login, result.Login)
require.NotEqual(t, user2.Name, result.Name)
})

ss.Cfg.CaseInsensitiveLogin = false
})

Expand Down

0 comments on commit d55e5b8

Please sign in to comment.