diff --git a/raft.go b/raft.go index 773dea5bb..06a40d849 100644 --- a/raft.go +++ b/raft.go @@ -29,14 +29,21 @@ func (r *Raft) getRPCHeader() RPCHeader { } } -// dispositionRPC houses logic about whether this instance of Raft can process -// an RPC message with the given header. -func (r *Raft) dispositionRPC(header RPCHeader) bool { +// checkRPCHeader houses logic about whether this instance of Raft can process +// the given RPC message. +func (r *Raft) checkRPCHeader(rpc RPC) error { + // Get the header off the RPC message. + wh, ok := rpc.Command.(WithRPCHeader) + if !ok { + return fmt.Errorf("RPC does not have a header") + } + header := wh.GetRPCHeader() + // First check is to just make sure the code can understand the // protocol at all. if header.ProtocolVersion < ProtocolVersionMin || header.ProtocolVersion > ProtocolVersionMax { - return false + return ErrUnsupportedProtocol } // Second check is whether we should support this message, given the @@ -45,7 +52,11 @@ func (r *Raft) dispositionRPC(header RPCHeader) bool { // currently what we want, and in general support one version back. We // may need to revisit this policy depending on how future protocol // changes evolve. - return header.ProtocolVersion >= r.conf.ProtocolVersion-1 + if header.ProtocolVersion < r.conf.ProtocolVersion-1 { + return ErrUnsupportedProtocol + } + + return nil } // commitTuple is used to send an index that was committed, @@ -798,16 +809,8 @@ func (r *Raft) processLog(l *Log, future *logFuture) { // processRPC is called to handle an incoming RPC request. This must only be // called from the main thread. func (r *Raft) processRPC(rpc RPC) { - if wh, ok := rpc.Command.(WithRPCHeader); ok { - if ok := r.dispositionRPC(wh.GetRPCHeader()); !ok { - r.logger.Printf("[ERR] raft: Ignoring unsupported RPC %#v for command: %#v", wh, rpc.Command) - rpc.Respond(nil, ErrUnsupportedProtocol) - return - } - - } else { - r.logger.Printf("[ERR] raft: Ignoring un-versioned command: %#v", rpc.Command) - rpc.Respond(nil, fmt.Errorf("unversioned command")) + if err := r.checkRPCHeader(rpc); err != nil { + rpc.Respond(nil, err) return }