Skip to content

Commit

Permalink
Open() func now validates connection, fixed bwmarrin#198
Browse files Browse the repository at this point in the history
Now the open function will follow through a bit more and insure that the
proper sequence of events happens during the Open call.  This required
some refactoring and a few mild changes in the onEvent func.
  • Loading branch information
bwmarrin committed Nov 11, 2017
1 parent 43bf6cf commit 7d1657e
Showing 1 changed file with 124 additions and 68 deletions.
192 changes: 124 additions & 68 deletions wsapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"compress/zlib"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"runtime"
Expand Down Expand Up @@ -45,100 +46,160 @@ type resumePacket struct {
} `json:"d"`
}

// Open opens a websocket connection to Discord.
func (s *Session) Open() (err error) {

// Open creates a websocket connection to Discord.
// See: https://discordapp.com/developers/docs/topics/gateway#connecting
func (s *Session) Open() error {
s.log(LogInformational, "called")

s.Lock()
defer func() {
if err != nil {
s.Unlock()
}
}()
var err error

// A basic state is a hard requirement for Voice.
if s.State == nil {
state := NewState()
state.TrackChannels = false
state.TrackEmojis = false
state.TrackMembers = false
state.TrackRoles = false
state.TrackVoice = false
s.State = state
}
// Prevent Open or other major Session functions from
// being called while Open is still running.
s.Lock()
defer s.Unlock()

// If the websock is already open, bail out here.
if s.wsConn != nil {
err = ErrWSAlreadyOpen
return
}

if s.VoiceConnections == nil {
s.log(LogInformational, "creating new VoiceConnections map")
s.VoiceConnections = make(map[string]*VoiceConnection)
return ErrWSAlreadyOpen
}

// Get the gateway to use for the Websocket connection
if s.gateway == "" {
s.gateway, err = s.Gateway()
if err != nil {
return
return err
}

// Add the version and encoding to the URL
s.gateway = s.gateway + "?v=" + APIVersion + "&encoding=json"
}

// Connect to the Gateway
s.log(LogInformational, "connecting to gateway %s", s.gateway)
header := http.Header{}
header.Add("accept-encoding", "zlib")

s.log(LogInformational, "connecting to gateway %s", s.gateway)
s.wsConn, _, err = websocket.DefaultDialer.Dial(s.gateway, header)
if err != nil {
s.log(LogWarning, "error connecting to gateway %s, %s", s.gateway, err)
s.gateway = "" // clear cached gateway
// TODO: should we add a retry block here?
return
s.wsConn = nil // Just to be safe.
return err
}

defer func() {
// because of this, all code below must set err to the error
// when exiting with an error :) Maybe someone has a better
// way :)
if err != nil {
s.wsConn.Close()
s.wsConn = nil
}
}()

// The first response from Discord should be an Op 10 (Hello) Packet.
// When processed by onEvent the heartbeat goroutine will be started.
mt, m, err := s.wsConn.ReadMessage()
if err != nil {
return err
}
e, err := s.onEvent(mt, m)
if err != nil {
return err
}
if e.Operation != 10 {
err = fmt.Errorf("Expecting Op 10, got Op %d instead.", e.Operation)
return err
}
s.log(LogInformational, "Op 10 Hello Packet received from Discord")
s.LastHeartbeatAck = time.Now().UTC()
var h helloOp
if err = json.Unmarshal(e.RawData, &h); err != nil {
err = fmt.Errorf("error unmarshalling helloOp, %s", err)
return err
}

// Now we send either an Op 2 Identity if this is a brand new
// connection or Op 6 Resume if we are resuming an existing connection.
sequence := atomic.LoadInt64(s.sequence)
if s.sessionID != "" && sequence > 0 {
if s.sessionID == "" && sequence == 0 {

// Send Op 2 Identity Packet
err = s.identify()
if err != nil {
err = fmt.Errorf("error sending identify packet to gateway, %s, %s", s.gateway, err)
return err
}

} else {

// Send Op 6 Resume Packet
p := resumePacket{}
p.Op = 6
p.Data.Token = s.Token
p.Data.SessionID = s.sessionID
p.Data.Sequence = sequence

s.log(LogInformational, "sending resume packet to gateway")
s.wsMutex.Lock()
err = s.wsConn.WriteJSON(p)
s.wsMutex.Unlock()
if err != nil {
s.log(LogWarning, "error sending gateway resume packet, %s, %s", s.gateway, err)
return
err = fmt.Errorf("error sending gateway resume packet, %s, %s", s.gateway, err)
return err
}

} else {

err = s.identify()
if err != nil {
s.log(LogWarning, "error sending gateway identify packet, %s, %s", s.gateway, err)
return
}
}

// Create listening outside of listen, as it needs to happen inside the mutex
// lock.
s.listening = make(chan interface{})
go s.listen(s.wsConn, s.listening)
s.LastHeartbeatAck = time.Now().UTC()
// A basic state is a hard requirement for Voice.
// We create it here so the below READY/RESUMED packet can populate
// the state :)
// XXX: Move to New() func?
if s.State == nil {
state := NewState()
state.TrackChannels = false
state.TrackEmojis = false
state.TrackMembers = false
state.TrackRoles = false
state.TrackVoice = false
s.State = state
}

s.Unlock()
// Now Discord should send us a READY or RESUMED packet.
mt, m, err = s.wsConn.ReadMessage()
if err != nil {
return err
}
e, err = s.onEvent(mt, m)
if err != nil {
return err
}
if e.Type != `READY` && e.Type != `RESUMED` {
// This is not fatal, but it does not follow their API documentation.
s.log(LogWarning, "Expected READY/RESUMED, instead got:\n%#v\n", e)
}
s.log(LogInformational, "First Packet:\n%#v\n", e)

s.log(LogInformational, "emit connect event")
s.log(LogInformational, "We are now connected to Discord, emitting connect event")
s.handleEvent(connectEventType, &Connect{})

// A VoiceConnections map is a hard requirement for Voice.
// XXX: can this be moved to when opening a voice connection?
if s.VoiceConnections == nil {
s.log(LogInformational, "creating new VoiceConnections map")
s.VoiceConnections = make(map[string]*VoiceConnection)
}

// Create listening chan outside of listen, as it needs to happen inside the
// mutex lock and needs to exist before calling heartbeat and listen
// go rountines.
s.listening = make(chan interface{})

// Start sending heartbeats and reading messages from Discord.
go s.heartbeat(s.wsConn, s.listening, h.HeartbeatInterval)
go s.listen(s.wsConn, s.listening)

s.log(LogInformational, "exiting")
return
return nil
}

// listen polls the websocket connection for events, it will stop when the
Expand Down Expand Up @@ -364,9 +425,7 @@ func (s *Session) RequestGuildMembers(guildID, query string, limit int) (err err
//
// If you use the AddHandler() function to register a handler for the
// "OnEvent" event then all events will be passed to that handler.
//
// TODO: You may also register a custom event handler entirely using...
func (s *Session) onEvent(messageType int, message []byte) {
func (s *Session) onEvent(messageType int, message []byte) (*Event, error) {

var err error
var reader io.Reader
Expand All @@ -378,7 +437,7 @@ func (s *Session) onEvent(messageType int, message []byte) {
z, err2 := zlib.NewReader(reader)
if err2 != nil {
s.log(LogError, "error uncompressing websocket message, %s", err)
return
return nil, err2
}

defer func() {
Expand All @@ -396,7 +455,7 @@ func (s *Session) onEvent(messageType int, message []byte) {
decoder := json.NewDecoder(reader)
if err = decoder.Decode(&e); err != nil {
s.log(LogError, "error decoding websocket message, %s", err)
return
return e, err
}

s.log(LogDebug, "Op: %d, Seq: %d, Type: %s, Data: %s\n\n", e.Operation, e.Sequence, e.Type, string(e.RawData))
Expand All @@ -410,10 +469,10 @@ func (s *Session) onEvent(messageType int, message []byte) {
s.wsMutex.Unlock()
if err != nil {
s.log(LogError, "error sending heartbeat in response to Op1")
return
return e, err
}

return
return e, nil
}

// Reconnect
Expand All @@ -422,7 +481,7 @@ func (s *Session) onEvent(messageType int, message []byte) {
s.log(LogInformational, "Closing and reconnecting in response to Op7")
s.Close()
s.reconnect()
return
return e, nil
}

// Invalid Session
Expand All @@ -434,36 +493,31 @@ func (s *Session) onEvent(messageType int, message []byte) {
err = s.identify()
if err != nil {
s.log(LogWarning, "error sending gateway identify packet, %s, %s", s.gateway, err)
return
return e, err
}

return
return e, nil
}

if e.Operation == 10 {
var h helloOp
if err = json.Unmarshal(e.RawData, &h); err != nil {
s.log(LogError, "error unmarshalling helloOp, %s", err)
} else {
go s.heartbeat(s.wsConn, s.listening, h.HeartbeatInterval)
}
return
// Op10 is handled by Open()
return e, nil
}

if e.Operation == 11 {
s.Lock()
s.LastHeartbeatAck = time.Now().UTC()
s.Unlock()
s.log(LogInformational, "got heartbeat ACK")
return
return e, nil
}

// Do not try to Dispatch a non-Dispatch Message
if e.Operation != 0 {
// But we probably should be doing something with them.
// TEMP
s.log(LogWarning, "unknown Op: %d, Seq: %d, Type: %s, Data: %s, message: %s", e.Operation, e.Sequence, e.Type, string(e.RawData), string(message))
return
return e, nil
}

// Store the message sequence
Expand Down Expand Up @@ -492,6 +546,8 @@ func (s *Session) onEvent(messageType int, message []byte) {

// For legacy reasons, we send the raw event also, this could be useful for handling unknown events.
s.handleEvent(eventEventType, e)

return e, nil
}

// ------------------------------------------------------------------------------------------------
Expand Down

0 comments on commit 7d1657e

Please sign in to comment.