Skip to content

Commit

Permalink
Chore: split enhanced mode instance (#936)
Browse files Browse the repository at this point in the history
Co-authored-by: Dreamacro <[email protected]>
  • Loading branch information
Kr328 and Dreamacro authored Sep 17, 2020
1 parent e773f95 commit 558ac6b
Show file tree
Hide file tree
Showing 8 changed files with 228 additions and 113 deletions.
5 changes: 5 additions & 0 deletions component/fakeip/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,11 @@ func (p *Pool) Gateway() net.IP {
return uintToIP(p.gateway)
}

// PatchFrom clone cache from old pool
func (p *Pool) PatchFrom(o *Pool) {
o.cache.CloneTo(p.cache)
}

func (p *Pool) get(host string) net.IP {
current := p.offset
for {
Expand Down
46 changes: 46 additions & 0 deletions component/resolver/enhancer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package resolver

import (
"net"
)

var DefaultHostMapper Enhancer

type Enhancer interface {
FakeIPEnabled() bool
MappingEnabled() bool
IsFakeIP(net.IP) bool
FindHostByIP(net.IP) (string, bool)
}

func FakeIPEnabled() bool {
if mapper := DefaultHostMapper; mapper != nil {
return mapper.FakeIPEnabled()
}

return false
}

func MappingEnabled() bool {
if mapper := DefaultHostMapper; mapper != nil {
return mapper.MappingEnabled()
}

return false
}

func IsFakeIP(ip net.IP) bool {
if mapper := DefaultHostMapper; mapper != nil {
return mapper.IsFakeIP(ip)
}

return false
}

func FindHostByIP(ip net.IP) (string, bool) {
if mapper := DefaultHostMapper; mapper != nil {
return mapper.FindHostByIP(ip)
}

return "", false
}
76 changes: 76 additions & 0 deletions dns/enhancer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package dns

import (
"net"

"github.com/Dreamacro/clash/common/cache"
"github.com/Dreamacro/clash/component/fakeip"
)

type ResolverEnhancer struct {
mode EnhancedMode
fakePool *fakeip.Pool
mapping *cache.LruCache
}

func (h *ResolverEnhancer) FakeIPEnabled() bool {
return h.mode == FAKEIP
}

func (h *ResolverEnhancer) MappingEnabled() bool {
return h.mode == FAKEIP || h.mode == MAPPING
}

func (h *ResolverEnhancer) IsFakeIP(ip net.IP) bool {
if !h.FakeIPEnabled() {
return false
}

if pool := h.fakePool; pool != nil {
return pool.Exist(ip)
}

return false
}

func (h *ResolverEnhancer) FindHostByIP(ip net.IP) (string, bool) {
if pool := h.fakePool; pool != nil {
if host, existed := pool.LookBack(ip); existed {
return host, true
}
}

if mapping := h.mapping; mapping != nil {
if host, existed := h.mapping.Get(ip.String()); existed {
return host.(string), true
}
}

return "", false
}

func (h *ResolverEnhancer) PatchFrom(o *ResolverEnhancer) {
if h.mapping != nil && o.mapping != nil {
o.mapping.CloneTo(h.mapping)
}

if h.fakePool != nil && o.fakePool != nil {
h.fakePool.PatchFrom(o.fakePool)
}
}

func NewEnhancer(cfg Config) *ResolverEnhancer {
var fakePool *fakeip.Pool
var mapping *cache.LruCache

if cfg.EnhancedMode != NORMAL {
fakePool = cfg.Pool
mapping = cache.NewLRUCache(cache.WithSize(4096), cache.WithStale(true))
}

return &ResolverEnhancer{
mode: cfg.EnhancedMode,
fakePool: fakePool,
mapping: mapping,
}
}
90 changes: 64 additions & 26 deletions dns/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,31 +3,31 @@ package dns
import (
"net"
"strings"
"time"

"github.com/Dreamacro/clash/common/cache"
"github.com/Dreamacro/clash/component/fakeip"
"github.com/Dreamacro/clash/component/trie"
"github.com/Dreamacro/clash/log"

D "github.com/miekg/dns"
)

type handler func(w D.ResponseWriter, r *D.Msg)
type handler func(r *D.Msg) (*D.Msg, error)
type middleware func(next handler) handler

func withHosts(hosts *trie.DomainTrie) middleware {
return func(next handler) handler {
return func(w D.ResponseWriter, r *D.Msg) {
return func(r *D.Msg) (*D.Msg, error) {
q := r.Question[0]

if !isIPRequest(q) {
next(w, r)
return
return next(r)
}

record := hosts.Search(strings.TrimRight(q.Name, "."))
if record == nil {
next(w, r)
return
return next(r)
}

ip := record.Data.(net.IP)
Expand All @@ -46,22 +46,60 @@ func withHosts(hosts *trie.DomainTrie) middleware {

msg.Answer = []D.RR{rr}
} else {
next(w, r)
return
return next(r)
}

msg.SetRcode(r, D.RcodeSuccess)
msg.Authoritative = true
msg.RecursionAvailable = true

w.WriteMsg(msg)
return msg, nil
}
}
}

func withMapping(mapping *cache.LruCache) middleware {
return func(next handler) handler {
return func(r *D.Msg) (*D.Msg, error) {
q := r.Question[0]

if !isIPRequest(q) {
return next(r)
}

msg, err := next(r)
if err != nil {
return nil, err
}

host := strings.TrimRight(q.Name, ".")

for _, ans := range msg.Answer {
var ip net.IP
var ttl uint32

switch a := ans.(type) {
case *D.A:
ip = a.A
ttl = a.Hdr.Ttl
case *D.AAAA:
ip = a.AAAA
ttl = a.Hdr.Ttl
default:
continue
}

mapping.SetWithExpire(ip.String(), host, time.Now().Add(time.Second*time.Duration(ttl)))
}

return msg, nil
}
}
}

func withFakeIP(fakePool *fakeip.Pool) middleware {
return func(next handler) handler {
return func(w D.ResponseWriter, r *D.Msg) {
return func(r *D.Msg) (*D.Msg, error) {
q := r.Question[0]

if q.Qtype == D.TypeAAAA {
Expand All @@ -72,17 +110,14 @@ func withFakeIP(fakePool *fakeip.Pool) middleware {
msg.Authoritative = true
msg.RecursionAvailable = true

w.WriteMsg(msg)
return
return msg, nil
} else if q.Qtype != D.TypeA {
next(w, r)
return
return next(r)
}

host := strings.TrimRight(q.Name, ".")
if fakePool.LookupHost(host) {
next(w, r)
return
return next(r)
}

rr := &D.A{}
Expand All @@ -97,13 +132,13 @@ func withFakeIP(fakePool *fakeip.Pool) middleware {
msg.Authoritative = true
msg.RecursionAvailable = true

w.WriteMsg(msg)
return msg, nil
}
}
}

func withResolver(resolver *Resolver) handler {
return func(w D.ResponseWriter, r *D.Msg) {
return func(r *D.Msg) (*D.Msg, error) {
q := r.Question[0]

// return a empty AAAA msg when ipv6 disabled
Expand All @@ -115,19 +150,18 @@ func withResolver(resolver *Resolver) handler {
msg.Authoritative = true
msg.RecursionAvailable = true

w.WriteMsg(msg)
return
return msg, nil
}

msg, err := resolver.Exchange(r)
if err != nil {
log.Debugln("[DNS Server] Exchange %s failed: %v", q.String(), err)
D.HandleFailed(w, r)
return
return msg, err
}
msg.SetRcode(r, msg.Rcode)
msg.Authoritative = true
w.WriteMsg(msg)

return msg, nil
}
}

Expand All @@ -142,15 +176,19 @@ func compose(middlewares []middleware, endpoint handler) handler {
return h
}

func newHandler(resolver *Resolver) handler {
func newHandler(resolver *Resolver, mapper *ResolverEnhancer) handler {
middlewares := []middleware{}

if resolver.hosts != nil {
middlewares = append(middlewares, withHosts(resolver.hosts))
}

if resolver.FakeIPEnabled() {
middlewares = append(middlewares, withFakeIP(resolver.pool))
if mapper.mode == FAKEIP {
middlewares = append(middlewares, withFakeIP(mapper.fakePool))
}

if mapper.mode != NORMAL {
middlewares = append(middlewares, withMapping(mapper.mapping))
}

return compose(middlewares, withResolver(resolver))
Expand Down
Loading

0 comments on commit 558ac6b

Please sign in to comment.