Skip to content

Commit

Permalink
Make userapi responsible for checking access tokens (#1133)
Browse files Browse the repository at this point in the history
* Make userapi responsible for checking access tokens

There's still plenty of dependencies on account/device DBs, but this
is a start. This is a breaking change as it adds a required config
value `listen.user_api`.

* Cleanup

* Review comments and test fix
  • Loading branch information
kegsay authored Jun 16, 2020
1 parent 57b7fa3 commit 9c77022
Show file tree
Hide file tree
Showing 66 changed files with 421 additions and 400 deletions.
125 changes: 20 additions & 105 deletions clientapi/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,14 @@ package auth
import (
"context"
"crypto/rand"
"database/sql"
"encoding/base64"
"fmt"
"net/http"
"strings"

"github.com/matrix-org/dendrite/appservice/types"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/internal/config"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/util"
)

Expand All @@ -39,7 +36,7 @@ var tokenByteLength = 32
// DeviceDatabase represents a device database.
type DeviceDatabase interface {
// Look up the device matching the given access token.
GetDeviceByAccessToken(ctx context.Context, token string) (*authtypes.Device, error)
GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error)
}

// AccountDatabase represents an account database.
Expand All @@ -48,22 +45,14 @@ type AccountDatabase interface {
GetAccountByLocalpart(ctx context.Context, localpart string) (*authtypes.Account, error)
}

// Data contains information required to authenticate a request.
type Data struct {
AccountDB AccountDatabase
DeviceDB DeviceDatabase
// AppServices is the list of all registered AS
AppServices []config.ApplicationService
}

// VerifyUserFromRequest authenticates the HTTP request,
// on success returns Device of the requester.
// Finds local user or an application service user.
// Note: For an AS user, AS dummy device is returned.
// On failure returns an JSON error response which can be sent to the client.
func VerifyUserFromRequest(
req *http.Request, data Data,
) (*authtypes.Device, *util.JSONResponse) {
req *http.Request, userAPI api.UserInternalAPI,
) (*api.Device, *util.JSONResponse) {
// Try to find the Application Service user
token, err := ExtractAccessToken(req)
if err != nil {
Expand All @@ -72,105 +61,31 @@ func VerifyUserFromRequest(
JSON: jsonerror.MissingToken(err.Error()),
}
}

// Search for app service with given access_token
var appService *config.ApplicationService
for _, as := range data.AppServices {
if as.ASToken == token {
appService = &as
break
}
var res api.QueryAccessTokenResponse
err = userAPI.QueryAccessToken(req.Context(), &api.QueryAccessTokenRequest{
AccessToken: token,
AppServiceUserID: req.URL.Query().Get("user_id"),
}, &res)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryAccessToken failed")
jsonErr := jsonerror.InternalServerError()
return nil, &jsonErr
}

if appService != nil {
// Create a dummy device for AS user
dev := authtypes.Device{
// Use AS dummy device ID
ID: types.AppServiceDeviceID,
// AS dummy device has AS's token.
AccessToken: token,
}

userID := req.URL.Query().Get("user_id")
localpart, err := userutil.ParseUsernameParam(userID, nil)
if err != nil {
return nil, &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidUsername(err.Error()),
}
}

if localpart != "" { // AS is masquerading as another user
// Verify that the user is registered
account, err := data.AccountDB.GetAccountByLocalpart(req.Context(), localpart)
// Verify that account exists & appServiceID matches
if err == nil && account.AppServiceID == appService.ID {
// Set the userID of dummy device
dev.UserID = userID
return &dev, nil
}

if res.Err != nil {
if forbidden, ok := res.Err.(*api.ErrorForbidden); ok {
return nil, &util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("Application service has not registered this user"),
JSON: jsonerror.Forbidden(forbidden.Message),
}
}

// AS is not masquerading as any user, so use AS's sender_localpart
dev.UserID = appService.SenderLocalpart
return &dev, nil
}

// Try to find local user from device database
dev, devErr := verifyAccessToken(req, data.DeviceDB)
if devErr == nil {
return dev, verifyUserParameters(req)
}

return nil, &util.JSONResponse{
Code: http.StatusUnauthorized,
JSON: jsonerror.UnknownToken("Unrecognized access token"), // nolint: misspell
}
}

// verifyUserParameters ensures that a request coming from a regular user is not
// using any query parameters reserved for an application service
func verifyUserParameters(req *http.Request) *util.JSONResponse {
if req.URL.Query().Get("ts") != "" {
return &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.Unknown("parameter 'ts' not allowed without valid parameter 'access_token'"),
}
}
return nil
}

// verifyAccessToken verifies that an access token was supplied in the given HTTP request
// and returns the device it corresponds to. Returns resErr (an error response which can be
// sent to the client) if the token is invalid or there was a problem querying the database.
func verifyAccessToken(req *http.Request, deviceDB DeviceDatabase) (device *authtypes.Device, resErr *util.JSONResponse) {
token, err := ExtractAccessToken(req)
if err != nil {
resErr = &util.JSONResponse{
if res.Device == nil {
return nil, &util.JSONResponse{
Code: http.StatusUnauthorized,
JSON: jsonerror.MissingToken(err.Error()),
}
return
}
device, err = deviceDB.GetDeviceByAccessToken(req.Context(), token)
if err != nil {
if err == sql.ErrNoRows {
resErr = &util.JSONResponse{
Code: http.StatusUnauthorized,
JSON: jsonerror.UnknownToken("Unknown token"),
}
} else {
util.GetLogger(req.Context()).WithError(err).Error("deviceDB.GetDeviceByAccessToken failed")
jsonErr := jsonerror.InternalServerError()
resErr = &jsonErr
JSON: jsonerror.UnknownToken("Unknown token"),
}
}
return
return res.Device, nil
}

// GenerateAccessToken creates a new access token. Returns an error if failed to generate
Expand Down
30 changes: 0 additions & 30 deletions clientapi/auth/authtypes/device.go

This file was deleted.

10 changes: 5 additions & 5 deletions clientapi/auth/storage/devices/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@ package devices
import (
"context"

"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/userapi/api"
)

type Database interface {
GetDeviceByAccessToken(ctx context.Context, token string) (*authtypes.Device, error)
GetDeviceByID(ctx context.Context, localpart, deviceID string) (*authtypes.Device, error)
GetDevicesByLocalpart(ctx context.Context, localpart string) ([]authtypes.Device, error)
CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string) (dev *authtypes.Device, returnErr error)
GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error)
GetDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error)
GetDevicesByLocalpart(ctx context.Context, localpart string) ([]api.Device, error)
CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string) (dev *api.Device, returnErr error)
UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error
RemoveDevice(ctx context.Context, deviceID, localpart string) error
RemoveDevices(ctx context.Context, localpart string, devices []string) error
Expand Down
20 changes: 10 additions & 10 deletions clientapi/auth/storage/devices/postgres/devices_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ import (
"time"

"github.com/lib/pq"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
)

Expand Down Expand Up @@ -135,14 +135,14 @@ func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN
func (s *devicesStatements) insertDevice(
ctx context.Context, txn *sql.Tx, id, localpart, accessToken string,
displayName *string,
) (*authtypes.Device, error) {
) (*api.Device, error) {
createdTimeMS := time.Now().UnixNano() / 1000000
var sessionID int64
stmt := sqlutil.TxStmt(txn, s.insertDeviceStmt)
if err := stmt.QueryRowContext(ctx, id, localpart, accessToken, createdTimeMS, displayName).Scan(&sessionID); err != nil {
return nil, err
}
return &authtypes.Device{
return &api.Device{
ID: id,
UserID: userutil.MakeUserID(localpart, s.serverName),
AccessToken: accessToken,
Expand Down Expand Up @@ -189,8 +189,8 @@ func (s *devicesStatements) updateDeviceName(

func (s *devicesStatements) selectDeviceByToken(
ctx context.Context, accessToken string,
) (*authtypes.Device, error) {
var dev authtypes.Device
) (*api.Device, error) {
var dev api.Device
var localpart string
stmt := s.selectDeviceByTokenStmt
err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart)
Expand All @@ -205,8 +205,8 @@ func (s *devicesStatements) selectDeviceByToken(
// localpart and deviceID
func (s *devicesStatements) selectDeviceByID(
ctx context.Context, localpart, deviceID string,
) (*authtypes.Device, error) {
var dev authtypes.Device
) (*api.Device, error) {
var dev api.Device
stmt := s.selectDeviceByIDStmt
err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&dev.DisplayName)
if err == nil {
Expand All @@ -218,8 +218,8 @@ func (s *devicesStatements) selectDeviceByID(

func (s *devicesStatements) selectDevicesByLocalpart(
ctx context.Context, localpart string,
) ([]authtypes.Device, error) {
devices := []authtypes.Device{}
) ([]api.Device, error) {
devices := []api.Device{}

rows, err := s.selectDevicesByLocalpartStmt.QueryContext(ctx, localpart)

Expand All @@ -229,7 +229,7 @@ func (s *devicesStatements) selectDevicesByLocalpart(
defer internal.CloseAndLogIfError(ctx, rows, "selectDevicesByLocalpart: rows.close() failed")

for rows.Next() {
var dev authtypes.Device
var dev api.Device
var id, displayname sql.NullString
err = rows.Scan(&id, &displayname)
if err != nil {
Expand Down
10 changes: 5 additions & 5 deletions clientapi/auth/storage/devices/postgres/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ import (
"database/sql"
"encoding/base64"

"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
)

Expand Down Expand Up @@ -52,22 +52,22 @@ func NewDatabase(dataSourceName string, dbProperties sqlutil.DbProperties, serve
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByAccessToken(
ctx context.Context, token string,
) (*authtypes.Device, error) {
) (*api.Device, error) {
return d.devices.selectDeviceByToken(ctx, token)
}

// GetDeviceByID returns the device matching the given ID.
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByID(
ctx context.Context, localpart, deviceID string,
) (*authtypes.Device, error) {
) (*api.Device, error) {
return d.devices.selectDeviceByID(ctx, localpart, deviceID)
}

// GetDevicesByLocalpart returns the devices matching the given localpart.
func (d *Database) GetDevicesByLocalpart(
ctx context.Context, localpart string,
) ([]authtypes.Device, error) {
) ([]api.Device, error) {
return d.devices.selectDevicesByLocalpart(ctx, localpart)
}

Expand All @@ -80,7 +80,7 @@ func (d *Database) GetDevicesByLocalpart(
func (d *Database) CreateDevice(
ctx context.Context, localpart string, deviceID *string, accessToken string,
displayName *string,
) (dev *authtypes.Device, returnErr error) {
) (dev *api.Device, returnErr error) {
if deviceID != nil {
returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
var err error
Expand Down
Loading

0 comments on commit 9c77022

Please sign in to comment.