package httputils

import (
	"crypto/tls"
	"fmt"
	"io"
	"net/http"
	"net/textproto"
	"strconv"
	"strings"
	"time"

	"github.com/lucas-clemente/quic-go"
	"github.com/lucas-clemente/quic-go/http3"
	"github.com/valyala/fasthttp"
)

// NetHTTPContext adapts a net/http (or quic-go http3) ResponseWriter and
// Request pair to the request-context methods used when serving through a
// Server.
type NetHTTPContext struct {
	httpWriter  http.ResponseWriter
	httpRequest *http.Request
	server      *Server
	connTime    time.Time
	requestTime time.Time
	tlsState    tls.ConnectionState
	// timingEvents counts the Server-Timing entries added so far. It is used as
	// a numeric prefix on each metric name (presumably to keep entries distinct
	// and ordered — confirm with consumers of the header).
	timingEvents uint
}

// NewRequestContextFromHttp wraps w and r in a NetHTTPContext bound to server.
// This adapter has no access to the connection's accept time, so the
// connection time and the request time are both set to the same instant.
func NewRequestContextFromHttp(server *Server, w http.ResponseWriter, r *http.Request) *NetHTTPContext {
	now := time.Now()
	return &NetHTTPContext{
		httpWriter:  w,
		httpRequest: r,
		connTime:    now,
		requestTime: now,
		server:      server,
		tlsState:    getTLSState(w, r),
	}
}

// getTLSState returns the TLS state of the connection carrying r. For
// requests served by quic-go's http3 server the state is read from the QUIC
// connection; otherwise r.TLS is used. Plaintext requests yield the zero
// value.
func getTLSState(w http.ResponseWriter, r *http.Request) tls.ConnectionState {
	if hj, ok := w.(http3.Hijacker); ok {
		if conn, ok := hj.StreamCreator().(quic.Connection); ok {
			return conn.ConnectionState().TLS.ConnectionState
		}
	}
	if r.TLS != nil {
		return *r.TLS
	}
	return tls.ConnectionState{}
}

// serverTimingDescEscaper escapes the characters that would terminate or
// corrupt the quoted-string "desc" parameter of a Server-Timing entry.
var serverTimingDescEscaper = strings.NewReplacer(`\`, `\\`, `"`, `\"`)

// GetExtraHeaders returns the owning server's extra headers.
func (c *NetHTTPContext) GetExtraHeaders() map[string]string {
	return c.server.GetExtraHeaders()
}

// AddTiming appends a Server-Timing entry named name with description desc
// and duration d, reported in milliseconds. Negative durations are reported
// as 0. Like every header, it only takes effect if added before the status
// code is written.
func (c *NetHTTPContext) AddTiming(name string, desc string, d time.Duration) {
	if d < 0 {
		d = 0
	}
	c.AddResponseHeader("Server-Timing", fmt.Sprintf("%d_%s;desc=\"%s\";dur=%.6F",
		c.timingEvents, name, serverTimingDescEscaper.Replace(desc), float64(d.Nanoseconds())/1e6))
	c.timingEvents++
}

// AddTimingInformational appends a Server-Timing entry that carries only a
// description, with no duration.
func (c *NetHTTPContext) AddTimingInformational(name string, desc string) {
	c.AddResponseHeader("Server-Timing", fmt.Sprintf("%d_%s;desc=\"%s\"",
		c.timingEvents, name, serverTimingDescEscaper.Replace(desc)))
	c.timingEvents++
}

// GetPath returns the request URL path.
func (c *NetHTTPContext) GetPath() string {
	return c.httpRequest.URL.Path
}

// GetConnectionTime returns the time recorded when this context was created.
func (c *NetHTTPContext) GetConnectionTime() time.Time {
	return c.connTime
}

// GetRequestTime returns the time recorded when this context was created.
func (c *NetHTTPContext) GetRequestTime() time.Time {
	return c.requestTime
}

// GetTLSServerName returns the TLS server name (SNI) sent by the client, or ""
// for plaintext connections.
func (c *NetHTTPContext) GetTLSServerName() string {
	return c.tlsState.ServerName
}

// GetHost returns the request's Host value.
func (c *NetHTTPContext) GetHost() string {
	return c.httpRequest.Host
}

// GetProtocol returns the request protocol string (for example "HTTP/1.1").
func (c *NetHTTPContext) GetProtocol() string {
	return c.httpRequest.Proto
}

// GetTLSVersion returns the negotiated TLS version, or 0 for plaintext
// connections.
func (c *NetHTTPContext) GetTLSVersion() uint16 {
	return c.tlsState.Version
}

// GetTLSCipher returns the negotiated TLS cipher suite, or 0 for plaintext
// connections.
func (c *NetHTTPContext) GetTLSCipher() uint16 {
	return c.tlsState.CipherSuite
}

// GetRequestHeader returns the first value of the named request header.
func (c *NetHTTPContext) GetRequestHeader(name string) string {
	return c.httpRequest.Header.Get(name)
}

// GetResponseHeader returns the first value of the named response header.
func (c *NetHTTPContext) GetResponseHeader(name string) string {
	return c.httpWriter.Header().Get(name)
}

// AddResponseHeader appends a value to the named response header.
func (c *NetHTTPContext) AddResponseHeader(name string, value string) {
	c.httpWriter.Header().Add(name, value)
}

// SetResponseHeader replaces the named response header with value.
func (c *NetHTTPContext) SetResponseHeader(name string, value string) {
	c.httpWriter.Header().Set(name, value)
}

// ServeStream writes stream as the response body and closes it.
//
// If stream implements DefinedStream, Last-Modified, Accept-Ranges and
// Content-Length are sent. A single forward byte range in the Range request
// header ("bytes=start-" or "bytes=start-end") is honoured with
// 206 Partial Content. Suffix ranges, multiple ranges and malformed or
// unsatisfiable ranges get 416 with "Content-Range: bytes */size".
// Other streams are sent with 200 and no Content-Length.
// HEAD requests get headers only. A nil stream yields 500.
func (c *NetHTTPContext) ServeStream(stream Stream) {
	if stream == nil {
		c.SetResponseCode(fasthttp.StatusInternalServerError)
		return
	}
	defer stream.Close()

	dstream, ok := stream.(DefinedStream)
	if !ok {
		// Unknown length: net/http frames the body itself (chunked on
		// HTTP/1.1), so no Transfer-Encoding header is set here.
		c.SetResponseCode(fasthttp.StatusOK)
		if !c.IsHead() {
			// The status is already sent; a copy error cannot be reported.
			_, _ = io.Copy(c.httpWriter, stream)
		}
		return
	}

	c.SetResponseHeader("Last-Modified", dstream.ModTime().UTC().Format(modTimeFormat))
	c.SetResponseHeader("Accept-Ranges", "bytes")
	size := dstream.Size()

	notSatisfiable := func() {
		c.SetResponseHeader("Content-Range", "bytes */"+strconv.FormatInt(size, 10))
		c.SetResponseCode(fasthttp.StatusRequestedRangeNotSatisfiable)
	}

	if rangeHeader := c.GetRequestHeader("range"); rangeHeader != "" {
		start, length, ok := parseByteRange(rangeHeader, size)
		if !ok {
			notSatisfiable()
			return
		}
		if _, err := dstream.Seek(start, io.SeekStart); err != nil {
			notSatisfiable()
			return
		}
		c.SetResponseHeader("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, start+length-1, size))
		size = length
		c.SetResponseHeader("Content-Length", strconv.FormatInt(size, 10))
		c.SetResponseCode(fasthttp.StatusPartialContent)
	} else {
		c.SetResponseHeader("Content-Length", strconv.FormatInt(size, 10))
		c.SetResponseCode(fasthttp.StatusOK)
	}

	if !c.IsHead() {
		// The status is already sent; a copy error cannot be reported.
		_, _ = io.CopyN(c.httpWriter, dstream, size)
	}
}

// parseByteRange parses a Range header of the form "bytes=start-" or
// "bytes=start-end" against content of the given size. It returns the first
// byte offset and the number of bytes to send; an end beyond the content is
// clamped to the last byte.
//
// ok is false when the value:
//   - lacks the "bytes=" unit or contains multiple ranges;
//   - is a suffix range ("bytes=-N");
//   - has a non-numeric or negative offset;
//   - has an end before its start;
//   - has a start at or past size.
func parseByteRange(header string, size int64) (start, length int64, ok bool) {
	const rangePrefix = "bytes="
	if !strings.HasPrefix(header, rangePrefix) {
		return 0, 0, false
	}
	spec := header[len(rangePrefix):]
	if strings.Contains(spec, ",") {
		// Only a single range is supported.
		return 0, 0, false
	}
	startStr, endStr, found := strings.Cut(spec, "-")
	if !found {
		return 0, 0, false
	}
	startStr, endStr = textproto.TrimString(startStr), textproto.TrimString(endStr)
	if startStr == "" {
		// Only forward ranges are supported, not suffix ranges.
		return 0, 0, false
	}
	first, err := strconv.ParseInt(startStr, 10, 64)
	if err != nil || first < 0 || first >= size {
		// first >= size means the range does not overlap the content.
		return 0, 0, false
	}
	if endStr == "" {
		// No end: the range extends to the end of the content.
		return first, size - first, true
	}
	last, err := strconv.ParseInt(endStr, 10, 64)
	if err != nil || last < first {
		return 0, 0, false
	}
	if last >= size {
		last = size - 1
	}
	return first, last - first + 1, true
}

// ServeBytes writes content to the response body. If no status code has been
// set, net/http sends 200 OK first.
func (c *NetHTTPContext) ServeBytes(content []byte) {
	// Write errors (e.g. client disconnect) cannot be reported to the client.
	_, _ = c.httpWriter.Write(content)
}

// SetResponseCode writes the response status line and headers. Headers set
// after this call have no effect.
func (c *NetHTTPContext) SetResponseCode(code int) {
	c.httpWriter.WriteHeader(code)
}

// DoRedirect replies with a redirect to location using the given status code.
func (c *NetHTTPContext) DoRedirect(location string, code int) {
	http.Redirect(c.httpWriter, c.httpRequest, location, code)
}

// GetBody returns the request body reader.
func (c *NetHTTPContext) GetBody() io.Reader {
	return c.httpRequest.Body
}

// IsGet reports whether the request method is GET.
func (c *NetHTTPContext) IsGet() bool {
	return c.httpRequest.Method == http.MethodGet
}

// IsPost reports whether the request method is POST.
func (c *NetHTTPContext) IsPost() bool {
	return c.httpRequest.Method == http.MethodPost
}

// IsOptions reports whether the request method is OPTIONS.
func (c *NetHTTPContext) IsOptions() bool {
	return c.httpRequest.Method == http.MethodOptions
}

// IsHead reports whether the request method is HEAD.
func (c *NetHTTPContext) IsHead() bool {
	return c.httpRequest.Method == http.MethodHead
}