Skip to content

Commit 9260035

Browse files
committed
Do not print NS on last hop
1 parent e46102c commit 9260035

File tree

3 files changed

+48
-47
lines changed

3 files changed

+48
-47
lines changed

client/cache.go

+3-18
Original file line numberDiff line numberDiff line change
@@ -39,29 +39,14 @@ func (s Server) String() string {
3939
return fmt.Sprintf("%s %d NS (%s): %v", s.Name, s.TTL, strings.Join(s.Addrs, ","), s.LookupErr)
4040
}
4141

42-
type Servers []Server
43-
44-
func (s Servers) String() string {
45-
if len(s) > 0 {
46-
if s[0].Name == "A.root-servers.net." {
47-
return "*.root-servers.net."
48-
}
49-
}
50-
names := make([]string, 0, len(s))
51-
for _, s := range s {
52-
names = append(names, s.Name)
53-
}
54-
return strings.Join(names, ", ")
55-
}
56-
5742
// DelegationCache store and retrive delegations.
5843
type DelegationCache struct {
59-
c map[string]Servers
44+
c map[string][]Server
6045
mu sync.Mutex
6146
}
6247

6348
// Get returns the most specific name servers for domain with its matching label.
64-
func (d *DelegationCache) Get(domain string) (label string, servers Servers) {
49+
func (d *DelegationCache) Get(domain string) (label string, servers []Server) {
6550
d.mu.Lock()
6651
defer d.mu.Unlock()
6752
for offset, end := 0, false; !end; offset, end = dns.NextLabel(domain, offset) {
@@ -85,7 +70,7 @@ func (d *DelegationCache) Add(domain string, s Server) error {
8570
}
8671
}
8772
if d.c == nil {
88-
d.c = map[string]Servers{}
73+
d.c = map[string][]Server{}
8974
}
9075
d.c[domain] = append(d.c[domain], s)
9176
return nil

client/client.go

+43-27
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ func (rs Responses) Fastest() *Response {
3737
}
3838

3939
type Tracer struct {
40-
GotDelegateResponses func(i int, m *dns.Msg, rs Responses)
40+
GotDelegateResponses func(i int, m *dns.Msg, rs Responses, last bool)
4141
FollowingCNAME func(domain, target string)
4242
}
4343

@@ -51,7 +51,7 @@ func New() Client {
5151

5252
// ParallelQuery perform an exchange using m with all servers in parallel and
5353
// return all responses.
54-
func (c *Client) ParallelQuery(m *dns.Msg, servers Servers) Responses {
54+
func (c *Client) ParallelQuery(m *dns.Msg, servers []Server) Responses {
5555
rc := make(chan Response)
5656
cnt := 0
5757
for _, s := range servers {
@@ -81,8 +81,9 @@ func (c *Client) RecursiveQuery(m *dns.Msg, tracer Tracer) (r *dns.Msg, rtt time
8181
m = m.Copy()
8282
qname := m.Question[0].Name
8383
qtype := m.Question[0].Qtype
84+
zone := "."
8485
for i := 1; i < 100; i++ {
85-
deleg, servers := c.DCache.Get(qname)
86+
_, servers := c.DCache.Get(qname)
8687
m.Question[0].Name = qname
8788
rs := c.ParallelQuery(m, servers)
8889

@@ -99,26 +100,47 @@ func (c *Client) RecursiveQuery(m *dns.Msg, tracer Tracer) (r *dns.Msg, rtt time
99100
}
100101
rtt += fr.RTT
101102

102-
done := false
103+
var done bool
104+
var deleg bool
105+
var cname string
103106
for _, rr := range r.Answer {
104-
if rr.Header().Rrtype == qtype && rr.Header().Name == qname {
107+
if rr.Header().Name == qname && rr.Header().Rrtype == qtype {
105108
done = true
106109
break
110+
} else if rr.Header().Rrtype == dns.TypeCNAME {
111+
cname = rr.Header().Name
112+
qname = rr.(*dns.CNAME).Target
113+
zone = "."
114+
}
115+
}
116+
if !done && cname == "" {
117+
for _, ns := range r.Ns {
118+
if ns, ok := ns.(*dns.NS); ok && len(ns.Header().Name) > len(zone) {
119+
deleg = true
120+
zone = ns.Header().Name
121+
break
122+
}
107123
}
108124
}
109125

110-
if !done {
126+
if deleg {
111127
lrttc := make(chan time.Duration)
112128
lc := 0
113129
for _, ns := range r.Ns {
114130
ns, ok := ns.(*dns.NS)
115131
if !ok {
116132
continue // skip DS records
117133
}
134+
name := ns.Header().Name
118135
var addrs []string
119136
for _, rr := range r.Extra {
120-
if a, ok := rr.(*dns.A); ok && a.Header().Name == ns.Ns {
121-
addrs = append(addrs, a.A.String())
137+
if rr.Header().Name == ns.Ns {
138+
switch a := rr.(type) {
139+
case *dns.A:
140+
addrs = append(addrs, a.A.String())
141+
case *dns.AAAA:
142+
addrs = append(addrs, a.AAAA.String())
143+
}
122144
}
123145
}
124146
s := Server{
@@ -137,12 +159,13 @@ func (c *Client) RecursiveQuery(m *dns.Msg, tracer Tracer) (r *dns.Msg, rtt time
137159
if err != nil {
138160
s.LookupErr = err
139161
}
140-
c.DCache.Add(ns.Header().Name, s)
162+
c.DCache.Add(name, s)
141163
lrttc <- s.LookupRTT
142164
}()
143165
continue
144166
}
145-
c.DCache.Add(ns.Header().Name, s)
167+
c.DCache.Add(name, s)
168+
c.LCache.Set(s.Name, s.Addrs)
146169
if tracer.GotDelegateResponses == nil {
147170
// If not traced, do not resolve all NS
148171
break
@@ -159,27 +182,17 @@ func (c *Client) RecursiveQuery(m *dns.Msg, tracer Tracer) (r *dns.Msg, rtt time
159182
}
160183

161184
if tracer.GotDelegateResponses != nil {
162-
tracer.GotDelegateResponses(i, m.Copy(), rs)
185+
last := !deleg && cname == ""
186+
tracer.GotDelegateResponses(i, m.Copy(), rs, last)
163187
}
164188

165-
if len(r.Answer) > 0 {
166-
var cname string
167-
for _, rr := range r.Answer {
168-
if rr.Header().Rrtype == dns.TypeCNAME {
169-
cname = rr.Header().Name
170-
qname = rr.(*dns.CNAME).Target
171-
}
172-
}
173-
if cname != "" {
174-
if tracer.FollowingCNAME != nil {
175-
tracer.FollowingCNAME(cname, qname)
176-
}
177-
continue
189+
if cname != "" {
190+
if tracer.FollowingCNAME != nil {
191+
tracer.FollowingCNAME(cname, qname)
178192
}
179-
return r, rtt, nil
193+
continue
180194
}
181-
182-
if label, _ := c.DCache.Get(qname); len(r.Ns) == 0 || deleg == label {
195+
if !deleg {
183196
return r, rtt, nil
184197
}
185198
}
@@ -214,6 +227,9 @@ func (c *Client) lookupHost(m *dns.Msg) (addrs []string, rtt time.Duration, err
214227
if r.RTT > rtt {
215228
rtt = r.RTT // get the longest of the two // queries
216229
}
230+
if r.Msg == nil {
231+
continue
232+
}
217233
for _, rr := range r.Msg.Answer {
218234
switch rr := rr.(type) {
219235
case *dns.A:

main.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ func main() {
8181
c := client.New()
8282
c.Client.Timeout = 500 * time.Millisecond
8383
t := client.Tracer{
84-
GotDelegateResponses: func(i int, m *dns.Msg, rs client.Responses) {
84+
GotDelegateResponses: func(i int, m *dns.Msg, rs client.Responses, last bool) {
8585
fr := rs.Fastest()
8686
var r *dns.Msg
8787
if fr != nil {
@@ -114,7 +114,7 @@ func main() {
114114
fmt.Print("\n")
115115
}
116116

117-
if len(r.Ns) > 0 {
117+
if !last && len(r.Ns) > 0 {
118118
var label string
119119
for _, rr := range r.Ns {
120120
if ns, ok := rr.(*dns.NS); ok {

0 commit comments

Comments
 (0)