diff --git a/httputils/nethttp.go b/httputils/nethttp.go index f884e7c..36b70b4 100644 --- a/httputils/nethttp.go +++ b/httputils/nethttp.go @@ -1,7 +1,10 @@ package httputils import ( + "crypto/tls" "fmt" + "github.com/lucas-clemente/quic-go" + "github.com/lucas-clemente/quic-go/http3" "github.com/valyala/fasthttp" "io" "net/http" @@ -17,6 +20,7 @@ type NetHTTPContext struct { server *Server connTime time.Time requestTime time.Time + tlsState tls.ConnectionState timingEvents uint } @@ -27,9 +31,23 @@ func NewRequestContextFromHttp(server *Server, w http.ResponseWriter, r *http.Re connTime: time.Now(), requestTime: time.Now(), server: server, + tlsState: getTLSState(w, r), } } +func getTLSState(w http.ResponseWriter, r *http.Request) tls.ConnectionState { + if hj, ok := w.(http3.Hijacker); ok { + if conn, ok := hj.StreamCreator().(quic.Connection); ok { + return conn.ConnectionState().TLS.ConnectionState + } + } + if r.TLS != nil { + return *r.TLS + } + + return tls.ConnectionState{} +} + func (c *NetHTTPContext) GetExtraHeaders() map[string]string { return c.server.GetExtraHeaders() } @@ -62,10 +80,7 @@ func (c *NetHTTPContext) GetRequestTime() time.Time { } func (c *NetHTTPContext) GetTLSServerName() string { - if c.httpRequest.TLS != nil { - return c.httpRequest.TLS.ServerName - } - return "" + return c.tlsState.ServerName } func (c *NetHTTPContext) GetHost() string { @@ -77,19 +92,11 @@ func (c *NetHTTPContext) GetProtocol() string { } func (c *NetHTTPContext) GetTLSVersion() uint16 { - if c.httpRequest.TLS != nil { - return c.httpRequest.TLS.Version - } - - return 0 + return c.tlsState.Version } func (c *NetHTTPContext) GetTLSCipher() uint16 { - if c.httpRequest.TLS != nil { - return c.httpRequest.TLS.CipherSuite - } - - return 0 + return c.tlsState.CipherSuite } func (c *NetHTTPContext) GetRequestHeader(name string) string {