Compare commits

...

3 commits

18 changed files with 438 additions and 462 deletions

View file

@ -58,12 +58,7 @@ Configuration is done in `config.json`, which you'll need to create with the fol
```json
{
"database": {
"username": "chihaya",
"password": "",
"database": "chihaya",
"proto": "tcp",
"addr": "127.0.0.1:3306",
"dsn": "chihaya:@tcp(127.0.0.1:3306)/chihaya",
"deadlock_pause": 1,
"deadlock_retries": 5
},
@ -96,7 +91,6 @@ Configuration is done in `config.json`, which you'll need to create with the fol
"proxy_header": "",
"timeout": {
"read": 1,
"read_header": 2,
"write": 1,
"idle": 30
}
@ -143,8 +137,7 @@ Configuration is done in `config.json`, which you'll need to create with the fol
- `admin_token` - administrative token used in `Authorization` header to access advanced prometheus statistics
- `proxy_header` - header name to look for user's real IP address, for example `X-Real-Ip`
- `timeout`
- `read_header` - timeout in seconds for reading request headers
- `read` - timeout in seconds for reading request; is reset after headers are read
- `read` - timeout in seconds for reading request
- `write` - timeout in seconds for writing response (total time spent)
- `idle` - how long (in seconds) to keep connection open for keep-alive requests
- `announce`

View file

@ -40,18 +40,16 @@ func TestMain(m *testing.M) {
panic(err)
}
f, err := os.OpenFile("config.json", os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
f, err := os.OpenFile(configFile, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
if err != nil {
panic(err)
}
configTest = make(Map)
dbConfig := map[string]interface{}{
"username": "chihaya",
"password": "",
"proto": "tcp",
"addr": "127.0.0.1:3306",
"database": "chihaya",
"dsn": "chihaya:@tcp(127.0.0.1:3306)/chihaya",
"deadlock_pause": json.Number("1"),
"deadlock_retries": json.Number("5"),
}
configTest["database"] = dbConfig
configTest["addr"] = ":34000"
@ -148,5 +146,5 @@ func TestSection(t *testing.T) {
}
func cleanup() {
_ = os.Remove("config.json")
_ = os.Remove(configFile)
}

View file

@ -19,8 +19,8 @@ package database
import (
"bytes"
"context"
"database/sql"
"fmt"
"os"
"sync"
"sync/atomic"
@ -35,11 +35,6 @@ import (
"github.com/go-sql-driver/mysql"
)
type Connection struct {
sqlDb *sql.DB
mutex sync.Mutex
}
type Database struct {
snatchChannel chan *bytes.Buffer
transferHistoryChannel chan *bytes.Buffer
@ -62,13 +57,15 @@ type Database struct {
TorrentGroupFreeleech atomic.Pointer[map[cdb.TorrentGroupKey]*cdb.TorrentGroupFreeleech]
Clients atomic.Pointer[map[uint16]string]
mainConn *Connection // Used for reloading and misc queries
bufferPool *util.BufferPool
transferHistoryLock sync.Mutex
terminate bool
conn *sql.DB
terminate atomic.Bool
ctx context.Context
ctxCancel func()
waitGroup sync.WaitGroup
}
@ -77,71 +74,66 @@ var (
maxDeadlockRetries int
)
var defaultDsn = map[string]string{
"username": "chihaya",
"password": "",
"proto": "tcp",
"addr": "127.0.0.1:3306",
"database": "chihaya",
}
const defaultDsn = "chihaya:@tcp(127.0.0.1:3306)/chihaya"
func (db *Database) Init() {
db.terminate = false
db.terminate.Store(false)
db.ctx, db.ctxCancel = context.WithCancel(context.Background())
log.Info.Print("Opening database connection...")
db.mainConn = Open()
db.conn = Open()
// Used for recording updates, so the max required size should be < 128 bytes. See queue.go for details
db.bufferPool = util.NewBufferPool(128)
var err error
db.loadUsersStmt, err = db.mainConn.sqlDb.Prepare(
db.loadUsersStmt, err = db.conn.Prepare(
"SELECT ID, torrent_pass, DownMultiplier, UpMultiplier, DisableDownload, TrackerHide " +
"FROM users_main WHERE Enabled = '1'")
if err != nil {
panic(err)
}
db.loadHnrStmt, err = db.mainConn.sqlDb.Prepare(
db.loadHnrStmt, err = db.conn.Prepare(
"SELECT h.uid, h.fid FROM transfer_history AS h " +
"JOIN users_main AS u ON u.ID = h.uid WHERE h.hnr = 1 AND u.Enabled = '1'")
if err != nil {
panic(err)
}
db.loadTorrentsStmt, err = db.mainConn.sqlDb.Prepare(
db.loadTorrentsStmt, err = db.conn.Prepare(
"SELECT ID, info_hash, DownMultiplier, UpMultiplier, Snatched, Status, GroupID, TorrentType FROM torrents")
if err != nil {
panic(err)
}
db.loadTorrentGroupFreeleechStmt, err = db.mainConn.sqlDb.Prepare(
db.loadTorrentGroupFreeleechStmt, err = db.conn.Prepare(
"SELECT GroupID, `Type`, DownMultiplier, UpMultiplier FROM torrent_group_freeleech")
if err != nil {
panic(err)
}
db.loadClientsStmt, err = db.mainConn.sqlDb.Prepare(
db.loadClientsStmt, err = db.conn.Prepare(
"SELECT id, peer_id FROM approved_clients WHERE archived = 0")
if err != nil {
panic(err)
}
db.loadFreeleechStmt, err = db.mainConn.sqlDb.Prepare(
db.loadFreeleechStmt, err = db.conn.Prepare(
"SELECT mod_setting FROM mod_core WHERE mod_option = 'global_freeleech'")
if err != nil {
panic(err)
}
db.cleanStalePeersStmt, err = db.mainConn.sqlDb.Prepare(
db.cleanStalePeersStmt, err = db.conn.Prepare(
"UPDATE transfer_history SET active = 0 WHERE last_announce < ? AND active = 1")
if err != nil {
panic(err)
}
db.unPruneTorrentStmt, err = db.mainConn.sqlDb.Prepare(
db.unPruneTorrentStmt, err = db.conn.Prepare(
"UPDATE torrents SET Status = 0 WHERE ID = ?")
if err != nil {
panic(err)
@ -179,7 +171,8 @@ func (db *Database) Init() {
func (db *Database) Terminate() {
log.Info.Print("Terminating database connection...")
db.terminate = true
db.terminate.Store(true)
db.ctxCancel()
log.Info.Print("Closing all flush channels...")
db.closeFlushChannels()
@ -190,13 +183,11 @@ func (db *Database) Terminate() {
}()
db.waitGroup.Wait()
db.mainConn.mutex.Lock()
_ = db.mainConn.Close()
db.mainConn.mutex.Unlock()
_ = db.conn.Close()
db.serialize()
}
func Open() *Connection {
func Open() *sql.DB {
databaseConfig := config.Section("database")
deadlockWaitTime, _ = databaseConfig.GetInt("deadlock_pause", 1)
maxDeadlockRetries, _ = databaseConfig.GetInt("deadlock_retries", 5)
@ -212,18 +203,7 @@ func Open() *Connection {
// First try to load the DSN from environment. USeful for tests.
databaseDsn := os.Getenv("DB_DSN")
if databaseDsn == "" {
dbUsername, _ := databaseConfig.Get("username", defaultDsn["username"])
dbPassword, _ := databaseConfig.Get("password", defaultDsn["password"])
dbProto, _ := databaseConfig.Get("proto", defaultDsn["proto"])
dbAddr, _ := databaseConfig.Get("addr", defaultDsn["addr"])
dbDatabase, _ := databaseConfig.Get("database", defaultDsn["database"])
databaseDsn = fmt.Sprintf("%s:%s@%s(%s)/%s",
dbUsername,
dbPassword,
dbProto,
dbAddr,
dbDatabase,
)
databaseDsn, _ = databaseConfig.Get("dsn", defaultDsn)
}
sqlDb, err := sql.Open("mysql", databaseDsn)
@ -235,16 +215,10 @@ func Open() *Connection {
log.Fatal.Fatalf("Couldn't ping database - %s", err)
}
return &Connection{
sqlDb: sqlDb,
}
return sqlDb
}
func (db *Connection) Close() error {
return db.sqlDb.Close()
}
func (db *Connection) query(stmt *sql.Stmt, args ...interface{}) *sql.Rows { //nolint:unparam
func (db *Database) query(stmt *sql.Stmt, args ...interface{}) *sql.Rows { //nolint:unparam
rows, _ := perform(func() (interface{}, error) {
return stmt.Query(args...)
}).(*sql.Rows)
@ -252,7 +226,7 @@ func (db *Connection) query(stmt *sql.Stmt, args ...interface{}) *sql.Rows { //n
return rows
}
func (db *Connection) execute(stmt *sql.Stmt, args ...interface{}) sql.Result {
func (db *Database) execute(stmt *sql.Stmt, args ...interface{}) sql.Result {
result, _ := perform(func() (interface{}, error) {
return stmt.Exec(args...)
}).(sql.Result)
@ -260,9 +234,9 @@ func (db *Connection) execute(stmt *sql.Stmt, args ...interface{}) sql.Result {
return result
}
func (db *Connection) exec(query *bytes.Buffer, args ...interface{}) sql.Result { //nolint:unparam
func (db *Database) exec(query *bytes.Buffer, args ...interface{}) sql.Result { //nolint:unparam
result, _ := perform(func() (interface{}, error) {
return db.sqlDb.Exec(query.String(), args...)
return db.conn.Exec(query.String(), args...)
}).(sql.Result)
return result

View file

@ -46,7 +46,7 @@ func TestMain(m *testing.M) {
db.Init()
fixtures, err = testfixtures.New(
testfixtures.Database(db.mainConn.sqlDb),
testfixtures.Database(db.conn),
testfixtures.Dialect("mariadb"),
testfixtures.Directory("fixtures"),
testfixtures.DangerousSkipTestDatabaseCheck(),
@ -188,6 +188,10 @@ func TestLoadTorrents(t *testing.T) {
{89, 252, 84, 49, 177, 28, 118, 28, 148, 205, 62, 185, 8, 37, 234, 110, 109, 200, 165, 241}: t3,
}
for k := range torrents {
torrents[k].InitializeLock()
}
// Test with fresh data
db.loadTorrents()
@ -375,7 +379,7 @@ func TestRecordAndFlushUsers(t *testing.T) {
deltaDownload = int64(float64(deltaRawDownload) * math.Float64frombits(testUser.DownMultiplier.Load()))
deltaUpload = int64(float64(deltaRawUpload) * math.Float64frombits(testUser.UpMultiplier.Load()))
row := db.mainConn.sqlDb.QueryRow("SELECT Uploaded, Downloaded, rawup, rawdl "+
row := db.conn.QueryRow("SELECT Uploaded, Downloaded, rawup, rawdl "+
"FROM users_main WHERE ID = ?", testUser.ID.Load())
err := row.Scan(&initUpload, &initDownload, &initRawUpload, &initRawDownload)
@ -390,7 +394,7 @@ func TestRecordAndFlushUsers(t *testing.T) {
}
time.Sleep(200 * time.Millisecond)
row = db.mainConn.sqlDb.QueryRow("SELECT Uploaded, Downloaded, rawup, rawdl "+
row = db.conn.QueryRow("SELECT Uploaded, Downloaded, rawup, rawdl "+
"FROM users_main WHERE ID = ?", testUser.ID.Load())
err = row.Scan(&upload, &download, &rawUpload, &rawDownload)
@ -471,7 +475,7 @@ func TestRecordAndFlushTransferHistory(t *testing.T) {
deltaActiveTime = 267
deltaSeedTime = 15
row := db.mainConn.sqlDb.QueryRow("SELECT uploaded, downloaded, activetime, seedtime, active, snatched "+
row := db.conn.QueryRow("SELECT uploaded, downloaded, activetime, seedtime, active, snatched "+
"FROM transfer_history WHERE uid = ? AND fid = ?", testPeer.UserID, testPeer.TorrentID)
err := row.Scan(&initRawUpload, &initRawDownload, &initActiveTime, &initSeedTime, &initActive, &initSnatch)
@ -492,7 +496,7 @@ func TestRecordAndFlushTransferHistory(t *testing.T) {
}
time.Sleep(200 * time.Millisecond)
row = db.mainConn.sqlDb.QueryRow("SELECT uploaded, downloaded, activetime, seedtime, active, snatched "+
row = db.conn.QueryRow("SELECT uploaded, downloaded, activetime, seedtime, active, snatched "+
"FROM transfer_history WHERE uid = ? AND fid = ?", testPeer.UserID, testPeer.TorrentID)
err = row.Scan(&rawUpload, &rawDownload, &activeTime, &seedTime, &active, &snatch)
@ -553,7 +557,7 @@ func TestRecordAndFlushTransferHistory(t *testing.T) {
var gotStartTime int64
row = db.mainConn.sqlDb.QueryRow("SELECT seeding, starttime, last_announce, remaining "+
row = db.conn.QueryRow("SELECT seeding, starttime, last_announce, remaining "+
"FROM transfer_history WHERE uid = ? AND fid = ?", gotPeer.UserID, gotPeer.TorrentID)
err = row.Scan(&gotPeer.Seeding, &gotStartTime, &gotPeer.LastAnnounce, &gotPeer.Left)
@ -598,7 +602,7 @@ func TestRecordAndFlushTransferHistory(t *testing.T) {
}
time.Sleep(200 * time.Millisecond)
row = db.mainConn.sqlDb.QueryRow("SELECT seeding, starttime, last_announce, remaining "+
row = db.conn.QueryRow("SELECT seeding, starttime, last_announce, remaining "+
"FROM transfer_history WHERE uid = ? AND fid = ?", gotPeer.UserID, gotPeer.TorrentID)
err = row.Scan(&gotPeer.Seeding, &gotPeer.StartTime, &gotPeer.LastAnnounce, &gotPeer.Left)
@ -635,7 +639,7 @@ func TestRecordAndFlushTransferIP(t *testing.T) {
deltaDownload = 236
deltaUpload = 3262
row := db.mainConn.sqlDb.QueryRow("SELECT uploaded, downloaded "+
row := db.conn.QueryRow("SELECT uploaded, downloaded "+
"FROM transfer_ips WHERE uid = ? AND fid = ? AND ip = ? AND client_id = ?",
testPeer.UserID, testPeer.TorrentID, testPeer.Addr.IPNumeric(), testPeer.ClientID)
@ -651,7 +655,7 @@ func TestRecordAndFlushTransferIP(t *testing.T) {
}
time.Sleep(200 * time.Millisecond)
row = db.mainConn.sqlDb.QueryRow("SELECT uploaded, downloaded "+
row = db.conn.QueryRow("SELECT uploaded, downloaded "+
"FROM transfer_ips WHERE uid = ? AND fid = ? AND ip = ? AND client_id = ?",
testPeer.UserID, testPeer.TorrentID, testPeer.Addr.IPNumeric(), testPeer.ClientID)
@ -687,7 +691,7 @@ func TestRecordAndFlushTransferIP(t *testing.T) {
var gotStartTime int64
row = db.mainConn.sqlDb.QueryRow("SELECT port, starttime, last_announce "+
row = db.conn.QueryRow("SELECT port, starttime, last_announce "+
"FROM transfer_ips WHERE uid = ? AND fid = ? AND ip = ? AND client_id = ?",
testPeer.UserID, testPeer.TorrentID, testPeer.Addr.IPNumeric(), testPeer.ClientID)
@ -732,7 +736,7 @@ func TestRecordAndFlushTransferIP(t *testing.T) {
Addr: cdb.NewPeerAddressFromIPPort(testPeer.Addr.IP(), 0),
}
row = db.mainConn.sqlDb.QueryRow("SELECT port, starttime, last_announce "+
row = db.conn.QueryRow("SELECT port, starttime, last_announce "+
"FROM transfer_ips WHERE uid = ? AND fid = ? AND ip = ? AND client_id = ?",
testPeer.UserID, testPeer.TorrentID, testPeer.Addr.IPNumeric(), testPeer.ClientID)
@ -770,7 +774,7 @@ func TestRecordAndFlushSnatch(t *testing.T) {
}
time.Sleep(200 * time.Millisecond)
row := db.mainConn.sqlDb.QueryRow("SELECT snatched_time "+
row := db.conn.QueryRow("SELECT snatched_time "+
"FROM transfer_history WHERE uid = ? AND fid = ?", testPeer.UserID, testPeer.TorrentID)
err := row.Scan(&snatchTime)
@ -823,7 +827,7 @@ func TestRecordAndFlushTorrents(t *testing.T) {
numSeeders int
)
row := db.mainConn.sqlDb.QueryRow("SELECT Snatched, last_action, Seeders, Leechers "+
row := db.conn.QueryRow("SELECT Snatched, last_action, Seeders, Leechers "+
"FROM torrents WHERE ID = ?", torrent.ID.Load())
err := row.Scan(&snatched, &lastAction, &numSeeders, &numLeechers)

View file

@ -19,6 +19,7 @@ package database
import (
"bytes"
"chihaya/util"
"errors"
"time"
@ -111,25 +112,9 @@ func (db *Database) flushTorrents() {
count int
)
conn := Open()
for {
query.Reset()
query.WriteString("CREATE TEMPORARY TABLE IF NOT EXISTS flush_torrents (" +
"ID int unsigned NOT NULL, " +
"Snatched int unsigned NOT NULL DEFAULT 0, " +
"Seeders int unsigned NOT NULL DEFAULT 0, " +
"Leechers int unsigned NOT NULL DEFAULT 0, " +
"last_action int NOT NULL DEFAULT 0, " +
"PRIMARY KEY (ID)) ENGINE=MEMORY")
conn.exec(&query)
query.Reset()
query.WriteString("TRUNCATE flush_torrents")
conn.exec(&query)
query.Reset()
query.WriteString("INSERT INTO flush_torrents VALUES ")
query.WriteString("INSERT IGNORE INTO torrents (ID, Snatched, Seeders, Leechers, last_action) VALUES ")
length := len(db.torrentChannel)
@ -149,7 +134,7 @@ func (db *Database) flushTorrents() {
if count > 0 {
logFlushes, _ := config.GetBool("log_flushes", true)
if logFlushes && !db.terminate {
if logFlushes && !db.terminate.Load() {
log.Info.Printf("{torrents} Flushing %d", count)
}
@ -158,18 +143,9 @@ func (db *Database) flushTorrents() {
query.WriteString(" ON DUPLICATE KEY UPDATE Snatched = Snatched + VALUE(Snatched), " +
"Seeders = VALUE(Seeders), Leechers = VALUE(Leechers), " +
"last_action = IF(last_action < VALUE(last_action), VALUE(last_action), last_action)")
conn.exec(&query)
db.exec(&query)
query.Reset()
query.WriteString("UPDATE torrents t, flush_torrents ft SET " +
"t.Snatched = t.Snatched + ft.Snatched, " +
"t.Seeders = ft.Seeders, " +
"t.Leechers = ft.Leechers, " +
"t.last_action = IF(t.last_action < ft.last_action, ft.last_action, t.last_action)" +
"WHERE t.ID = ft.ID")
conn.exec(&query)
if !db.terminate {
if !db.terminate.Load() {
elapsedTime := time.Since(startTime)
collectors.UpdateFlushTime("torrents", elapsedTime)
collectors.UpdateChannelsLen("torrents", count)
@ -178,14 +154,12 @@ func (db *Database) flushTorrents() {
if length < (torrentFlushBufferSize >> 1) {
time.Sleep(time.Duration(flushSleepInterval) * time.Second)
}
} else if db.terminate {
} else if db.terminate.Load() {
break
} else {
time.Sleep(time.Second)
}
}
_ = conn.Close()
}
func (db *Database) flushUsers() {
@ -197,25 +171,9 @@ func (db *Database) flushUsers() {
count int
)
conn := Open()
for {
query.Reset()
query.WriteString("CREATE TEMPORARY TABLE IF NOT EXISTS flush_users (" +
"ID int unsigned NOT NULL, " +
"Uploaded bigint unsigned NOT NULL DEFAULT 0, " +
"Downloaded bigint unsigned NOT NULL DEFAULT 0, " +
"rawdl bigint unsigned NOT NULL DEFAULT 0, " +
"rawup bigint unsigned NOT NULL DEFAULT 0, " +
"PRIMARY KEY (ID)) ENGINE=MEMORY")
conn.exec(&query)
query.Reset()
query.WriteString("TRUNCATE flush_users")
conn.exec(&query)
query.Reset()
query.WriteString("INSERT INTO flush_users VALUES ")
query.WriteString("INSERT IGNORE INTO users_main (ID, Uploaded, Downloaded, rawdl, rawup) VALUES ")
length := len(db.userChannel)
@ -235,7 +193,7 @@ func (db *Database) flushUsers() {
if count > 0 {
logFlushes, _ := config.GetBool("log_flushes", true)
if logFlushes && !db.terminate {
if logFlushes && !db.terminate.Load() {
log.Info.Printf("{users_main} Flushing %d", count)
}
@ -243,18 +201,9 @@ func (db *Database) flushUsers() {
query.WriteString(" ON DUPLICATE KEY UPDATE Uploaded = Uploaded + VALUE(Uploaded), " +
"Downloaded = Downloaded + VALUE(Downloaded), rawdl = rawdl + VALUE(rawdl), rawup = rawup + VALUE(rawup)")
conn.exec(&query)
db.exec(&query)
query.Reset()
query.WriteString("UPDATE users_main u, flush_users fu SET " +
"u.Uploaded = u.Uploaded + fu.Uploaded, " +
"u.Downloaded = u.Downloaded + fu.Downloaded, " +
"u.rawdl = u.rawdl + fu.rawdl, " +
"u.rawup = u.rawup + fu.rawup " +
"WHERE u.ID = fu.ID")
conn.exec(&query)
if !db.terminate {
if !db.terminate.Load() {
elapsedTime := time.Since(startTime)
collectors.UpdateFlushTime("users", elapsedTime)
collectors.UpdateChannelsLen("users", count)
@ -263,14 +212,12 @@ func (db *Database) flushUsers() {
if length < (userFlushBufferSize >> 1) {
time.Sleep(time.Duration(flushSleepInterval) * time.Second)
}
} else if db.terminate {
} else if db.terminate.Load() {
break
} else {
time.Sleep(time.Second)
}
}
_ = conn.Close()
}
func (db *Database) flushTransferHistory() {
@ -282,8 +229,6 @@ func (db *Database) flushTransferHistory() {
count int
)
conn := Open()
for {
length, err := func() (int, error) {
db.transferHistoryLock.Lock()
@ -311,7 +256,7 @@ func (db *Database) flushTransferHistory() {
if count > 0 {
logFlushes, _ := config.GetBool("log_flushes", true)
if logFlushes && !db.terminate {
if logFlushes && !db.terminate.Load() {
log.Info.Printf("{transfer_history} Flushing %d", count)
}
@ -323,16 +268,16 @@ func (db *Database) flushTransferHistory() {
"seedtime = seedtime + VALUE(seedtime), last_announce = VALUE(last_announce), " +
"active = VALUE(active), snatched = snatched + VALUE(snatched);")
conn.exec(&query)
db.exec(&query)
if !db.terminate {
if !db.terminate.Load() {
elapsedTime := time.Since(startTime)
collectors.UpdateFlushTime("transfer_history", elapsedTime)
collectors.UpdateChannelsLen("transfer_history", count)
}
return length, nil
} else if db.terminate {
} else if db.terminate.Load() {
return 0, errDbTerminate
}
@ -347,8 +292,6 @@ func (db *Database) flushTransferHistory() {
time.Sleep(time.Second)
}
}
_ = conn.Close()
}
func (db *Database) flushTransferIps() {
@ -360,8 +303,6 @@ func (db *Database) flushTransferIps() {
count int
)
conn := Open()
for {
query.Reset()
query.WriteString("INSERT INTO transfer_ips (uid, fid, client_id, ip, port, uploaded, downloaded, " +
@ -385,7 +326,7 @@ func (db *Database) flushTransferIps() {
if count > 0 {
logFlushes, _ := config.GetBool("log_flushes", true)
if logFlushes && !db.terminate {
if logFlushes && !db.terminate.Load() {
log.Info.Printf("{transfer_ips} Flushing %d", count)
}
@ -394,9 +335,9 @@ func (db *Database) flushTransferIps() {
// todo: port should be part of PK
query.WriteString("\nON DUPLICATE KEY UPDATE port = VALUE(port), downloaded = downloaded + VALUE(downloaded), " +
"uploaded = uploaded + VALUE(uploaded), last_announce = VALUE(last_announce)")
conn.exec(&query)
db.exec(&query)
if !db.terminate {
if !db.terminate.Load() {
elapsedTime := time.Since(startTime)
collectors.UpdateFlushTime("transfer_ips", elapsedTime)
collectors.UpdateChannelsLen("transfer_ips", count)
@ -405,14 +346,12 @@ func (db *Database) flushTransferIps() {
if length < (transferIpsFlushBufferSize >> 1) {
time.Sleep(time.Duration(flushSleepInterval) * time.Second)
}
} else if db.terminate {
} else if db.terminate.Load() {
break
} else {
time.Sleep(time.Second)
}
}
_ = conn.Close()
}
func (db *Database) flushSnatches() {
@ -424,8 +363,6 @@ func (db *Database) flushSnatches() {
count int
)
conn := Open()
for {
query.Reset()
query.WriteString("INSERT INTO transfer_history (uid, fid, snatched_time) VALUES\n")
@ -448,16 +385,16 @@ func (db *Database) flushSnatches() {
if count > 0 {
logFlushes, _ := config.GetBool("log_flushes", true)
if logFlushes && !db.terminate {
if logFlushes && !db.terminate.Load() {
log.Info.Printf("{snatches} Flushing %d", count)
}
startTime := time.Now()
query.WriteString("\nON DUPLICATE KEY UPDATE snatched_time = VALUE(snatched_time)")
conn.exec(&query)
db.exec(&query)
if !db.terminate {
if !db.terminate.Load() {
elapsedTime := time.Since(startTime)
collectors.UpdateFlushTime("snatches", elapsedTime)
collectors.UpdateChannelsLen("snatches", count)
@ -466,14 +403,12 @@ func (db *Database) flushSnatches() {
if length < (snatchFlushBufferSize >> 1) {
time.Sleep(time.Duration(flushSleepInterval) * time.Second)
}
} else if db.terminate {
} else if db.terminate.Load() {
break
} else {
time.Sleep(time.Second)
}
}
_ = conn.Close()
}
func (db *Database) purgeInactivePeers() {
@ -483,7 +418,7 @@ func (db *Database) purgeInactivePeers() {
count int
)
for !db.terminate {
util.ContextTick(db.ctx, time.Duration(purgeInactivePeersInterval)*time.Second, func() {
start = time.Now()
now = start.Unix()
count = 0
@ -544,11 +479,8 @@ func (db *Database) purgeInactivePeers() {
db.transferHistoryLock.Lock()
defer db.transferHistoryLock.Unlock()
db.mainConn.mutex.Lock()
defer db.mainConn.mutex.Unlock()
start = time.Now()
result := db.mainConn.execute(db.cleanStalePeersStmt, oldestActive)
result := db.execute(db.cleanStalePeersStmt, oldestActive)
if result != nil {
rows, err := result.RowsAffected()
@ -559,7 +491,5 @@ func (db *Database) purgeInactivePeers() {
}
}
}()
time.Sleep(time.Duration(purgeInactivePeersInterval) * time.Second)
}
})
}

View file

@ -179,7 +179,5 @@ func (db *Database) QueueSnatch(peer *cdb.Peer, now int64) {
}
func (db *Database) UnPrune(torrent *cdb.Torrent) {
db.mainConn.mutex.Lock()
db.mainConn.execute(db.unPruneTorrentStmt, torrent.ID.Load())
db.mainConn.mutex.Unlock()
db.execute(db.unPruneTorrentStmt, torrent.ID.Load())
}

View file

@ -26,6 +26,7 @@ import (
"chihaya/config"
cdb "chihaya/database/types"
"chihaya/log"
"chihaya/util"
)
var (
@ -50,31 +51,26 @@ func init() {
*/
func (db *Database) startReloading() {
go func() {
for !db.terminate {
time.Sleep(time.Duration(reloadInterval) * time.Second)
util.ContextTick(db.ctx, time.Duration(reloadInterval)*time.Second, func() {
db.waitGroup.Add(1)
defer db.waitGroup.Done()
db.loadUsers()
db.loadHitAndRuns()
db.loadTorrents()
db.loadGroupsFreeleech()
db.loadConfig()
db.loadClients()
db.waitGroup.Done()
}
})
}()
}
func (db *Database) loadUsers() {
db.mainConn.mutex.Lock()
defer db.mainConn.mutex.Unlock()
start := time.Now()
dbUsers := *db.Users.Load()
newUsers := make(map[string]*cdb.User, len(dbUsers))
rows := db.mainConn.query(db.loadUsersStmt)
rows := db.query(db.loadUsersStmt)
if rows == nil {
log.Error.Print("Failed to load hit and runs from database")
log.WriteStack()
@ -126,13 +122,10 @@ func (db *Database) loadUsers() {
}
func (db *Database) loadHitAndRuns() {
db.mainConn.mutex.Lock()
defer db.mainConn.mutex.Unlock()
start := time.Now()
newHnr := make(map[cdb.UserTorrentPair]struct{})
rows := db.mainConn.query(db.loadHnrStmt)
rows := db.query(db.loadHnrStmt)
if rows == nil {
log.Error.Print("Failed to load hit and runs from database")
log.WriteStack()
@ -173,12 +166,9 @@ func (db *Database) loadTorrents() {
newTorrents := make(map[cdb.TorrentHash]*cdb.Torrent, len(dbTorrents))
db.mainConn.mutex.Lock()
defer db.mainConn.mutex.Unlock()
start = time.Now()
rows := db.mainConn.query(db.loadTorrentsStmt)
rows := db.query(db.loadTorrentsStmt)
if rows == nil {
log.Error.Print("Failed to load torrents from database")
log.WriteStack()
@ -261,13 +251,10 @@ func (db *Database) loadTorrents() {
}
func (db *Database) loadGroupsFreeleech() {
db.mainConn.mutex.Lock()
defer db.mainConn.mutex.Unlock()
start := time.Now()
newTorrentGroupFreeleech := make(map[cdb.TorrentGroupKey]*cdb.TorrentGroupFreeleech)
rows := db.mainConn.query(db.loadTorrentGroupFreeleechStmt)
rows := db.query(db.loadTorrentGroupFreeleechStmt)
if rows == nil {
log.Error.Print("Failed to load torrent group freeleech data from database")
log.WriteStack()
@ -311,10 +298,7 @@ func (db *Database) loadGroupsFreeleech() {
}
func (db *Database) loadConfig() {
db.mainConn.mutex.Lock()
defer db.mainConn.mutex.Unlock()
rows := db.mainConn.query(db.loadFreeleechStmt)
rows := db.query(db.loadFreeleechStmt)
if rows == nil {
log.Error.Print("Failed to load config from database")
log.WriteStack()
@ -339,13 +323,10 @@ func (db *Database) loadConfig() {
}
func (db *Database) loadClients() {
db.mainConn.mutex.Lock()
defer db.mainConn.mutex.Unlock()
start := time.Now()
newClients := make(map[uint16]string)
rows := db.mainConn.query(db.loadClientsStmt)
rows := db.query(db.loadClientsStmt)
if rows == nil {
log.Error.Print("Failed to load clients from database")
log.WriteStack()

View file

@ -18,6 +18,7 @@
package database
import (
"chihaya/util"
"fmt"
"os"
"time"
@ -37,10 +38,9 @@ func init() {
func (db *Database) startSerializing() {
go func() {
for !db.terminate {
time.Sleep(time.Duration(serializeInterval) * time.Second)
util.ContextTick(db.ctx, time.Duration(serializeInterval)*time.Second, func() {
db.serialize()
}
})
}()
}

2
go.mod
View file

@ -9,6 +9,7 @@ require (
github.com/jinzhu/copier v0.3.5
github.com/prometheus/client_golang v1.15.1
github.com/prometheus/common v0.44.0
github.com/valyala/fasthttp v1.47.0
github.com/viney-shih/go-lock v1.1.2
github.com/zeebo/bencode v1.0.0
)
@ -33,6 +34,7 @@ require (
github.com/prometheus/procfs v0.9.0 // indirect
github.com/segmentio/asm v1.2.0 // indirect
github.com/shopspring/decimal v1.3.1 // indirect
github.com/valyala/bytebufferpool v1.0.0 // indirect
go.opentelemetry.io/otel v1.15.0 // indirect
go.opentelemetry.io/otel/trace v1.15.0 // indirect
golang.org/x/sync v0.1.0 // indirect

6
go.sum
View file

@ -89,6 +89,10 @@ github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8=
github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk=
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
github.com/valyala/fasthttp v1.47.0 h1:y7moDoxYzMooFpT5aHgNgVOQDrS3qlkfiP9mDtGGK9c=
github.com/valyala/fasthttp v1.47.0/go.mod h1:k2zXd82h/7UZc3VOdJ2WaUqt1uZ/XpXAfE9i+HBC3lA=
github.com/viney-shih/go-lock v1.1.2 h1:3TdGTiHZCPqBdTvFbQZQN/TRZzKF3KWw2rFEyKz3YqA=
github.com/viney-shih/go-lock v1.1.2/go.mod h1:Yijm78Ljteb3kRiJrbLAxVntkUukGu5uzSxq/xV7OO8=
github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI=
@ -108,7 +112,7 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.6.0 h1:qfktjS5LUO+fFKeJXZ+ikTRijMmljikvG68fpMMruSc=
golang.org/x/crypto v0.7.0 h1:AvwMYaRytfdeVt3u6mLaxYtErKYjxA2OXjJ1HHq6t3A=
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=

View file

@ -20,7 +20,7 @@ package server
import (
"bytes"
"encoding/json"
"net/http"
"github.com/valyala/fasthttp"
"time"
"chihaya/log"
@ -35,10 +35,10 @@ func alive(buf *bytes.Buffer) int {
res, err := json.Marshal(response{time.Now().UnixMilli(), time.Since(handler.startTime).Milliseconds()})
if err != nil {
log.Error.Print("Failed to marshal JSON alive response: ", err)
return http.StatusInternalServerError
return fasthttp.StatusInternalServerError
}
buf.Write(res)
return http.StatusOK
return fasthttp.StatusOK
}

View file

@ -21,9 +21,9 @@ import (
"bytes"
"context"
"fmt"
"github.com/valyala/fasthttp"
"math"
"net"
"net/http"
"strings"
"time"
@ -111,67 +111,68 @@ func getPublicIPV4(ipAddr string, exists bool) (string, bool) {
return ipAddr, !private
}
func announce(ctx context.Context, qs string, header http.Header, remoteAddr string, user *cdb.User,
db *database.Database, buf *bytes.Buffer) int {
qp, err := params.ParseQuery(qs)
func announce(
ctx context.Context, queryArgs *fasthttp.Args, header *fasthttp.RequestHeader,
remoteAddr net.Addr, user *cdb.User, db *database.Database, buf *bytes.Buffer) int {
qp, err := params.ParseQuery(queryArgs)
if err != nil {
panic(err)
}
// Mandatory parameters
infoHashes := qp.InfoHashes()
peerID, _ := qp.Get("peer_id")
port, portExists := qp.GetUint16("port")
uploaded, uploadedExists := qp.GetUint64("uploaded")
downloaded, downloadedExists := qp.GetUint64("downloaded")
left, leftExists := qp.GetUint64("left")
infoHashes := qp.Params.InfoHashes
peerID := qp.Params.PeerID
port, portExists := qp.Params.Port, qp.Exists.Port
uploaded, uploadedExists := qp.Params.Uploaded, qp.Exists.Uploaded
downloaded, downloadedExists := qp.Params.Downloaded, qp.Exists.Downloaded
left, leftExists := qp.Params.Left, qp.Exists.Left
if infoHashes == nil {
if len(infoHashes) == 0 {
failure("Malformed request - missing info_hash", buf, 1*time.Hour)
return http.StatusOK // Required by torrent clients to interpret failure response
return fasthttp.StatusOK // Required by torrent clients to interpret failure response
} else if len(infoHashes) > 1 {
failure("Malformed request - can only announce singular info_hash", buf, 1*time.Hour)
return http.StatusOK // Required by torrent clients to interpret failure response
return fasthttp.StatusOK // Required by torrent clients to interpret failure response
}
if peerID == "" {
if len(peerID) == 0 {
failure("Malformed request - missing peer_id", buf, 1*time.Hour)
return http.StatusOK // Required by torrent clients to interpret failure response
return fasthttp.StatusOK // Required by torrent clients to interpret failure response
}
if len(peerID) != 20 {
failure("Malformed request - invalid peer_id", buf, 1*time.Hour)
return http.StatusOK // Required by torrent clients to interpret failure response
return fasthttp.StatusOK // Required by torrent clients to interpret failure response
}
if !portExists {
failure("Malformed request - missing port", buf, 1*time.Hour)
return http.StatusOK // Required by torrent clients to interpret failure response
return fasthttp.StatusOK // Required by torrent clients to interpret failure response
}
if strictPort && port < 1024 || port > 65535 {
failure(fmt.Sprintf("Malformed request - port outside of acceptable range (port: %d)", port), buf, 1*time.Hour)
return http.StatusOK // Required by torrent clients to interpret failure response
return fasthttp.StatusOK // Required by torrent clients to interpret failure response
}
if !uploadedExists {
failure("Malformed request - missing uploaded", buf, 1*time.Hour)
return http.StatusOK // Required by torrent clients to interpret failure response
return fasthttp.StatusOK // Required by torrent clients to interpret failure response
}
if !downloadedExists {
failure("Malformed request - missing downloaded", buf, 1*time.Hour)
return http.StatusOK // Required by torrent clients to interpret failure response
return fasthttp.StatusOK // Required by torrent clients to interpret failure response
}
if !leftExists {
failure("Malformed request - missing left", buf, 1*time.Hour)
return http.StatusOK // Required by torrent clients to interpret failure response
return fasthttp.StatusOK // Required by torrent clients to interpret failure response
}
ipAddr := func() string {
ipV4, existsV4 := getPublicIPV4(qp.Get("ipv4")) // First try to get IPv4 address if client sent it
ip, exists := getPublicIPV4(qp.Get("ip")) // ... then try to get public IP if sent by client
ipV4, existsV4 := getPublicIPV4(qp.Params.IPv4, qp.Exists.IPv4) // First try to get IPv4 address if client sent it
ip, exists := getPublicIPV4(qp.Params.IP, qp.Exists.IP) // ... then try to get public IP if sent by client
// Fail if ip and ipv4 are not same, and both are provided
if (existsV4 && exists) && (ip != ipV4) {
@ -188,15 +189,17 @@ func announce(ctx context.Context, qs string, header http.Header, remoteAddr str
// Check for proxied IP header
if proxyHeader, exists := config.Section("http").Get("proxy_header", ""); exists {
if ips, exists := header[proxyHeader]; exists && len(ips) > 0 {
return ips[0]
if ips := header.PeekAll(proxyHeader); len(ips) > 0 {
return string(ips[0])
}
}
// Check for IP in socket
portIndex := strings.LastIndex(remoteAddr, ":")
remoteAddrString := remoteAddr.String()
portIndex := strings.LastIndex(remoteAddrString, ":")
if portIndex != -1 {
return remoteAddr[0:portIndex]
return remoteAddrString[0:portIndex]
}
// Everything else failed
@ -206,19 +209,19 @@ func announce(ctx context.Context, qs string, header http.Header, remoteAddr str
ipBytes := net.ParseIP(ipAddr).To4()
if nil == ipBytes {
failure(fmt.Sprintf("Failed to parse IP address (ip: %s)", ipAddr), buf, 1*time.Hour)
return http.StatusOK // Required by torrent clients to interpret failure response
return fasthttp.StatusOK // Required by torrent clients to interpret failure response
}
clientID, matched := clientApproved(peerID, db)
if !matched {
failure(fmt.Sprintf("Your client is not approved (peer_id: %s)", peerID), buf, 1*time.Hour)
return http.StatusOK // Required by torrent clients to interpret failure response
return fasthttp.StatusOK // Required by torrent clients to interpret failure response
}
torrent, exists := (*db.Torrents.Load())[infoHashes[0]]
if !exists {
failure("This torrent does not exist", buf, 5*time.Minute)
return http.StatusOK // Required by torrent clients to interpret failure response
return fasthttp.StatusOK // Required by torrent clients to interpret failure response
}
if torrentStatus := torrent.Status.Load(); torrentStatus == 1 && left == 0 {
@ -232,10 +235,10 @@ func announce(ctx context.Context, qs string, header http.Header, remoteAddr str
go db.UnPrune(torrent)
} else if torrentStatus != 0 {
failure(fmt.Sprintf("This torrent does not exist (status: %d, left: %d)", torrentStatus, left), buf, 15*time.Minute)
return http.StatusOK // Required by torrent clients to interpret failure response
return fasthttp.StatusOK // Required by torrent clients to interpret failure response
}
numWant, exists := qp.GetUint16("numwant")
numWant, exists := qp.Params.NumWant, qp.Exists.NumWant
if !exists {
numWant = uint16(defaultNumWant)
} else if numWant > uint16(maxNumWant) {
@ -250,21 +253,21 @@ func announce(ctx context.Context, qs string, header http.Header, remoteAddr str
active = true
)
event, _ := qp.Get("event")
event := qp.Params.Event
completed := event == "completed"
peerKey := cdb.NewPeerKey(user.ID.Load(), cdb.PeerIDFromRawString(peerID))
// Take torrent peers lock to read/write on it to prevent race conditions
if !torrent.TryLockWithContext(ctx) {
return http.StatusRequestTimeout
return fasthttp.StatusRequestTimeout
}
defer torrent.Unlock()
if left > 0 {
if isDisabledDownload(db, user, torrent) {
failure("Your download privileges are disabled", buf, 1*time.Hour)
return http.StatusOK // Required by torrent clients to interpret failure response
return fasthttp.StatusOK // Required by torrent clients to interpret failure response
}
peer, exists = torrent.Leechers[peerKey]
@ -446,11 +449,9 @@ func announce(ctx context.Context, qs string, header http.Header, remoteAddr str
response["interval"] = announceInterval + announceDrift
if numWant > 0 && active {
compactString, exists := qp.Get("compact")
compact := !exists || compactString != "0" // Defaults to being compact
compact := !qp.Exists.Compact || !qp.Params.Compact
noPeerIDString, exists := qp.Get("no_peer_id")
noPeerID := exists && noPeerIDString == "1"
noPeerID := qp.Exists.NoPeerID && qp.Params.NoPeerID
var peerCount int
if seeding {
@ -531,14 +532,14 @@ func announce(ctx context.Context, qs string, header http.Header, remoteAddr str
// Early exit before response write
select {
case <-ctx.Done():
return http.StatusRequestTimeout
return fasthttp.StatusRequestTimeout
default:
}
encoder := bencode.NewEncoder(buf)
if err = encoder.Encode(response); err != nil {
if err := encoder.Encode(response); err != nil {
panic(err)
}
return http.StatusOK
return fasthttp.StatusOK
}

View file

@ -20,7 +20,7 @@ package server
import (
"bytes"
"context"
"net/http"
"github.com/valyala/fasthttp"
"time"
"chihaya/collectors"
@ -31,9 +31,9 @@ import (
"github.com/prometheus/common/expfmt"
)
var bearerPrefix = "Bearer "
var bearerPrefix = []byte("Bearer ")
func metrics(ctx context.Context, auth string, db *database.Database, buf *bytes.Buffer) int {
func metrics(ctx context.Context, auth []byte, db *database.Database, buf *bytes.Buffer) int {
dbUsers := *db.Users.Load()
dbTorrents := *db.Torrents.Load()
@ -46,7 +46,7 @@ func metrics(ctx context.Context, auth string, db *database.Database, buf *bytes
// Early exit before response write
select {
case <-ctx.Done():
return http.StatusRequestTimeout
return fasthttp.StatusRequestTimeout
default:
}
@ -68,9 +68,9 @@ func metrics(ctx context.Context, auth string, db *database.Database, buf *bytes
}
n := len(bearerPrefix)
if len(auth) > n && auth[:n] == bearerPrefix {
if len(auth) > n && bytes.Equal(auth[:n], bearerPrefix) {
adminToken, exists := config.Section("http").Get("admin_token", "")
if exists && auth[n:] == adminToken {
if exists && bytes.Equal(auth[:n], []byte(adminToken)) {
mfs, _ := prometheus.DefaultGatherer.Gather()
for _, mf := range mfs {
@ -82,5 +82,5 @@ func metrics(ctx context.Context, auth string, db *database.Database, buf *bytes
}
}
return http.StatusOK
return fasthttp.StatusOK
}

View file

@ -19,99 +19,153 @@
package params
import (
"net/url"
"strconv"
"strings"
"bytes"
cdb "chihaya/database/types"
"github.com/valyala/fasthttp"
"strconv"
)
type QueryParam struct {
query string
params map[string]string
infoHashes []cdb.TorrentHash
}
Params struct {
Uploaded uint64
Downloaded uint64
Left uint64
func ParseQuery(query string) (qp *QueryParam, err error) {
qp = &QueryParam{
query: query,
infoHashes: nil,
params: make(map[string]string),
Port uint16
NumWant uint16
PeerID string
IPv4 string
IP string
Event string
TestGarbageUnescape string
Compact bool
NoPeerID bool
InfoHashes []cdb.TorrentHash
}
for query != "" {
key := query
if i := strings.Index(key, "&"); i >= 0 {
key, query = key[:i], key[i+1:]
} else {
query = ""
Exists struct {
Uploaded bool
Downloaded bool
Left bool
Port bool
NumWant bool
PeerID bool
IPv4 bool
IP bool
Event bool
TestGarbageUnescape bool
Compact bool
NoPeerID bool
InfoHashes bool
}
}
var uploadedKey = []byte("uploaded")
var downloadedKey = []byte("downloaded")
var leftKey = []byte("left")
var portKey = []byte("port")
var numWant = []byte("numwant")
var peerIDKey = []byte("peer_id")
var ipv4Key = []byte("ipv4")
var ipKey = []byte("ip")
var eventKey = []byte("event")
var testGarbageUnescapeKey = []byte("!@#")
var infoHashKey = []byte("info_hash")
var compactKey = []byte("compact")
var noPeerIDKey = []byte("no_peer_id")
func ParseQuery(queryArgs *fasthttp.Args) (qp QueryParam, err error) {
var returnError error
queryArgs.VisitAll(func(key, value []byte) {
if returnError != nil {
return
}
if key == "" {
continue
}
value := ""
if i := strings.Index(key, "="); i >= 0 {
key, value = key[:i], key[i+1:]
}
key, err = url.QueryUnescape(key)
if err != nil {
panic(err)
}
value, err = url.QueryUnescape(value)
if err != nil {
panic(err)
}
if key == "info_hash" {
if len(value) == cdb.TorrentHashSize {
qp.infoHashes = append(qp.infoHashes, cdb.TorrentHashFromBytes([]byte(value)))
key = bytes.ToLower(key)
switch true {
case bytes.Equal(key, uploadedKey):
n, err := strconv.ParseUint(string(value), 10, 0)
if err != nil {
returnError = err
return
}
} else {
qp.params[strings.ToLower(key)] = value
qp.Params.Uploaded = n
qp.Exists.Uploaded = true
case bytes.Equal(key, downloadedKey):
n, err := strconv.ParseUint(string(value), 10, 0)
if err != nil {
returnError = err
return
}
qp.Params.Downloaded = n
qp.Exists.Downloaded = true
case bytes.Equal(key, leftKey):
n, err := strconv.ParseUint(string(value), 10, 0)
if err != nil {
returnError = err
return
}
qp.Params.Left = n
qp.Exists.Left = true
case bytes.Equal(key, portKey):
n, err := strconv.ParseUint(string(value), 10, 16)
if err != nil {
returnError = err
return
}
qp.Params.Port = uint16(n)
qp.Exists.Port = true
case bytes.Equal(key, numWant):
n, err := strconv.ParseUint(string(value), 10, 16)
if err != nil {
returnError = err
return
}
qp.Params.NumWant = uint16(n)
qp.Exists.NumWant = true
case bytes.Equal(key, peerIDKey):
qp.Params.PeerID = string(value)
qp.Exists.PeerID = true
case bytes.Equal(key, ipv4Key):
qp.Params.IPv4 = string(value)
qp.Exists.IPv4 = true
case bytes.Equal(key, ipKey):
qp.Params.IP = string(value)
qp.Exists.IP = true
case bytes.Equal(key, eventKey):
qp.Params.Event = string(value)
qp.Exists.Event = true
case bytes.Equal(key, testGarbageUnescapeKey):
qp.Params.TestGarbageUnescape = string(value)
qp.Exists.TestGarbageUnescape = true
case bytes.Equal(key, infoHashKey):
if len(value) == cdb.TorrentHashSize {
qp.Params.InfoHashes = append(qp.Params.InfoHashes, cdb.TorrentHashFromBytes(value))
qp.Exists.InfoHashes = true
}
case bytes.Equal(key, compactKey):
qp.Params.Compact = bytes.Equal(value, []byte{'1'})
qp.Exists.Compact = true
case bytes.Equal(key, noPeerIDKey):
qp.Params.NoPeerID = bytes.Equal(value, []byte{'1'})
qp.Exists.NoPeerID = true
}
}
})
return qp, nil
}
func (qp *QueryParam) getUint(which string, bitSize int) (ret uint64, exists bool) {
str, exists := qp.params[which]
if exists {
var err error
ret, err = strconv.ParseUint(str, 10, bitSize)
if err != nil {
exists = false
}
}
return
}
func (qp *QueryParam) Get(which string) (ret string, exists bool) {
ret, exists = qp.params[which]
return
}
func (qp *QueryParam) GetUint64(which string) (ret uint64, exists bool) {
return qp.getUint(which, 64)
}
func (qp *QueryParam) GetUint16(which string) (ret uint16, exists bool) {
tmp, exists := qp.getUint(which, 16)
ret = uint16(tmp)
return
}
func (qp *QueryParam) InfoHashes() []cdb.TorrentHash {
return qp.infoHashes
}
func (qp *QueryParam) RawQuery() string {
return qp.query
return qp, returnError
}

View file

@ -19,6 +19,7 @@ package params
import (
"fmt"
"github.com/valyala/fasthttp"
"net/url"
"os"
"reflect"
@ -44,83 +45,95 @@ func TestMain(m *testing.M) {
}
func TestParseQuery(t *testing.T) {
query := ""
var queryParsed QueryParam
queryParsed.Params.Event, queryParsed.Exists.Event = "completed", true
queryParsed.Params.Port, queryParsed.Exists.Port = 25362, true
queryParsed.Params.PeerID, queryParsed.Exists.PeerID = "-CH010-VnpZR7uz31I1A", true
queryParsed.Params.Left, queryParsed.Exists.Left = 0, true
query := fmt.Sprintf("event=%s&port=%d&peer_id=%s&left=%d",
queryParsed.Params.Event,
queryParsed.Params.Port,
queryParsed.Params.PeerID,
queryParsed.Params.Left,
)
for _, infoHash := range infoHashes {
query += "info_hash=" + url.QueryEscape(string(infoHash[:])) + "&"
queryParsed.Params.InfoHashes = append(queryParsed.Params.InfoHashes, infoHash)
queryParsed.Exists.InfoHashes = true
query += "&info_hash=" + url.QueryEscape(string(infoHash[:]))
}
queryMap := make(map[string]string)
queryMap["event"] = "completed"
queryMap["port"] = "25362"
queryMap["peer_id"] = "-CH010-VnpZR7uz31I1A"
queryMap["left"] = "0"
args := fasthttp.Args{}
args.Parse(query)
for k, v := range queryMap {
query += k + "=" + v + "&"
}
query = query[:len(query)-1]
qp, err := ParseQuery(query)
qp, err := ParseQuery(&args)
if err != nil {
panic(err)
t.Fatal(err)
}
if !reflect.DeepEqual(qp.params, queryMap) {
t.Fatalf("Parsed query map (%v) is not deeply equal as original (%v)!", qp.params, queryMap)
}
if !reflect.DeepEqual(qp.infoHashes, infoHashes) {
t.Fatalf("Parsed info hashes (%v) are not deeply equal as original (%v)!", qp.infoHashes, infoHashes)
if !reflect.DeepEqual(qp, queryParsed) {
t.Fatalf("Parsed query map (%v) is not deeply equal as original (%v)!", qp, queryParsed)
}
}
func TestBrokenParseQuery(t *testing.T) {
brokenQueryMap := make(map[string]string)
brokenQueryMap["event"] = "started"
brokenQueryMap["bug"] = ""
brokenQueryMap["yes"] = ""
var brokenQueryParsed QueryParam
brokenQueryParsed.Params.Event, brokenQueryParsed.Exists.Event = "started", true
brokenQueryParsed.Params.IPv4, brokenQueryParsed.Exists.IPv4 = "", true
brokenQueryParsed.Params.IP, brokenQueryParsed.Exists.IP = "", true
qp, err := ParseQuery("event=started&bug=&yes=")
args := fasthttp.Args{}
args.Parse("event=started&ipv4=&ip=")
qp, err := ParseQuery(&args)
if err != nil {
panic(err)
t.Fatal(err)
}
if !reflect.DeepEqual(qp.params, brokenQueryMap) {
t.Fatalf("Parsed query map (%v) is not deeply equal as original (%v)!", qp.params, brokenQueryMap)
if !reflect.DeepEqual(qp, brokenQueryParsed) {
t.Fatalf("Parsed query map (%v) is not deeply equal as original (%v)!", qp, brokenQueryParsed)
}
}
func TestLowerKey(t *testing.T) {
qp, err := ParseQuery("EvEnT=c0mPl3tED")
args := fasthttp.Args{}
args.Parse("EvEnT=c0mPl3tED")
qp, err := ParseQuery(&args)
if err != nil {
panic(err)
t.Fatal(err)
}
if param, exists := qp.Get("event"); !exists || param != "c0mPl3tED" {
if param, exists := qp.Params.Event, qp.Exists.Event; !exists || param != "c0mPl3tED" {
t.Fatalf("Got parsed value %s but expected c0mPl3tED for \"event\"!", param)
}
}
func TestUnescape(t *testing.T) {
qp, err := ParseQuery("%21%40%23=%24%25%5E")
args := fasthttp.Args{}
args.Parse("%21%40%23=%24%25%5E")
qp, err := ParseQuery(&args)
if err != nil {
panic(err)
t.Fatal(err)
}
if param, exists := qp.Get("!@#"); !exists || param != "$%^" {
if param, exists := qp.Params.TestGarbageUnescape, qp.Exists.TestGarbageUnescape; !exists || param != "$%^" {
t.Fatal(fmt.Sprintf("Got parsed value %s but expected", param), "$%^ for \"!@#\"!")
}
}
func TestGet(t *testing.T) {
qp, err := ParseQuery("event=completed")
func TestString(t *testing.T) {
args := fasthttp.Args{}
args.Parse("event=completed")
qp, err := ParseQuery(&args)
if err != nil {
panic(err)
t.Fatal(err)
}
if param, exists := qp.Get("event"); !exists || param != "completed" {
if param, exists := qp.Params.Event, qp.Exists.Event; !exists || param != "completed" {
t.Fatalf("Got parsed value %s but expected completed for \"event\"!", param)
}
}
@ -128,12 +141,15 @@ func TestGet(t *testing.T) {
func TestGetUint64(t *testing.T) {
val := uint64(1<<62 + 42)
qp, err := ParseQuery("left=" + strconv.FormatUint(val, 10))
args := fasthttp.Args{}
args.Parse("left=" + strconv.FormatUint(val, 10))
qp, err := ParseQuery(&args)
if err != nil {
panic(err)
t.Fatal(err)
}
if parsedVal, exists := qp.GetUint64("left"); !exists || parsedVal != val {
if parsedVal, exists := qp.Params.Left, qp.Exists.Left; !exists || parsedVal != val {
t.Fatalf("Got parsed value %v but expected %v for \"left\"!", parsedVal, val)
}
}
@ -141,12 +157,15 @@ func TestGetUint64(t *testing.T) {
func TestGetUint16(t *testing.T) {
val := uint16(1<<15 + 4242)
qp, err := ParseQuery("port=" + strconv.FormatUint(uint64(val), 10))
args := fasthttp.Args{}
args.Parse("port=" + strconv.FormatUint(uint64(val), 10))
qp, err := ParseQuery(&args)
if err != nil {
panic(err)
t.Fatal(err)
}
if parsedVal, exists := qp.GetUint16("port"); !exists || parsedVal != val {
if parsedVal, exists := qp.Params.Port, qp.Exists.Port; !exists || parsedVal != val {
t.Fatalf("Got parsed value %v but expected %v for \"port\"!", parsedVal, val)
}
}
@ -160,25 +179,15 @@ func TestInfoHashes(t *testing.T) {
query = query[:len(query)-1]
qp, err := ParseQuery(query)
args := fasthttp.Args{}
args.Parse(query)
qp, err := ParseQuery(&args)
if err != nil {
panic(err)
t.Fatal(err)
}
if !reflect.DeepEqual(qp.InfoHashes(), infoHashes) {
t.Fatalf("Parsed info hashes (%v) are not deeply equal as original (%v)!", qp.InfoHashes(), infoHashes)
}
}
func TestRawQuery(t *testing.T) {
q := "event=completed&port=25541&left=0&uploaded=0&downloaded=0"
qp, err := ParseQuery(q)
if err != nil {
panic(err)
}
if rq := qp.RawQuery(); rq != q {
t.Fatalf("Got raw query %s but expected %s", rq, q)
if !reflect.DeepEqual(qp.Params.InfoHashes, infoHashes) {
t.Fatalf("Parsed info hashes (%v) are not deeply equal as original (%v)!", qp.Params.InfoHashes, infoHashes)
}
}

View file

@ -19,13 +19,12 @@ package server
import (
"bytes"
"context"
"net/http"
"chihaya/config"
"chihaya/database"
cdb "chihaya/database/types"
"chihaya/server/params"
"context"
"github.com/valyala/fasthttp"
"github.com/zeebo/bencode"
)
@ -45,8 +44,10 @@ func writeScrapeInfo(torrent *cdb.Torrent) map[string]interface{} {
return ret
}
func scrape(ctx context.Context, qs string, user *cdb.User, db *database.Database, buf *bytes.Buffer) int {
qp, err := params.ParseQuery(qs)
func scrape(
ctx context.Context, queryArgs *fasthttp.Args,
user *cdb.User, db *database.Database, buf *bytes.Buffer) int {
qp, err := params.ParseQuery(queryArgs)
if err != nil {
panic(err)
}
@ -56,8 +57,8 @@ func scrape(ctx context.Context, qs string, user *cdb.User, db *database.Databas
dbTorrents := *db.Torrents.Load()
if qp.InfoHashes() != nil {
for _, infoHash := range qp.InfoHashes() {
if len(qp.Params.InfoHashes) > 0 {
for _, infoHash := range qp.Params.InfoHashes {
torrent, exists := dbTorrents[infoHash]
if exists {
if !isDisabledDownload(db, user, torrent) {
@ -80,14 +81,14 @@ func scrape(ctx context.Context, qs string, user *cdb.User, db *database.Databas
// Early exit before response write
select {
case <-ctx.Done():
return http.StatusRequestTimeout
return fasthttp.StatusRequestTimeout
default:
}
encoder := bencode.NewEncoder(buf)
if err = encoder.Encode(scrapeData); err != nil {
if err := encoder.Encode(scrapeData); err != nil {
panic(err)
}
return http.StatusOK
return fasthttp.StatusOK
}

View file

@ -20,10 +20,9 @@ package server
import (
"bytes"
"context"
"github.com/valyala/fasthttp"
"net"
"net/http"
"path"
"strconv"
"sync"
"sync/atomic"
"time"
@ -63,10 +62,10 @@ var (
listener net.Listener
)
func (handler *httpHandler) respond(ctx context.Context, r *http.Request, buf *bytes.Buffer) int {
dir, action := path.Split(r.URL.Path)
func (handler *httpHandler) respond(ctx context.Context, requestContext *fasthttp.RequestCtx, buf *bytes.Buffer) int {
dir, action := path.Split(string(requestContext.Request.URI().Path()))
if action == "" {
return http.StatusNotFound
return fasthttp.StatusNotFound
}
/*
@ -82,7 +81,7 @@ func (handler *httpHandler) respond(ctx context.Context, r *http.Request, buf *b
return alive(buf)
}
return http.StatusNotFound
return fasthttp.StatusNotFound
}
/*
@ -94,26 +93,33 @@ func (handler *httpHandler) respond(ctx context.Context, r *http.Request, buf *b
user := isPasskeyValid(passkey, handler.db)
if user == nil {
failure("Your passkey is invalid", buf, 1*time.Hour)
return http.StatusOK
return fasthttp.StatusOK
}
switch action {
case "announce":
return announce(ctx, r.URL.RawQuery, r.Header, r.RemoteAddr, user, handler.db, buf)
return announce(
ctx,
requestContext.Request.URI().QueryArgs(),
&requestContext.Request.Header,
requestContext.RemoteAddr(),
user,
handler.db,
buf)
case "scrape":
if enabled, _ := config.GetBool("scrape", true); !enabled {
return http.StatusNotFound
return fasthttp.StatusNotFound
}
return scrape(ctx, r.URL.RawQuery, user, handler.db, buf)
return scrape(ctx, requestContext.Request.URI().QueryArgs(), user, handler.db, buf)
case "metrics":
return metrics(ctx, r.Header.Get("Authorization"), handler.db, buf)
return metrics(ctx, requestContext.Request.Header.PeekBytes([]byte("Authorization")), handler.db, buf)
}
return http.StatusNotFound
return fasthttp.StatusNotFound
}
func (handler *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
func (handler *httpHandler) ServeHTTP(requestContext *fasthttp.RequestCtx) {
if handler.terminate {
return
}
@ -125,9 +131,6 @@ func (handler *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
handler.waitGroup.Add(1)
defer handler.waitGroup.Done()
// Flush buffered data to client
defer w.(http.Flusher).Flush()
buf := handler.bufferPool.Take()
// Mark buf to be returned to bufferPool after we are done with it
defer handler.bufferPool.Give(buf)
@ -135,25 +138,25 @@ func (handler *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Gracefully handle panics so that they're confined to single request and don't crash server
defer func() {
if err := recover(); err != nil {
log.Error.Printf("Recovered from panicking request handler - %v\nURL was: %s", err, r.URL)
log.Error.Printf("Recovered from panicking request handler - %v\nURL was: %s", err, requestContext.URI().String())
log.WriteStack()
collectors.IncrementErroredRequests()
if w.Header().Get("Content-Type") == "" {
if len(requestContext.Response.Header.ContentType()) == 0 {
buf.Reset()
w.WriteHeader(http.StatusInternalServerError)
requestContext.Response.SetStatusCode(fasthttp.StatusInternalServerError)
}
}
}()
// Prepare and start new context with timeout to abort long-running requests
ctx, cancel := context.WithTimeout(r.Context(), handler.contextTimeout)
ctx, cancel := context.WithTimeout(context.Background(), handler.contextTimeout)
defer cancel()
/* Pass flow to handler; note that handler should be responsible for actually cancelling
its own work based on request context cancellation */
status := handler.respond(ctx, r, buf)
status := handler.respond(ctx, requestContext, buf)
select {
case <-ctx.Done():
@ -163,20 +166,20 @@ func (handler *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
failure("Request context deadline exceeded", buf, 5*time.Minute)
w.Header().Add("Content-Length", strconv.Itoa(buf.Len()))
w.Header().Add("Content-Type", "text/plain")
w.WriteHeader(http.StatusOK) // Required by torrent clients to interpret failure response
_, _ = w.Write(buf.Bytes())
requestContext.Request.Header.SetContentLength(buf.Len())
requestContext.Request.Header.SetContentTypeBytes([]byte("text/plain"))
requestContext.Response.SetStatusCode(fasthttp.StatusOK) // Required by torrent clients to interpret failure response
_, _ = requestContext.Write(buf.Bytes())
case context.Canceled:
collectors.IncrementCancelRequests()
w.WriteHeader(http.StatusRequestTimeout)
requestContext.Response.SetStatusCode(fasthttp.StatusRequestTimeout)
}
default:
w.Header().Add("Content-Length", strconv.Itoa(buf.Len()))
w.Header().Add("Content-Type", "text/plain")
w.WriteHeader(status)
_, _ = w.Write(buf.Bytes())
requestContext.Request.Header.SetContentLength(buf.Len())
requestContext.Request.Header.SetContentTypeBytes([]byte("text/plain"))
requestContext.Response.SetStatusCode(status)
_, _ = requestContext.Write(buf.Bytes())
}
}
@ -190,7 +193,6 @@ func Start() {
addr, _ := config.Section("http").Get("addr", ":34000")
readTimeout, _ := config.Section("http").Section("timeout").GetInt("read", 1)
readHeaderTimeout, _ := config.Section("http").Section("timeout").GetInt("read_header", 2)
writeTimeout, _ := config.Section("http").Section("timeout").GetInt("write", 3)
idleTimeout, _ := config.Section("http").Section("timeout").GetInt("idle", 30)
@ -198,17 +200,23 @@ func Start() {
handler.contextTimeout = time.Duration(writeTimeout)*time.Second - 200*time.Millisecond
// Create new server instance
server := &http.Server{
Handler: handler,
ReadTimeout: time.Duration(readTimeout) * time.Second,
ReadHeaderTimeout: time.Duration(readHeaderTimeout) * time.Second,
WriteTimeout: time.Duration(writeTimeout) * time.Second,
IdleTimeout: time.Duration(idleTimeout) * time.Second,
server := &fasthttp.Server{
Handler: handler.ServeHTTP,
ReadTimeout: time.Duration(readTimeout) * time.Second,
WriteTimeout: time.Duration(writeTimeout) * time.Second,
IdleTimeout: time.Duration(idleTimeout) * time.Second,
GetOnly: true,
DisablePreParseMultipartForm: true,
NoDefaultServerHeader: true,
NoDefaultDate: true,
NoDefaultContentType: true,
CloseOnShutdown: true,
}
if idleTimeout <= 0 {
log.Warning.Print("Setting idleTimeout <= 0 disables Keep-Alive which might negatively impact performance")
server.SetKeepAlivesEnabled(false)
server.DisableKeepalive = true
}
// Start new goroutine to calculate throughput
@ -261,8 +269,7 @@ func Start() {
// Wait for active connections to finish processing
handler.waitGroup.Wait()
// Close server so that it does not Accept(), https://github.com/golang/go/issues/10527
_ = server.Close()
_ = server.Shutdown()
log.Info.Print("Now closed and not accepting any new connections")

20
util/context_ticker.go Normal file
View file

@ -0,0 +1,20 @@
package util
import (
"context"
"time"
)
func ContextTick(ctx context.Context, d time.Duration, onTick func()) {
ticker := time.NewTicker(d)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
onTick()
}
}
}