Skip to content

Commit

Permalink
Fix disabled DNS resolver fail (#978)
Browse files Browse the repository at this point in the history
Fix fail of DNS when it disabled in the settings
  • Loading branch information
gigovich authored Jun 22, 2023
1 parent c20f98c commit 774d8e9
Show file tree
Hide file tree
Showing 2 changed files with 164 additions and 30 deletions.
63 changes: 34 additions & 29 deletions client/internal/dns/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ type DefaultServer struct {
ctx context.Context
ctxCancel context.CancelFunc
mux sync.Mutex
fakeResolverWG sync.WaitGroup
udpFilterHookID string
server *dns.Server
dnsMux *dns.ServeMux
dnsMuxMap registeredHandlerMap
Expand Down Expand Up @@ -105,7 +105,10 @@ func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAdd
defaultServer.enabled = hasValidDnsServer(initialDnsCfg)
}

defaultServer.evalRuntimeAddress()
if wgInterface.IsUserspaceBind() {
defaultServer.evelRuntimeAddressForUserspace()
}

return defaultServer, nil
}

Expand All @@ -118,6 +121,9 @@ func (s *DefaultServer) Initialize() (err error) {
return nil
}

if !s.wgInterface.IsUserspaceBind() {
s.evalRuntimeAddress()
}
s.hostManager, err = newHostManager(s.wgInterface)
return
}
Expand All @@ -126,17 +132,8 @@ func (s *DefaultServer) Initialize() (err error) {
func (s *DefaultServer) listen() {
// nil check required in unit tests
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() {
s.fakeResolverWG.Add(1)
go func() {
s.setListenerStatus(true)
defer s.setListenerStatus(false)

hookID := s.filterDNSTraffic()
s.fakeResolverWG.Wait()
if err := s.wgInterface.GetFilter().RemovePacketHook(hookID); err != nil {
log.Errorf("unable to remove DNS packet hook: %s", err)
}
}()
s.udpFilterHookID = s.filterDNSTraffic()
s.setListenerStatus(true)
return
}

Expand All @@ -153,6 +150,10 @@ func (s *DefaultServer) listen() {
}()
}

// DnsIP returns the DNS resolver server IP address
//
// When kernel space interface used it return real DNS server listener IP address
// For bind interface, fake DNS resolver address returned (second last IP address from Nebird network)
func (s *DefaultServer) DnsIP() string {
if !s.enabled {
return ""
Expand Down Expand Up @@ -201,17 +202,25 @@ func (s *DefaultServer) Stop() {
}
}

if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() && s.listenerIsRunning {
s.fakeResolverWG.Done()
}

err := s.stopListener()
if err != nil {
log.Error(err)
}
}

func (s *DefaultServer) stopListener() error {
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() && s.listenerIsRunning {
// udpFilterHookID here empty only in the unit tests
if filter := s.wgInterface.GetFilter(); filter != nil && s.udpFilterHookID != "" {
if err := filter.RemovePacketHook(s.udpFilterHookID); err != nil {
log.Errorf("unable to remove DNS packet hook: %s", err)
}
}
s.udpFilterHookID = ""
s.listenerIsRunning = false
return nil
}

if !s.listenerIsRunning {
return nil
}
Expand Down Expand Up @@ -275,12 +284,8 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
// is the service should be disabled, we stop the listener or fake resolver
// and proceed with a regular update to clean up the handlers and records
if !update.ServiceEnable {
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() {
s.fakeResolverWG.Done()
} else {
if err := s.stopListener(); err != nil {
log.Error(err)
}
if err := s.stopListener(); err != nil {
log.Error(err)
}
} else if !s.listenerIsRunning {
s.listen()
Expand Down Expand Up @@ -555,17 +560,17 @@ func (s *DefaultServer) filterDNSTraffic() string {
return filter.AddUDPPacketHook(false, net.ParseIP(s.runtimeIP), uint16(s.runtimePort), hook)
}

func (s *DefaultServer) evelRuntimeAddressForUserspace() {
s.runtimeIP = getLastIPFromNetwork(s.wgInterface.Address().Network, 1)
s.runtimePort = defaultPort
s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort)
}

func (s *DefaultServer) evalRuntimeAddress() {
defer func() {
s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort)
}()

if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() {
s.runtimeIP = getLastIPFromNetwork(s.wgInterface.Address().Network, 1)
s.runtimePort = defaultPort
return
}

if s.customAddress != nil {
s.runtimeIP = s.customAddress.Addr().String()
s.runtimePort = int(s.customAddress.Port())
Expand Down
131 changes: 130 additions & 1 deletion client/internal/dns/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,18 @@ import (
"fmt"
"net"
"net/netip"
"os"
"strings"
"testing"
"time"

"github.com/golang/mock/gomock"
"github.com/miekg/dns"

"github.com/netbirdio/netbird/client/internal/stdnet"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/iface"
pfmock "github.com/netbirdio/netbird/iface/mocks"
)

var zoneRecords = []nbdns.SimpleRecord{
Expand Down Expand Up @@ -241,7 +244,6 @@ func TestUpdateDNSServer(t *testing.T) {
dnsServer.updateSerial = testCase.initSerial
// pretend we are running
dnsServer.listenerIsRunning = true
dnsServer.fakeResolverWG.Add(1)

err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
if err != nil {
Expand Down Expand Up @@ -276,6 +278,133 @@ func TestUpdateDNSServer(t *testing.T) {
}
}

func TestDNSFakeResolverHandleUpdates(t *testing.T) {
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
defer os.Setenv("NB_WG_KERNEL_DISABLED", ov)

os.Setenv("NB_WG_KERNEL_DISABLED", "true")
newNet, err := stdnet.NewNet(nil)
if err != nil {
t.Errorf("create stdnet: %v", err)
return
}

wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.1/32", iface.DefaultMTU, nil, newNet)
if err != nil {
t.Errorf("build interface wireguard: %v", err)
return
}

err = wgIface.Create()
if err != nil {
t.Errorf("crate and init wireguard interface: %v", err)
return
}
defer func() {
if err = wgIface.Close(); err != nil {
t.Logf("close wireguard interface: %v", err)
}
}()

ctrl := gomock.NewController(t)
defer ctrl.Finish()

_, ipNet, err := net.ParseCIDR("100.66.100.1/32")
if err != nil {
t.Errorf("parse CIDR: %v", err)
return
}

packetfilter := pfmock.NewMockPacketFilter(ctrl)
packetfilter.EXPECT().SetNetwork(ipNet)
packetfilter.EXPECT().DropOutgoing(gomock.Any()).AnyTimes()
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
packetfilter.EXPECT().RemovePacketHook(gomock.Any()).AnyTimes()

if err := wgIface.SetFilter(packetfilter); err != nil {
t.Errorf("set packet filter: %v", err)
return
}

dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", nil)
if err != nil {
t.Errorf("create DNS server: %v", err)
return
}

err = dnsServer.Initialize()
if err != nil {
t.Errorf("run DNS server: %v", err)
return
}
defer func() {
if err = dnsServer.hostManager.restoreHostDNS(); err != nil {
t.Logf("restore DNS settings on the host: %v", err)
return
}
}()

dnsServer.dnsMuxMap = registeredHandlerMap{zoneRecords[0].Name: &localResolver{}}
dnsServer.localResolver.registeredMap = registrationMap{"netbird.cloud": struct{}{}}
dnsServer.updateSerial = 0

nameServers := []nbdns.NameServer{
{
IP: netip.MustParseAddr("8.8.8.8"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
{
IP: netip.MustParseAddr("8.8.4.4"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
}

update := nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
Domains: []string{"netbird.io"},
NameServers: nameServers,
},
{
NameServers: nameServers,
Primary: true,
},
},
}

// Start the server with regular configuration
if err := dnsServer.UpdateDNSServer(1, update); err != nil {
t.Fatalf("update dns server should not fail, got error: %v", err)
return
}

update2 := update
update2.ServiceEnable = false
// Disable the server, stop the listener
if err := dnsServer.UpdateDNSServer(2, update2); err != nil {
t.Fatalf("update dns server should not fail, got error: %v", err)
return
}

update3 := update2
update3.NameServerGroups = update3.NameServerGroups[:1]
// But service still get updates and we checking that we handle
// internal state in the right way
if err := dnsServer.UpdateDNSServer(3, update3); err != nil {
t.Fatalf("update dns server should not fail, got error: %v", err)
return
}
}

func TestDNSServerStartStop(t *testing.T) {
testCases := []struct {
name string
Expand Down

0 comments on commit 774d8e9

Please sign in to comment.