Skip to content

Commit d75cb7b

Browse files
committed
Fix domain comparison bug
1 parent 9260035 commit d75cb7b

File tree

2 files changed

+10
-3
lines changed

2 files changed

+10
-3
lines changed

client/cache.go

+3-1
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ type DelegationCache struct {
4949
func (d *DelegationCache) Get(domain string) (label string, servers []Server) {
5050
d.mu.Lock()
5151
defer d.mu.Unlock()
52+
domain = strings.ToLower(domain)
5253
for offset, end := 0, false; !end; offset, end = dns.NextLabel(domain, offset) {
5354
label = domain[offset:]
5455
var found bool
@@ -64,8 +65,9 @@ func (d *DelegationCache) Get(domain string) (label string, servers []Server) {
6465
func (d *DelegationCache) Add(domain string, s Server) error {
6566
d.mu.Lock()
6667
defer d.mu.Unlock()
68+
domain = strings.ToLower(domain)
6769
for _, s2 := range d.c[domain] {
68-
if s2.Name == s.Name {
70+
if domainEqual(s2.Name, s.Name) {
6971
return nil
7072
}
7173
}

client/client.go

+7-2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package client
33
import (
44
"errors"
55
"net"
6+
"strings"
67
"time"
78

89
"github.com/miekg/dns"
@@ -74,6 +75,10 @@ func (c *Client) ParallelQuery(m *dns.Msg, servers []Server) Responses {
7475
return rs
7576
}
7677

78+
func domainEqual(d1, d2 string) bool {
79+
return strings.ToLower(dns.Fqdn(d1)) == strings.ToLower(dns.Fqdn(d2))
80+
}
81+
7782
// RecursiveQuery performs a recursive query by querying all the available name
7883
// servers to gather statistics.
7984
func (c *Client) RecursiveQuery(m *dns.Msg, tracer Tracer) (r *dns.Msg, rtt time.Duration, err error) {
@@ -104,7 +109,7 @@ func (c *Client) RecursiveQuery(m *dns.Msg, tracer Tracer) (r *dns.Msg, rtt time
104109
var deleg bool
105110
var cname string
106111
for _, rr := range r.Answer {
107-
if rr.Header().Name == qname && rr.Header().Rrtype == qtype {
112+
if domainEqual(rr.Header().Name, qname) && rr.Header().Rrtype == qtype {
108113
done = true
109114
break
110115
} else if rr.Header().Rrtype == dns.TypeCNAME {
@@ -134,7 +139,7 @@ func (c *Client) RecursiveQuery(m *dns.Msg, tracer Tracer) (r *dns.Msg, rtt time
134139
name := ns.Header().Name
135140
var addrs []string
136141
for _, rr := range r.Extra {
137-
if rr.Header().Name == ns.Ns {
142+
if domainEqual(rr.Header().Name, ns.Ns) {
138143
switch a := rr.(type) {
139144
case *dns.A:
140145
addrs = append(addrs, a.A.String())

0 commit comments

Comments
 (0)