Skip to content

Commit 7f9c160

Browse files
authored
Protect attribute "Inherit" method from recursive type definitions. (goadesign#1345)
1 parent 7e25d41 commit 7f9c160

File tree

2 files changed

+92
-4
lines changed

2 files changed

+92
-4
lines changed

design/definitions.go

+16-4
Original file line numberDiff line numberDiff line change
@@ -1204,21 +1204,33 @@ func (a *AttributeDefinition) Merge(other *AttributeDefinition) *AttributeDefini
12041204

12051205
// Inherit merges the properties of existing target type attributes with the argument's.
12061206
// The algorithm is recursive so that child attributes are also merged.
1207-
func (a *AttributeDefinition) Inherit(parent *AttributeDefinition) {
1207+
func (a *AttributeDefinition) Inherit(parent *AttributeDefinition, seen ...map[*AttributeDefinition]struct{}) {
12081208
if !a.shouldInherit(parent) {
12091209
return
12101210
}
12111211

12121212
a.inheritValidations(parent)
1213-
a.inheritRecursive(parent)
1213+
a.inheritRecursive(parent, seen...)
12141214
}
12151215

12161216
// DSL returns the initialization DSL.
12171217
func (a *AttributeDefinition) DSL() func() {
12181218
return a.DSLFunc
12191219
}
12201220

1221-
func (a *AttributeDefinition) inheritRecursive(parent *AttributeDefinition) {
1221+
func (a *AttributeDefinition) inheritRecursive(parent *AttributeDefinition, seen ...map[*AttributeDefinition]struct{}) {
1222+
// prevent infinite recursions
1223+
var s map[*AttributeDefinition]struct{}
1224+
if len(seen) > 0 {
1225+
s = seen[0]
1226+
if _, ok := s[parent]; ok {
1227+
return
1228+
}
1229+
} else {
1230+
s = make(map[*AttributeDefinition]struct{})
1231+
}
1232+
s[parent] = struct{}{}
1233+
12221234
if !a.shouldInherit(parent) {
12231235
return
12241236
}
@@ -1239,7 +1251,7 @@ func (a *AttributeDefinition) inheritRecursive(parent *AttributeDefinition) {
12391251
att.Type = patt.Type
12401252
} else if att.shouldInherit(patt) {
12411253
for _, att := range att.Type.ToObject() {
1242-
att.Inherit(patt.Type.ToObject()[n])
1254+
att.Inherit(patt.Type.ToObject()[n], s)
12431255
}
12441256
}
12451257
if att.Example == nil {

design/definitions_test.go

+76
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,82 @@ import (
99
. "github.com/onsi/gomega"
1010
)
1111

12+
var _ = Describe("Inherit", func() {
13+
var child, parent *design.AttributeDefinition
14+
15+
BeforeEach(func() {
16+
parent = &design.AttributeDefinition{Type: design.Object{}}
17+
child = &design.AttributeDefinition{Type: design.Object{}}
18+
})
19+
20+
JustBeforeEach(func() {
21+
child.Inherit(parent)
22+
})
23+
24+
Context("with a empty parent", func() {
25+
const attName = "c"
26+
BeforeEach(func() {
27+
child.Type.(design.Object)[attName] = &design.AttributeDefinition{Type: design.String}
28+
})
29+
30+
It("does not change", func() {
31+
obj := child.Type.(design.Object)
32+
Ω(obj).Should(HaveLen(1))
33+
Ω(obj).Should(HaveKey(attName))
34+
})
35+
})
36+
37+
Context("with a parent that defines no inherited attribute", func() {
38+
const (
39+
attName = "c"
40+
def = "default"
41+
)
42+
43+
BeforeEach(func() {
44+
child.Type.(design.Object)[attName] = &design.AttributeDefinition{Type: design.String}
45+
parent.Type.(design.Object)["other"] = &design.AttributeDefinition{Type: design.String, DefaultValue: def}
46+
})
47+
48+
It("does not change", func() {
49+
obj := child.Type.(design.Object)
50+
Ω(obj).Should(HaveLen(1))
51+
Ω(obj).Should(HaveKey(attName))
52+
Ω(obj[attName].DefaultValue).Should(BeNil())
53+
})
54+
})
55+
56+
Context("with a parent that defines an inherited attribute", func() {
57+
const (
58+
attName = "c"
59+
def = "default"
60+
)
61+
62+
BeforeEach(func() {
63+
child.Type.(design.Object)[attName] = &design.AttributeDefinition{Type: design.String}
64+
parent.Type.(design.Object)[attName] = &design.AttributeDefinition{Type: design.String, DefaultValue: def}
65+
})
66+
67+
It("inherits the default value", func() {
68+
obj := child.Type.(design.Object)
69+
Ω(obj).Should(HaveLen(1))
70+
Ω(obj).Should(HaveKey(attName))
71+
Ω(obj[attName].DefaultValue).Should(Equal(def))
72+
})
73+
})
74+
75+
Context("with recursive type definitions", func() {
76+
BeforeEach(func() {
77+
po := design.Object{}
78+
parent = &design.AttributeDefinition{Type: po}
79+
child = &design.AttributeDefinition{Type: &design.UserTypeDefinition{AttributeDefinition: parent}}
80+
po["recurse"] = child
81+
})
82+
83+
It("does not recurse infinitely", func() {})
84+
})
85+
86+
})
87+
1288
var _ = Describe("IsRequired", func() {
1389
var required string
1490
var attName string

0 commit comments

Comments
 (0)