diff --git a/go.mod b/go.mod index ac3a853..ad9b876 100644 --- a/go.mod +++ b/go.mod @@ -16,7 +16,7 @@ require ( github.com/cheekybits/genny v1.0.0 // indirect github.com/fsnotify/fsnotify v1.5.4 // indirect github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 // indirect - github.com/klauspost/compress v1.15.5 // indirect + github.com/klauspost/compress v1.15.6 // indirect github.com/klauspost/cpuid/v2 v2.0.12 // indirect github.com/marten-seemann/qpack v0.2.1 // indirect github.com/marten-seemann/qtls-go1-16 v0.1.5 // indirect @@ -36,7 +36,7 @@ require ( github.com/valyala/fastrand v1.1.0 // indirect golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e // indirect golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3 // indirect - golang.org/x/net v0.0.0-20220526153639-5463443f8c37 // indirect + golang.org/x/net v0.0.0-20220531201128-c960675eff93 // indirect golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a // indirect golang.org/x/text v0.3.7 // indirect golang.org/x/tools v0.1.10 // indirect diff --git a/go.sum b/go.sum index be6c680..679be92 100644 --- a/go.sum +++ b/go.sum @@ -81,8 +81,8 @@ github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCV github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/klauspost/compress v1.15.0/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= -github.com/klauspost/compress v1.15.5 h1:qyCLMz2JCrKADihKOh9FxnW3houKeNsp2h5OEz0QSEA= -github.com/klauspost/compress v1.15.5/go.mod h1:PhcZ0MbTNciWF3rruxRgKxI5NkcHHrHUDtV4Yw2GlzU= +github.com/klauspost/compress v1.15.6 h1:6D9PcO8QWu0JyaQ2zUMmu16T1T+zjjEpP91guRsvDfY= +github.com/klauspost/compress v1.15.6/go.mod h1:PhcZ0MbTNciWF3rruxRgKxI5NkcHHrHUDtV4Yw2GlzU= github.com/klauspost/cpuid/v2 v2.0.4/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/klauspost/cpuid/v2 v2.0.12 h1:p9dKCg8i4gmOxtv35DvrYoWqYzQrvEVdjQ762Y0OqZE= @@ -241,8 +241,8 @@ golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96b golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk= golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= -golang.org/x/net v0.0.0-20220526153639-5463443f8c37 h1:lUkvobShwKsOesNfWWlCS5q7fnbG1MEliIzwu886fn8= -golang.org/x/net v0.0.0-20220526153639-5463443f8c37/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.0.0-20220531201128-c960675eff93 h1:MYimHLfoXEpOhqd/zgoA/uoXzHB86AEky4LAx5ij9xA= +golang.org/x/net v0.0.0-20220531201128-c960675eff93/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= diff --git a/httputils/context.go b/httputils/context.go deleted file mode 100644 index ca227ee..0000000 --- a/httputils/context.go +++ /dev/null @@ -1,223 +0,0 @@ -package httputils - -import ( - "bytes" - "fmt" - "github.com/valyala/fasthttp" - "io" - "net/http" - "time" -) - -var fsHandler = (&fasthttp.FS{ - Root: "/", - AcceptByteRange: true, - Compress: false, - CompressBrotli: false, - CacheDuration: time.Minute * 15, - PathRewrite: func(ctx *fasthttp.RequestCtx) []byte { - return ctx.Request.URI().PathOriginal() - }, -}).NewRequestHandler() - -type RequestContext struct { - fasthttp *fasthttp.RequestCtx - httpWriter *http.ResponseWriter - httpRequest *http.Request - server *Server - connTime time.Time - requestTime time.Time - timingEvents uint -} - -func NewRequestContextFromFastHttp(server *Server, ctx *fasthttp.RequestCtx) RequestContext { - return RequestContext{ - fasthttp: ctx, - connTime: ctx.ConnTime(), - requestTime: ctx.Time(), - server: server, - } -} - -func NewRequestContextFromHttp(server *Server, w http.ResponseWriter, r *http.Request) RequestContext { - return RequestContext{ - httpWriter: &w, - httpRequest: r, - connTime: time.Now(), - requestTime: time.Now(), - server: server, - } -} - -func (c *RequestContext) GetExtraHeaders() map[string]string { - return c.server.GetExtraHeaders() -} - -func (c *RequestContext) AddTiming(name string, desc string, d time.Duration) { - if d < 0 { - d = 0 - } - - c.AddResponseHeader("Server-Timing", fmt.Sprintf("%d_%s;desc=\"%s\";dur=%.6F", c.timingEvents, name, desc, float64(d.Nanoseconds())/1e6)) - c.timingEvents++ -} - -func (c *RequestContext) AddTimingInformational(name string, desc string) { - c.AddResponseHeader("Server-Timing", fmt.Sprintf("%d_%s;desc=\"%s\"", c.timingEvents, name, desc)) - c.timingEvents++ -} - -func (c *RequestContext) GetPath() string { - if c.fasthttp != nil { - return string(c.fasthttp.Path()) - } else if c.httpRequest != nil { - return c.httpRequest.URL.Path - } - - return "" -} - -func (c *RequestContext) GetConnectionTime() time.Time { - return c.connTime -} - -func (c *RequestContext) GetRequestTime() time.Time { - return c.requestTime -} - -func (c *RequestContext) IsFastHttp() bool { - return c.fasthttp != nil -} - -func (c *RequestContext) GetTLSServerName() string { - if c.fasthttp != nil { - return c.fasthttp.TLSConnectionState().ServerName - } else if c.httpRequest != nil { - return c.httpRequest.TLS.ServerName - } - - return "" -} - -func (c *RequestContext) GetRequestHeader(name string) string { - if c.fasthttp != nil { - return string(c.fasthttp.Request.Header.Peek(name)) - } else if c.httpRequest != nil { - return c.httpRequest.Header.Get(name) - } - - return "" -} - -func (c *RequestContext) GetResponseHeader(name string) string { - if c.fasthttp != nil { - return string(c.fasthttp.Response.Header.Peek(name)) - } else if c.httpWriter != nil { - return (*c.httpWriter).Header().Get(name) - } - - return "" -} - -func (c *RequestContext) AddResponseHeader(name string, value string) { - if c.fasthttp != nil { - c.fasthttp.Response.Header.Add(name, value) - } else if c.httpWriter != nil { - (*c.httpWriter).Header().Add(name, value) - } -} - -func (c *RequestContext) SetResponseHeader(name string, value string) { - if c.fasthttp != nil { - c.fasthttp.Response.Header.Set(name, value) - } else if c.httpWriter != nil { - (*c.httpWriter).Header().Set(name, value) - } -} - -func (c *RequestContext) ServeFile(path string) { - if c.fasthttp != nil { - c.fasthttp.Request.URI().Reset() - c.fasthttp.Request.URI().SetPath(path) - fsHandler(c.fasthttp) - } else if c.httpWriter != nil { - http.ServeFile(*c.httpWriter, c.httpRequest, path) - } -} - -func (c *RequestContext) ServeBytes(content []byte) { - if c.fasthttp != nil { - c.fasthttp.Write(content) - } else if c.httpWriter != nil { - (*c.httpWriter).Write(content) - } -} - -func (c *RequestContext) SetResponseCode(code int) { - if c.fasthttp != nil { - c.fasthttp.Response.SetStatusCode(code) - } else if c.httpWriter != nil { - (*c.httpWriter).WriteHeader(code) - } -} - -func (c *RequestContext) DoRedirect(location string, code int) { - if c.fasthttp != nil { - c.fasthttp.Redirect(location, code) - } else if c.httpWriter != nil { - http.Redirect(*c.httpWriter, c.httpRequest, location, code) - } -} - -func (c *RequestContext) GetBody() io.Reader { - if c.fasthttp != nil { - b := c.fasthttp.Request.Body() - buf := make([]byte, len(b)) - copy(buf, b) - return bytes.NewBuffer(buf) - } else if c.httpRequest != nil { - return c.httpRequest.Body - } - - return nil -} - -func (c *RequestContext) IsGet() bool { - if c.fasthttp != nil { - return c.fasthttp.IsGet() - } else if c.httpRequest != nil { - return c.httpRequest.Method == "GET" - } - - return false -} - -func (c *RequestContext) IsPost() bool { - if c.fasthttp != nil { - return c.fasthttp.IsPost() - } else if c.httpRequest != nil { - return c.httpRequest.Method == "POST" - } - - return false -} - -func (c *RequestContext) IsOptions() bool { - if c.fasthttp != nil { - return c.fasthttp.IsOptions() - } else if c.httpRequest != nil { - return c.httpRequest.Method == "OPTIONS" - } - - return false -} - -func (c *RequestContext) IsHead() bool { - if c.fasthttp != nil { - return c.fasthttp.IsHead() - } else if c.httpRequest != nil { - return c.httpRequest.Method == "HEAD" - } - - return false -} diff --git a/httputils/fasthttp.go b/httputils/fasthttp.go new file mode 100644 index 0000000..d5facb6 --- /dev/null +++ b/httputils/fasthttp.go @@ -0,0 +1,131 @@ +package httputils + +import ( + "bytes" + "fmt" + "github.com/valyala/fasthttp" + "io" + "time" +) + +var fsHandler = (&fasthttp.FS{ + Root: "/", + AcceptByteRange: true, + Compress: false, + CompressBrotli: false, + CacheDuration: time.Minute * 15, + PathRewrite: func(ctx *fasthttp.RequestCtx) []byte { + return ctx.Request.URI().PathOriginal() + }, +}).NewRequestHandler() + +type FastHTTPContext struct { + ctx *fasthttp.RequestCtx + server *Server + connTime time.Time + requestTime time.Time + timingEvents uint +} + +func NewRequestContextFromFastHttp(server *Server, ctx *fasthttp.RequestCtx) *FastHTTPContext { + if ctx == nil { + return nil + } + return &FastHTTPContext{ + ctx: ctx, + connTime: ctx.ConnTime(), + requestTime: ctx.Time(), + server: server, + } +} + +func (c *FastHTTPContext) GetExtraHeaders() map[string]string { + return c.server.GetExtraHeaders() +} + +func (c *FastHTTPContext) AddTiming(name string, desc string, d time.Duration) { + if d < 0 { + d = 0 + } + + c.AddResponseHeader("Server-Timing", fmt.Sprintf("%d_%s;desc=\"%s\";dur=%.6F", c.timingEvents, name, desc, float64(d.Nanoseconds())/1e6)) + c.timingEvents++ +} + +func (c *FastHTTPContext) AddTimingInformational(name string, desc string) { + c.AddResponseHeader("Server-Timing", fmt.Sprintf("%d_%s;desc=\"%s\"", c.timingEvents, name, desc)) + c.timingEvents++ +} + +func (c *FastHTTPContext) GetPath() string { + return string(c.ctx.Path()) +} + +func (c *FastHTTPContext) GetConnectionTime() time.Time { + return c.connTime +} + +func (c *FastHTTPContext) GetRequestTime() time.Time { + return c.requestTime +} + +func (c *FastHTTPContext) GetTLSServerName() string { + return c.ctx.TLSConnectionState().ServerName +} + +func (c *FastHTTPContext) GetRequestHeader(name string) string { + return string(c.ctx.Request.Header.Peek(name)) +} + +func (c *FastHTTPContext) GetResponseHeader(name string) string { + return string(c.ctx.Response.Header.Peek(name)) +} + +func (c *FastHTTPContext) AddResponseHeader(name string, value string) { + c.ctx.Response.Header.Add(name, value) +} + +func (c *FastHTTPContext) SetResponseHeader(name string, value string) { + c.ctx.Response.Header.Set(name, value) +} + +func (c *FastHTTPContext) ServeFile(path string) { + c.ctx.Request.URI().Reset() + c.ctx.Request.URI().SetPath(path) + fsHandler(c.ctx) +} + +func (c *FastHTTPContext) ServeBytes(content []byte) { + c.ctx.Write(content) +} + +func (c *FastHTTPContext) SetResponseCode(code int) { + c.ctx.Response.SetStatusCode(code) +} + +func (c *FastHTTPContext) DoRedirect(location string, code int) { + c.ctx.Redirect(location, code) +} + +func (c *FastHTTPContext) GetBody() io.Reader { + b := c.ctx.Request.Body() + buf := make([]byte, len(b)) + copy(buf, b) + return bytes.NewBuffer(buf) +} + +func (c *FastHTTPContext) IsGet() bool { + return c.ctx.IsGet() +} + +func (c *FastHTTPContext) IsPost() bool { + return c.ctx.IsPost() +} + +func (c *FastHTTPContext) IsOptions() bool { + return c.ctx.IsOptions() +} + +func (c *FastHTTPContext) IsHead() bool { + return c.ctx.IsHead() +} diff --git a/httputils/nethttp.go b/httputils/nethttp.go new file mode 100644 index 0000000..207abf8 --- /dev/null +++ b/httputils/nethttp.go @@ -0,0 +1,124 @@ +package httputils + +import ( + "fmt" + "io" + "net/http" + "time" +) + +type NetHTTPContext struct { + httpWriter http.ResponseWriter + httpRequest *http.Request + server *Server + connTime time.Time + requestTime time.Time + timingEvents uint +} + +func NewRequestContextFromHttp(server *Server, w http.ResponseWriter, r *http.Request) *NetHTTPContext { + return &NetHTTPContext{ + httpWriter: w, + httpRequest: r, + connTime: time.Now(), + requestTime: time.Now(), + server: server, + } +} + +func (c *NetHTTPContext) GetExtraHeaders() map[string]string { + return c.server.GetExtraHeaders() +} + +func (c *NetHTTPContext) AddTiming(name string, desc string, d time.Duration) { + if d < 0 { + d = 0 + } + + c.AddResponseHeader("Server-Timing", fmt.Sprintf("%d_%s;desc=\"%s\";dur=%.6F", c.timingEvents, name, desc, float64(d.Nanoseconds())/1e6)) + c.timingEvents++ +} + +func (c *NetHTTPContext) AddTimingInformational(name string, desc string) { + c.AddResponseHeader("Server-Timing", fmt.Sprintf("%d_%s;desc=\"%s\"", c.timingEvents, name, desc)) + c.timingEvents++ +} + +func (c *NetHTTPContext) GetPath() string { + return c.httpRequest.URL.Path + +} + +func (c *NetHTTPContext) GetConnectionTime() time.Time { + return c.connTime +} + +func (c *NetHTTPContext) GetRequestTime() time.Time { + return c.requestTime +} + +func (c *NetHTTPContext) GetTLSServerName() string { + + return c.httpRequest.TLS.ServerName + +} + +func (c *NetHTTPContext) GetRequestHeader(name string) string { + return c.httpRequest.Header.Get(name) + +} + +func (c *NetHTTPContext) GetResponseHeader(name string) string { + return c.httpWriter.Header().Get(name) +} + +func (c *NetHTTPContext) AddResponseHeader(name string, value string) { + c.httpWriter.Header().Add(name, value) +} + +func (c *NetHTTPContext) SetResponseHeader(name string, value string) { + c.httpWriter.Header().Set(name, value) +} + +func (c *NetHTTPContext) ServeFile(path string) { + + http.ServeFile(c.httpWriter, c.httpRequest, path) + +} + +func (c *NetHTTPContext) ServeBytes(content []byte) { + + c.httpWriter.Write(content) +} + +func (c *NetHTTPContext) SetResponseCode(code int) { + + c.httpWriter.WriteHeader(code) +} + +func (c *NetHTTPContext) DoRedirect(location string, code int) { + http.Redirect(c.httpWriter, c.httpRequest, location, code) + +} + +func (c *NetHTTPContext) GetBody() io.Reader { + return c.httpRequest.Body +} + +func (c *NetHTTPContext) IsGet() bool { + return c.httpRequest.Method == "GET" +} + +func (c *NetHTTPContext) IsPost() bool { + return c.httpRequest.Method == "POST" + +} + +func (c *NetHTTPContext) IsOptions() bool { + return c.httpRequest.Method == "OPTIONS" + +} + +func (c *NetHTTPContext) IsHead() bool { + return c.httpRequest.Method == "HEAD" +} diff --git a/httputils/server.go b/httputils/server.go index f461b76..83d09e3 100644 --- a/httputils/server.go +++ b/httputils/server.go @@ -6,6 +6,7 @@ import ( "github.com/dgrr/http2" "github.com/lucas-clemente/quic-go/http3" "github.com/valyala/fasthttp" + "io" "log" "net" "net/http" @@ -14,22 +15,52 @@ import ( ) type Server struct { - ListenAddress string - TLSConfig *tlsutils.Configuration - EnableHTTP2 bool - EnableHTTP3 bool - Debug bool - Handler RequestHandler + ListenAddress string + TLSConfig *tlsutils.Configuration + EnableHTTP2 bool + FastHTTPRequestServer FastHTTPRequestServer + NetHTTPRequestServer NetHTTPRequestServer + EnableHTTP3 bool + Debug bool + Handler RequestHandler extraHeadersMutex sync.RWMutex extraHeaders map[string]string } -type RequestHandler func(ctx *RequestContext) +type RequestContext interface { + GetPath() string + GetConnectionTime() time.Time + GetRequestTime() time.Time + GetTLSServerName() string + GetBody() io.Reader + GetRequestHeader(name string) string + GetResponseHeader(name string) string + AddResponseHeader(name string, value string) + SetResponseHeader(name string, value string) + ServeFile(path string) + ServeBytes(content []byte) + SetResponseCode(code int) + DoRedirect(location string, code int) + + IsGet() bool + IsPost() bool + IsOptions() bool + IsHead() bool + + GetExtraHeaders() map[string]string + + AddTiming(name string, desc string, d time.Duration) + AddTimingInformational(name string, desc string) +} + +type RequestHandler func(ctx RequestContext) + +type FastHTTPRequestServer func(ctx *fasthttp.RequestCtx) RequestContext +type NetHTTPRequestServer func(w http.ResponseWriter, r *http.Request) RequestContext func (server *Server) Serve() { var wg sync.WaitGroup - defer wg.Wait() if server.EnableHTTP2 { server.TLSConfig.Config.NextProtos = []string{ @@ -41,15 +72,22 @@ func (server *Server) Serve() { } wg.Add(1) - go func(wg *sync.WaitGroup) { + go func() { defer wg.Done() + handler := func(ctx *fasthttp.RequestCtx) RequestContext { + return NewRequestContextFromFastHttp(server, ctx) + } + + if server.FastHTTPRequestServer != nil { + handler = server.FastHTTPRequestServer + } + s := &fasthttp.Server{ ReadTimeout: 5 * time.Second, IdleTimeout: 15 * time.Second, Handler: func(ctx *fasthttp.RequestCtx) { - context := NewRequestContextFromFastHttp(server, ctx) - server.Handler(&context) + server.Handler(handler(ctx)) }, NoDefaultServerHeader: true, NoDefaultDate: true, @@ -74,23 +112,30 @@ func (server *Server) Serve() { if err != nil { log.Panic(err) } - }(&wg) + }() if server.EnableHTTP3 { wg.Add(1) - go func(wg *sync.WaitGroup) { + go func() { defer wg.Done() + handler := func(w http.ResponseWriter, r *http.Request) RequestContext { + return NewRequestContextFromHttp(server, w, r) + } + + if server.NetHTTPRequestServer != nil { + handler = server.NetHTTPRequestServer + } + s := &http3.Server{ Server: &http.Server{ Addr: server.ListenAddress, TLSConfig: server.TLSConfig.QUICConfig, Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - context := NewRequestContextFromHttp(server, w, r) if server.Debug { log.Print("Received HTTP/3 request") } - server.Handler(&context) + server.Handler(handler(w, r)) }), }, } @@ -104,8 +149,10 @@ func (server *Server) Serve() { if err != nil { log.Panic(err) } - }(&wg) + }() } + + wg.Wait() } func (server *Server) GetExtraHeaders() map[string]string {