blob: 3966bbee9b0c84934bc207a7d51a8a74a1d7d2c9 [file] [log] [blame]
From d7117da84b89be273699faa2131d83c475c3ae03 Mon Sep 17 00:00:00 2001
From: David Stevens <stevensd@chromium.org>
Date: Wed, 13 Nov 2019 15:38:58 +0900
Subject: [PATCH] Revert "rpcc: Repeat the error that closed the connection
(#74)"
This reverts commit 15d95f6d349e0fca451c2b5a02382c1ae5622c76.
---
rpcc/conn.go | 66 +++++++++++++++------------------------------
rpcc/conn_test.go | 54 -------------------------------------
rpcc/stream.go | 2 +-
rpcc/stream_test.go | 4 +--
4 files changed, 25 insertions(+), 101 deletions(-)
diff --git a/rpcc/conn.go b/rpcc/conn.go
index a37bf04..bcb327c 100644
--- a/rpcc/conn.go
+++ b/rpcc/conn.go
@@ -240,17 +240,18 @@ type Conn struct {
compressionLevel func(level int) error
mu sync.Mutex // Protects following.
+ closed bool
reqSeq uint64
pending map[uint64]*rpcCall
- streams map[string]*streamClients
- closed bool
- err error // Protected by mu and closed until context is cancelled.
reqMu sync.Mutex // Protects following.
req Request
// Encodes and decodes JSON onto conn. Encoding is
// guarded by mutex and decoding is done by recv.
codec Codec
+
+ streamMu sync.Mutex // Protects following.
+ streams map[string]*streamClients
}
// Response represents an RPC response or notification sent by the server.
@@ -377,7 +378,7 @@ func (c *Conn) send(ctx context.Context, call *rpcCall) (err error) {
c.mu.Lock()
if c.closed {
c.mu.Unlock()
- return c.err
+ return ErrConnClosing
}
c.reqSeq++
reqID := c.reqSeq
@@ -429,25 +430,26 @@ func (c *Conn) send(ctx context.Context, call *rpcCall) (err error) {
// notify handles RPC notifications and sends them
// to the appropriate stream listeners.
func (c *Conn) notify(method string, data []byte) {
- c.mu.Lock()
+ c.streamMu.Lock()
stream := c.streams[method]
+ c.streamMu.Unlock()
+
if stream != nil {
// Stream writer must be able to handle incoming writes
// even after it has been removed (unsubscribed).
stream.write(method, data)
}
- c.mu.Unlock()
}
// listen registers a new stream listener (chan) for the RPC notification
// method. Returns a function for removing the listener. Error if the
// connection is closed.
func (c *Conn) listen(method string, w streamWriter) (func(), error) {
- c.mu.Lock()
- defer c.mu.Unlock()
+ c.streamMu.Lock()
+ defer c.streamMu.Unlock()
if c.streams == nil {
- return nil, c.err
+ return nil, ErrConnClosing
}
stream, ok := c.streams[method]
@@ -461,59 +463,35 @@ func (c *Conn) listen(method string, w streamWriter) (func(), error) {
return unsub, nil
}
-type closeError struct {
- msg string
- err error
-}
-
-func (e *closeError) Cause() error {
- return e.err
-}
-
-func (e *closeError) Error() string {
- return fmt.Sprintf("%s: %v", e.msg, e.err)
-}
-
-// Close closes the connection. Subsequent calls to Close will return the error
-// that closed the connection.
+// Close closes the connection.
func (c *Conn) close(err error) error {
- c.mu.Lock()
- defer c.mu.Unlock()
+ c.cancel()
+ c.mu.Lock()
if c.closed {
- return c.err
+ c.mu.Unlock()
+ return ErrConnClosing
}
c.closed = true
if err == nil {
err = ErrConnClosing
- } else {
- err = &closeError{msg: ErrConnClosing.Error(), err: err}
}
- c.err = err
for id, call := range c.pending {
delete(c.pending, id)
call.done(err)
}
- // Stop sending on all streams by signaling
- // that the connection is closed.
+ c.mu.Unlock()
+
+ // Stop sending on all streams.
+ c.streamMu.Lock()
c.streams = nil
+ c.streamMu.Unlock()
// Conn can be nil if DialContext did not complete.
if c.conn != nil {
- wserr := c.conn.Close()
- if wserr != nil && err == ErrConnClosing {
- err = wserr
- c.err = &closeError{msg: ErrConnClosing.Error(), err: err}
- }
+ err = c.conn.Close()
}
- // Delay cancel until c.err has settled, at this point any active
- // streams will be closed.
- c.cancel()
-
- if err == ErrConnClosing {
- return nil
- }
return err
}
diff --git a/rpcc/conn_test.go b/rpcc/conn_test.go
index 1450830..0cc4b3b 100644
--- a/rpcc/conn_test.go
+++ b/rpcc/conn_test.go
@@ -339,60 +339,6 @@ func TestConn_StreamRecv(t *testing.T) {
}
}
-func TestConn_PropagateError(t *testing.T) {
- srv := newTestServer(t, nil)
- defer srv.Close()
-
- ctx, cancel := context.WithCancel(context.Background())
- defer cancel()
-
- s1, err := NewStream(ctx, "test.Stream1", srv.conn)
- if err != nil {
- t.Fatal(err)
- }
- defer s1.Close()
- s2, err := NewStream(ctx, "test.Stream2", srv.conn)
- if err != nil {
- t.Fatal(err)
- }
- defer s2.Close()
-
- errC := make(chan error, 2)
- go func() {
- errC <- Invoke(ctx, "test.Invoke", nil, nil, srv.conn)
- }()
- go func() {
- var reply string
- errC <- s1.RecvMsg(&reply)
- }()
-
- // Give a little time for both Invoke & Recv.
- time.Sleep(5 * time.Millisecond)
-
- srv.wsConn.Close()
-
- // Give a little time for connection to close.
- time.Sleep(5 * time.Millisecond)
-
- lastErr := Invoke(ctx, "test.Invoke", nil, nil, srv.conn)
- if lastErr == nil {
- t.Error("RecvMsg on closed connection: got nil, want an error")
- }
-
- var reply string
- err = s2.RecvMsg(&reply)
- if err != lastErr {
- t.Errorf("Error was not repeated, got %v, want %v", err, lastErr)
- }
-
- for i := 0; i < 2; i++ {
- err := <-errC
- if err != lastErr {
- t.Errorf("Error was not repeated, got %v, want %v", err, lastErr)
- }
- }
-}
-
func TestConn_Context(t *testing.T) {
srv := newTestServer(t, nil)
defer srv.Close()
diff --git a/rpcc/stream.go b/rpcc/stream.go
index 367c147..507d2c4 100644
--- a/rpcc/stream.go
+++ b/rpcc/stream.go
@@ -172,7 +172,7 @@ func (s *streamClient) watch() {
case <-s.ctx.Done():
s.close(s.ctx.Err())
case <-s.conn.ctx.Done():
- s.close(s.conn.err)
+ s.close(ErrConnClosing)
case <-s.done:
}
}
diff --git a/rpcc/stream_test.go b/rpcc/stream_test.go
index 2758e02..7d25007 100644
--- a/rpcc/stream_test.go
+++ b/rpcc/stream_test.go
@@ -198,7 +198,7 @@ func TestStream_RecvAfterConnClose(t *testing.T) {
conn.notify("test", []byte(`"message2"`))
conn.notify("test", []byte(`"message3"`))
- conn.Close()
+ connCancel()
for i := 0; i < 3; i++ {
var reply string
@@ -210,7 +210,7 @@ func TestStream_RecvAfterConnClose(t *testing.T) {
err = s.RecvMsg(nil)
if err != ErrConnClosing {
- t.Errorf("err got %v, want %v", err, ErrConnClosing)
+ t.Errorf("err got %v, want ErrConnClosing", err)
}
}
--
2.24.0.rc1.363.gb1bccd3e3d-goog