diff --git a/CHANGELOG.md b/CHANGELOG.md index b74aa2214..999893662 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -82,6 +82,8 @@ - (@todo docs) Use a default fetch function that will return all relations in case the fetchFunc argument of `Dao.ExpandRecord()` and `Dao.ExpandRecords()` is `nil`. +- (@todo docs) Added `record.ExpandedOne(rel)` and `record.ExpandedAll(rel)` helpers to retrieve casted single or multiple expand relations from the already loaded "expand" Record data. + ## v0.16.9 diff --git a/daos/record.go b/daos/record.go index d6df53927..302a4eee3 100644 --- a/daos/record.go +++ b/daos/record.go @@ -197,6 +197,8 @@ func (dao *Dao) FindRecordsByIds( return records, nil } +// @todo consider to depricate as it may be easier to just use dao.RecordQuery() +// // FindRecordsByExpr finds all records by the specified db expression. // // Returns all collection records if no expressions are provided. diff --git a/models/record.go b/models/record.go index 7be2c44e0..8b4becdec 100644 --- a/models/record.go +++ b/models/record.go @@ -378,6 +378,56 @@ func (m *Record) GetStringSlice(key string) []string { return list.ToUniqueStringSlice(m.Get(key)) } +// ExpandedOne retrieves a single relation Record from the already +// loaded expand data of the current model. +// +// If the requested expand relation is multiple, this method returns +// only first available Record from the expanded relation. +// +// Returns nil if there is no such expand relation loaded. +func (m *Record) ExpandedOne(relField string) *Record { + if m.expand == nil { + return nil + } + + rel := m.expand.Get(relField) + + switch v := rel.(type) { + case *Record: + return v + case []*Record: + if len(v) > 0 { + return v[0] + } + } + + return nil +} + +// ExpandedAll retrieves a slice of relation Records from the already +// loaded expand data of the current model. +// +// If the requested expand relation is single, this method normalizes +// the return result and will wrap the single model as a slice. +// +// Returns nil slice if there is no such expand relation loaded. +func (m *Record) ExpandedAll(relField string) []*Record { + if m.expand == nil { + return nil + } + + rel := m.expand.Get(relField) + + switch v := rel.(type) { + case *Record: + return []*Record{v} + case []*Record: + return v + } + + return nil +} + // Retrieves the "key" json field value and unmarshals it into "result". // // Example diff --git a/models/record_test.go b/models/record_test.go index 1bdc990a3..95dd478dc 100644 --- a/models/record_test.go +++ b/models/record_test.go @@ -541,6 +541,70 @@ func TestRecordMergeExpandNilCheck(t *testing.T) { } } +func TestRecordExpandedRel(t *testing.T) { + collection := &models.Collection{} + + main := models.NewRecord(collection) + + single := models.NewRecord(collection) + single.Id = "single" + + multiple1 := models.NewRecord(collection) + multiple1.Id = "multiple1" + + multiple2 := models.NewRecord(collection) + multiple2.Id = "multiple2" + + main.SetExpand(map[string]any{ + "single": single, + "multiple": []*models.Record{multiple1, multiple2}, + }) + + if v := main.ExpandedOne("missing"); v != nil { + t.Fatalf("Expected nil, got %v", v) + } + + if v := main.ExpandedOne("single"); v == nil || v.Id != "single" { + t.Fatalf("Expected record with id %q, got %v", "single", v) + } + + if v := main.ExpandedOne("multiple"); v == nil || v.Id != "multiple1" { + t.Fatalf("Expected record with id %q, got %v", "multiple1", v) + } +} + +func TestRecordExpandedAll(t *testing.T) { + collection := &models.Collection{} + + main := models.NewRecord(collection) + + single := models.NewRecord(collection) + single.Id = "single" + + multiple1 := models.NewRecord(collection) + multiple1.Id = "multiple1" + + multiple2 := models.NewRecord(collection) + multiple2.Id = "multiple2" + + main.SetExpand(map[string]any{ + "single": single, + "multiple": []*models.Record{multiple1, multiple2}, + }) + + if v := main.ExpandedAll("missing"); v != nil { + t.Fatalf("Expected nil, got %v", v) + } + + if v := main.ExpandedAll("single"); len(v) != 1 || v[0].Id != "single" { + t.Fatalf("Expected [single] slice, got %v", v) + } + + if v := main.ExpandedAll("multiple"); len(v) != 2 || v[0].Id != "multiple1" || v[1].Id != "multiple2" { + t.Fatalf("Expected [multiple1, multiple2] slice, got %v", v) + } +} + func TestRecordSchemaData(t *testing.T) { collection := &models.Collection{ Type: models.CollectionTypeAuth,