diff --git a/agent/websockets/connection.go b/agent/websockets/connection.go index 021c941..6be922b 100644 --- a/agent/websockets/connection.go +++ b/agent/websockets/connection.go @@ -57,7 +57,7 @@ func (m *message) Serialize(version int) interface{} { // and encapsulates it in an API that is a little more amenable to how the server side // of our websocket shim is implemented. type Connection struct { - ctx context.Context + done func() <-chan struct{} cancel context.CancelFunc clientMessages chan *message serverMessages chan *message @@ -107,12 +107,9 @@ func NewConnection(ctx context.Context, targetURL string, header http.Header, er // push messages. That way our handling of reads and writes are consistent. clientMessages := make(chan *message, 10) - closeConn := make(chan bool) go func() { - defer func() { - close(serverMessages) - closeConn <- true - }() + defer close(serverMessages) + defer cancel() for { select { case <-ctx.Done(): @@ -133,9 +130,7 @@ func NewConnection(ctx context.Context, targetURL string, header http.Header, er } }() go func() { - defer func() { - closeConn <- true - }() + defer cancel() for { select { case <-ctx.Done(): @@ -157,16 +152,14 @@ func NewConnection(ctx context.Context, targetURL string, header http.Header, er } }() go func() { - <-closeConn - // if either routines finishes, terminate the other - cancel() + <-ctx.Done() // closing the serverConn. This will cause serverConn.ReadMessage to stop. if err := serverConn.Close(); err != nil { errCallback(fmt.Errorf("failure closing a server websocket connection: %v", err)) } }() return &Connection{ - ctx: ctx, + done: ctx.Done, cancel: cancel, clientMessages: clientMessages, serverMessages: serverMessages, @@ -226,7 +219,7 @@ func (conn *Connection) SendClientMessage(msg interface{}, injectionEnabled bool } } select { - case <-conn.ctx.Done(): + case <-conn.done(): return fmt.Errorf("attempt to send a client message on a closed websocket connection") default: conn.clientMessages <- clientMessage