MakyuuIchaival/httputils/nethttp.go

246 lines
6 KiB
Go

package httputils
import (
"crypto/tls"
"fmt"
"github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/http3"
"github.com/valyala/fasthttp"
"io"
"net/http"
"net/textproto"
"strconv"
"strings"
"time"
)
// NetHTTPContext wraps a net/http ResponseWriter and Request (including
// quic-go HTTP/3 writers, whose TLS state is read by getTLSState) together
// with per-request metadata captured when the context is created.
type NetHTTPContext struct {
	httpWriter  http.ResponseWriter
	httpRequest *http.Request
	server      *Server // owning server; supplies GetExtraHeaders

	// Both stamped with time.Now() by NewRequestContextFromHttp.
	connTime, requestTime time.Time

	tlsState     tls.ConnectionState // zero value for plaintext requests
	timingEvents uint                // sequence number prefixed to Server-Timing metric names
}
// NewRequestContextFromHttp wraps a net/http handler's ResponseWriter and
// Request in a NetHTTPContext bound to server. The connection and request
// timestamps are both taken from time.Now() at this point (net/http does not
// expose when the connection was accepted). The TLS state is captured once,
// up front.
func NewRequestContextFromHttp(server *Server, w http.ResponseWriter, r *http.Request) *NetHTTPContext {
	rc := &NetHTTPContext{
		server:      server,
		httpWriter:  w,
		httpRequest: r,
	}
	rc.connTime = time.Now()
	rc.requestTime = time.Now()
	rc.tlsState = getTLSState(w, r)
	return rc
}
// getTLSState returns the TLS connection state for a request.
//
// When w implements quic-go's http3.Hijacker and its StreamCreator is a
// quic.Connection, the state is read from that QUIC connection. Otherwise
// r.TLS is used. If neither is available (plaintext), a zero
// tls.ConnectionState is returned.
func getTLSState(w http.ResponseWriter, r *http.Request) tls.ConnectionState {
	hj, isHTTP3 := w.(http3.Hijacker)
	if isHTTP3 {
		conn, isQUIC := hj.StreamCreator().(quic.Connection)
		if isQUIC {
			return conn.ConnectionState().TLS.ConnectionState
		}
	}
	if state := r.TLS; state != nil {
		return *state
	}
	var plaintext tls.ConnectionState
	return plaintext
}
// GetExtraHeaders returns the extra headers configured on the owning Server,
// delegating to Server.GetExtraHeaders.
func (c *NetHTTPContext) GetExtraHeaders() map[string]string {
	srv := c.server
	return srv.GetExtraHeaders()
}
// serverTimingDescEscaper backslash-escapes the two characters that may not
// appear bare inside an HTTP quoted-string (RFC 9110 §5.6.4), so a desc value
// cannot terminate the quoted string early and corrupt the header.
var serverTimingDescEscaper = strings.NewReplacer(`\`, `\\`, `"`, `\"`)

// AddTiming appends a Server-Timing metric carrying a duration.
//
// The metric name is emitted as "<seq>_<name>", where seq is a per-context
// counter, so repeated names remain distinct. desc is written as an escaped
// quoted-string, and dur is in milliseconds with microsecond-level decimals.
// Negative durations are clamped to zero. name is not validated and must
// already be a valid HTTP token.
//
// Like any response header, this must be called before the status code is
// written (SetResponseCode / first body write) to reach the client.
func (c *NetHTTPContext) AddTiming(name string, desc string, d time.Duration) {
	if d < 0 {
		d = 0
	}
	ms := float64(d.Nanoseconds()) / 1e6
	c.AddResponseHeader("Server-Timing", fmt.Sprintf("%d_%s;desc=\"%s\";dur=%.6F", c.timingEvents, name, serverTimingDescEscaper.Replace(desc), ms))
	c.timingEvents++
}

// AddTimingInformational appends a Server-Timing metric with a description
// but no duration. Naming and escaping follow AddTiming: "<seq>_<name>" and
// an escaped quoted-string desc.
func (c *NetHTTPContext) AddTimingInformational(name string, desc string) {
	c.AddResponseHeader("Server-Timing", fmt.Sprintf("%d_%s;desc=\"%s\"", c.timingEvents, name, serverTimingDescEscaper.Replace(desc)))
	c.timingEvents++
}
// GetPath returns the path component of the request URL (URL.Path).
func (c *NetHTTPContext) GetPath() string {
	u := c.httpRequest.URL
	return u.Path
}
// GetConnectionTime returns the connection timestamp stored in the context.
// For net/http contexts it is taken when the context is constructed.
func (c *NetHTTPContext) GetConnectionTime() (t time.Time) {
	t = c.connTime
	return t
}
// GetRequestTime returns the timestamp recorded when the context was built
// for this request.
func (c *NetHTTPContext) GetRequestTime() (t time.Time) {
	t = c.requestTime
	return t
}
// GetTLSServerName returns the SNI server name from the captured TLS state.
// It is empty for plaintext requests.
func (c *NetHTTPContext) GetTLSServerName() string {
	state := &c.tlsState
	return state.ServerName
}
// GetHost returns the request's Host value as parsed by net/http.
func (c *NetHTTPContext) GetHost() string {
	req := c.httpRequest
	return req.Host
}
// GetProtocol returns the request protocol string (Request.Proto, e.g.
// "HTTP/1.1").
func (c *NetHTTPContext) GetProtocol() string {
	req := c.httpRequest
	return req.Proto
}
// GetTLSVersion returns the negotiated TLS version (tls.VersionTLS12 etc.).
// It is 0 for plaintext requests.
func (c *NetHTTPContext) GetTLSVersion() uint16 {
	state := &c.tlsState
	return state.Version
}
// GetTLSCipher returns the negotiated TLS cipher suite ID. It is 0 for
// plaintext requests.
func (c *NetHTTPContext) GetTLSCipher() uint16 {
	state := &c.tlsState
	return state.CipherSuite
}
// GetRequestHeader returns the first value of the named request header.
// The lookup is case-insensitive via http.Header.Get, and the result is ""
// when the header is absent.
func (c *NetHTTPContext) GetRequestHeader(name string) string {
	hdr := c.httpRequest.Header
	return hdr.Get(name)
}
// GetResponseHeader returns the first value currently set for the named
// response header, or "" when it is unset.
func (c *NetHTTPContext) GetResponseHeader(name string) string {
	hdr := c.httpWriter.Header()
	return hdr.Get(name)
}
// AddResponseHeader appends value to the named response header, keeping any
// existing values.
func (c *NetHTTPContext) AddResponseHeader(name string, value string) {
	hdr := c.httpWriter.Header()
	hdr.Add(name, value)
}
// SetResponseHeader replaces all values of the named response header with
// value.
func (c *NetHTTPContext) SetResponseHeader(name string, value string) {
	hdr := c.httpWriter.Header()
	hdr.Set(name, value)
}
// ServeStream writes stream to the response and closes it.
//
// Behaviour by stream type:
//   - nil: the response is a 500.
//   - DefinedStream (size and modification time known): the response carries
//     Last-Modified, Accept-Ranges and Content-Length. A single forward byte
//     range ("bytes=START-" or "bytes=START-END") is honoured with a 206.
//     Malformed, multiple, suffix ("bytes=-N") or out-of-bounds ranges, and a
//     failed Seek, are answered with a 416 plus "Content-Range: bytes */SIZE".
//   - Any other stream: it is copied with a 200 and no Content-Length, so the
//     HTTP server picks the transfer framing.
//
// For HEAD requests only the status and headers are sent.
func (c *NetHTTPContext) ServeStream(stream Stream) {
	if stream == nil {
		c.SetResponseCode(fasthttp.StatusInternalServerError)
		return
	}
	// Deferred only after the nil check: evaluating stream.Close on a nil
	// interface panics at the defer statement itself.
	defer stream.Close()

	dstream, defined := stream.(DefinedStream)
	if !defined {
		// NOTE(review): if Stream already embeds io.Reader this assertion
		// always succeeds and stream can be passed to io.Copy directly.
		reader, readable := stream.(io.Reader)
		if !readable {
			c.SetResponseCode(fasthttp.StatusInternalServerError)
			return
		}
		// Unknown length: with no Content-Length, net/http chooses the framing
		// itself (chunked on HTTP/1.1). A hand-set Transfer-Encoding is
		// unnecessary, ignored once WriteHeader has run, and forbidden on
		// HTTP/2 and HTTP/3.
		c.SetResponseCode(fasthttp.StatusOK)
		if !c.IsHead() {
			// Headers are already sent; a copy error can only surface as a
			// truncated body, so it is deliberately ignored.
			_, _ = io.Copy(c.httpWriter, reader)
		}
		return
	}

	c.SetResponseHeader("Last-Modified", dstream.ModTime().UTC().Format(modTimeFormat))
	c.SetResponseHeader("Accept-Ranges", "bytes")
	size := dstream.Size()

	status := fasthttp.StatusOK
	if rangeHeader := c.GetRequestHeader("range"); len(rangeHeader) > 0 {
		start, length, ok := parseSingleByteRange(rangeHeader, size)
		if ok {
			_, err := dstream.Seek(start, io.SeekStart)
			ok = err == nil
		}
		if !ok {
			// RFC 9110 §15.5.17: a 416 should state the current length.
			c.SetResponseHeader("Content-Range", "bytes */"+strconv.FormatInt(size, 10))
			c.SetResponseCode(fasthttp.StatusRequestedRangeNotSatisfiable)
			return
		}
		c.SetResponseHeader("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, start+length-1, size))
		size = length
		status = fasthttp.StatusPartialContent
	}

	c.SetResponseHeader("Content-Length", strconv.FormatInt(size, 10))
	c.SetResponseCode(status)
	if !c.IsHead() {
		// Headers are already sent; a copy error can only surface as a
		// truncated body, so it is deliberately ignored.
		_, _ = io.CopyN(c.httpWriter, dstream, size)
	}
}

// parseSingleByteRange parses a Range header value against a representation
// of size bytes.
//
// Only a single forward range in the "bytes" unit is supported:
// "bytes=START-" (to the end) or "bytes=START-END" (END is clamped to
// size-1). It returns the first byte offset and the number of bytes to send.
// ok is false for malformed values, multiple ranges, suffix ranges
// ("bytes=-N"), and ranges starting at or beyond size.
func parseSingleByteRange(header string, size int64) (start, length int64, ok bool) {
	const rangePrefix = "bytes="
	if !strings.HasPrefix(header, rangePrefix) {
		return 0, 0, false
	}
	spec := header[len(rangePrefix):]
	if strings.Contains(spec, ",") {
		// Multipart (multiple-range) responses are not implemented.
		return 0, 0, false
	}
	first, last, found := strings.Cut(spec, "-")
	if !found {
		return 0, 0, false
	}
	first, last = textproto.TrimString(first), textproto.TrimString(last)
	if first == "" {
		// Suffix ranges ("last N bytes") are not supported.
		return 0, 0, false
	}
	begin, err := strconv.ParseInt(first, 10, 64)
	if err != nil || begin < 0 || begin >= size {
		// Unparsable start, or a range that does not overlap the content.
		return 0, 0, false
	}
	if last == "" {
		return begin, size - begin, true
	}
	end, err := strconv.ParseInt(last, 10, 64)
	if err != nil || end < begin {
		return 0, 0, false
	}
	if end >= size {
		end = size - 1
	}
	return begin, end - begin + 1, true
}
// ServeBytes writes content to the response body. Any write error is
// discarded and not reported to the caller.
func (c *NetHTTPContext) ServeBytes(content []byte) {
	w := c.httpWriter
	_, _ = w.Write(content)
}
// SetResponseCode sends the status code via ResponseWriter.WriteHeader.
// Per net/http semantics, response header changes made after this call are
// not transmitted.
func (c *NetHTTPContext) SetResponseCode(code int) {
	w := c.httpWriter
	w.WriteHeader(code)
}
// DoRedirect replies with a redirect to location using the given status
// code, delegating to http.Redirect.
func (c *NetHTTPContext) DoRedirect(location string, code int) {
	w, r := c.httpWriter, c.httpRequest
	http.Redirect(w, r, location, code)
}
// GetBody exposes the request body (Request.Body) as an io.Reader.
func (c *NetHTTPContext) GetBody() io.Reader {
	body := c.httpRequest.Body
	return body
}
// IsGet reports whether the request method is exactly "GET".
func (c *NetHTTPContext) IsGet() bool {
	method := c.httpRequest.Method
	return method == http.MethodGet
}
// IsPost reports whether the request method is exactly "POST".
func (c *NetHTTPContext) IsPost() bool {
	method := c.httpRequest.Method
	return method == http.MethodPost
}
// IsOptions reports whether the request method is exactly "OPTIONS".
func (c *NetHTTPContext) IsOptions() bool {
	method := c.httpRequest.Method
	return method == http.MethodOptions
}
// IsHead reports whether the request method is exactly "HEAD", in which case
// callers such as ServeStream omit the body.
func (c *NetHTTPContext) IsHead() bool {
	method := c.httpRequest.Method
	return method == http.MethodHead
}