From bc9af1d613b9c0d83caef9abf45d03e2a5ab4e87 Mon Sep 17 00:00:00 2001 From: WeebDataHoarder <57538841+WeebDataHoarder@users.noreply.github.com> Date: Mon, 6 Jun 2022 22:16:43 +0200 Subject: [PATCH] Added stream for context for file serving --- contentmessage/message.go | 2 +- httputils/fasthttp.go | 99 ++++++++++++++++++++++++++++++++------- httputils/fileserver.go | 68 +++++++++++++++++++++++++++ httputils/nethttp.go | 95 ++++++++++++++++++++++++++++++++++--- httputils/server.go | 10 +++- 5 files changed, 248 insertions(+), 26 deletions(-) create mode 100644 httputils/fileserver.go diff --git a/contentmessage/message.go b/contentmessage/message.go index a5ec7c1..b20bef7 100644 --- a/contentmessage/message.go +++ b/contentmessage/message.go @@ -143,7 +143,7 @@ func (s *ContentMessage) Encode() []byte { return append(message, s.Signature...) } -func (s ContentMessage) String() string { +func (s *ContentMessage) String() string { return fmt.Sprintf("%d %x %d %s %x", s.Version, s.PublicKey, s.IssueTime, s.Identifier.String(), s.Signature) } diff --git a/httputils/fasthttp.go b/httputils/fasthttp.go index d5facb6..4956182 100644 --- a/httputils/fasthttp.go +++ b/httputils/fasthttp.go @@ -5,20 +5,12 @@ import ( "fmt" "github.com/valyala/fasthttp" "io" + "net/textproto" + "strconv" + "strings" "time" ) -var fsHandler = (&fasthttp.FS{ - Root: "/", - AcceptByteRange: true, - Compress: false, - CompressBrotli: false, - CacheDuration: time.Minute * 15, - PathRewrite: func(ctx *fasthttp.RequestCtx) []byte { - return ctx.Request.URI().PathOriginal() - }, -}).NewRequestHandler() - type FastHTTPContext struct { ctx *fasthttp.RequestCtx server *Server @@ -89,14 +81,89 @@ func (c *FastHTTPContext) SetResponseHeader(name string, value string) { c.ctx.Response.Header.Set(name, value) } -func (c *FastHTTPContext) ServeFile(path string) { - c.ctx.Request.URI().Reset() - c.ctx.Request.URI().SetPath(path) - fsHandler(c.ctx) +func (c *FastHTTPContext) ServeStream(stream Stream) { + const rangePrefix = "bytes=" + + if stream == nil { + c.SetResponseCode(fasthttp.StatusInternalServerError) + return + } + + if dstream, ok := stream.(DefinedStream); ok { + c.SetResponseHeader("Accept-Ranges", "bytes") + c.SetResponseHeader("Last-Modified", dstream.ModTime().UTC().Format(modTimeFormat)) + + size := dstream.Size() + rangeHeader := c.GetRequestHeader("range") + if len(rangeHeader) > 0 { + if !strings.HasPrefix(rangeHeader, rangePrefix) { + c.SetResponseCode(fasthttp.StatusRequestedRangeNotSatisfiable) + return + } + ranges := strings.Split(rangeHeader[len(rangePrefix):], ",") + if len(ranges) != 1 { + c.SetResponseCode(fasthttp.StatusRequestedRangeNotSatisfiable) + return + } + + ra := ranges[0] + start, end, ok := strings.Cut(ra, "-") + if !ok { + c.SetResponseCode(fasthttp.StatusRequestedRangeNotSatisfiable) + return + } + start, end = textproto.TrimString(start), textproto.TrimString(end) + var rangeStart, rangeLength int64 + if start == "" { + //Only supports forward ranges, not backward ranges + c.SetResponseCode(fasthttp.StatusRequestedRangeNotSatisfiable) + return + } else { + i, err := strconv.ParseInt(start, 10, 64) + if err != nil || i < 0 { + //Only supports forward ranges, not backward ranges + } + if i >= size { + // If the range begins after the size of the content, + // then it does not overlap. + c.SetResponseCode(fasthttp.StatusRequestedRangeNotSatisfiable) + return + } + rangeStart = i + if end == "" { + // If no end is specified, range extends to end of the file. + rangeLength = size - rangeStart + } else { + i, err := strconv.ParseInt(end, 10, 64) + if err != nil || rangeStart > i { + c.SetResponseCode(fasthttp.StatusRequestedRangeNotSatisfiable) + return + } + if i >= size { + i = size - 1 + } + rangeLength = i - rangeStart + 1 + } + } + + if _, err := dstream.Seek(rangeStart, io.SeekStart); err != nil { + c.SetResponseCode(fasthttp.StatusRequestedRangeNotSatisfiable) + return + } + + c.SetResponseCode(fasthttp.StatusPartialContent) + c.SetResponseHeader("Content-Range", fmt.Sprintf("bytes %d-%d/%d", rangeStart, rangeStart+rangeLength-1, size)) + size = rangeLength + } + + c.ctx.SetBodyStream(stream, int(size)) + } else { + c.ctx.SetBodyStream(stream, -1) + } } func (c *FastHTTPContext) ServeBytes(content []byte) { - c.ctx.Write(content) + c.ctx.SetBody(content) } func (c *FastHTTPContext) SetResponseCode(code int) { diff --git a/httputils/fileserver.go b/httputils/fileserver.go new file mode 100644 index 0000000..b866d34 --- /dev/null +++ b/httputils/fileserver.go @@ -0,0 +1,68 @@ +package httputils + +import ( + "io" + "os" + "time" +) + +type Stream interface { + io.Reader + io.Closer +} + +type SeekableStream interface { + Stream + io.Seeker +} + +type DefinedStream interface { + SeekableStream + Size() int64 + ModTime() time.Time +} + +type FileStream struct { + file *os.File + size int64 + modTime time.Time +} + +const modTimeFormat = "Mon, 02 Jan 2006 15:04:05 GMT" + +func (fs *FileStream) Read(p []byte) (n int, err error) { + return fs.file.Read(p) +} + +func (fs *FileStream) Seek(offset int64, whence int) (int64, error) { + return fs.file.Seek(offset, whence) +} + +func (fs *FileStream) Close() error { + return fs.file.Close() +} + +func (fs *FileStream) Size() int64 { + return fs.size +} + +func (fs *FileStream) ModTime() time.Time { + return fs.modTime +} + +func NewStreamFromFile(file *os.File) DefinedStream { + if file == nil { + return nil + } + + fstat, err := file.Stat() + if err != nil { + return nil + } + + return &FileStream{ + file: file, + size: fstat.Size(), + modTime: fstat.ModTime(), + } +} diff --git a/httputils/nethttp.go b/httputils/nethttp.go index 207abf8..4976d42 100644 --- a/httputils/nethttp.go +++ b/httputils/nethttp.go @@ -2,8 +2,12 @@ package httputils import ( "fmt" + "github.com/valyala/fasthttp" "io" "net/http" + "net/textproto" + "strconv" + "strings" "time" ) @@ -58,9 +62,7 @@ func (c *NetHTTPContext) GetRequestTime() time.Time { } func (c *NetHTTPContext) GetTLSServerName() string { - return c.httpRequest.TLS.ServerName - } func (c *NetHTTPContext) GetRequestHeader(name string) string { @@ -80,25 +82,104 @@ func (c *NetHTTPContext) SetResponseHeader(name string, value string) { c.httpWriter.Header().Set(name, value) } -func (c *NetHTTPContext) ServeFile(path string) { +func (c *NetHTTPContext) ServeStream(stream Stream) { + const rangePrefix = "bytes=" - http.ServeFile(c.httpWriter, c.httpRequest, path) + if stream == nil { + c.SetResponseCode(fasthttp.StatusInternalServerError) + return + } + if dstream, ok := stream.(DefinedStream); ok { + c.SetResponseHeader("Last-Modified", dstream.ModTime().UTC().Format(modTimeFormat)) + c.SetResponseHeader("Accept-Ranges", "bytes") + + size := dstream.Size() + rangeHeader := c.GetRequestHeader("range") + if len(rangeHeader) > 0 { + if !strings.HasPrefix(rangeHeader, rangePrefix) { + c.SetResponseCode(fasthttp.StatusRequestedRangeNotSatisfiable) + return + } + ranges := strings.Split(rangeHeader[len(rangePrefix):], ",") + if len(ranges) != 1 { + c.SetResponseCode(fasthttp.StatusRequestedRangeNotSatisfiable) + return + } + + ra := ranges[0] + start, end, ok := strings.Cut(ra, "-") + if !ok { + c.SetResponseCode(fasthttp.StatusRequestedRangeNotSatisfiable) + return + } + start, end = textproto.TrimString(start), textproto.TrimString(end) + var rangeStart, rangeLength int64 + if start == "" { + //Only supports forward ranges, not backward ranges + c.SetResponseCode(fasthttp.StatusRequestedRangeNotSatisfiable) + return + } else { + i, err := strconv.ParseInt(start, 10, 64) + if err != nil || i < 0 { + //Only supports forward ranges, not backward ranges + } + if i >= size { + // If the range begins after the size of the content, + // then it does not overlap. + c.SetResponseCode(fasthttp.StatusRequestedRangeNotSatisfiable) + return + } + rangeStart = i + if end == "" { + // If no end is specified, range extends to end of the file. + rangeLength = size - rangeStart + } else { + i, err := strconv.ParseInt(end, 10, 64) + if err != nil || rangeStart > i { + c.SetResponseCode(fasthttp.StatusRequestedRangeNotSatisfiable) + return + } + if i >= size { + i = size - 1 + } + rangeLength = i - rangeStart + 1 + } + } + + if _, err := dstream.Seek(rangeStart, io.SeekStart); err != nil { + c.SetResponseCode(fasthttp.StatusRequestedRangeNotSatisfiable) + return + } + + c.SetResponseCode(fasthttp.StatusPartialContent) + c.SetResponseHeader("Content-Range", fmt.Sprintf("bytes %d-%d/%d", rangeStart, rangeStart+rangeLength-1, size)) + size = rangeLength + } + + c.SetResponseHeader("Content-Length", strconv.FormatInt(size, 10)) + + if !c.IsHead() { + io.CopyN(c.httpWriter, dstream, size) + } + } else { + if !c.IsHead() { + c.SetResponseHeader("Transfer-Encoding", "chunked") + io.Copy(c.httpWriter, dstream) + } + } } func (c *NetHTTPContext) ServeBytes(content []byte) { - c.httpWriter.Write(content) } func (c *NetHTTPContext) SetResponseCode(code int) { - c.httpWriter.WriteHeader(code) } func (c *NetHTTPContext) DoRedirect(location string, code int) { http.Redirect(c.httpWriter, c.httpRequest, location, code) - } func (c *NetHTTPContext) GetBody() io.Reader { diff --git a/httputils/server.go b/httputils/server.go index 83d09e3..16dee25 100644 --- a/httputils/server.go +++ b/httputils/server.go @@ -8,6 +8,7 @@ import ( "github.com/valyala/fasthttp" "io" "log" + "mime" "net" "net/http" "sync" @@ -28,6 +29,8 @@ type Server struct { extraHeaders map[string]string } +const ServeStreamChunked = -1 + type RequestContext interface { GetPath() string GetConnectionTime() time.Time @@ -38,7 +41,8 @@ type RequestContext interface { GetResponseHeader(name string) string AddResponseHeader(name string, value string) SetResponseHeader(name string, value string) - ServeFile(path string) + + ServeStream(stream Stream) ServeBytes(content []byte) SetResponseCode(code int) DoRedirect(location string, code int) @@ -62,10 +66,12 @@ type NetHTTPRequestServer func(w http.ResponseWriter, r *http.Request) RequestCo func (server *Server) Serve() { var wg sync.WaitGroup + mime.AddExtensionType(".bin", "application/octet-stream") + if server.EnableHTTP2 { server.TLSConfig.Config.NextProtos = []string{ - "h2", "http/1.1", + http2.H2TLSProto, } } else { server.AddExtraHeader("Connection", "close")