Skip to content

Commit

Permalink
add allowed host middleware and remove workDir middleware (ollama#3018
Browse files Browse the repository at this point in the history
)
  • Loading branch information
jmorganca authored Mar 9, 2024
1 parent ecc133d commit fc8c044
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 26 deletions.
77 changes: 60 additions & 17 deletions server/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"log/slog"
"net"
"net/http"
"net/netip"
"os"
"os/signal"
"path/filepath"
Expand All @@ -35,7 +36,7 @@ import (
var mode string = gin.DebugMode

type Server struct {
WorkDir string
addr net.Addr
}

func init() {
Expand Down Expand Up @@ -904,15 +905,64 @@ var defaultAllowOrigins = []string{
"0.0.0.0",
}

func NewServer() (*Server, error) {
workDir, err := os.MkdirTemp("", "ollama")
if err != nil {
return nil, err
func allowedHost(host string) bool {
if host == "" || host == "localhost" {
return true
}

if hostname, err := os.Hostname(); err == nil && host == hostname {
return true
}

var tlds = []string{
".localhost",
".local",
".internal",
}

for _, tld := range tlds {
if strings.HasSuffix(host, "."+tld) {
return true
}
}

return &Server{
WorkDir: workDir,
}, nil
return false
}

func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc {
return func(c *gin.Context) {
if addr == nil {
c.Next()
return
}

if !netip.MustParseAddrPort(addr.String()).Addr().IsLoopback() {
c.Next()
return
}

if addrPort, _ := netip.ParseAddrPort(c.Request.Host); addrPort.Addr().IsLoopback() {
c.Next()
return
}

if addr, _ := netip.ParseAddr(c.Request.Host); addr.IsLoopback() {
c.Next()
return
}

host, _, err := net.SplitHostPort(c.Request.Host)
if err != nil {
host = c.Request.Host
}

if allowedHost(host) {
c.Next()
return
}

c.AbortWithStatus(http.StatusForbidden)
}
}

func (s *Server) GenerateRoutes() http.Handler {
Expand All @@ -938,10 +988,7 @@ func (s *Server) GenerateRoutes() http.Handler {
r := gin.Default()
r.Use(
cors.New(config),
func(c *gin.Context) {
c.Set("workDir", s.WorkDir)
c.Next()
},
allowedHostsMiddleware(s.addr),
)

r.POST("/api/pull", PullModelHandler)
Expand Down Expand Up @@ -1010,10 +1057,7 @@ func Serve(ln net.Listener) error {
}
}

s, err := NewServer()
if err != nil {
return err
}
s := &Server{addr: ln.Addr()}
r := s.GenerateRoutes()

slog.Info(fmt.Sprintf("Listening on %s (version %s)", ln.Addr(), version.Version))
Expand All @@ -1029,7 +1073,6 @@ func Serve(ln net.Listener) error {
if loaded.runner != nil {
loaded.runner.Close()
}
os.RemoveAll(s.WorkDir)
os.Exit(0)
}()

Expand Down
10 changes: 1 addition & 9 deletions server/routes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,6 @@ import (
"github.com/jmorganca/ollama/version"
)

func setupServer(t *testing.T) (*Server, error) {
t.Helper()

return NewServer()
}

func Test_Routes(t *testing.T) {
type testCase struct {
Name string
Expand Down Expand Up @@ -207,9 +201,7 @@ func Test_Routes(t *testing.T) {
},
}

s, err := setupServer(t)
assert.Nil(t, err)

s := Server{}
router := s.GenerateRoutes()

httpSrv := httptest.NewServer(router)
Expand Down

0 comments on commit fc8c044

Please sign in to comment.