Skip to content

Commit

Permalink
Export service fields (warrant-dev#83)
Browse files Browse the repository at this point in the history
* Export service fields to allow repos and services to be extended

* Rename GetByObjectId to match service and repository methods

* Fix user usages and add warrant service exports
  • Loading branch information
stanleyphu authored Apr 15, 2023
1 parent 743cd5b commit c9abeb1
Show file tree
Hide file tree
Showing 13 changed files with 219 additions and 219 deletions.
34 changes: 17 additions & 17 deletions pkg/authz/check/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,29 +15,29 @@ import (

type CheckService struct {
service.BaseService
warrantRepo warrant.WarrantRepository
eventSvc event.EventService
ctxSvc wntContext.ContextService
objectTypeSvc objecttype.ObjectTypeService
WarrantRepository warrant.WarrantRepository
EventSvc event.EventService
CtxSvc wntContext.ContextService
ObjectTypeSvc objecttype.ObjectTypeService
}

func NewService(env service.Env, warrantRepo warrant.WarrantRepository, ctxSvc wntContext.ContextService, eventSvc event.EventService, objectTypeSvc objecttype.ObjectTypeService) CheckService {
return CheckService{
BaseService: service.NewBaseService(env),
warrantRepo: warrantRepo,
ctxSvc: ctxSvc,
eventSvc: eventSvc,
objectTypeSvc: objectTypeSvc,
BaseService: service.NewBaseService(env),
WarrantRepository: warrantRepo,
CtxSvc: ctxSvc,
EventSvc: eventSvc,
ObjectTypeSvc: objectTypeSvc,
}
}

func (svc CheckService) getWithContextMatch(ctx context.Context, spec warrant.WarrantSpec) (*warrant.WarrantSpec, error) {
warrant, err := svc.warrantRepo.GetWithContextMatch(ctx, spec.ObjectType, spec.ObjectId, spec.Relation, spec.Subject.ObjectType, spec.Subject.ObjectId, spec.Subject.Relation, spec.Context.ToHash())
warrant, err := svc.WarrantRepository.GetWithContextMatch(ctx, spec.ObjectType, spec.ObjectId, spec.Relation, spec.Subject.ObjectType, spec.Subject.ObjectId, spec.Subject.Relation, spec.Context.ToHash())
if err != nil || warrant == nil {
return nil, err
}

contextSetSpec, err := svc.ctxSvc.ListByWarrantId(ctx, []int64{warrant.GetID()})
contextSetSpec, err := svc.CtxSvc.ListByWarrantId(ctx, []int64{warrant.GetID()})
if err != nil {
return nil, err
}
Expand All @@ -51,7 +51,7 @@ func (svc CheckService) getMatchingSubjects(ctx context.Context, objectType stri
log.Debug().Msgf("Getting matching subjects for %s:%s#%s@%s:___%s", objectType, objectId, relation, subjectType, wntCtx)

warrantSpecs := make([]warrant.WarrantSpec, 0)
objectTypeSpec, err := svc.objectTypeSvc.GetByTypeId(ctx, objectType)
objectTypeSpec, err := svc.ObjectTypeSvc.GetByTypeId(ctx, objectType)
if err != nil {
return warrantSpecs, err
}
Expand All @@ -60,7 +60,7 @@ func (svc CheckService) getMatchingSubjects(ctx context.Context, objectType stri
return warrantSpecs, nil
}

warrants, err := svc.warrantRepo.GetAllMatchingObjectAndRelation(
warrants, err := svc.WarrantRepository.GetAllMatchingObjectAndRelation(
ctx,
objectType,
objectId,
Expand All @@ -81,7 +81,7 @@ func (svc CheckService) getMatchingSubjects(ctx context.Context, objectType stri
return warrantSpecs, err
}

warrants, err = svc.warrantRepo.GetAllMatchingWildcard(
warrants, err = svc.WarrantRepository.GetAllMatchingWildcard(
ctx,
objectType,
objectId,
Expand Down Expand Up @@ -357,7 +357,7 @@ func (svc CheckService) Check(ctx context.Context, authInfo *service.AuthInfo, w
}

// Attempt to match against defined rules for target relation
objectTypeSpec, err := svc.objectTypeSvc.GetByTypeId(ctx, warrantCheck.ObjectType)
objectTypeSpec, err := svc.ObjectTypeSvc.GetByTypeId(ctx, warrantCheck.ObjectType)
if err != nil {
return false, decisionPath, err
}
Expand All @@ -369,15 +369,15 @@ func (svc CheckService) Check(ctx context.Context, authInfo *service.AuthInfo, w
}

if match {
err := svc.eventSvc.TrackAccessAllowedEvent(ctx, warrantCheck.ObjectType, warrantCheck.ObjectId, warrantCheck.Relation, warrantCheck.Subject.ObjectType, warrantCheck.Subject.ObjectId, warrantCheck.Subject.Relation, warrantCheck.Context)
err := svc.EventSvc.TrackAccessAllowedEvent(ctx, warrantCheck.ObjectType, warrantCheck.ObjectId, warrantCheck.Relation, warrantCheck.Subject.ObjectType, warrantCheck.Subject.ObjectId, warrantCheck.Subject.Relation, warrantCheck.Context)
if err != nil {
return false, decisionPath, err
}

return true, decisionPath, nil
}

err = svc.eventSvc.TrackAccessDeniedEvent(ctx, warrantCheck.ObjectType, warrantCheck.ObjectId, warrantCheck.Relation, warrantCheck.Subject.ObjectType, warrantCheck.Subject.ObjectId, warrantCheck.Subject.Relation, warrantCheck.Context)
err = svc.EventSvc.TrackAccessDeniedEvent(ctx, warrantCheck.ObjectType, warrantCheck.ObjectId, warrantCheck.Relation, warrantCheck.Subject.ObjectType, warrantCheck.Subject.ObjectId, warrantCheck.Subject.Relation, warrantCheck.Context)
if err != nil {
return false, decisionPath, err
}
Expand Down
42 changes: 21 additions & 21 deletions pkg/authz/feature/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,44 +14,44 @@ const ResourceTypeFeature = "feature"

type FeatureService struct {
service.BaseService
repo FeatureRepository
eventSvc event.EventService
objectSvc object.ObjectService
Repository FeatureRepository
EventSvc event.EventService
ObjectSvc object.ObjectService
}

func NewService(env service.Env, repo FeatureRepository, eventSvc event.EventService, objectSvc object.ObjectService) FeatureService {
func NewService(env service.Env, repository FeatureRepository, eventSvc event.EventService, objectSvc object.ObjectService) FeatureService {
return FeatureService{
BaseService: service.NewBaseService(env),
repo: repo,
eventSvc: eventSvc,
objectSvc: objectSvc,
Repository: repository,
EventSvc: eventSvc,
ObjectSvc: objectSvc,
}
}

func (svc FeatureService) Create(ctx context.Context, featureSpec FeatureSpec) (*FeatureSpec, error) {
var newFeature Model
err := svc.Env().DB().WithinTransaction(ctx, func(txCtx context.Context) error {
createdObject, err := svc.objectSvc.Create(txCtx, *featureSpec.ToObjectSpec())
createdObject, err := svc.ObjectSvc.Create(txCtx, *featureSpec.ToObjectSpec())
if err != nil {
return err
}

_, err = svc.repo.GetByFeatureId(txCtx, featureSpec.FeatureId)
_, err = svc.Repository.GetByFeatureId(txCtx, featureSpec.FeatureId)
if err == nil {
return service.NewDuplicateRecordError("Feature", featureSpec.FeatureId, "A feature with the given featureId already exists")
}

newFeatureId, err := svc.repo.Create(txCtx, featureSpec.ToFeature(createdObject.ID))
newFeatureId, err := svc.Repository.Create(txCtx, featureSpec.ToFeature(createdObject.ID))
if err != nil {
return err
}

newFeature, err = svc.repo.GetById(txCtx, newFeatureId)
newFeature, err = svc.Repository.GetById(txCtx, newFeatureId)
if err != nil {
return err
}

err = svc.eventSvc.TrackResourceCreated(ctx, ResourceTypeFeature, newFeature.GetFeatureId(), newFeature.ToFeatureSpec())
err = svc.EventSvc.TrackResourceCreated(ctx, ResourceTypeFeature, newFeature.GetFeatureId(), newFeature.ToFeatureSpec())
if err != nil {
return err
}
Expand All @@ -67,7 +67,7 @@ func (svc FeatureService) Create(ctx context.Context, featureSpec FeatureSpec) (
}

func (svc FeatureService) GetByFeatureId(ctx context.Context, featureId string) (*FeatureSpec, error) {
feature, err := svc.repo.GetByFeatureId(ctx, featureId)
feature, err := svc.Repository.GetByFeatureId(ctx, featureId)
if err != nil {
return nil, err
}
Expand All @@ -77,7 +77,7 @@ func (svc FeatureService) GetByFeatureId(ctx context.Context, featureId string)

func (svc FeatureService) List(ctx context.Context, listParams middleware.ListParams) ([]FeatureSpec, error) {
featureSpecs := make([]FeatureSpec, 0)
features, err := svc.repo.List(ctx, listParams)
features, err := svc.Repository.List(ctx, listParams)
if err != nil {
return featureSpecs, nil
}
Expand All @@ -90,25 +90,25 @@ func (svc FeatureService) List(ctx context.Context, listParams middleware.ListPa
}

func (svc FeatureService) UpdateByFeatureId(ctx context.Context, featureId string, featureSpec UpdateFeatureSpec) (*FeatureSpec, error) {
currentFeature, err := svc.repo.GetByFeatureId(ctx, featureId)
currentFeature, err := svc.Repository.GetByFeatureId(ctx, featureId)
if err != nil {
return nil, err
}

currentFeature.SetName(featureSpec.Name)
currentFeature.SetDescription(featureSpec.Description)
err = svc.repo.UpdateByFeatureId(ctx, featureId, currentFeature)
err = svc.Repository.UpdateByFeatureId(ctx, featureId, currentFeature)
if err != nil {
return nil, err
}

updatedFeature, err := svc.repo.GetByFeatureId(ctx, featureId)
updatedFeature, err := svc.Repository.GetByFeatureId(ctx, featureId)
if err != nil {
return nil, err
}

updatedFeatureSpec := updatedFeature.ToFeatureSpec()
err = svc.eventSvc.TrackResourceUpdated(ctx, ResourceTypeFeature, updatedFeature.GetFeatureId(), updatedFeatureSpec)
err = svc.EventSvc.TrackResourceUpdated(ctx, ResourceTypeFeature, updatedFeature.GetFeatureId(), updatedFeatureSpec)
if err != nil {
return nil, err
}
Expand All @@ -118,17 +118,17 @@ func (svc FeatureService) UpdateByFeatureId(ctx context.Context, featureId strin

func (svc FeatureService) DeleteByFeatureId(ctx context.Context, featureId string) error {
err := svc.Env().DB().WithinTransaction(ctx, func(txCtx context.Context) error {
err := svc.repo.DeleteByFeatureId(txCtx, featureId)
err := svc.Repository.DeleteByFeatureId(txCtx, featureId)
if err != nil {
return err
}

err = svc.objectSvc.DeleteByObjectTypeAndId(txCtx, objecttype.ObjectTypeFeature, featureId)
err = svc.ObjectSvc.DeleteByObjectTypeAndId(txCtx, objecttype.ObjectTypeFeature, featureId)
if err != nil {
return err
}

err = svc.eventSvc.TrackResourceDeleted(ctx, ResourceTypeFeature, featureId, nil)
err = svc.EventSvc.TrackResourceDeleted(ctx, ResourceTypeFeature, featureId, nil)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/authz/object/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func ListHandler(svc ObjectService, w http.ResponseWriter, r *http.Request) erro
func GetHandler(svc ObjectService, w http.ResponseWriter, r *http.Request) error {
objectType := mux.Vars(r)["objectType"]
objectIdParam := mux.Vars(r)["objectId"]
object, err := svc.GetByObjectId(r.Context(), objectType, objectIdParam)
object, err := svc.GetByObjectTypeAndId(r.Context(), objectType, objectIdParam)
if err != nil {
return err
}
Expand Down
30 changes: 15 additions & 15 deletions pkg/authz/object/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,32 +12,32 @@ import (

type ObjectService struct {
service.BaseService
repo ObjectRepository
eventSvc event.EventService
warrantSvc warrant.WarrantService
Repository ObjectRepository
EventSvc event.EventService
WarrantSvc warrant.WarrantService
}

func NewService(env service.Env, repo ObjectRepository, eventSvc event.EventService, warrantSvc warrant.WarrantService) ObjectService {
func NewService(env service.Env, repository ObjectRepository, eventSvc event.EventService, warrantSvc warrant.WarrantService) ObjectService {
return ObjectService{
BaseService: service.NewBaseService(env),
repo: repo,
eventSvc: eventSvc,
warrantSvc: warrantSvc,
Repository: repository,
EventSvc: eventSvc,
WarrantSvc: warrantSvc,
}
}

func (svc ObjectService) Create(ctx context.Context, objectSpec ObjectSpec) (*ObjectSpec, error) {
_, err := svc.repo.GetByObjectTypeAndId(ctx, objectSpec.ObjectType, objectSpec.ObjectId)
_, err := svc.Repository.GetByObjectTypeAndId(ctx, objectSpec.ObjectType, objectSpec.ObjectId)
if err == nil {
return nil, service.NewDuplicateRecordError("Object", fmt.Sprintf("%s:%s", objectSpec.ObjectType, objectSpec.ObjectId), "An object with the given objectType and objectId already exists")
}

newObjectId, err := svc.repo.Create(ctx, *objectSpec.ToObject())
newObjectId, err := svc.Repository.Create(ctx, *objectSpec.ToObject())
if err != nil {
return nil, err
}

newObject, err := svc.repo.GetById(ctx, newObjectId)
newObject, err := svc.Repository.GetById(ctx, newObjectId)
if err != nil {
return nil, err
}
Expand All @@ -47,12 +47,12 @@ func (svc ObjectService) Create(ctx context.Context, objectSpec ObjectSpec) (*Ob

func (svc ObjectService) DeleteByObjectTypeAndId(ctx context.Context, objectType string, objectId string) error {
err := svc.Env().DB().WithinTransaction(ctx, func(txCtx context.Context) error {
err := svc.repo.DeleteByObjectTypeAndId(txCtx, objectType, objectId)
err := svc.Repository.DeleteByObjectTypeAndId(txCtx, objectType, objectId)
if err != nil {
return err
}

err = svc.warrantSvc.DeleteRelatedWarrants(txCtx, objectType, objectId)
err = svc.WarrantSvc.DeleteRelatedWarrants(txCtx, objectType, objectId)
if err != nil {
return err
}
Expand All @@ -63,8 +63,8 @@ func (svc ObjectService) DeleteByObjectTypeAndId(ctx context.Context, objectType
return err
}

func (svc ObjectService) GetByObjectId(ctx context.Context, objectType string, objectId string) (*ObjectSpec, error) {
object, err := svc.repo.GetByObjectTypeAndId(ctx, objectType, objectId)
func (svc ObjectService) GetByObjectTypeAndId(ctx context.Context, objectType string, objectId string) (*ObjectSpec, error) {
object, err := svc.Repository.GetByObjectTypeAndId(ctx, objectType, objectId)
if err != nil {
return nil, err
}
Expand All @@ -74,7 +74,7 @@ func (svc ObjectService) GetByObjectId(ctx context.Context, objectType string, o

func (svc ObjectService) List(ctx context.Context, filterOptions *FilterOptions, listParams middleware.ListParams) ([]ObjectSpec, error) {
objectSpecs := make([]ObjectSpec, 0)
objects, err := svc.repo.List(ctx, filterOptions, listParams)
objects, err := svc.Repository.List(ctx, filterOptions, listParams)
if err != nil {
return objectSpecs, err
}
Expand Down
Loading

0 comments on commit c9abeb1

Please sign in to comment.