Skip to content

Commit

Permalink
Prevent removing last admin (gotify#130)
Browse files Browse the repository at this point in the history
  • Loading branch information
饺子w authored and jmattheis committed Feb 26, 2019
1 parent ec5b1f8 commit 2fa395c
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 0 deletions.
9 changes: 9 additions & 0 deletions api/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ type UserDatabase interface {
DeleteUserByID(id uint) error
UpdateUser(user *model.User)
CreateUser(user *model.User) error
CountUser(condition ...interface{}) int
}

// UserChangeNotifier notifies listeners for user changes.
Expand Down Expand Up @@ -252,6 +253,10 @@ func (a *UserAPI) GetUserByID(ctx *gin.Context) {
func (a *UserAPI) DeleteUserByID(ctx *gin.Context) {
withID(ctx, "id", func(id uint) {
if user := a.DB.GetUserByID(id); user != nil {
if user.Admin && a.DB.CountUser(&model.User{Admin: true}) == 1 {
ctx.AbortWithError(400, errors.New("cannot delete last admin"))
return
}
if err := a.UserChangeNotifier.fireUserDeleted(id); err != nil {
ctx.AbortWithError(500, err)
return
Expand Down Expand Up @@ -350,6 +355,10 @@ func (a *UserAPI) UpdateUserByID(ctx *gin.Context) {
var user *model.UserExternalWithPass
if err := ctx.Bind(&user); err == nil {
if oldUser := a.DB.GetUserByID(id); oldUser != nil {
if !user.Admin && oldUser.Admin && a.DB.CountUser(&model.User{Admin: true}) == 1 {
ctx.AbortWithError(400, errors.New("cannot delete last admin"))
return
}
internal := a.toInternalUser(user, oldUser.Pass)
internal.ID = id
a.DB.UpdateUser(internal)
Expand Down
29 changes: 29 additions & 0 deletions api/user_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,19 @@ func (s *UserSuite) Test_GetUserByID_UnknownUser() {
assert.Equal(s.T(), 404, s.recorder.Code)
}

func (s *UserSuite) Test_DeleteUserByID_LastAdmin_Expect400() {
s.db.CreateUser(&model.User{
ID: 7,
Name: "admin",
Admin: true,
})
s.ctx.Params = gin.Params{{Key: "id", Value: "7"}}

s.a.DeleteUserByID(s.ctx)

assert.Equal(s.T(), 400, s.recorder.Code)
}

func (s *UserSuite) Test_DeleteUserByID_InvalidID() {
s.ctx.Params = gin.Params{{Key: "id", Value: "abc"}}

Expand Down Expand Up @@ -221,6 +234,22 @@ func (s *UserSuite) Test_UpdateUserByID_InvalidID() {
assert.Equal(s.T(), 400, s.recorder.Code)
}

func (s *UserSuite) Test_UpdateUserByID_LastAdmin_Expect400() {
s.db.CreateUser(&model.User{
ID: 7,
Name: "admin",
Admin: true,
})

s.ctx.Params = gin.Params{{Key: "id", Value: "7"}}

s.ctx.Request = httptest.NewRequest("POST", "/user/7", strings.NewReader(`{"name": "admin", "pass": "", "admin": false}`))
s.ctx.Request.Header.Set("Content-Type", "application/json")
s.a.UpdateUserByID(s.ctx)

assert.Equal(s.T(), 400, s.recorder.Code)
}

func (s *UserSuite) Test_UpdateUserByID_UnknownUser() {
s.ctx.Params = gin.Params{{Key: "id", Value: "2"}}

Expand Down
13 changes: 13 additions & 0 deletions database/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,19 @@ func (d *GormDatabase) GetUserByID(id uint) *model.User {
return nil
}

// CountUser returns the user count which satisfies the given condition.
func (d *GormDatabase) CountUser(condition ...interface{}) int {
c := -1
handle := d.DB.Model(new(model.User))
if len(condition) == 1 {
handle = handle.Where(condition[0])
} else if len(condition) > 1 {
handle = handle.Where(condition[0], condition[1:]...)
}
handle.Count(&c)
return c
}

// GetUsers returns all users.
func (d *GormDatabase) GetUsers() []*model.User {
var users []*model.User
Expand Down
3 changes: 3 additions & 0 deletions database/user_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ func (s *DatabaseSuite) TestUser() {

jmattheis := s.db.GetUserByID(1)
assert.NotNil(s.T(), jmattheis, "on bootup the first user should be automatically created")
assert.Equal(s.T(), 1, s.db.CountUser("admin = ?", true), 1, "there is initially one admin")

users := s.db.GetUsers()
assert.Len(s.T(), users, 1)
Expand All @@ -19,6 +20,7 @@ func (s *DatabaseSuite) TestUser() {
nicories := &model.User{Name: "nicories", Pass: []byte{1, 2, 3, 4}, Admin: false}
s.db.CreateUser(nicories)
assert.NotEqual(s.T(), 0, nicories.ID, "on create user a new id should be assigned")
assert.Equal(s.T(), 2, s.db.CountUser(), "two users should exist")

assert.Equal(s.T(), nicories, s.db.GetUserByName("nicories"))

Expand All @@ -35,6 +37,7 @@ func (s *DatabaseSuite) TestUser() {
assert.Equal(s.T(), &model.User{ID: nicories.ID, Name: "tom", Pass: []byte{12}, Admin: true}, tom)
users = s.db.GetUsers()
assert.Len(s.T(), users, 2)
assert.Equal(s.T(), 2, s.db.CountUser(&model.User{Admin: true}), "two admins exist")

s.db.DeleteUserByID(tom.ID)
users = s.db.GetUsers()
Expand Down

0 comments on commit 2fa395c

Please sign in to comment.