Skip to content

Commit

Permalink
add config file and reload signal
Browse files Browse the repository at this point in the history
  • Loading branch information
vacuityv committed Apr 25, 2023
1 parent 0a1741e commit e5da63c
Show file tree
Hide file tree
Showing 8 changed files with 241 additions and 50 deletions.
35 changes: 15 additions & 20 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,6 @@ To install, run

$ go get github.com/vacuityv/vacproxy

Build

$ go install github.com/vacuityv/vacproxy/cmd/vacproxy

You will now find a `vacproxy` binary in your `$GOPATH/bin` directory.

Usage
Expand All @@ -24,21 +20,20 @@ Run `vacproxy -help` for more information.

$ vacproxy -help
Usage of ./vacproxy:
-auth string
basic credentials(username:password)
-bind string
proxy bind address (default "0.0.0.0:7777")
-daemon
run as daemon
-log string
the log file path (default "./vacproxy.log")
-pid string
the pid file path (default "./vacproxy.pid")
-q quit proxy
-s string
Send signal to the daemon:
quit — graceful shutdown
stop — fast shutdown
reload — reloading the configuration file
-bind string
proxy bind address (default "0.0.0.0:7777")
-config string
config file (default "./config.yml")
-log string
the log file path (default "./vacproxy.log")
-pid string
the pid file path (default "./vacproxy.pid")
-q
quit proxy
-s string
Send signal to the daemon:
stop — shutdown, same as -q
reload — reloading the configuration file



18 changes: 18 additions & 0 deletions config.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
name: test

# 代理鉴权配置,enabled为true且user和password均不为空代表鉴权
auth:
enabled: true
user: test
password: 1234

# 请求ip白名单,放空代表不限制
inAllowList:
- 127.0.0.1
- 192.168.100.*

# 目标域名/ip白名单,放空代表不限制
outAllowList:
# - weixin.qq.com
# - alipay.com
# - baidu.com
5 changes: 4 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ module vacproxy

go 1.20

require github.com/sevlyar/go-daemon v0.1.6
require (
github.com/sevlyar/go-daemon v0.1.6
gopkg.in/yaml.v3 v3.0.1
)

require (
github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0 // indirect
Expand Down
50 changes: 29 additions & 21 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,38 +11,32 @@ import (

var (
signal = flag.String("s", "", `Send signal to the daemon:
stop — shutdown
stop — shutdown, same as -q
reload — reloading the configuration file`)
)

var (
stop = make(chan struct{})
done = make(chan struct{})
stop = make(chan struct{})
done = make(chan struct{})
reload = make(chan struct{})
)

func main() {

addr := flag.String("bind", "0.0.0.0:7777", "proxy bind address")
auth := flag.String("auth", "", "basic credentials(username:password)")
logf := flag.String("log", "./vacproxy.log", "the log file path")
pidf := flag.String("pid", "./vacproxy.pid", "the pid file path")
quit := flag.Bool("q", false, "quit proxy")
nd := flag.Bool("nd", false, "not run as daemon")
configFile := flag.String("config", "./config.yml", "config file")
flag.Parse()

if *quit || *signal == "stop" {
*signal = "stop"
*nd = false
*quit = true
}

if *nd {
startWorker(*addr, *auth)
return
}

daemon.AddCommand(daemon.StringFlag(signal, "stop"), syscall.SIGTERM, termHandler)
//daemon.AddCommand(daemon.StringFlag(signal, "reload"), syscall.SIGHUP, reloadHandler)
daemon.AddCommand(daemon.StringFlag(signal, "reload"), syscall.SIGHUP, reloadHandler)

cntxt := &daemon.Context{
PidFileName: *pidf,
Expand Down Expand Up @@ -75,30 +69,44 @@ func main() {
log.Println("daemon started")

log.Println(*addr)
startWorker(*addr, *auth)
go startWorker(*addr, *configFile)

err = daemon.ServeSignals()
if err != nil {
log.Fatalf("Error: %s", err.Error())
}

log.Println("daemon terminated")

}

func startWorker(addr string, auth string) {
func startWorker(addr string, configFile string) {
// main worker
server := service.NewServer(addr, auth)
server := service.NewServer(addr, configFile)
// watch the signal
go watchSig(server)
// start server
server.Start()
select {
case <-stop:
server.Stop()
default:

}

func watchSig(s *service.Server) {
for {
select {
case <-stop:
s.Stop()
case <-reload:
s.Reload()
default:
}
}
}

func reloadHandler(sig os.Signal) error {
reload <- struct{}{}
return nil
}

func termHandler(sig os.Signal) error {
log.Println("terminating...")
stop <- struct{}{}
if sig == syscall.SIGQUIT {
<-done
Expand Down
34 changes: 28 additions & 6 deletions service/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,37 @@ func (c *serverConnnection) serve() {
return
}

localForClientIp, localForClientPort, clientIp, clientPort := getIpAddr(c.rwc)
localForServerIp, localForServerPort, serverIp, serverPort := getIpAddr(remoteConn)

// check client ip
if len(c.server.config.InAllowList) > 0 && !c.server.inMatch.Match(clientIp) {
log.Printf("%d-clientIp not in allow list: %s", c.logId, clientIp)
_, err = c.rwc.Write([]byte("HTTP/1.1 401 clientIp not in allow list " + clientIp + "\r\n\r\n"))
if err != nil {
log.Printf("%s", err.Error())
return
}
}
// check server ip
domain := strings.Split(remote, ":")[0]
if len(c.server.config.OutAllowList) > 0 && !c.server.outMatch.Match(domain) {
log.Printf("%d-server host not in allow list: %s", c.logId, domain)
_, err = c.rwc.Write([]byte("HTTP/1.1 401 server host not in allow list " + domain + "\r\n\r\n"))
if err != nil {
log.Printf("%s", err.Error())
return
}
}

log.Printf("%d-client connect %s:%d to %s:%d", c.logId, clientIp, clientPort, localForClientIp, localForClientPort)
log.Printf("%d-server connect %s:%d to %s:%d", c.logId, localForServerIp, localForServerPort, serverIp, serverPort)

if isHttps {
// if https, should sent 200 to client
_, err = c.rwc.Write([]byte("HTTP/1.1 200 Connection established\r\n\r\n"))
if err != nil {
log.Fatalln(err)
log.Printf("%s", err.Error())
return
}
} else {
Expand All @@ -63,12 +89,7 @@ func (c *serverConnnection) serve() {

// build bidirectional-streams
log.Printf("%d-begin tunnel", c.logId)
localForClientIp, localForClientPort, clientIp, clientPort := getIpAddr(c.rwc)
localForServerIp, localForServerPort, serverIp, serverPort := getIpAddr(remoteConn)
log.Printf("%d-client connect %s:%d to %s:%d", c.logId, clientIp, clientPort, localForClientIp, localForClientPort)
log.Printf("%d-server connect %s:%d to %s:%d", c.logId, localForServerIp, localForServerPort, serverIp, serverPort)
c.tunnel(remoteConn)
log.Printf("%d-end tunnel", c.logId)
}

/*
Expand Down Expand Up @@ -151,6 +172,7 @@ func (c *serverConnnection) tunnel(remoteConn net.Conn) {
}
defer remoteConn.Close()
defer c.rwc.Close()
defer log.Printf("%d-end tunnel", c.logId)

clientDoneCh, serverDoneCh := make(chan struct{}), make(chan struct{})
go dataCopy(remoteConn, c.rwc, serverDoneCh)
Expand Down
47 changes: 47 additions & 0 deletions service/ipDomainMatcher.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package service

import (
"strings"
)

type Node struct {
children map[string]*Node
value bool
}

func NewNode() *Node {
return &Node{
children: make(map[string]*Node),
value: false,
}
}

func (n *Node) Insert(domain string) {
parts := strings.Split(domain, ".")
node := n
for _, part := range parts {
if _, ok := node.children[part]; !ok {
node.children[part] = NewNode()
}
node = node.children[part]
}
node.value = true
}

func (n *Node) Match(domain string) bool {
parts := strings.Split(domain, ".")
node := n
for i, part := range parts {
if child, ok := node.children[part]; ok {
node = child
} else if child, ok := node.children["*"]; ok && i == len(parts)-1 {
node = child
} else {
return false
}
if node.value {
return true
}
}
return false
}
Loading

0 comments on commit e5da63c

Please sign in to comment.