Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions agent/websockets/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ type Connection struct {
clientMessages chan *message
serverMessages chan *message
protocolVersion int
subprotocol string
}

// This map defines the set of headers that should be stripped from the WS request, as they
Expand Down Expand Up @@ -169,6 +170,7 @@ func NewConnection(ctx context.Context, targetURL string, header http.Header, er
cancel: cancel,
clientMessages: clientMessages,
serverMessages: serverMessages,
subprotocol: serverConn.Subprotocol(),
}, nil
}

Expand Down Expand Up @@ -264,6 +266,11 @@ func (conn *Connection) ReadServerMessages() ([]interface{}, error) {
}
}

// Subprotocol reports the websocket subprotocol (if any) that was accepted by the server.
func (conn *Connection) Subprotocol() string {
return conn.subprotocol
}

// injectWebsocketMessage injects a shim header value into a single websocket message in-place.
// Returns a pointer to a new copy of the struct on success.
func injectWebsocketMessage(msg *message, injectionPath []string, injectionValues map[string]string) (*message, error) {
Expand Down
8 changes: 5 additions & 3 deletions agent/websockets/shim.go
Original file line number Diff line number Diff line change
Expand Up @@ -289,9 +289,10 @@ func ShimBody(shimPath string) (func(resp *http.Response) error, error) {
}

type sessionMessage struct {
ID string `json:"id,omitempty"`
Message interface{} `json:"msg,omitempty"`
Version int `json:"v,omitempty"`
ID string `json:"id,omitempty"`
Message interface{} `json:"msg,omitempty"`
Version int `json:"v,omitempty"`
Subprotocol string `json:"s,omitempty"`
}

func createShimChannel(ctx context.Context, host, shimPath string, rewriteHost bool, openWebsocketWrapper func(http.Handler, *metrics.MetricHandler) http.Handler, enableWebsocketInjection bool, metricHandler *metrics.MetricHandler) http.Handler {
Expand Down Expand Up @@ -330,6 +331,7 @@ func createShimChannel(ctx context.Context, host, shimPath string, rewriteHost b
ID: sessionID,
Message: targetURL.String(),
Version: conn.protocolVersion,
Subprotocol: conn.Subprotocol(),
}
respBytes, err := json.Marshal(resp)
if err != nil {
Expand Down
Loading