Skip to content

Commit 5665706

Browse files
committedDec 18, 2019
Fix ca* checks
1 parent 8e6b725 commit 5665706

File tree

2 files changed

+135
-80
lines changed

2 files changed

+135
-80
lines changed
 

‎firewall.go

+93-41
Original file line numberDiff line numberDiff line change
@@ -83,19 +83,23 @@ func newFirewallTable() *FirewallTable {
8383
}
8484
}
8585

86+
type FirewallCA struct {
87+
Any *FirewallRule
88+
CANames map[string]*FirewallRule
89+
CAShas map[string]*FirewallRule
90+
}
91+
8692
type FirewallRule struct {
87-
// Any makes Hosts, Groups, and CIDR irrelevant. CAName and CASha still need to be checked
88-
Any bool
89-
Hosts map[string]struct{}
90-
Groups [][]string
91-
CIDR *CIDRTree
92-
CANames map[string]struct{}
93-
CAShas map[string]struct{}
93+
// Any makes Hosts, Groups, and CIDR irrelevant
94+
Any bool
95+
Hosts map[string]struct{}
96+
Groups [][]string
97+
CIDR *CIDRTree
9498
}
9599

96100
// Even though ports are uint16, int32 maps are faster for lookup
97101
// Plus we can use `-1` for fragment rules
98-
type firewallPort map[int32]*FirewallRule
102+
type firewallPort map[int32]*FirewallCA
99103

100104
type FirewallPacket struct {
101105
LocalIP uint32
@@ -182,9 +186,9 @@ func NewFirewall(tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.N
182186

183187
func NewFirewallFromConfig(nc *cert.NebulaCertificate, c *Config) (*Firewall, error) {
184188
fw := NewFirewall(
185-
c.GetDuration("firewall.conntrack.tcp_timeout", time.Duration(time.Minute*12)),
186-
c.GetDuration("firewall.conntrack.udp_timeout", time.Duration(time.Minute*3)),
187-
c.GetDuration("firewall.conntrack.default_timeout", time.Duration(time.Minute*10)),
189+
c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12),
190+
c.GetDuration("firewall.conntrack.udp_timeout", time.Minute*3),
191+
c.GetDuration("firewall.conntrack.default_timeout", time.Minute*10),
188192
nc,
189193
//TODO: max_connections
190194
)
@@ -499,12 +503,9 @@ func (fp firewallPort) addRule(startPort int32, endPort int32, groups []string,
499503

500504
for i := startPort; i <= endPort; i++ {
501505
if _, ok := fp[i]; !ok {
502-
fp[i] = &FirewallRule{
503-
Groups: make([][]string, 0),
504-
Hosts: make(map[string]struct{}),
505-
CIDR: NewCIDRTree(),
506-
CANames: make(map[string]struct{}),
507-
CAShas: make(map[string]struct{}),
506+
fp[i] = &FirewallCA{
507+
CANames: make(map[string]*FirewallRule),
508+
CAShas: make(map[string]*FirewallRule),
508509
}
509510
}
510511

@@ -539,15 +540,83 @@ func (fp firewallPort) match(p FirewallPacket, incoming bool, c *cert.NebulaCert
539540
return fp[fwPortAny].match(p, c, caPool)
540541
}
541542

542-
func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet, caName string, caSha string) error {
543-
if caName != "" {
544-
fr.CANames[caName] = struct{}{}
543+
func (fc *FirewallCA) addRule(groups []string, host string, ip *net.IPNet, caName, caSha string) error {
544+
// If there is an any rule then there is no need to establish specific ca rules
545+
if fc.Any != nil {
546+
return fc.Any.addRule(groups, host, ip)
547+
}
548+
549+
fr := func() *FirewallRule {
550+
return &FirewallRule{
551+
Hosts: make(map[string]struct{}),
552+
Groups: make([][]string, 0),
553+
CIDR: NewCIDRTree(),
554+
}
555+
}
556+
557+
any := false
558+
if caSha == "" && caName == "" {
559+
any = true
560+
}
561+
562+
if any {
563+
if fc.Any == nil {
564+
fc.Any = fr()
565+
}
566+
567+
// If it's any we need to wipe out any pre-existing rules to save on memory
568+
fc.CAShas = make(map[string]*FirewallRule)
569+
fc.CANames = make(map[string]*FirewallRule)
570+
return fc.Any.addRule(groups, host, ip)
545571
}
546572

547573
if caSha != "" {
548-
fr.CAShas[caSha] = struct{}{}
574+
if _, ok := fc.CAShas[caSha]; !ok {
575+
fc.CAShas[caSha] = fr()
576+
}
577+
err := fc.CAShas[caSha].addRule(groups, host, ip)
578+
if err != nil {
579+
return err
580+
}
581+
}
582+
583+
if caName != "" {
584+
if _, ok := fc.CANames[caName]; !ok {
585+
fc.CANames[caName] = fr()
586+
}
587+
err := fc.CANames[caName].addRule(groups, host, ip)
588+
if err != nil {
589+
return err
590+
}
591+
}
592+
593+
return nil
594+
}
595+
596+
func (fc *FirewallCA) match(p FirewallPacket, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
597+
if fc == nil {
598+
return false
599+
}
600+
601+
if fc.Any != nil {
602+
return fc.Any.match(p, c)
549603
}
550604

605+
if t, ok := fc.CAShas[c.Details.Issuer]; ok {
606+
if t.match(p, c) {
607+
return true
608+
}
609+
}
610+
611+
s, err := caPool.GetCAForCert(c)
612+
if err != nil {
613+
return false
614+
}
615+
616+
return fc.CANames[s.Details.Name].match(p, c)
617+
}
618+
619+
func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet) error {
551620
if fr.Any {
552621
return nil
553622
}
@@ -593,28 +662,11 @@ func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool
593662
return false
594663
}
595664

596-
func (fr *FirewallRule) match(p FirewallPacket, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
665+
func (fr *FirewallRule) match(p FirewallPacket, c *cert.NebulaCertificate) bool {
597666
if fr == nil {
598667
return false
599668
}
600669

601-
// CASha and CAName always need to be checked
602-
if len(fr.CAShas) > 0 {
603-
if _, ok := fr.CAShas[c.Details.Issuer]; !ok {
604-
return false
605-
}
606-
}
607-
608-
if len(fr.CANames) > 0 {
609-
s, err := caPool.GetCAForCert(c)
610-
if err != nil {
611-
return false
612-
}
613-
if _, ok := fr.CANames[s.Details.Name]; !ok {
614-
return false
615-
}
616-
}
617-
618670
// Shortcut path for if groups, hosts, or cidr contained an `any`
619671
if fr.Any {
620672
return true
@@ -773,7 +825,7 @@ func setTCPRTTTracking(c *conn, p []byte) {
773825
ihl := int(p[0]&0x0f) << 2
774826

775827
// Don't track FIN packets
776-
if uint8(p[ihl+13])&tcpFIN != 0 {
828+
if p[ihl+13]&tcpFIN != 0 {
777829
return
778830
}
779831

@@ -787,7 +839,7 @@ func (f *Firewall) checkTCPRTT(c *conn, p []byte) bool {
787839
}
788840

789841
ihl := int(p[0]&0x0f) << 2
790-
if uint8(p[ihl+13])&tcpACK == 0 {
842+
if p[ihl+13]&tcpACK == 0 {
791843
return false
792844
}
793845

‎firewall_test.go

+42-39
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"bytes"
55
"encoding/binary"
66
"errors"
7+
"fmt"
78
"math"
89
"net"
910
"testing"
@@ -61,37 +62,37 @@ func TestFirewall_AddRule(t *testing.T) {
6162
assert.Nil(t, fw.AddRule(true, fwProtoTCP, 1, 1, []string{}, "", nil, "", ""))
6263
// Make sure an empty rule creates structure but doesn't allow anything to flow
6364
//TODO: ideally an empty rule would return an error
64-
assert.False(t, fw.InRules.TCP[1].Any)
65-
assert.Empty(t, fw.InRules.TCP[1].Groups)
66-
assert.Empty(t, fw.InRules.TCP[1].Hosts)
67-
assert.Nil(t, fw.InRules.TCP[1].CIDR.root.left)
68-
assert.Nil(t, fw.InRules.TCP[1].CIDR.root.right)
69-
assert.Nil(t, fw.InRules.TCP[1].CIDR.root.value)
65+
assert.False(t, fw.InRules.TCP[1].Any.Any)
66+
assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
67+
assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
68+
assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.left)
69+
assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.right)
70+
assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.value)
7071

7172
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
7273
assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "", ""))
73-
assert.False(t, fw.InRules.UDP[1].Any)
74-
assert.Contains(t, fw.InRules.UDP[1].Groups[0], "g1")
75-
assert.Empty(t, fw.InRules.UDP[1].Hosts)
76-
assert.Nil(t, fw.InRules.UDP[1].CIDR.root.left)
77-
assert.Nil(t, fw.InRules.UDP[1].CIDR.root.right)
78-
assert.Nil(t, fw.InRules.UDP[1].CIDR.root.value)
74+
assert.False(t, fw.InRules.UDP[1].Any.Any)
75+
assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0], "g1")
76+
assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
77+
assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.left)
78+
assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.right)
79+
assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.value)
7980

8081
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
8182
assert.Nil(t, fw.AddRule(true, fwProtoICMP, 1, 1, []string{}, "h1", nil, "", ""))
82-
assert.False(t, fw.InRules.ICMP[1].Any)
83-
assert.Empty(t, fw.InRules.ICMP[1].Groups)
84-
assert.Contains(t, fw.InRules.ICMP[1].Hosts, "h1")
85-
assert.Nil(t, fw.InRules.ICMP[1].CIDR.root.left)
86-
assert.Nil(t, fw.InRules.ICMP[1].CIDR.root.right)
87-
assert.Nil(t, fw.InRules.ICMP[1].CIDR.root.value)
83+
assert.False(t, fw.InRules.ICMP[1].Any.Any)
84+
assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
85+
assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
86+
assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.left)
87+
assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.right)
88+
assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.value)
8889

8990
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
9091
assert.Nil(t, fw.AddRule(false, fwProtoAny, 1, 1, []string{}, "", ti, "", ""))
91-
assert.False(t, fw.OutRules.AnyProto[1].Any)
92-
assert.Empty(t, fw.OutRules.AnyProto[1].Groups)
93-
assert.Empty(t, fw.OutRules.AnyProto[1].Hosts)
94-
assert.NotNil(t, fw.OutRules.AnyProto[1].CIDR.Match(ip2int(ti.IP)))
92+
assert.False(t, fw.OutRules.AnyProto[1].Any.Any)
93+
assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups)
94+
assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts)
95+
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.CIDR.Match(ip2int(ti.IP)))
9596

9697
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
9798
assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "ca-name", ""))
@@ -104,28 +105,30 @@ func TestFirewall_AddRule(t *testing.T) {
104105
// Set any and clear fields
105106
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
106107
assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, "", ""))
107-
assert.Equal(t, []string{"g1", "g2"}, fw.OutRules.AnyProto[0].Groups[0])
108-
assert.Contains(t, fw.OutRules.AnyProto[0].Hosts, "h1")
109-
assert.NotNil(t, fw.OutRules.AnyProto[0].CIDR.Match(ip2int(ti.IP)))
108+
assert.Equal(t, []string{"g1", "g2"}, fw.OutRules.AnyProto[0].Any.Groups[0])
109+
assert.Contains(t, fw.OutRules.AnyProto[0].Any.Hosts, "h1")
110+
assert.NotNil(t, fw.OutRules.AnyProto[0].Any.CIDR.Match(ip2int(ti.IP)))
110111

111112
// run twice just to make sure
113+
//TODO: these ANY rules should clear the CA firewall portion
112114
assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
113115
assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "any", nil, "", ""))
114-
assert.True(t, fw.OutRules.AnyProto[0].Any)
115-
assert.Empty(t, fw.OutRules.AnyProto[0].Groups)
116-
assert.Empty(t, fw.OutRules.AnyProto[0].Hosts)
117-
assert.Nil(t, fw.OutRules.AnyProto[0].CIDR.root.left)
118-
assert.Nil(t, fw.OutRules.AnyProto[0].CIDR.root.right)
119-
assert.Nil(t, fw.OutRules.AnyProto[0].CIDR.root.value)
116+
assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
117+
assert.Empty(t, fw.OutRules.AnyProto[0].Any.Groups)
118+
assert.Empty(t, fw.OutRules.AnyProto[0].Any.Hosts)
119+
assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.left)
120+
assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.right)
121+
assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.value)
122+
fmt.Printf("%+v\n", fw.OutRules.AnyProto[0])
120123

121124
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
122125
assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "any", nil, "", ""))
123-
assert.True(t, fw.OutRules.AnyProto[0].Any)
126+
assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
124127

125128
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
126129
_, anyIp, _ := net.ParseCIDR("0.0.0.0/0")
127130
assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "", anyIp, "", ""))
128-
assert.True(t, fw.OutRules.AnyProto[0].Any)
131+
assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
129132

130133
// Test error conditions
131134
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
@@ -209,11 +212,11 @@ func BenchmarkFirewallTable_match(b *testing.B) {
209212
}
210213

211214
_, n, _ := net.ParseCIDR("172.1.1.1/32")
212-
ft.TCP.addRule(10, 10, []string{"good-group"}, "good-host", n, "", "")
213-
ft.TCP.addRule(10, 10, []string{"good-group2"}, "good-host", n, "", "")
214-
ft.TCP.addRule(10, 10, []string{"good-group3"}, "good-host", n, "", "")
215-
ft.TCP.addRule(10, 10, []string{"good-group4"}, "good-host", n, "", "")
216-
ft.TCP.addRule(10, 10, []string{"good-group, good-group1"}, "good-host", n, "", "")
215+
_ = ft.TCP.addRule(10, 10, []string{"good-group"}, "good-host", n, "", "")
216+
_ = ft.TCP.addRule(10, 10, []string{"good-group2"}, "good-host", n, "", "")
217+
_ = ft.TCP.addRule(10, 10, []string{"good-group3"}, "good-host", n, "", "")
218+
_ = ft.TCP.addRule(10, 10, []string{"good-group4"}, "good-host", n, "", "")
219+
_ = ft.TCP.addRule(10, 10, []string{"good-group, good-group1"}, "good-host", n, "", "")
217220
cp := cert.NewCAPool()
218221

219222
b.Run("fail on proto", func(b *testing.B) {
@@ -281,7 +284,7 @@ func BenchmarkFirewallTable_match(b *testing.B) {
281284
}
282285
})
283286

284-
ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, "", "")
287+
_ = ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, "", "")
285288

286289
b.Run("pass on ip with any port", func(b *testing.B) {
287290
ip := ip2int(net.IPv4(172, 1, 1, 1))

0 commit comments

Comments
 (0)
Please sign in to comment.