diff --git a/dialer.go b/dialer.go index e66678e..d4d80a5 100644 --- a/dialer.go +++ b/dialer.go @@ -31,6 +31,9 @@ type Handshake struct { // Extensions is the list of negotiated extensions. Extensions []httphead.Option + + // Header all request headers obtained during handshake + Header http.Header } // Errors used by the websocket client. diff --git a/server.go b/server.go index 863bb22..edb4ff1 100644 --- a/server.go +++ b/server.go @@ -109,6 +109,10 @@ func Upgrade(conn io.ReadWriter) (Handshake, error) { // HTTPUpgrader contains options for upgrading connection to websocket from // net/http Handler arguments. type HTTPUpgrader struct { + // CopyHeadersToHandshake setting specifies whether headers should be preserved during the handshake process. + // If enabled, the headers will be copied to Handshake.Header. + CopyHeadersToHandshake bool + // Timeout is the maximum amount of time an Upgrade() will spent while // writing handshake response. // @@ -241,6 +245,12 @@ func (u HTTPUpgrader) Upgrade(r *http.Request, w http.ResponseWriter) (conn net. if h := u.Header; h != nil { header[0] = HandshakeHeaderHTTP(h) } + + if u.CopyHeadersToHandshake { + // set handshake header + hs.Header = r.Header + } + if err == nil { httpWriteResponseUpgrade(rw.Writer, strToBytes(nonce), hs, header.WriteTo) err = rw.Writer.Flush() @@ -262,6 +272,10 @@ func (u HTTPUpgrader) Upgrade(r *http.Request, w http.ResponseWriter) (conn net. // Upgrader contains options for upgrading connection to websocket. type Upgrader struct { + // CopyHeadersToHandshake setting specifies whether headers should be preserved during the handshake process. + // If enabled, the headers will be copied to Handshake.Header. + CopyHeadersToHandshake bool + // ReadBufferSize and WriteBufferSize is an I/O buffer sizes. // They used to read and write http data while upgrading to WebSocket. // Allocated buffers are pooled with sync.Pool to avoid extra allocations. @@ -498,6 +512,12 @@ func (u Upgrader) Upgrade(conn io.ReadWriter) (hs Handshake, err error) { nonce = make([]byte, nonceSize) ) + + if u.CopyHeadersToHandshake { + // init handshake headers + hs.Header = make(http.Header) + } + for err == nil { line, e := readLine(br) if e != nil { @@ -514,6 +534,11 @@ func (u Upgrader) Upgrade(conn io.ReadWriter) (hs Handshake, err error) { break } + if u.CopyHeadersToHandshake { + // copy and add header + hs.Header.Add(btsToString(bytes.Clone(k)), btsToString(bytes.Clone(v))) + } + switch btsToString(k) { case headerHostCanonical: headerSeen |= headerSeenHost