Skip to content

Commit

Permalink
split the Session.Close(error) in Close() and CloseWithError(error)
Browse files Browse the repository at this point in the history
  • Loading branch information
marten-seemann committed Jul 6, 2018
1 parent 2bc5b7f commit 8b2992a
Show file tree
Hide file tree
Showing 22 changed files with 168 additions and 105 deletions.
2 changes: 1 addition & 1 deletion benchmark/benchmark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func init() {
b.RecordValue("transfer rate [MB/s]", float64(dataLen)/1e6/runtime.Seconds())

ln.Close()
sess.Close(nil)
sess.Close()
}, samples)
})
}
Expand Down
12 changes: 6 additions & 6 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -336,8 +336,8 @@ func (c *client) establishSecureConnection(ctx context.Context) error {

select {
case <-ctx.Done():
// The session sending a PeerGoingAway error to the server.
c.session.Close(nil)
// The session will send a PeerGoingAway error to the server.
c.session.Close()
return ctx.Err()
case err := <-errorChan:
return err
Expand Down Expand Up @@ -366,7 +366,7 @@ func (c *client) handlePacketImpl(p *receivedPacket) error {

// version negotiation packets have no payload
if err := c.handleVersionNegotiationPacket(p.header); err != nil {
c.session.Close(err)
c.session.destroy(err)
}
return nil
}
Expand Down Expand Up @@ -474,7 +474,7 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
}

c.logger.Infof("Switching to QUIC version %s. New connection ID: %s", newVersion, c.destConnID)
c.session.Close(errCloseSessionForNewVersion)
c.session.destroy(errCloseSessionForNewVersion)
return nil
}

Expand Down Expand Up @@ -526,13 +526,13 @@ func (c *client) createNewTLSSession(
return err
}

func (c *client) Close(err error) error {
func (c *client) Close() error {
c.mutex.Lock()
defer c.mutex.Unlock()
if c.session == nil {
return nil
}
return c.session.Close(err)
return c.session.Close()
}

func (c *client) GetVersion() protocol.VersionNumber {
Expand Down
2 changes: 1 addition & 1 deletion client_multiplexer.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func (m *clientMultiplexer) listen(c net.PacketConn, p *connManager) {
n, addr, err := c.ReadFrom(data)
if err != nil {
if !strings.HasSuffix(err.Error(), "use of closed network connection") {
p.manager.Close(err)
p.manager.Close()
}
return
}
Expand Down
11 changes: 5 additions & 6 deletions client_multiplexer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ var _ = Describe("Client Multiplexer", func() {
conn.dataToRead <- getPacket(connID)
Eventually(handledPacket).Should(BeClosed())
// makes the listen go routine return
packetHandler.EXPECT().Close(gomock.Any()).AnyTimes()
packetHandler.EXPECT().Close().AnyTimes()
close(conn.dataToRead)
})

Expand Down Expand Up @@ -85,8 +85,8 @@ var _ = Describe("Client Multiplexer", func() {
Eventually(handledPacket2).Should(BeClosed())

// makes the listen go routine return
packetHandler1.EXPECT().Close(gomock.Any()).AnyTimes()
packetHandler2.EXPECT().Close(gomock.Any()).AnyTimes()
packetHandler1.EXPECT().Close().AnyTimes()
packetHandler2.EXPECT().Close().AnyTimes()
close(conn.dataToRead)
})

Expand Down Expand Up @@ -114,11 +114,10 @@ var _ = Describe("Client Multiplexer", func() {

It("closes the packet handlers when reading from the conn fails", func() {
conn := newMockPacketConn()
testErr := errors.New("test error")
conn.readErr = testErr
conn.readErr = errors.New("test error")
done := make(chan struct{})
packetHandler := NewMockQuicSession(mockCtrl)
packetHandler.EXPECT().Close(testErr).Do(func(error) {
packetHandler.EXPECT().Close().Do(func() {
close(done)
})
getClientMultiplexer().AddConn(conn, 8)
Expand Down
14 changes: 7 additions & 7 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ var _ = Describe("Client", func() {

AfterEach(func() {
if s, ok := cl.session.(*session); ok {
s.Close(nil)
s.Close()
}
Eventually(areSessionsRunning).Should(BeFalse())
})
Expand Down Expand Up @@ -254,7 +254,7 @@ var _ = Describe("Client", func() {
close(dialed)
}()
Consistently(dialed).ShouldNot(BeClosed())
sess.EXPECT().Close(nil)
sess.EXPECT().Close()
cancel()
Eventually(dialed).Should(BeClosed())
})
Expand Down Expand Up @@ -493,7 +493,7 @@ var _ = Describe("Client", func() {
sess1 := NewMockQuicSession(mockCtrl)
run1 := make(chan struct{})
sess1.EXPECT().run().Do(func() { <-run1 }).Return(errCloseSessionForNewVersion)
sess1.EXPECT().Close(errCloseSessionForNewVersion).Do(func(error) { close(run1) })
sess1.EXPECT().destroy(errCloseSessionForNewVersion).Do(func(error) { close(run1) })
sess2 := NewMockQuicSession(mockCtrl)
sess2.EXPECT().run()
sessionChan := make(chan *MockQuicSession, 2)
Expand Down Expand Up @@ -538,7 +538,7 @@ var _ = Describe("Client", func() {
sess1 := NewMockQuicSession(mockCtrl)
run1 := make(chan struct{})
sess1.EXPECT().run().Do(func() { <-run1 }).Return(errCloseSessionForNewVersion)
sess1.EXPECT().Close(errCloseSessionForNewVersion).Do(func(error) { close(run1) })
sess1.EXPECT().destroy(errCloseSessionForNewVersion).Do(func(error) { close(run1) })
sess2 := NewMockQuicSession(mockCtrl)
sess2.EXPECT().run()
sessionChan := make(chan *MockQuicSession, 2)
Expand Down Expand Up @@ -578,15 +578,15 @@ var _ = Describe("Client", func() {

It("errors if no matching version is found", func() {
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().Close(gomock.Any())
sess.EXPECT().destroy(qerr.InvalidVersion)
cl.session = sess
cl.config = &Config{Versions: protocol.SupportedVersions}
cl.handlePacket(composeVersionNegotiationPacket(connID, []protocol.VersionNumber{1}))
})

It("errors if the version is supported by quic-go, but disabled by the quic.Config", func() {
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().Close(gomock.Any())
sess.EXPECT().destroy(qerr.InvalidVersion)
cl.session = sess
v := protocol.VersionNumber(1234)
Expect(v).ToNot(Equal(cl.version))
Expand All @@ -597,7 +597,7 @@ var _ = Describe("Client", func() {
It("changes to the version preferred by the quic.Config", func() {
mockMultiplexer.EXPECT().AddHandler(gomock.Any(), gomock.Any(), gomock.Any())
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().Close(errCloseSessionForNewVersion)
sess.EXPECT().destroy(errCloseSessionForNewVersion)
cl.session = sess
versions := []protocol.VersionNumber{1234, 4321}
cl.config = &Config{Versions: versions}
Expand Down
7 changes: 5 additions & 2 deletions h2quic/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,11 +280,14 @@ func (c *client) CloseWithError(e error) error {
if c.session == nil {
return nil
}
return c.session.Close(e)
return c.session.CloseWithError(e)
}

func (c *client) Close() error {
return c.CloseWithError(nil)
if c.session == nil {
return nil
}
return c.session.Close()
}

// copied from net/transport.go
Expand Down
6 changes: 3 additions & 3 deletions h2quic/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ func (s *Server) serveImpl(tlsConfig *tls.Config, conn net.PacketConn) error {
func (s *Server) handleHeaderStream(session streamCreator) {
stream, err := session.AcceptStream()
if err != nil {
session.Close(qerr.Error(qerr.InvalidHeadersStreamData, err.Error()))
session.CloseWithError(qerr.Error(qerr.InvalidHeadersStreamData, err.Error()))
return
}

Expand All @@ -143,7 +143,7 @@ func (s *Server) handleHeaderStream(session streamCreator) {
if _, ok := err.(*qerr.QuicError); !ok {
s.logger.Errorf("error handling h2 request: %s", err.Error())
}
session.Close(err)
session.CloseWithError(err)
return
}
}
Expand Down Expand Up @@ -246,7 +246,7 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream,
}
if s.CloseAfterFirstRequest {
time.Sleep(100 * time.Millisecond)
session.Close(nil)
session.Close()
}
}()

Expand Down
7 changes: 5 additions & 2 deletions h2quic/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,18 @@ func (s *mockSession) OpenStreamSync() (quic.Stream, error) {
}
return s.OpenStream()
}
func (s *mockSession) Close(e error) error {
s.closedWithError = e
func (s *mockSession) Close() error {
s.ctxCancel()
if !s.closed {
close(s.blockOpenStreamChan)
}
s.closed = true
return nil
}
func (s *mockSession) CloseWithError(e error) error {
s.closedWithError = e
return s.Close()
}
func (s *mockSession) LocalAddr() net.Addr {
panic("not implemented")
}
Expand Down
2 changes: 1 addition & 1 deletion integrationtests/self/conn_id_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ var _ = Describe("Connection ID lengths tests", func() {
conf,
)
Expect(err).ToNot(HaveOccurred())
defer cl.Close(nil)
defer cl.Close()
str, err := cl.AcceptStream()
Expect(err).ToNot(HaveOccurred())
data, err := ioutil.ReadAll(str)
Expand Down
14 changes: 7 additions & 7 deletions integrationtests/self/handshake_drop_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ var _ = Describe("Handshake drop tests", func() {
defer GinkgoRecover()
sess, err := ln.Accept()
Expect(err).ToNot(HaveOccurred())
defer sess.Close(nil)
defer sess.Close()
str, err := sess.AcceptStream()
Expect(err).ToNot(HaveOccurred())
b := make([]byte, 6)
Expand All @@ -83,8 +83,8 @@ var _ = Describe("Handshake drop tests", func() {

var serverSession quic.Session
Eventually(serverSessionChan, 10*time.Second).Should(Receive(&serverSession))
sess.Close(nil)
serverSession.Close(nil)
sess.Close()
serverSession.Close()
},
}

Expand Down Expand Up @@ -117,8 +117,8 @@ var _ = Describe("Handshake drop tests", func() {

var serverSession quic.Session
Eventually(serverSessionChan, 10*time.Second).Should(Receive(&serverSession))
sess.Close(nil)
serverSession.Close(nil)
sess.Close()
serverSession.Close()
},
}

Expand All @@ -141,8 +141,8 @@ var _ = Describe("Handshake drop tests", func() {
var serverSession quic.Session
Eventually(serverSessionChan, 10*time.Second).Should(Receive(&serverSession))
// both server and client accepted a session. Close now.
sess.Close(nil)
serverSession.Close(nil)
sess.Close()
serverSession.Close()
},
}

Expand Down
2 changes: 1 addition & 1 deletion integrationtests/self/rtt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ var _ = Describe("non-zero RTT", func() {
data, err := ioutil.ReadAll(str)
Expect(err).ToNot(HaveOccurred())
Expect(data).To(Equal(testserver.PRData))
sess.Close(nil)
sess.Close()
Eventually(done).Should(BeClosed())
})
}
Expand Down
2 changes: 1 addition & 1 deletion integrationtests/self/stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ var _ = Describe("Bidirectional streams", func() {
sess, err := server.Accept()
Expect(err).ToNot(HaveOccurred())
runSendingPeer(sess)
sess.Close(nil)
sess.Close()
}()

client, err := quic.DialAddr(serverAddr, nil, qconf)
Expand Down
2 changes: 1 addition & 1 deletion integrationtests/self/uni_stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ var _ = Describe("Unidirectional Streams", func() {
sess, err = server.Accept()
Expect(err).ToNot(HaveOccurred())
runReceivingPeer(sess)
sess.Close(nil)
sess.Close()
}()

client, err := quic.DialAddr(serverAddr, nil, qconf)
Expand Down
6 changes: 4 additions & 2 deletions interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,10 @@ type Session interface {
LocalAddr() net.Addr
// RemoteAddr returns the address of the peer.
RemoteAddr() net.Addr
// Close closes the connection. The error will be sent to the remote peer in a CONNECTION_CLOSE frame. An error value of nil is allowed and will cause a normal PeerGoingAway to be sent.
Close(error) error
// Close the connection.
io.Closer
// Close the connection with an error.
CloseWithError(error) error
// The context is cancelled when the session is closed.
// Warning: This API should not be considered stable and might change soon.
Context() context.Context
Expand Down
10 changes: 6 additions & 4 deletions mock_packet_handler_manager_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

30 changes: 26 additions & 4 deletions mock_quic_session_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 8b2992a

Please sign in to comment.