Skip to content

Commit

Permalink
Write process context on node start to simplify test orchestration (a…
Browse files Browse the repository at this point in the history
  • Loading branch information
marun authored Jul 25, 2023
1 parent 8df4c5f commit 2b8dc5f
Show file tree
Hide file tree
Showing 7 changed files with 102 additions and 87 deletions.
14 changes: 0 additions & 14 deletions api/server/mock_server.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

67 changes: 6 additions & 61 deletions api/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ package server

import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
Expand All @@ -29,7 +28,6 @@ import (
"github.com/ava-labs/avalanchego/snow/engine/common"
"github.com/ava-labs/avalanchego/trace"
"github.com/ava-labs/avalanchego/utils/constants"
"github.com/ava-labs/avalanchego/utils/ips"
"github.com/ava-labs/avalanchego/utils/logging"
)

Expand Down Expand Up @@ -66,8 +64,6 @@ type Server interface {
PathAdderWithReadLock
// Dispatch starts the API server
Dispatch() error
// DispatchTLS starts the API server with the provided TLS certificate
DispatchTLS(certBytes, keyBytes []byte) error
// RegisterChain registers the API endpoints associated with this chain.
// That is, add <route, handler> pairs to server so that API calls can be
// made to the VM.
Expand All @@ -88,9 +84,6 @@ type server struct {
log logging.Logger
// generates new logs for chains to write to
factory logging.Factory
// Listens for HTTP traffic on this address
listenHost string
listenPort string

shutdownTimeout time.Duration

Expand All @@ -103,14 +96,16 @@ type server struct {
router *router

srv *http.Server

// Listener used to serve traffic
listener net.Listener
}

// New returns an instance of a Server.
func New(
log logging.Logger,
factory logging.Factory,
host string,
port uint16,
listener net.Listener,
allowedOrigins []string,
shutdownTimeout time.Duration,
nodeID ids.NodeID,
Expand Down Expand Up @@ -153,8 +148,6 @@ func New(
return &server{
log: log,
factory: factory,
listenHost: host,
listenPort: fmt.Sprintf("%d", port),
shutdownTimeout: shutdownTimeout,
tracingEnabled: tracingEnabled,
tracer: tracer,
Expand All @@ -167,60 +160,12 @@ func New(
WriteTimeout: httpConfig.WriteTimeout,
IdleTimeout: httpConfig.IdleTimeout,
},
listener: listener,
}, nil
}

func (s *server) Dispatch() error {
listenAddress := net.JoinHostPort(s.listenHost, s.listenPort)
listener, err := net.Listen("tcp", listenAddress)
if err != nil {
return err
}

ipPort, err := ips.ToIPPort(listener.Addr().String())
if err != nil {
s.log.Info("HTTP API server listening",
zap.String("address", listenAddress),
)
} else {
s.log.Info("HTTP API server listening",
zap.String("host", s.listenHost),
zap.Uint16("port", ipPort.Port),
)
}

return s.srv.Serve(listener)
}

func (s *server) DispatchTLS(certBytes, keyBytes []byte) error {
listenAddress := net.JoinHostPort(s.listenHost, s.listenPort)
cert, err := tls.X509KeyPair(certBytes, keyBytes)
if err != nil {
return err
}
config := &tls.Config{
MinVersion: tls.VersionTLS12,
Certificates: []tls.Certificate{cert},
}

listener, err := tls.Listen("tcp", listenAddress, config)
if err != nil {
return err
}

ipPort, err := ips.ToIPPort(listener.Addr().String())
if err != nil {
s.log.Info("HTTPS API server listening",
zap.String("address", listenAddress),
)
} else {
s.log.Info("HTTPS API server listening",
zap.String("host", s.listenHost),
zap.Uint16("port", ipPort.Port),
)
}

return s.srv.Serve(listener)
return s.srv.Serve(s.listener)
}

func (s *server) RegisterChain(chainName string, ctx *snow.ConsensusContext, vm common.VM) {
Expand Down
2 changes: 2 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -1471,6 +1471,8 @@ func GetNodeConfig(v *viper.Viper) (node.Config, error) {

nodeConfig.ChainDataDir = GetExpandedArg(v, ChainDataDirKey)

nodeConfig.ProcessContextFilePath = GetExpandedArg(v, ProcessContextFileKey)

nodeConfig.ProvidedFlags = providedFlags(v)
return nodeConfig, nil
}
Expand Down
3 changes: 3 additions & 0 deletions config/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ var (
defaultSubnetConfigDir = filepath.Join(defaultConfigDir, "subnets")
defaultPluginDir = filepath.Join(defaultUnexpandedDataDir, "plugins")
defaultChainDataDir = filepath.Join(defaultUnexpandedDataDir, "chainData")
defaultProcessContextPath = filepath.Join(defaultUnexpandedDataDir, "process.json")
)

func deprecateFlags(fs *pflag.FlagSet) error {
Expand Down Expand Up @@ -368,6 +369,8 @@ func addNodeFlags(fs *pflag.FlagSet) {
fs.Bool(TracingInsecureKey, true, "If true, don't use TLS when sending trace data")
fs.Float64(TracingSampleRateKey, 0.1, "The fraction of traces to sample. If >= 1, always sample. If <= 0, never sample")
fs.StringToString(TracingHeadersKey, map[string]string{}, "The headers to provide the trace indexer")

fs.String(ProcessContextFileKey, defaultProcessContextPath, "The path to write process context to (including PID, API URI, and staking address).")
}

// BuildFlagSet returns a complete set of flags for avalanchego
Expand Down
1 change: 1 addition & 0 deletions config/keys.go
Original file line number Diff line number Diff line change
Expand Up @@ -209,4 +209,5 @@ const (
TracingSampleRateKey = "tracing-sample-rate"
TracingExporterTypeKey = "tracing-exporter-type"
TracingHeadersKey = "tracing-headers"
ProcessContextFileKey = "process-context-file"
)
4 changes: 4 additions & 0 deletions node/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,4 +233,8 @@ type Config struct {
// ChainDataDir is the root path for per-chain directories where VMs can
// write arbitrary data.
ChainDataDir string `json:"chainDataDir"`

// Path to write process context to (including PID, API URI, and
// staking address).
ProcessContextFilePath string `json:"processContextFilePath"`
}
98 changes: 86 additions & 12 deletions node/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ package node
import (
"context"
"crypto"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -144,6 +146,11 @@ type Node struct {
networkNamespace string
Net network.Network

// The staking address will optionally be written to a process context
// file to enable other nodes to be configured to use this node as a
// beacon.
stakingAddress string

// tlsKeyLogWriterCloser is a debug file handle that writes all the TLS
// session keys. This value should only be non-nil during debugging.
tlsKeyLogWriterCloser io.WriteCloser
Expand All @@ -154,6 +161,8 @@ type Node struct {
// current validators of the network
vdrs validators.Manager

apiURI string

// Handles HTTP API calls
APIServer server.Server

Expand Down Expand Up @@ -254,6 +263,9 @@ func (n *Node) initNetworking(primaryNetVdrs validators.Set) error {
)
}

// Record the bound address to enable inclusion in process context file.
n.stakingAddress = listener.Addr().String()

tlsKey, ok := n.Config.StakingTLSCert.PrivateKey.(crypto.Signer)
if !ok {
return errInvalidTLSKey
Expand Down Expand Up @@ -374,19 +386,51 @@ func (n *Node) initNetworking(primaryNetVdrs validators.Set) error {
return err
}

type NodeProcessContext struct {
// The process id of the node
PID int `json:"pid"`
// URI to access the node API
// Format: [https|http]://[host]:[port]
URI string `json:"uri"`
// Address other nodes can use to communicate with this node
// Format: [host]:[port]
StakingAddress string `json:"stakingAddress"`
}

// Write process context to the configured path. Supports the use of
// dynamically chosen network ports with local network orchestration.
func (n *Node) writeProcessContext() error {
n.Log.Info("writing process context", zap.String("path", n.Config.ProcessContextFilePath))

// Write the process context to disk
processContext := &NodeProcessContext{
PID: os.Getpid(),
URI: n.apiURI,
StakingAddress: n.stakingAddress, // Set by network initialization
}
bytes, err := json.MarshalIndent(processContext, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal process context: %w", err)
}
if err := os.WriteFile(n.Config.ProcessContextFilePath, bytes, perms.ReadWrite); err != nil {
return fmt.Errorf("failed to write process context: %w", err)
}
return nil
}

// Dispatch starts the node's servers.
// Returns when the node exits.
func (n *Node) Dispatch() error {
if err := n.writeProcessContext(); err != nil {
return err
}

// Start the HTTP API server
go n.Log.RecoverAndPanic(func() {
var err error
if n.Config.HTTPSEnabled {
n.Log.Debug("initializing API server with TLS")
err = n.APIServer.DispatchTLS(n.Config.HTTPSCert, n.Config.HTTPSKey)
} else {
n.Log.Debug("initializing API server without TLS")
err = n.APIServer.Dispatch()
}
n.Log.Info("API server listening",
zap.String("uri", n.apiURI),
)
err := n.APIServer.Dispatch()
// When [n].Shutdown() is called, [n.APIServer].Close() is called.
// This causes [n.APIServer].Dispatch() to return an error.
// If that happened, don't log/return an error here.
Expand Down Expand Up @@ -429,6 +473,16 @@ func (n *Node) Dispatch() error {

// Wait until the node is done shutting down before returning
n.DoneShuttingDown.Wait()

// Remove the process context file to communicate to an orchestrator
// that the node is no longer running.
if err := os.Remove(n.Config.ProcessContextFilePath); err != nil && !os.IsNotExist(err) {
n.Log.Error("removal of process context file failed",
zap.String("path", n.Config.ProcessContextFilePath),
zap.Error(err),
)
}

return err
}

Expand Down Expand Up @@ -601,13 +655,34 @@ func (n *Node) initMetrics() {
func (n *Node) initAPIServer() error {
n.Log.Info("initializing API server")

listenAddress := net.JoinHostPort(n.Config.HTTPHost, fmt.Sprintf("%d", n.Config.HTTPPort))
listener, err := net.Listen("tcp", listenAddress)
if err != nil {
return err
}

protocol := "http"
if n.Config.HTTPSEnabled {
cert, err := tls.X509KeyPair(n.Config.HTTPSCert, n.Config.HTTPSKey)
if err != nil {
return err
}
config := &tls.Config{
MinVersion: tls.VersionTLS12,
Certificates: []tls.Certificate{cert},
}
listener = tls.NewListener(listener, config)

protocol = "https"
}
n.apiURI = fmt.Sprintf("%s://%s", protocol, listener.Addr())

if !n.Config.APIRequireAuthToken {
var err error
n.APIServer, err = server.New(
n.Log,
n.LogFactory,
n.Config.HTTPHost,
n.Config.HTTPPort,
listener,
n.Config.HTTPAllowedOrigins,
n.Config.ShutdownTimeout,
n.ID,
Expand All @@ -629,8 +704,7 @@ func (n *Node) initAPIServer() error {
n.APIServer, err = server.New(
n.Log,
n.LogFactory,
n.Config.HTTPHost,
n.Config.HTTPPort,
listener,
n.Config.HTTPAllowedOrigins,
n.Config.ShutdownTimeout,
n.ID,
Expand Down

0 comments on commit 2b8dc5f

Please sign in to comment.