This commit is contained in:
DataHoarder 2023-06-15 12:17:52 +02:00
parent 6fc797f884
commit fad91c9146
Signed by: DataHoarder
SSH key fingerprint: SHA256:OLTRf6Fl87G52SiR7sWLGNzlJt4WOX+tfI2yxo0z7xk
8 changed files with 94 additions and 82 deletions

View file

@ -96,7 +96,6 @@ Configuration is done in `config.json`, which you'll need to create with the fol
"proxy_header": "",
"timeout": {
"read": 1,
"read_header": 2,
"write": 1,
"idle": 30
}
@ -143,8 +142,7 @@ Configuration is done in `config.json`, which you'll need to create with the fol
- `admin_token` - administrative token used in `Authorization` header to access advanced prometheus statistics
- `proxy_header` - header name to look for user's real IP address, for example `X-Real-Ip`
- `timeout`
- `read_header` - timeout in seconds for reading request headers
- `read` - timeout in seconds for reading request; is reset after headers are read
- `read` - timeout in seconds for reading request
- `write` - timeout in seconds for writing response (total time spent)
- `idle` - how long (in seconds) to keep connection open for keep-alive requests
- `announce`

2
go.mod
View file

@ -9,6 +9,7 @@ require (
github.com/jinzhu/copier v0.3.5
github.com/prometheus/client_golang v1.15.1
github.com/prometheus/common v0.44.0
github.com/valyala/fasthttp v1.47.0
github.com/viney-shih/go-lock v1.1.2
github.com/zeebo/bencode v1.0.0
)
@ -33,6 +34,7 @@ require (
github.com/prometheus/procfs v0.9.0 // indirect
github.com/segmentio/asm v1.2.0 // indirect
github.com/shopspring/decimal v1.3.1 // indirect
github.com/valyala/bytebufferpool v1.0.0 // indirect
go.opentelemetry.io/otel v1.15.0 // indirect
go.opentelemetry.io/otel/trace v1.15.0 // indirect
golang.org/x/sync v0.1.0 // indirect

6
go.sum
View file

@ -89,6 +89,10 @@ github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8=
github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk=
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
github.com/valyala/fasthttp v1.47.0 h1:y7moDoxYzMooFpT5aHgNgVOQDrS3qlkfiP9mDtGGK9c=
github.com/valyala/fasthttp v1.47.0/go.mod h1:k2zXd82h/7UZc3VOdJ2WaUqt1uZ/XpXAfE9i+HBC3lA=
github.com/viney-shih/go-lock v1.1.2 h1:3TdGTiHZCPqBdTvFbQZQN/TRZzKF3KWw2rFEyKz3YqA=
github.com/viney-shih/go-lock v1.1.2/go.mod h1:Yijm78Ljteb3kRiJrbLAxVntkUukGu5uzSxq/xV7OO8=
github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI=
@ -108,7 +112,7 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.6.0 h1:qfktjS5LUO+fFKeJXZ+ikTRijMmljikvG68fpMMruSc=
golang.org/x/crypto v0.7.0 h1:AvwMYaRytfdeVt3u6mLaxYtErKYjxA2OXjJ1HHq6t3A=
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=

View file

@ -20,7 +20,7 @@ package server
import (
"bytes"
"encoding/json"
"net/http"
"github.com/valyala/fasthttp"
"time"
"chihaya/log"
@ -35,10 +35,10 @@ func alive(buf *bytes.Buffer) int {
res, err := json.Marshal(response{time.Now().UnixMilli(), time.Since(handler.startTime).Milliseconds()})
if err != nil {
log.Error.Print("Failed to marshal JSON alive response: ", err)
return http.StatusInternalServerError
return fasthttp.StatusInternalServerError
}
buf.Write(res)
return http.StatusOK
return fasthttp.StatusOK
}

View file

@ -21,9 +21,9 @@ import (
"bytes"
"context"
"fmt"
"github.com/valyala/fasthttp"
"math"
"net"
"net/http"
"strings"
"time"
@ -111,7 +111,7 @@ func getPublicIPV4(ipAddr string, exists bool) (string, bool) {
return ipAddr, !private
}
func announce(ctx context.Context, qs string, header http.Header, remoteAddr string, user *cdb.User,
func announce(ctx context.Context, qs string, header *fasthttp.RequestHeader, remoteAddr net.Addr, user *cdb.User,
db *database.Database, buf *bytes.Buffer) int {
qp, err := params.ParseQuery(qs)
if err != nil {
@ -128,45 +128,45 @@ func announce(ctx context.Context, qs string, header http.Header, remoteAddr str
if infoHashes == nil {
failure("Malformed request - missing info_hash", buf, 1*time.Hour)
return http.StatusOK // Required by torrent clients to interpret failure response
return fasthttp.StatusOK // Required by torrent clients to interpret failure response
} else if len(infoHashes) > 1 {
failure("Malformed request - can only announce singular info_hash", buf, 1*time.Hour)
return http.StatusOK // Required by torrent clients to interpret failure response
return fasthttp.StatusOK // Required by torrent clients to interpret failure response
}
if peerID == "" {
failure("Malformed request - missing peer_id", buf, 1*time.Hour)
return http.StatusOK // Required by torrent clients to interpret failure response
return fasthttp.StatusOK // Required by torrent clients to interpret failure response
}
if len(peerID) != 20 {
failure("Malformed request - invalid peer_id", buf, 1*time.Hour)
return http.StatusOK // Required by torrent clients to interpret failure response
return fasthttp.StatusOK // Required by torrent clients to interpret failure response
}
if !portExists {
failure("Malformed request - missing port", buf, 1*time.Hour)
return http.StatusOK // Required by torrent clients to interpret failure response
return fasthttp.StatusOK // Required by torrent clients to interpret failure response
}
if strictPort && port < 1024 || port > 65535 {
failure(fmt.Sprintf("Malformed request - port outside of acceptable range (port: %d)", port), buf, 1*time.Hour)
return http.StatusOK // Required by torrent clients to interpret failure response
return fasthttp.StatusOK // Required by torrent clients to interpret failure response
}
if !uploadedExists {
failure("Malformed request - missing uploaded", buf, 1*time.Hour)
return http.StatusOK // Required by torrent clients to interpret failure response
return fasthttp.StatusOK // Required by torrent clients to interpret failure response
}
if !downloadedExists {
failure("Malformed request - missing downloaded", buf, 1*time.Hour)
return http.StatusOK // Required by torrent clients to interpret failure response
return fasthttp.StatusOK // Required by torrent clients to interpret failure response
}
if !leftExists {
failure("Malformed request - missing left", buf, 1*time.Hour)
return http.StatusOK // Required by torrent clients to interpret failure response
return fasthttp.StatusOK // Required by torrent clients to interpret failure response
}
ipAddr := func() string {
@ -188,15 +188,17 @@ func announce(ctx context.Context, qs string, header http.Header, remoteAddr str
// Check for proxied IP header
if proxyHeader, exists := config.Section("http").Get("proxy_header", ""); exists {
if ips, exists := header[proxyHeader]; exists && len(ips) > 0 {
return ips[0]
if ips := header.PeekAll(proxyHeader); len(ips) > 0 {
return string(ips[0])
}
}
// Check for IP in socket
portIndex := strings.LastIndex(remoteAddr, ":")
remoteAddrString := remoteAddr.String()
portIndex := strings.LastIndex(remoteAddrString, ":")
if portIndex != -1 {
return remoteAddr[0:portIndex]
return remoteAddrString[0:portIndex]
}
// Everything else failed
@ -206,19 +208,19 @@ func announce(ctx context.Context, qs string, header http.Header, remoteAddr str
ipBytes := net.ParseIP(ipAddr).To4()
if nil == ipBytes {
failure(fmt.Sprintf("Failed to parse IP address (ip: %s)", ipAddr), buf, 1*time.Hour)
return http.StatusOK // Required by torrent clients to interpret failure response
return fasthttp.StatusOK // Required by torrent clients to interpret failure response
}
clientID, matched := clientApproved(peerID, db)
if !matched {
failure(fmt.Sprintf("Your client is not approved (peer_id: %s)", peerID), buf, 1*time.Hour)
return http.StatusOK // Required by torrent clients to interpret failure response
return fasthttp.StatusOK // Required by torrent clients to interpret failure response
}
torrent, exists := (*db.Torrents.Load())[infoHashes[0]]
if !exists {
failure("This torrent does not exist", buf, 5*time.Minute)
return http.StatusOK // Required by torrent clients to interpret failure response
return fasthttp.StatusOK // Required by torrent clients to interpret failure response
}
if torrentStatus := torrent.Status.Load(); torrentStatus == 1 && left == 0 {
@ -232,7 +234,7 @@ func announce(ctx context.Context, qs string, header http.Header, remoteAddr str
go db.UnPrune(torrent)
} else if torrentStatus != 0 {
failure(fmt.Sprintf("This torrent does not exist (status: %d, left: %d)", torrentStatus, left), buf, 15*time.Minute)
return http.StatusOK // Required by torrent clients to interpret failure response
return fasthttp.StatusOK // Required by torrent clients to interpret failure response
}
numWant, exists := qp.GetUint16("numwant")
@ -257,14 +259,14 @@ func announce(ctx context.Context, qs string, header http.Header, remoteAddr str
// Take torrent peers lock to read/write on it to prevent race conditions
if !torrent.TryLockWithContext(ctx) {
return http.StatusRequestTimeout
return fasthttp.StatusRequestTimeout
}
defer torrent.Unlock()
if left > 0 {
if isDisabledDownload(db, user, torrent) {
failure("Your download privileges are disabled", buf, 1*time.Hour)
return http.StatusOK // Required by torrent clients to interpret failure response
return fasthttp.StatusOK // Required by torrent clients to interpret failure response
}
peer, exists = torrent.Leechers[peerKey]
@ -531,7 +533,7 @@ func announce(ctx context.Context, qs string, header http.Header, remoteAddr str
// Early exit before response write
select {
case <-ctx.Done():
return http.StatusRequestTimeout
return fasthttp.StatusRequestTimeout
default:
}
@ -540,5 +542,5 @@ func announce(ctx context.Context, qs string, header http.Header, remoteAddr str
panic(err)
}
return http.StatusOK
return fasthttp.StatusOK
}

View file

@ -20,7 +20,7 @@ package server
import (
"bytes"
"context"
"net/http"
"github.com/valyala/fasthttp"
"time"
"chihaya/collectors"
@ -31,9 +31,9 @@ import (
"github.com/prometheus/common/expfmt"
)
var bearerPrefix = "Bearer "
var bearerPrefix = []byte("Bearer ")
func metrics(ctx context.Context, auth string, db *database.Database, buf *bytes.Buffer) int {
func metrics(ctx context.Context, auth []byte, db *database.Database, buf *bytes.Buffer) int {
dbUsers := *db.Users.Load()
dbTorrents := *db.Torrents.Load()
@ -46,7 +46,7 @@ func metrics(ctx context.Context, auth string, db *database.Database, buf *bytes
// Early exit before response write
select {
case <-ctx.Done():
return http.StatusRequestTimeout
return fasthttp.StatusRequestTimeout
default:
}
@ -68,9 +68,9 @@ func metrics(ctx context.Context, auth string, db *database.Database, buf *bytes
}
n := len(bearerPrefix)
if len(auth) > n && auth[:n] == bearerPrefix {
if len(auth) > n && bytes.Equal(auth[:n], bearerPrefix) {
adminToken, exists := config.Section("http").Get("admin_token", "")
if exists && auth[n:] == adminToken {
if exists && bytes.Equal(auth[:n], []byte(adminToken)) {
mfs, _ := prometheus.DefaultGatherer.Gather()
for _, mf := range mfs {
@ -82,5 +82,5 @@ func metrics(ctx context.Context, auth string, db *database.Database, buf *bytes
}
}
return http.StatusOK
return fasthttp.StatusOK
}

View file

@ -19,13 +19,12 @@ package server
import (
"bytes"
"context"
"net/http"
"chihaya/config"
"chihaya/database"
cdb "chihaya/database/types"
"chihaya/server/params"
"context"
"github.com/valyala/fasthttp"
"github.com/zeebo/bencode"
)
@ -80,7 +79,7 @@ func scrape(ctx context.Context, qs string, user *cdb.User, db *database.Databas
// Early exit before response write
select {
case <-ctx.Done():
return http.StatusRequestTimeout
return fasthttp.StatusRequestTimeout
default:
}
@ -89,5 +88,5 @@ func scrape(ctx context.Context, qs string, user *cdb.User, db *database.Databas
panic(err)
}
return http.StatusOK
return fasthttp.StatusOK
}

View file

@ -20,10 +20,9 @@ package server
import (
"bytes"
"context"
"github.com/valyala/fasthttp"
"net"
"net/http"
"path"
"strconv"
"sync"
"sync/atomic"
"time"
@ -63,10 +62,10 @@ var (
listener net.Listener
)
func (handler *httpHandler) respond(ctx context.Context, r *http.Request, buf *bytes.Buffer) int {
dir, action := path.Split(r.URL.Path)
func (handler *httpHandler) respond(ctx context.Context, requestContext *fasthttp.RequestCtx, buf *bytes.Buffer) int {
dir, action := path.Split(string(requestContext.Request.URI().Path()))
if action == "" {
return http.StatusNotFound
return fasthttp.StatusNotFound
}
/*
@ -82,7 +81,7 @@ func (handler *httpHandler) respond(ctx context.Context, r *http.Request, buf *b
return alive(buf)
}
return http.StatusNotFound
return fasthttp.StatusNotFound
}
/*
@ -94,26 +93,33 @@ func (handler *httpHandler) respond(ctx context.Context, r *http.Request, buf *b
user := isPasskeyValid(passkey, handler.db)
if user == nil {
failure("Your passkey is invalid", buf, 1*time.Hour)
return http.StatusOK
return fasthttp.StatusOK
}
switch action {
case "announce":
return announce(ctx, r.URL.RawQuery, r.Header, r.RemoteAddr, user, handler.db, buf)
return announce(
ctx,
string(requestContext.Request.URI().QueryString()),
&requestContext.Request.Header,
requestContext.RemoteAddr(),
user,
handler.db,
buf)
case "scrape":
if enabled, _ := config.GetBool("scrape", true); !enabled {
return http.StatusNotFound
return fasthttp.StatusNotFound
}
return scrape(ctx, r.URL.RawQuery, user, handler.db, buf)
return scrape(ctx, string(requestContext.Request.URI().QueryString()), user, handler.db, buf)
case "metrics":
return metrics(ctx, r.Header.Get("Authorization"), handler.db, buf)
return metrics(ctx, requestContext.Request.Header.PeekBytes([]byte("Authorization")), handler.db, buf)
}
return http.StatusNotFound
return fasthttp.StatusNotFound
}
func (handler *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
func (handler *httpHandler) handler(requestContext *fasthttp.RequestCtx) {
if handler.terminate {
return
}
@ -125,9 +131,6 @@ func (handler *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
handler.waitGroup.Add(1)
defer handler.waitGroup.Done()
// Flush buffered data to client
defer w.(http.Flusher).Flush()
buf := handler.bufferPool.Take()
// Mark buf to be returned to bufferPool after we are done with it
defer handler.bufferPool.Give(buf)
@ -135,25 +138,25 @@ func (handler *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Gracefully handle panics so that they're confined to single request and don't crash server
defer func() {
if err := recover(); err != nil {
log.Error.Printf("Recovered from panicking request handler - %v\nURL was: %s", err, r.URL)
log.Error.Printf("Recovered from panicking request handler - %v\nURL was: %s", err, requestContext.URI().String())
log.WriteStack()
collectors.IncrementErroredRequests()
if w.Header().Get("Content-Type") == "" {
if len(requestContext.Response.Header.ContentType()) == 0 {
buf.Reset()
w.WriteHeader(http.StatusInternalServerError)
requestContext.Response.SetStatusCode(fasthttp.StatusInternalServerError)
}
}
}()
// Prepare and start new context with timeout to abort long-running requests
ctx, cancel := context.WithTimeout(r.Context(), handler.contextTimeout)
ctx, cancel := context.WithTimeout(context.Background(), handler.contextTimeout)
defer cancel()
/* Pass flow to handler; note that handler should be responsible for actually cancelling
its own work based on request context cancellation */
status := handler.respond(ctx, r, buf)
status := handler.respond(ctx, requestContext, buf)
select {
case <-ctx.Done():
@ -163,20 +166,20 @@ func (handler *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
failure("Request context deadline exceeded", buf, 5*time.Minute)
w.Header().Add("Content-Length", strconv.Itoa(buf.Len()))
w.Header().Add("Content-Type", "text/plain")
w.WriteHeader(http.StatusOK) // Required by torrent clients to interpret failure response
_, _ = w.Write(buf.Bytes())
requestContext.Response.Header.SetContentLength(buf.Len())
requestContext.Response.Header.SetContentTypeBytes([]byte("text/plain"))
requestContext.Response.SetStatusCode(fasthttp.StatusOK) // Required by torrent clients to interpret failure response
_, _ = requestContext.Write(buf.Bytes())
case context.Canceled:
collectors.IncrementCancelRequests()
w.WriteHeader(http.StatusRequestTimeout)
requestContext.Response.SetStatusCode(fasthttp.StatusRequestTimeout)
}
default:
w.Header().Add("Content-Length", strconv.Itoa(buf.Len()))
w.Header().Add("Content-Type", "text/plain")
w.WriteHeader(status)
_, _ = w.Write(buf.Bytes())
requestContext.Response.Header.SetContentLength(buf.Len())
requestContext.Response.Header.SetContentTypeBytes([]byte("text/plain"))
requestContext.Response.SetStatusCode(status)
_, _ = requestContext.Write(buf.Bytes())
}
}
@ -190,7 +193,6 @@ func Start() {
addr, _ := config.Section("http").Get("addr", ":34000")
readTimeout, _ := config.Section("http").Section("timeout").GetInt("read", 1)
readHeaderTimeout, _ := config.Section("http").Section("timeout").GetInt("read_header", 2)
writeTimeout, _ := config.Section("http").Section("timeout").GetInt("write", 3)
idleTimeout, _ := config.Section("http").Section("timeout").GetInt("idle", 30)
@ -198,17 +200,23 @@ func Start() {
handler.contextTimeout = time.Duration(writeTimeout)*time.Second - 200*time.Millisecond
// Create new server instance
server := &http.Server{
Handler: handler,
ReadTimeout: time.Duration(readTimeout) * time.Second,
ReadHeaderTimeout: time.Duration(readHeaderTimeout) * time.Second,
WriteTimeout: time.Duration(writeTimeout) * time.Second,
IdleTimeout: time.Duration(idleTimeout) * time.Second,
server := &fasthttp.Server{
Handler: handler.handler,
ReadTimeout: time.Duration(readTimeout) * time.Second,
WriteTimeout: time.Duration(writeTimeout) * time.Second,
IdleTimeout: time.Duration(idleTimeout) * time.Second,
GetOnly: true,
DisablePreParseMultipartForm: true,
NoDefaultServerHeader: true,
NoDefaultDate: true,
NoDefaultContentType: true,
CloseOnShutdown: true,
}
if idleTimeout <= 0 {
log.Warning.Print("Setting idleTimeout <= 0 disables Keep-Alive which might negatively impact performance")
server.SetKeepAlivesEnabled(false)
server.DisableKeepalive = true
}
// Start new goroutine to calculate throughput
@ -261,8 +269,7 @@ func Start() {
// Wait for active connections to finish processing
handler.waitGroup.Wait()
// Close server so that it does not Accept(), https://github.com/golang/go/issues/10527
_ = server.Close()
_ = server.Shutdown()
log.Info.Print("Now closed and not accepting any new connections")