Compare commits

...

7 commits

28 changed files with 1256 additions and 969 deletions

View file

@ -58,12 +58,7 @@ Configuration is done in `config.json`, which you'll need to create with the fol
```json
{
"database": {
"username": "chihaya",
"password": "",
"database": "chihaya",
"proto": "tcp",
"addr": "127.0.0.1:3306",
"dsn": "chihaya:@tcp(127.0.0.1:3306)/chihaya",
"deadlock_pause": 1,
"deadlock_retries": 5
},
@ -96,7 +91,6 @@ Configuration is done in `config.json`, which you'll need to create with the fol
"proxy_header": "",
"timeout": {
"read": 1,
"read_header": 2,
"write": 1,
"idle": 30
}
@ -143,8 +137,7 @@ Configuration is done in `config.json`, which you'll need to create with the fol
- `admin_token` - administrative token used in `Authorization` header to access advanced prometheus statistics
- `proxy_header` - header name to look for user's real IP address, for example `X-Real-Ip`
- `timeout`
- `read_header` - timeout in seconds for reading request headers
- `read` - timeout in seconds for reading request; is reset after headers are read
- `read` - timeout in seconds for reading request
- `write` - timeout in seconds for writing response (total time spent)
- `idle` - how long (in seconds) to keep connection open for keep-alive requests
- `announce`

View file

@ -133,17 +133,17 @@ func main() {
newUserID++
// Create mapping to get consistent peers
anonUserMapping[user.ID] = newUserID
anonUserMapping[user.ID.Load()] = newUserID
// Replaces user id
user.ID = newUserID
user.ID.Store(newUserID)
// Replaces hidden flag
user.TrackerHide = false
user.TrackerHide.Store(false)
// Replace Up/Down multipliers with baseline
user.UpMultiplier = 1.0
user.DownMultiplier = 1.0
user.UpMultiplier.Store(math.Float64bits(1.0))
user.DownMultiplier.Store(math.Float64bits(1.0))
// Replace passkey with a random provided one of same length
for {

View file

@ -32,6 +32,7 @@ type AdminCollector struct {
deadlockAbortedMetric *prometheus.Desc
erroredRequestsMetric *prometheus.Desc
timeoutRequestsMetric *prometheus.Desc
cancelRequestsMetric *prometheus.Desc
sqlErrorCountMetric *prometheus.Desc
serializationTimeSummary *prometheus.Histogram
@ -81,6 +82,7 @@ var (
deadlockAborted = 0
erroredRequests = 0
timeoutRequests = 0
cancelRequests = 0
sqlErrorCount = 0
)
@ -131,6 +133,8 @@ func NewAdminCollector() *AdminCollector {
"Number of failed requests", nil, nil),
timeoutRequestsMetric: prometheus.NewDesc("chihaya_requests_timeout",
"Number of requests for which context deadline was exceeded", nil, nil),
cancelRequestsMetric: prometheus.NewDesc("chihaya_requests_cancel",
"Number of requests for which context was prematurely cancelled", nil, nil),
sqlErrorCountMetric: prometheus.NewDesc("chihaya_sql_errors_count",
"Number of SQL errors", nil, nil),
@ -152,6 +156,7 @@ func (collector *AdminCollector) Describe(ch chan<- *prometheus.Desc) {
ch <- collector.deadlockTimeMetric
ch <- collector.erroredRequestsMetric
ch <- collector.timeoutRequestsMetric
ch <- collector.cancelRequestsMetric
ch <- collector.sqlErrorCountMetric
serializationTime.Describe(ch)
@ -171,6 +176,7 @@ func (collector *AdminCollector) Collect(ch chan<- prometheus.Metric) {
ch <- prometheus.MustNewConstMetric(collector.deadlockTimeMetric, prometheus.CounterValue, deadlockTime.Seconds())
ch <- prometheus.MustNewConstMetric(collector.erroredRequestsMetric, prometheus.CounterValue, float64(erroredRequests))
ch <- prometheus.MustNewConstMetric(collector.timeoutRequestsMetric, prometheus.CounterValue, float64(timeoutRequests))
ch <- prometheus.MustNewConstMetric(collector.cancelRequestsMetric, prometheus.CounterValue, float64(cancelRequests))
ch <- prometheus.MustNewConstMetric(collector.sqlErrorCountMetric, prometheus.CounterValue, float64(sqlErrorCount))
serializationTime.Collect(ch)
@ -204,6 +210,10 @@ func IncrementTimeoutRequests() {
timeoutRequests++
}
func IncrementCancelRequests() {
cancelRequests++
}
func IncrementSQLErrorCount() {
sqlErrorCount++
}

View file

@ -40,18 +40,16 @@ func TestMain(m *testing.M) {
panic(err)
}
f, err := os.OpenFile("config.json", os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
f, err := os.OpenFile(configFile, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
if err != nil {
panic(err)
}
configTest = make(Map)
dbConfig := map[string]interface{}{
"username": "chihaya",
"password": "",
"proto": "tcp",
"addr": "127.0.0.1:3306",
"database": "chihaya",
"dsn": "chihaya:@tcp(127.0.0.1:3306)/chihaya",
"deadlock_pause": json.Number("1"),
"deadlock_retries": json.Number("5"),
}
configTest["database"] = dbConfig
configTest["addr"] = ":34000"
@ -148,5 +146,5 @@ func TestSection(t *testing.T) {
}
func cleanup() {
_ = os.Remove("config.json")
_ = os.Remove(configFile)
}

View file

@ -19,9 +19,8 @@ package database
import (
"bytes"
"context"
"database/sql"
"fmt"
"github.com/viney-shih/go-lock"
"os"
"sync"
"sync/atomic"
@ -36,15 +35,7 @@ import (
"github.com/go-sql-driver/mysql"
)
type Connection struct {
sqlDb *sql.DB
mutex sync.Mutex
}
type Database struct {
TorrentsLock lock.RWMutex
UsersLock lock.RWMutex
snatchChannel chan *bytes.Buffer
transferHistoryChannel chan *bytes.Buffer
transferIpsChannel chan *bytes.Buffer
@ -60,19 +51,21 @@ type Database struct {
cleanStalePeersStmt *sql.Stmt
unPruneTorrentStmt *sql.Stmt
Users map[string]*cdb.User
Users atomic.Pointer[map[string]*cdb.User]
HitAndRuns atomic.Pointer[map[cdb.UserTorrentPair]struct{}]
Torrents map[cdb.TorrentHash]*cdb.Torrent // SHA-1 hash (20 bytes)
TorrentGroupFreeleech atomic.Pointer[map[cdb.TorrentGroup]*cdb.TorrentGroupFreeleech]
Torrents atomic.Pointer[map[cdb.TorrentHash]*cdb.Torrent]
TorrentGroupFreeleech atomic.Pointer[map[cdb.TorrentGroupKey]*cdb.TorrentGroupFreeleech]
Clients atomic.Pointer[map[uint16]string]
mainConn *Connection // Used for reloading and misc queries
bufferPool *util.BufferPool
transferHistoryLock lock.Mutex
transferHistoryLock sync.Mutex
terminate bool
conn *sql.DB
terminate atomic.Bool
ctx context.Context
ctxCancel func()
waitGroup sync.WaitGroup
}
@ -81,82 +74,76 @@ var (
maxDeadlockRetries int
)
var defaultDsn = map[string]string{
"username": "chihaya",
"password": "",
"proto": "tcp",
"addr": "127.0.0.1:3306",
"database": "chihaya",
}
const defaultDsn = "chihaya:@tcp(127.0.0.1:3306)/chihaya"
func (db *Database) Init() {
db.terminate = false
db.terminate.Store(false)
db.ctx, db.ctxCancel = context.WithCancel(context.Background())
log.Info.Print("Opening database connection...")
db.mainConn = Open()
// Initializing locks
db.TorrentsLock = lock.NewCASMutex()
db.UsersLock = lock.NewCASMutex()
db.conn = Open()
// Used for recording updates, so the max required size should be < 128 bytes. See queue.go for details
db.bufferPool = util.NewBufferPool(128)
var err error
db.loadUsersStmt, err = db.mainConn.sqlDb.Prepare(
db.loadUsersStmt, err = db.conn.Prepare(
"SELECT ID, torrent_pass, DownMultiplier, UpMultiplier, DisableDownload, TrackerHide " +
"FROM users_main WHERE Enabled = '1'")
if err != nil {
panic(err)
}
db.loadHnrStmt, err = db.mainConn.sqlDb.Prepare(
db.loadHnrStmt, err = db.conn.Prepare(
"SELECT h.uid, h.fid FROM transfer_history AS h " +
"JOIN users_main AS u ON u.ID = h.uid WHERE h.hnr = 1 AND u.Enabled = '1'")
if err != nil {
panic(err)
}
db.loadTorrentsStmt, err = db.mainConn.sqlDb.Prepare(
db.loadTorrentsStmt, err = db.conn.Prepare(
"SELECT ID, info_hash, DownMultiplier, UpMultiplier, Snatched, Status, GroupID, TorrentType FROM torrents")
if err != nil {
panic(err)
}
db.loadTorrentGroupFreeleechStmt, err = db.mainConn.sqlDb.Prepare(
db.loadTorrentGroupFreeleechStmt, err = db.conn.Prepare(
"SELECT GroupID, `Type`, DownMultiplier, UpMultiplier FROM torrent_group_freeleech")
if err != nil {
panic(err)
}
db.loadClientsStmt, err = db.mainConn.sqlDb.Prepare(
db.loadClientsStmt, err = db.conn.Prepare(
"SELECT id, peer_id FROM approved_clients WHERE archived = 0")
if err != nil {
panic(err)
}
db.loadFreeleechStmt, err = db.mainConn.sqlDb.Prepare(
db.loadFreeleechStmt, err = db.conn.Prepare(
"SELECT mod_setting FROM mod_core WHERE mod_option = 'global_freeleech'")
if err != nil {
panic(err)
}
db.cleanStalePeersStmt, err = db.mainConn.sqlDb.Prepare(
db.cleanStalePeersStmt, err = db.conn.Prepare(
"UPDATE transfer_history SET active = 0 WHERE last_announce < ? AND active = 1")
if err != nil {
panic(err)
}
db.unPruneTorrentStmt, err = db.mainConn.sqlDb.Prepare(
db.unPruneTorrentStmt, err = db.conn.Prepare(
"UPDATE torrents SET Status = 0 WHERE ID = ?")
if err != nil {
panic(err)
}
db.Users = make(map[string]*cdb.User)
db.Torrents = make(map[cdb.TorrentHash]*cdb.Torrent)
dbUsers := make(map[string]*cdb.User)
db.Users.Store(&dbUsers)
dbTorrents := make(map[cdb.TorrentHash]*cdb.Torrent)
db.Torrents.Store(&dbTorrents)
dbHitAndRuns := make(map[cdb.UserTorrentPair]struct{})
db.HitAndRuns.Store(&dbHitAndRuns)
@ -184,7 +171,8 @@ func (db *Database) Init() {
func (db *Database) Terminate() {
log.Info.Print("Terminating database connection...")
db.terminate = true
db.terminate.Store(true)
db.ctxCancel()
log.Info.Print("Closing all flush channels...")
db.closeFlushChannels()
@ -195,13 +183,11 @@ func (db *Database) Terminate() {
}()
db.waitGroup.Wait()
db.mainConn.mutex.Lock()
_ = db.mainConn.Close()
db.mainConn.mutex.Unlock()
_ = db.conn.Close()
db.serialize()
}
func Open() *Connection {
func Open() *sql.DB {
databaseConfig := config.Section("database")
deadlockWaitTime, _ = databaseConfig.GetInt("deadlock_pause", 1)
maxDeadlockRetries, _ = databaseConfig.GetInt("deadlock_retries", 5)
@ -217,18 +203,7 @@ func Open() *Connection {
// First try to load the DSN from environment. USeful for tests.
databaseDsn := os.Getenv("DB_DSN")
if databaseDsn == "" {
dbUsername, _ := databaseConfig.Get("username", defaultDsn["username"])
dbPassword, _ := databaseConfig.Get("password", defaultDsn["password"])
dbProto, _ := databaseConfig.Get("proto", defaultDsn["proto"])
dbAddr, _ := databaseConfig.Get("addr", defaultDsn["addr"])
dbDatabase, _ := databaseConfig.Get("database", defaultDsn["database"])
databaseDsn = fmt.Sprintf("%s:%s@%s(%s)/%s",
dbUsername,
dbPassword,
dbProto,
dbAddr,
dbDatabase,
)
databaseDsn, _ = databaseConfig.Get("dsn", defaultDsn)
}
sqlDb, err := sql.Open("mysql", databaseDsn)
@ -240,16 +215,10 @@ func Open() *Connection {
log.Fatal.Fatalf("Couldn't ping database - %s", err)
}
return &Connection{
sqlDb: sqlDb,
}
return sqlDb
}
func (db *Connection) Close() error {
return db.sqlDb.Close()
}
func (db *Connection) query(stmt *sql.Stmt, args ...interface{}) *sql.Rows { //nolint:unparam
func (db *Database) query(stmt *sql.Stmt, args ...interface{}) *sql.Rows { //nolint:unparam
rows, _ := perform(func() (interface{}, error) {
return stmt.Query(args...)
}).(*sql.Rows)
@ -257,7 +226,7 @@ func (db *Connection) query(stmt *sql.Stmt, args ...interface{}) *sql.Rows { //n
return rows
}
func (db *Connection) execute(stmt *sql.Stmt, args ...interface{}) sql.Result {
func (db *Database) execute(stmt *sql.Stmt, args ...interface{}) sql.Result {
result, _ := perform(func() (interface{}, error) {
return stmt.Exec(args...)
}).(sql.Result)
@ -265,9 +234,9 @@ func (db *Connection) execute(stmt *sql.Stmt, args ...interface{}) sql.Result {
return result
}
func (db *Connection) exec(query *bytes.Buffer, args ...interface{}) sql.Result { //nolint:unparam
func (db *Database) exec(query *bytes.Buffer, args ...interface{}) sql.Result { //nolint:unparam
result, _ := perform(func() (interface{}, error) {
return db.sqlDb.Exec(query.String(), args...)
return db.conn.Exec(query.String(), args...)
}).(sql.Result)
return result

View file

@ -19,6 +19,8 @@ package database
import (
"fmt"
"github.com/google/go-cmp/cmp"
"math"
"net"
"os"
"reflect"
@ -44,7 +46,7 @@ func TestMain(m *testing.M) {
db.Init()
fixtures, err = testfixtures.New(
testfixtures.Database(db.mainConn.sqlDb),
testfixtures.Database(db.conn),
testfixtures.Dialect("mariadb"),
testfixtures.Directory("fixtures"),
testfixtures.DangerousSkipTestDatabaseCheck(),
@ -59,47 +61,55 @@ func TestMain(m *testing.M) {
func TestLoadUsers(t *testing.T) {
prepareTestDatabase()
db.Users = make(map[string]*cdb.User)
dbUsers := make(map[string]*cdb.User)
db.Users.Store(&dbUsers)
testUser1 := &cdb.User{}
testUser1.ID.Store(1)
testUser1.DownMultiplier.Store(math.Float64bits(1))
testUser1.UpMultiplier.Store(math.Float64bits(1))
testUser1.DisableDownload.Store(false)
testUser1.TrackerHide.Store(false)
testUser2 := &cdb.User{}
testUser2.ID.Store(2)
testUser2.DownMultiplier.Store(math.Float64bits(2))
testUser2.UpMultiplier.Store(math.Float64bits(0.5))
testUser2.DisableDownload.Store(true)
testUser2.TrackerHide.Store(true)
users := map[string]*cdb.User{
"mUztWMpBYNCqzmge6vGeEUGSrctJbgpQ": {
ID: 1,
UpMultiplier: 1,
DownMultiplier: 1,
TrackerHide: false,
DisableDownload: false,
},
"tbHfQDQ9xDaQdsNv5CZBtHPfk7KGzaCw": {
ID: 2,
UpMultiplier: 0.5,
DownMultiplier: 2,
TrackerHide: true,
DisableDownload: true,
},
"mUztWMpBYNCqzmge6vGeEUGSrctJbgpQ": testUser1,
"tbHfQDQ9xDaQdsNv5CZBtHPfk7KGzaCw": testUser2,
}
// Test with fresh data
db.loadUsers()
if len(db.Users) != len(users) {
t.Fatal(fixtureFailure("Did not load all users as expected from fixture file", len(users), len(db.Users)))
dbUsers = *db.Users.Load()
if len(dbUsers) != len(users) {
t.Fatal(fixtureFailure("Did not load all users as expected from fixture file", len(users), len(dbUsers)))
}
for passkey, user := range users {
if !reflect.DeepEqual(user, db.Users[passkey]) {
if !reflect.DeepEqual(user, dbUsers[passkey]) {
t.Fatal(fixtureFailure(
fmt.Sprintf("Did not load user (%s) as expected from fixture file", passkey),
user,
db.Users[passkey]))
dbUsers[passkey]))
}
}
// Now test load on top of existing data
oldUsers := db.Users
oldUsers := dbUsers
db.loadUsers()
if !reflect.DeepEqual(oldUsers, db.Users) {
t.Fatal(fixtureFailure("Did not reload users as expected from fixture file", oldUsers, db.Users))
dbUsers = *db.Users.Load()
if !reflect.DeepEqual(oldUsers, dbUsers) {
t.Fatal(fixtureFailure("Did not reload users as expected from fixture file", oldUsers, dbUsers))
}
}
@ -133,89 +143,96 @@ func TestLoadHitAndRuns(t *testing.T) {
func TestLoadTorrents(t *testing.T) {
prepareTestDatabase()
db.Torrents = make(map[cdb.TorrentHash]*cdb.Torrent)
dbTorrents := make(map[cdb.TorrentHash]*cdb.Torrent)
db.Torrents.Store(&dbTorrents)
t1 := &cdb.Torrent{
Seeders: map[cdb.PeerKey]*cdb.Peer{},
Leechers: map[cdb.PeerKey]*cdb.Peer{},
}
t1.ID.Store(1)
t1.Status.Store(1)
t1.Snatched.Store(2)
t1.DownMultiplier.Store(math.Float64bits(1))
t1.UpMultiplier.Store(math.Float64bits(1))
t1.Group.GroupID.Store(1)
t1.Group.TorrentType.Store(cdb.MustTorrentTypeFromString("anime"))
t2 := &cdb.Torrent{
Seeders: map[cdb.PeerKey]*cdb.Peer{},
Leechers: map[cdb.PeerKey]*cdb.Peer{},
}
t2.ID.Store(2)
t2.Status.Store(0)
t2.Snatched.Store(0)
t2.DownMultiplier.Store(math.Float64bits(2))
t2.UpMultiplier.Store(math.Float64bits(0.5))
t2.Group.GroupID.Store(1)
t2.Group.TorrentType.Store(cdb.MustTorrentTypeFromString("music"))
t3 := &cdb.Torrent{
Seeders: map[cdb.PeerKey]*cdb.Peer{},
Leechers: map[cdb.PeerKey]*cdb.Peer{},
}
t3.ID.Store(3)
t3.Status.Store(0)
t3.Snatched.Store(0)
t3.DownMultiplier.Store(math.Float64bits(1))
t3.UpMultiplier.Store(math.Float64bits(1))
t3.Group.GroupID.Store(2)
t3.Group.TorrentType.Store(cdb.MustTorrentTypeFromString("anime"))
torrents := map[cdb.TorrentHash]*cdb.Torrent{
{114, 239, 32, 237, 220, 181, 67, 143, 115, 182, 216, 141, 120, 196, 223, 193, 102, 123, 137, 56}: {
ID: 1,
Status: 1,
Snatched: 2,
DownMultiplier: 1,
UpMultiplier: 1,
Seeders: map[cdb.PeerKey]*cdb.Peer{},
Leechers: map[cdb.PeerKey]*cdb.Peer{},
Group: cdb.TorrentGroup{
GroupID: 1,
TorrentType: "anime",
},
},
{22, 168, 45, 221, 87, 225, 140, 177, 94, 34, 242, 225, 196, 234, 222, 46, 187, 131, 177, 155}: {
ID: 2,
Status: 0,
Snatched: 0,
DownMultiplier: 2,
UpMultiplier: 0.5,
Seeders: map[cdb.PeerKey]*cdb.Peer{},
Leechers: map[cdb.PeerKey]*cdb.Peer{},
Group: cdb.TorrentGroup{
GroupID: 1,
TorrentType: "music",
},
},
{89, 252, 84, 49, 177, 28, 118, 28, 148, 205, 62, 185, 8, 37, 234, 110, 109, 200, 165, 241}: {
ID: 3,
Status: 0,
Snatched: 0,
DownMultiplier: 1,
UpMultiplier: 1,
Seeders: map[cdb.PeerKey]*cdb.Peer{},
Leechers: map[cdb.PeerKey]*cdb.Peer{},
Group: cdb.TorrentGroup{
GroupID: 2,
TorrentType: "anime",
},
},
{114, 239, 32, 237, 220, 181, 67, 143, 115, 182, 216, 141, 120, 196, 223, 193, 102, 123, 137, 56}: t1,
{22, 168, 45, 221, 87, 225, 140, 177, 94, 34, 242, 225, 196, 234, 222, 46, 187, 131, 177, 155}: t2,
{89, 252, 84, 49, 177, 28, 118, 28, 148, 205, 62, 185, 8, 37, 234, 110, 109, 200, 165, 241}: t3,
}
for k := range torrents {
torrents[k].InitializeLock()
}
// Test with fresh data
db.loadTorrents()
if len(db.Torrents) != len(torrents) {
dbTorrents = *db.Torrents.Load()
if len(dbTorrents) != len(torrents) {
t.Fatal(fixtureFailure("Did not load all torrents as expected from fixture file",
len(torrents),
len(db.Torrents)))
len(dbTorrents)))
}
for hash, torrent := range torrents {
if !reflect.DeepEqual(torrent, db.Torrents[hash]) {
if !cmp.Equal(torrent, dbTorrents[hash], cdb.TorrentTestCompareOptions...) {
hashHex, _ := hash.MarshalText()
t.Fatal(fixtureFailure(
fmt.Sprintf("Did not load torrent (%s) as expected from fixture file", hash),
fmt.Sprintf("Did not load torrent (%s) as expected from fixture file", string(hashHex)),
torrent,
db.Torrents[hash]))
dbTorrents[hash]))
}
}
// Now test load on top of existing data
oldTorrents := db.Torrents
oldTorrents := dbTorrents
db.loadTorrents()
if !reflect.DeepEqual(oldTorrents, db.Torrents) {
t.Fatal(fixtureFailure("Did not reload torrents as expected from fixture file", oldTorrents, db.Torrents))
dbTorrents = *db.Torrents.Load()
if !cmp.Equal(oldTorrents, dbTorrents, cdb.TorrentTestCompareOptions...) {
t.Fatal(fixtureFailure("Did not reload torrents as expected from fixture file", oldTorrents, dbTorrents))
}
}
func TestLoadGroupsFreeleech(t *testing.T) {
prepareTestDatabase()
dbMap := make(map[cdb.TorrentGroup]*cdb.TorrentGroupFreeleech)
dbMap := make(map[cdb.TorrentGroupKey]*cdb.TorrentGroupFreeleech)
db.TorrentGroupFreeleech.Store(&dbMap)
torrentGroupFreeleech := map[cdb.TorrentGroup]*cdb.TorrentGroupFreeleech{
{
GroupID: 2,
TorrentType: "anime",
}: {
torrentGroupFreeleech := map[cdb.TorrentGroupKey]*cdb.TorrentGroupFreeleech{
cdb.MustTorrentGroupKeyFromString("anime", 2): {
DownMultiplier: 0,
UpMultiplier: 2,
},
@ -297,36 +314,48 @@ func TestLoadClients(t *testing.T) {
func TestUnPrune(t *testing.T) {
prepareTestDatabase()
h := cdb.TorrentHash{114, 239, 32, 237, 220, 181, 67, 143, 115, 182, 216, 141, 120, 196, 223, 193, 102, 123, 137, 56}
torrent := cdb.Torrent{
Seeders: db.Torrents[h].Seeders,
Leechers: db.Torrents[h].Leechers,
Group: db.Torrents[h].Group,
ID: db.Torrents[h].ID,
Snatched: db.Torrents[h].Snatched,
Status: db.Torrents[h].Status,
LastAction: db.Torrents[h].LastAction,
UpMultiplier: db.Torrents[h].UpMultiplier,
DownMultiplier: db.Torrents[h].DownMultiplier,
}
torrent.Status = 0
dbTorrents := *db.Torrents.Load()
db.UnPrune(db.Torrents[h])
h := cdb.TorrentHash{114, 239, 32, 237, 220, 181, 67, 143, 115, 182, 216, 141, 120, 196, 223, 193, 102, 123, 137, 56}
dbTorrent := dbTorrents[h]
torrent := cdb.Torrent{
Seeders: dbTorrent.Seeders,
Leechers: dbTorrent.Leechers,
}
torrent.SeedersLength.Store(uint32(len(torrent.Seeders)))
torrent.LeechersLength.Store(uint32(len(torrent.Leechers)))
torrent.ID.Store(dbTorrent.ID.Load())
torrent.Status.Store(dbTorrent.Status.Load())
torrent.Snatched.Store(dbTorrent.Snatched.Load())
torrent.DownMultiplier.Store(dbTorrent.DownMultiplier.Load())
torrent.UpMultiplier.Store(dbTorrent.UpMultiplier.Load())
torrent.Group.GroupID.Store(dbTorrent.Group.GroupID.Load())
torrent.Group.TorrentType.Store(dbTorrent.Group.TorrentType.Load())
torrent.InitializeLock()
torrent.Status.Store(0)
db.UnPrune(dbTorrents[h])
db.loadTorrents()
if !reflect.DeepEqual(&torrent, db.Torrents[h]) {
dbTorrents = *db.Torrents.Load()
if !cmp.Equal(&torrent, dbTorrents[h], cdb.TorrentTestCompareOptions...) {
t.Fatal(fixtureFailure(
fmt.Sprintf("Torrent (%x) was not unpruned properly", h),
&torrent,
db.Torrents[h]))
dbTorrents[h]))
}
}
func TestRecordAndFlushUsers(t *testing.T) {
prepareTestDatabase()
testUser := db.Users["tbHfQDQ9xDaQdsNv5CZBtHPfk7KGzaCw"]
dbUsers := *db.Users.Load()
testUser := dbUsers["tbHfQDQ9xDaQdsNv5CZBtHPfk7KGzaCw"]
var (
initUpload int64
@ -347,11 +376,11 @@ func TestRecordAndFlushUsers(t *testing.T) {
deltaRawDownload = 83472
deltaRawUpload = 245
deltaDownload = int64(float64(deltaRawDownload) * testUser.DownMultiplier)
deltaUpload = int64(float64(deltaRawUpload) * testUser.UpMultiplier)
deltaDownload = int64(float64(deltaRawDownload) * math.Float64frombits(testUser.DownMultiplier.Load()))
deltaUpload = int64(float64(deltaRawUpload) * math.Float64frombits(testUser.UpMultiplier.Load()))
row := db.mainConn.sqlDb.QueryRow("SELECT Uploaded, Downloaded, rawup, rawdl "+
"FROM users_main WHERE ID = ?", testUser.ID)
row := db.conn.QueryRow("SELECT Uploaded, Downloaded, rawup, rawdl "+
"FROM users_main WHERE ID = ?", testUser.ID.Load())
err := row.Scan(&initUpload, &initDownload, &initRawUpload, &initRawDownload)
if err != nil {
@ -365,8 +394,8 @@ func TestRecordAndFlushUsers(t *testing.T) {
}
time.Sleep(200 * time.Millisecond)
row = db.mainConn.sqlDb.QueryRow("SELECT Uploaded, Downloaded, rawup, rawdl "+
"FROM users_main WHERE ID = ?", testUser.ID)
row = db.conn.QueryRow("SELECT Uploaded, Downloaded, rawup, rawdl "+
"FROM users_main WHERE ID = ?", testUser.ID.Load())
err = row.Scan(&upload, &download, &rawUpload, &rawDownload)
if err != nil {
@ -446,7 +475,7 @@ func TestRecordAndFlushTransferHistory(t *testing.T) {
deltaActiveTime = 267
deltaSeedTime = 15
row := db.mainConn.sqlDb.QueryRow("SELECT uploaded, downloaded, activetime, seedtime, active, snatched "+
row := db.conn.QueryRow("SELECT uploaded, downloaded, activetime, seedtime, active, snatched "+
"FROM transfer_history WHERE uid = ? AND fid = ?", testPeer.UserID, testPeer.TorrentID)
err := row.Scan(&initRawUpload, &initRawDownload, &initActiveTime, &initSeedTime, &initActive, &initSnatch)
@ -467,7 +496,7 @@ func TestRecordAndFlushTransferHistory(t *testing.T) {
}
time.Sleep(200 * time.Millisecond)
row = db.mainConn.sqlDb.QueryRow("SELECT uploaded, downloaded, activetime, seedtime, active, snatched "+
row = db.conn.QueryRow("SELECT uploaded, downloaded, activetime, seedtime, active, snatched "+
"FROM transfer_history WHERE uid = ? AND fid = ?", testPeer.UserID, testPeer.TorrentID)
err = row.Scan(&rawUpload, &rawDownload, &activeTime, &seedTime, &active, &snatch)
@ -528,7 +557,7 @@ func TestRecordAndFlushTransferHistory(t *testing.T) {
var gotStartTime int64
row = db.mainConn.sqlDb.QueryRow("SELECT seeding, starttime, last_announce, remaining "+
row = db.conn.QueryRow("SELECT seeding, starttime, last_announce, remaining "+
"FROM transfer_history WHERE uid = ? AND fid = ?", gotPeer.UserID, gotPeer.TorrentID)
err = row.Scan(&gotPeer.Seeding, &gotStartTime, &gotPeer.LastAnnounce, &gotPeer.Left)
@ -554,7 +583,14 @@ func TestRecordAndFlushTransferHistory(t *testing.T) {
Left: 65000,
}
db.QueueTransferHistory(testPeer, 0, 1000, 1, 0, 1, true)
db.QueueTransferHistory(
testPeer,
0,
1000,
1,
0,
1,
true)
gotPeer = &cdb.Peer{
UserID: testPeer.UserID,
@ -566,7 +602,7 @@ func TestRecordAndFlushTransferHistory(t *testing.T) {
}
time.Sleep(200 * time.Millisecond)
row = db.mainConn.sqlDb.QueryRow("SELECT seeding, starttime, last_announce, remaining "+
row = db.conn.QueryRow("SELECT seeding, starttime, last_announce, remaining "+
"FROM transfer_history WHERE uid = ? AND fid = ?", gotPeer.UserID, gotPeer.TorrentID)
err = row.Scan(&gotPeer.Seeding, &gotPeer.StartTime, &gotPeer.LastAnnounce, &gotPeer.Left)
@ -603,7 +639,7 @@ func TestRecordAndFlushTransferIP(t *testing.T) {
deltaDownload = 236
deltaUpload = 3262
row := db.mainConn.sqlDb.QueryRow("SELECT uploaded, downloaded "+
row := db.conn.QueryRow("SELECT uploaded, downloaded "+
"FROM transfer_ips WHERE uid = ? AND fid = ? AND ip = ? AND client_id = ?",
testPeer.UserID, testPeer.TorrentID, testPeer.Addr.IPNumeric(), testPeer.ClientID)
@ -619,7 +655,7 @@ func TestRecordAndFlushTransferIP(t *testing.T) {
}
time.Sleep(200 * time.Millisecond)
row = db.mainConn.sqlDb.QueryRow("SELECT uploaded, downloaded "+
row = db.conn.QueryRow("SELECT uploaded, downloaded "+
"FROM transfer_ips WHERE uid = ? AND fid = ? AND ip = ? AND client_id = ?",
testPeer.UserID, testPeer.TorrentID, testPeer.Addr.IPNumeric(), testPeer.ClientID)
@ -655,7 +691,7 @@ func TestRecordAndFlushTransferIP(t *testing.T) {
var gotStartTime int64
row = db.mainConn.sqlDb.QueryRow("SELECT port, starttime, last_announce "+
row = db.conn.QueryRow("SELECT port, starttime, last_announce "+
"FROM transfer_ips WHERE uid = ? AND fid = ? AND ip = ? AND client_id = ?",
testPeer.UserID, testPeer.TorrentID, testPeer.Addr.IPNumeric(), testPeer.ClientID)
@ -700,7 +736,7 @@ func TestRecordAndFlushTransferIP(t *testing.T) {
Addr: cdb.NewPeerAddressFromIPPort(testPeer.Addr.IP(), 0),
}
row = db.mainConn.sqlDb.QueryRow("SELECT port, starttime, last_announce "+
row = db.conn.QueryRow("SELECT port, starttime, last_announce "+
"FROM transfer_ips WHERE uid = ? AND fid = ? AND ip = ? AND client_id = ?",
testPeer.UserID, testPeer.TorrentID, testPeer.Addr.IPNumeric(), testPeer.ClientID)
@ -738,7 +774,7 @@ func TestRecordAndFlushSnatch(t *testing.T) {
}
time.Sleep(200 * time.Millisecond)
row := db.mainConn.sqlDb.QueryRow("SELECT snatched_time "+
row := db.conn.QueryRow("SELECT snatched_time "+
"FROM transfer_history WHERE uid = ? AND fid = ?", testPeer.UserID, testPeer.TorrentID)
err := row.Scan(&snatchTime)
@ -758,22 +794,24 @@ func TestRecordAndFlushTorrents(t *testing.T) {
prepareTestDatabase()
h := cdb.TorrentHash{114, 239, 32, 237, 220, 181, 67, 143, 115, 182, 216, 141, 120, 196, 223, 193, 102, 123, 137, 56}
torrent := db.Torrents[h]
torrent.LastAction = time.Now().Unix()
torrent := (*db.Torrents.Load())[h]
torrent.LastAction.Store(time.Now().Unix())
torrent.Seeders[cdb.NewPeerKey(1, cdb.PeerIDFromRawString("test_peer_id_num_one"))] = &cdb.Peer{
UserID: 1,
TorrentID: torrent.ID,
TorrentID: torrent.ID.Load(),
ClientID: 1,
StartTime: time.Now().Unix(),
LastAnnounce: time.Now().Unix(),
}
torrent.Leechers[cdb.NewPeerKey(3, cdb.PeerIDFromRawString("test_peer_id_num_two"))] = &cdb.Peer{
UserID: 3,
TorrentID: torrent.ID,
TorrentID: torrent.ID.Load(),
ClientID: 2,
StartTime: time.Now().Unix(),
LastAnnounce: time.Now().Unix(),
}
torrent.SeedersLength.Store(uint32(len(torrent.Seeders)))
torrent.LeechersLength.Store(uint32(len(torrent.Leechers)))
db.QueueTorrent(torrent, 5)
@ -789,26 +827,26 @@ func TestRecordAndFlushTorrents(t *testing.T) {
numSeeders int
)
row := db.mainConn.sqlDb.QueryRow("SELECT Snatched, last_action, Seeders, Leechers "+
"FROM torrents WHERE ID = ?", torrent.ID)
row := db.conn.QueryRow("SELECT Snatched, last_action, Seeders, Leechers "+
"FROM torrents WHERE ID = ?", torrent.ID.Load())
err := row.Scan(&snatched, &lastAction, &numSeeders, &numLeechers)
if err != nil {
panic(err)
}
if torrent.Snatched+5 != snatched {
if uint16(torrent.Snatched.Load())+5 != snatched {
t.Fatal(fixtureFailure(
fmt.Sprintf("Snatches incorrectly updated in the database for torrent %x", h),
torrent.Snatched+5,
torrent.Snatched.Load()+5,
snatched,
))
}
if torrent.LastAction != lastAction {
if torrent.LastAction.Load() != lastAction {
t.Fatal(fixtureFailure(
fmt.Sprintf("Last incorrectly updated in the database for torrent %x", h),
torrent.LastAction,
torrent.LastAction.Load(),
lastAction,
))
}
@ -828,6 +866,22 @@ func TestRecordAndFlushTorrents(t *testing.T) {
numLeechers,
))
}
if int(torrent.SeedersLength.Load()) != numSeeders {
t.Fatal(fixtureFailure(
fmt.Sprintf("SeedersLength incorrectly updated in the database for torrent %x", h),
len(torrent.Seeders),
numSeeders,
))
}
if int(torrent.LeechersLength.Load()) != numLeechers {
t.Fatal(fixtureFailure(
fmt.Sprintf("LeechersLength incorrectly updated in the database for torrent %x", h),
len(torrent.Leechers),
numLeechers,
))
}
}
func TestTerminate(_ *testing.T) {

View file

@ -1,5 +0,0 @@
- ID: 1
Time: 1592659152
- ID: 2
Time: 1609032772

View file

@ -1,2 +0,0 @@
- ID: 1
Time: 1592659152

View file

@ -19,8 +19,8 @@ package database
import (
"bytes"
"chihaya/util"
"errors"
"github.com/viney-shih/go-lock"
"time"
"chihaya/collectors"
@ -82,8 +82,6 @@ func (db *Database) startFlushing() {
db.transferIpsChannel = make(chan *bytes.Buffer, transferIpsFlushBufferSize)
db.snatchChannel = make(chan *bytes.Buffer, snatchFlushBufferSize)
db.transferHistoryLock = lock.NewCASMutex()
go db.flushTorrents()
go db.flushUsers()
go db.flushTransferHistory() // Can not be blocking or it will lock purgeInactivePeers when chan is empty
@ -114,25 +112,9 @@ func (db *Database) flushTorrents() {
count int
)
conn := Open()
for {
query.Reset()
query.WriteString("CREATE TEMPORARY TABLE IF NOT EXISTS flush_torrents (" +
"ID int unsigned NOT NULL, " +
"Snatched int unsigned NOT NULL DEFAULT 0, " +
"Seeders int unsigned NOT NULL DEFAULT 0, " +
"Leechers int unsigned NOT NULL DEFAULT 0, " +
"last_action int NOT NULL DEFAULT 0, " +
"PRIMARY KEY (ID)) ENGINE=MEMORY")
conn.exec(&query)
query.Reset()
query.WriteString("TRUNCATE flush_torrents")
conn.exec(&query)
query.Reset()
query.WriteString("INSERT INTO flush_torrents VALUES ")
query.WriteString("INSERT IGNORE INTO torrents (ID, Snatched, Seeders, Leechers, last_action) VALUES ")
length := len(db.torrentChannel)
@ -152,7 +134,7 @@ func (db *Database) flushTorrents() {
if count > 0 {
logFlushes, _ := config.GetBool("log_flushes", true)
if logFlushes && !db.terminate {
if logFlushes && !db.terminate.Load() {
log.Info.Printf("{torrents} Flushing %d", count)
}
@ -161,18 +143,9 @@ func (db *Database) flushTorrents() {
query.WriteString(" ON DUPLICATE KEY UPDATE Snatched = Snatched + VALUE(Snatched), " +
"Seeders = VALUE(Seeders), Leechers = VALUE(Leechers), " +
"last_action = IF(last_action < VALUE(last_action), VALUE(last_action), last_action)")
conn.exec(&query)
db.exec(&query)
query.Reset()
query.WriteString("UPDATE torrents t, flush_torrents ft SET " +
"t.Snatched = t.Snatched + ft.Snatched, " +
"t.Seeders = ft.Seeders, " +
"t.Leechers = ft.Leechers, " +
"t.last_action = IF(t.last_action < ft.last_action, ft.last_action, t.last_action)" +
"WHERE t.ID = ft.ID")
conn.exec(&query)
if !db.terminate {
if !db.terminate.Load() {
elapsedTime := time.Since(startTime)
collectors.UpdateFlushTime("torrents", elapsedTime)
collectors.UpdateChannelsLen("torrents", count)
@ -181,14 +154,12 @@ func (db *Database) flushTorrents() {
if length < (torrentFlushBufferSize >> 1) {
time.Sleep(time.Duration(flushSleepInterval) * time.Second)
}
} else if db.terminate {
} else if db.terminate.Load() {
break
} else {
time.Sleep(time.Second)
}
}
_ = conn.Close()
}
func (db *Database) flushUsers() {
@ -200,25 +171,9 @@ func (db *Database) flushUsers() {
count int
)
conn := Open()
for {
query.Reset()
query.WriteString("CREATE TEMPORARY TABLE IF NOT EXISTS flush_users (" +
"ID int unsigned NOT NULL, " +
"Uploaded bigint unsigned NOT NULL DEFAULT 0, " +
"Downloaded bigint unsigned NOT NULL DEFAULT 0, " +
"rawdl bigint unsigned NOT NULL DEFAULT 0, " +
"rawup bigint unsigned NOT NULL DEFAULT 0, " +
"PRIMARY KEY (ID)) ENGINE=MEMORY")
conn.exec(&query)
query.Reset()
query.WriteString("TRUNCATE flush_users")
conn.exec(&query)
query.Reset()
query.WriteString("INSERT INTO flush_users VALUES ")
query.WriteString("INSERT IGNORE INTO users_main (ID, Uploaded, Downloaded, rawdl, rawup) VALUES ")
length := len(db.userChannel)
@ -238,7 +193,7 @@ func (db *Database) flushUsers() {
if count > 0 {
logFlushes, _ := config.GetBool("log_flushes", true)
if logFlushes && !db.terminate {
if logFlushes && !db.terminate.Load() {
log.Info.Printf("{users_main} Flushing %d", count)
}
@ -246,18 +201,9 @@ func (db *Database) flushUsers() {
query.WriteString(" ON DUPLICATE KEY UPDATE Uploaded = Uploaded + VALUE(Uploaded), " +
"Downloaded = Downloaded + VALUE(Downloaded), rawdl = rawdl + VALUE(rawdl), rawup = rawup + VALUE(rawup)")
conn.exec(&query)
db.exec(&query)
query.Reset()
query.WriteString("UPDATE users_main u, flush_users fu SET " +
"u.Uploaded = u.Uploaded + fu.Uploaded, " +
"u.Downloaded = u.Downloaded + fu.Downloaded, " +
"u.rawdl = u.rawdl + fu.rawdl, " +
"u.rawup = u.rawup + fu.rawup " +
"WHERE u.ID = fu.ID")
conn.exec(&query)
if !db.terminate {
if !db.terminate.Load() {
elapsedTime := time.Since(startTime)
collectors.UpdateFlushTime("users", elapsedTime)
collectors.UpdateChannelsLen("users", count)
@ -266,14 +212,12 @@ func (db *Database) flushUsers() {
if length < (userFlushBufferSize >> 1) {
time.Sleep(time.Duration(flushSleepInterval) * time.Second)
}
} else if db.terminate {
} else if db.terminate.Load() {
break
} else {
time.Sleep(time.Second)
}
}
_ = conn.Close()
}
func (db *Database) flushTransferHistory() {
@ -285,8 +229,6 @@ func (db *Database) flushTransferHistory() {
count int
)
conn := Open()
for {
length, err := func() (int, error) {
db.transferHistoryLock.Lock()
@ -314,7 +256,7 @@ func (db *Database) flushTransferHistory() {
if count > 0 {
logFlushes, _ := config.GetBool("log_flushes", true)
if logFlushes && !db.terminate {
if logFlushes && !db.terminate.Load() {
log.Info.Printf("{transfer_history} Flushing %d", count)
}
@ -326,16 +268,16 @@ func (db *Database) flushTransferHistory() {
"seedtime = seedtime + VALUE(seedtime), last_announce = VALUE(last_announce), " +
"active = VALUE(active), snatched = snatched + VALUE(snatched);")
conn.exec(&query)
db.exec(&query)
if !db.terminate {
if !db.terminate.Load() {
elapsedTime := time.Since(startTime)
collectors.UpdateFlushTime("transfer_history", elapsedTime)
collectors.UpdateChannelsLen("transfer_history", count)
}
return length, nil
} else if db.terminate {
} else if db.terminate.Load() {
return 0, errDbTerminate
}
@ -350,8 +292,6 @@ func (db *Database) flushTransferHistory() {
time.Sleep(time.Second)
}
}
_ = conn.Close()
}
func (db *Database) flushTransferIps() {
@ -363,8 +303,6 @@ func (db *Database) flushTransferIps() {
count int
)
conn := Open()
for {
query.Reset()
query.WriteString("INSERT INTO transfer_ips (uid, fid, client_id, ip, port, uploaded, downloaded, " +
@ -388,7 +326,7 @@ func (db *Database) flushTransferIps() {
if count > 0 {
logFlushes, _ := config.GetBool("log_flushes", true)
if logFlushes && !db.terminate {
if logFlushes && !db.terminate.Load() {
log.Info.Printf("{transfer_ips} Flushing %d", count)
}
@ -397,9 +335,9 @@ func (db *Database) flushTransferIps() {
// todo: port should be part of PK
query.WriteString("\nON DUPLICATE KEY UPDATE port = VALUE(port), downloaded = downloaded + VALUE(downloaded), " +
"uploaded = uploaded + VALUE(uploaded), last_announce = VALUE(last_announce)")
conn.exec(&query)
db.exec(&query)
if !db.terminate {
if !db.terminate.Load() {
elapsedTime := time.Since(startTime)
collectors.UpdateFlushTime("transfer_ips", elapsedTime)
collectors.UpdateChannelsLen("transfer_ips", count)
@ -408,14 +346,12 @@ func (db *Database) flushTransferIps() {
if length < (transferIpsFlushBufferSize >> 1) {
time.Sleep(time.Duration(flushSleepInterval) * time.Second)
}
} else if db.terminate {
} else if db.terminate.Load() {
break
} else {
time.Sleep(time.Second)
}
}
_ = conn.Close()
}
func (db *Database) flushSnatches() {
@ -427,8 +363,6 @@ func (db *Database) flushSnatches() {
count int
)
conn := Open()
for {
query.Reset()
query.WriteString("INSERT INTO transfer_history (uid, fid, snatched_time) VALUES\n")
@ -451,16 +385,16 @@ func (db *Database) flushSnatches() {
if count > 0 {
logFlushes, _ := config.GetBool("log_flushes", true)
if logFlushes && !db.terminate {
if logFlushes && !db.terminate.Load() {
log.Info.Printf("{snatches} Flushing %d", count)
}
startTime := time.Now()
query.WriteString("\nON DUPLICATE KEY UPDATE snatched_time = VALUE(snatched_time)")
conn.exec(&query)
db.exec(&query)
if !db.terminate {
if !db.terminate.Load() {
elapsedTime := time.Since(startTime)
collectors.UpdateFlushTime("snatches", elapsedTime)
collectors.UpdateChannelsLen("snatches", count)
@ -469,14 +403,12 @@ func (db *Database) flushSnatches() {
if length < (snatchFlushBufferSize >> 1) {
time.Sleep(time.Duration(flushSleepInterval) * time.Second)
}
} else if db.terminate {
} else if db.terminate.Load() {
break
} else {
time.Sleep(time.Second)
}
}
_ = conn.Close()
}
func (db *Database) purgeInactivePeers() {
@ -486,7 +418,7 @@ func (db *Database) purgeInactivePeers() {
count int
)
for !db.terminate {
util.ContextTick(db.ctx, time.Duration(purgeInactivePeersInterval)*time.Second, func() {
start = time.Now()
now = start.Unix()
count = 0
@ -494,45 +426,45 @@ func (db *Database) purgeInactivePeers() {
oldestActive := now - int64(peerInactivityInterval)
// First, remove inactive peers from memory
func() {
db.TorrentsLock.RLock()
defer db.TorrentsLock.RUnlock()
dbTorrents := *db.Torrents.Load()
for _, torrent := range dbTorrents {
func() {
//Take write lock to operate on peer entries
torrent.Lock()
defer torrent.Unlock()
for _, torrent := range db.Torrents {
func() {
//Take write lock to operate on entries
torrent.Lock()
defer torrent.Unlock()
countThisTorrent := count
countThisTorrent := count
for id, peer := range torrent.Leechers {
if peer.LastAnnounce < oldestActive {
delete(torrent.Leechers, id)
count++
}
for id, peer := range torrent.Leechers {
if peer.LastAnnounce < oldestActive {
delete(torrent.Leechers, id)
count++
}
}
if countThisTorrent != count && len(torrent.Leechers) == 0 {
/* Deallocate previous map since Go does not free space used on maps when deleting objects.
We're doing it only for Leechers as potential advantage from freeing one or two Seeders is
virtually nil, while Leechers can incur significant memory leaks due to initial swarm activity. */
torrent.Leechers = make(map[cdb.PeerKey]*cdb.Peer)
}
if countThisTorrent != count && len(torrent.Leechers) == 0 {
/* Deallocate previous map since Go does not free space used on maps when deleting objects.
We're doing it only for Leechers as potential advantage from freeing one or two Seeders is
virtually nil, while Leechers can incur significant memory leaks due to initial swarm activity. */
torrent.Leechers = make(map[cdb.PeerKey]*cdb.Peer)
}
for id, peer := range torrent.Seeders {
if peer.LastAnnounce < oldestActive {
delete(torrent.Seeders, id)
count++
}
for id, peer := range torrent.Seeders {
if peer.LastAnnounce < oldestActive {
delete(torrent.Seeders, id)
count++
}
}
if countThisTorrent != count {
db.QueueTorrent(torrent, 0)
}
}()
}
}()
if countThisTorrent != count {
// Update lengths of peers
torrent.SeedersLength.Store(uint32(len(torrent.Seeders)))
torrent.LeechersLength.Store(uint32(len(torrent.Leechers)))
db.QueueTorrent(torrent, 0)
}
}()
}
elapsedTime := time.Since(start)
collectors.UpdateFlushTime("purging_inactive_peers", elapsedTime)
@ -547,11 +479,8 @@ func (db *Database) purgeInactivePeers() {
db.transferHistoryLock.Lock()
defer db.transferHistoryLock.Unlock()
db.mainConn.mutex.Lock()
defer db.mainConn.mutex.Unlock()
start = time.Now()
result := db.mainConn.execute(db.cleanStalePeersStmt, oldestActive)
result := db.execute(db.cleanStalePeersStmt, oldestActive)
if result != nil {
rows, err := result.RowsAffected()
@ -562,7 +491,5 @@ func (db *Database) purgeInactivePeers() {
}
}
}()
time.Sleep(time.Duration(purgeInactivePeersInterval) * time.Second)
}
})
}

View file

@ -38,15 +38,15 @@ func (db *Database) QueueTorrent(torrent *cdb.Torrent, deltaSnatch uint8) {
tq := db.bufferPool.Take()
tq.WriteString("(")
tq.WriteString(strconv.FormatUint(uint64(torrent.ID), 10))
tq.WriteString(strconv.FormatUint(uint64(torrent.ID.Load()), 10))
tq.WriteString(",")
tq.WriteString(strconv.FormatUint(uint64(deltaSnatch), 10))
tq.WriteString(",")
tq.WriteString(strconv.FormatInt(int64(len(torrent.Seeders)), 10))
tq.WriteString(strconv.FormatUint(uint64(torrent.SeedersLength.Load()), 10))
tq.WriteString(",")
tq.WriteString(strconv.FormatInt(int64(len(torrent.Leechers)), 10))
tq.WriteString(strconv.FormatUint(uint64(torrent.LeechersLength.Load()), 10))
tq.WriteString(",")
tq.WriteString(strconv.FormatInt(torrent.LastAction, 10))
tq.WriteString(strconv.FormatInt(torrent.LastAction.Load(), 10))
tq.WriteString(")")
select {
@ -58,11 +58,16 @@ func (db *Database) QueueTorrent(torrent *cdb.Torrent, deltaSnatch uint8) {
}
}
func (db *Database) QueueUser(user *cdb.User, rawDeltaUp, rawDeltaDown, deltaUp, deltaDown int64) {
func (db *Database) QueueUser(
user *cdb.User,
rawDeltaUp,
rawDeltaDown,
deltaUp,
deltaDown int64) {
uq := db.bufferPool.Take()
uq.WriteString("(")
uq.WriteString(strconv.FormatUint(uint64(user.ID), 10))
uq.WriteString(strconv.FormatUint(uint64(user.ID.Load()), 10))
uq.WriteString(",")
uq.WriteString(strconv.FormatInt(deltaUp, 10))
uq.WriteString(",")
@ -174,7 +179,5 @@ func (db *Database) QueueSnatch(peer *cdb.Peer, now int64) {
}
func (db *Database) UnPrune(torrent *cdb.Torrent) {
db.mainConn.mutex.Lock()
db.mainConn.execute(db.unPruneTorrentStmt, torrent.ID)
db.mainConn.mutex.Unlock()
db.execute(db.unPruneTorrentStmt, torrent.ID.Load())
}

View file

@ -18,6 +18,7 @@
package database
import (
"math"
"sync/atomic"
"time"
@ -25,6 +26,7 @@ import (
"chihaya/config"
cdb "chihaya/database/types"
"chihaya/log"
"chihaya/util"
)
var (
@ -49,32 +51,26 @@ func init() {
*/
func (db *Database) startReloading() {
go func() {
for !db.terminate {
time.Sleep(time.Duration(reloadInterval) * time.Second)
util.ContextTick(db.ctx, time.Duration(reloadInterval)*time.Second, func() {
db.waitGroup.Add(1)
defer db.waitGroup.Done()
db.loadUsers()
db.loadHitAndRuns()
db.loadTorrents()
db.loadGroupsFreeleech()
db.loadConfig()
db.loadClients()
db.waitGroup.Done()
}
})
}()
}
func (db *Database) loadUsers() {
db.UsersLock.Lock()
defer db.UsersLock.Unlock()
db.mainConn.mutex.Lock()
defer db.mainConn.mutex.Unlock()
start := time.Now()
newUsers := make(map[string]*cdb.User, len(db.Users))
rows := db.mainConn.query(db.loadUsersStmt)
dbUsers := *db.Users.Load()
newUsers := make(map[string]*cdb.User, len(dbUsers))
rows := db.query(db.loadUsersStmt)
if rows == nil {
log.Error.Print("Failed to load hit and runs from database")
log.WriteStack()
@ -99,40 +95,37 @@ func (db *Database) loadUsers() {
log.WriteStack()
}
if old, exists := db.Users[torrentPass]; exists && old != nil {
old.ID = id
old.DownMultiplier = downMultiplier
old.UpMultiplier = upMultiplier
old.DisableDownload = disableDownload
old.TrackerHide = trackerHide
if old, exists := dbUsers[torrentPass]; exists && old != nil {
old.ID.Store(id)
old.DownMultiplier.Store(math.Float64bits(downMultiplier))
old.UpMultiplier.Store(math.Float64bits(upMultiplier))
old.DisableDownload.Store(disableDownload)
old.TrackerHide.Store(trackerHide)
newUsers[torrentPass] = old
} else {
newUsers[torrentPass] = &cdb.User{
ID: id,
UpMultiplier: upMultiplier,
DownMultiplier: downMultiplier,
DisableDownload: disableDownload,
TrackerHide: trackerHide,
}
u := &cdb.User{}
u.ID.Store(id)
u.DownMultiplier.Store(math.Float64bits(downMultiplier))
u.UpMultiplier.Store(math.Float64bits(upMultiplier))
u.DisableDownload.Store(disableDownload)
u.TrackerHide.Store(trackerHide)
newUsers[torrentPass] = u
}
}
db.Users = newUsers
db.Users.Store(&newUsers)
elapsedTime := time.Since(start)
collectors.UpdateReloadTime("users", elapsedTime)
log.Info.Printf("User load complete (%d rows, %s)", len(db.Users), elapsedTime.String())
log.Info.Printf("User load complete (%d rows, %s)", len(newUsers), elapsedTime.String())
}
func (db *Database) loadHitAndRuns() {
db.mainConn.mutex.Lock()
defer db.mainConn.mutex.Unlock()
start := time.Now()
newHnr := make(map[cdb.UserTorrentPair]struct{})
rows := db.mainConn.query(db.loadHnrStmt)
rows := db.query(db.loadHnrStmt)
if rows == nil {
log.Error.Print("Failed to load hit and runs from database")
log.WriteStack()
@ -169,100 +162,99 @@ func (db *Database) loadHitAndRuns() {
func (db *Database) loadTorrents() {
var start time.Time
newTorrents := make(map[cdb.TorrentHash]*cdb.Torrent)
dbTorrents := *db.Torrents.Load()
func() {
db.TorrentsLock.RLock()
defer db.TorrentsLock.RUnlock()
newTorrents := make(map[cdb.TorrentHash]*cdb.Torrent, len(dbTorrents))
db.mainConn.mutex.Lock()
defer db.mainConn.mutex.Unlock()
start = time.Now()
start = time.Now()
rows := db.query(db.loadTorrentsStmt)
if rows == nil {
log.Error.Print("Failed to load torrents from database")
log.WriteStack()
rows := db.mainConn.query(db.loadTorrentsStmt)
if rows == nil {
log.Error.Print("Failed to load torrents from database")
log.WriteStack()
return
}
return
}
defer func() {
_ = rows.Close()
}()
for rows.Next() {
var (
infoHash cdb.TorrentHash
id uint32
downMultiplier, upMultiplier float64
snatched uint16
status uint8
group cdb.TorrentGroup
)
if err := rows.Scan(
&id,
&infoHash,
&downMultiplier,
&upMultiplier,
&snatched,
&status,
&group.GroupID,
&group.TorrentType,
); err != nil {
log.Error.Printf("Error scanning torrent row: %s", err)
log.WriteStack()
}
if old, exists := db.Torrents[infoHash]; exists && old != nil {
func() {
old.Lock()
defer old.Unlock()
old.ID = id
old.DownMultiplier = downMultiplier
old.UpMultiplier = upMultiplier
old.Snatched = snatched
old.Status = status
old.Group = group
}()
newTorrents[infoHash] = old
} else {
newTorrents[infoHash] = &cdb.Torrent{
ID: id,
UpMultiplier: upMultiplier,
DownMultiplier: downMultiplier,
Snatched: snatched,
Status: status,
Group: group,
Seeders: make(map[cdb.PeerKey]*cdb.Peer),
Leechers: make(map[cdb.PeerKey]*cdb.Peer),
}
}
}
defer func() {
_ = rows.Close()
}()
db.TorrentsLock.Lock()
defer db.TorrentsLock.Unlock()
db.Torrents = newTorrents
for rows.Next() {
var (
infoHash cdb.TorrentHash
id uint32
downMultiplier, upMultiplier float64
snatched uint16
status uint8
groupID uint32
torrentType string
)
if err := rows.Scan(
&id,
&infoHash,
&downMultiplier,
&upMultiplier,
&snatched,
&status,
&groupID,
&torrentType,
); err != nil {
log.Error.Printf("Error scanning torrent row: %s", err)
log.WriteStack()
}
torrentTypeUint64, err := cdb.TorrentTypeFromString(torrentType)
if err != nil {
log.Error.Printf("Error storing torrent row: %s", err)
log.WriteStack()
}
if old, exists := dbTorrents[infoHash]; exists && old != nil {
old.ID.Store(id)
old.DownMultiplier.Store(math.Float64bits(downMultiplier))
old.UpMultiplier.Store(math.Float64bits(upMultiplier))
old.Snatched.Store(uint32(snatched))
old.Status.Store(uint32(status))
old.Group.TorrentType.Store(torrentTypeUint64)
old.Group.GroupID.Store(groupID)
newTorrents[infoHash] = old
} else {
t := &cdb.Torrent{
Seeders: make(map[cdb.PeerKey]*cdb.Peer),
Leechers: make(map[cdb.PeerKey]*cdb.Peer),
}
t.InitializeLock()
t.ID.Store(id)
t.DownMultiplier.Store(math.Float64bits(downMultiplier))
t.UpMultiplier.Store(math.Float64bits(upMultiplier))
t.Snatched.Store(uint32(snatched))
t.Status.Store(uint32(status))
t.Group.TorrentType.Store(torrentTypeUint64)
t.Group.GroupID.Store(groupID)
newTorrents[infoHash] = t
}
}
db.Torrents.Store(&newTorrents)
elapsedTime := time.Since(start)
collectors.UpdateReloadTime("torrents", elapsedTime)
log.Info.Printf("Torrent load complete (%d rows, %s)", len(db.Torrents), elapsedTime.String())
log.Info.Printf("Torrent load complete (%d rows, %s)", len(newTorrents), elapsedTime.String())
}
func (db *Database) loadGroupsFreeleech() {
db.mainConn.mutex.Lock()
defer db.mainConn.mutex.Unlock()
start := time.Now()
newTorrentGroupFreeleech := make(map[cdb.TorrentGroup]*cdb.TorrentGroupFreeleech)
newTorrentGroupFreeleech := make(map[cdb.TorrentGroupKey]*cdb.TorrentGroupFreeleech)
rows := db.mainConn.query(db.loadTorrentGroupFreeleechStmt)
rows := db.query(db.loadTorrentGroupFreeleechStmt)
if rows == nil {
log.Error.Print("Failed to load torrent group freeleech data from database")
log.WriteStack()
@ -277,15 +269,22 @@ func (db *Database) loadGroupsFreeleech() {
for rows.Next() {
var (
downMultiplier, upMultiplier float64
group cdb.TorrentGroup
groupID uint32
torrentType string
)
if err := rows.Scan(&group.GroupID, &group.TorrentType, &downMultiplier, &upMultiplier); err != nil {
log.Error.Printf("Error scanning torrent row: %s", err)
if err := rows.Scan(&groupID, &torrentType, &downMultiplier, &upMultiplier); err != nil {
log.Error.Printf("Error scanning torrent group freeleech row: %s", err)
log.WriteStack()
}
newTorrentGroupFreeleech[group] = &cdb.TorrentGroupFreeleech{
k, err := cdb.TorrentGroupKeyFromString(torrentType, groupID)
if err != nil {
log.Error.Printf("Error storing torrent group freeleech row: %s", err)
log.WriteStack()
}
newTorrentGroupFreeleech[k] = &cdb.TorrentGroupFreeleech{
UpMultiplier: upMultiplier,
DownMultiplier: downMultiplier,
}
@ -295,14 +294,11 @@ func (db *Database) loadGroupsFreeleech() {
elapsedTime := time.Since(start)
collectors.UpdateReloadTime("groups_freeleech", elapsedTime)
log.Info.Printf("Group freeleech load complete (%d rows, %s)", len(db.Torrents), elapsedTime.String())
log.Info.Printf("Group freeleech load complete (%d rows, %s)", len(newTorrentGroupFreeleech), elapsedTime.String())
}
func (db *Database) loadConfig() {
db.mainConn.mutex.Lock()
defer db.mainConn.mutex.Unlock()
rows := db.mainConn.query(db.loadFreeleechStmt)
rows := db.query(db.loadFreeleechStmt)
if rows == nil {
log.Error.Print("Failed to load config from database")
log.WriteStack()
@ -327,13 +323,10 @@ func (db *Database) loadConfig() {
}
func (db *Database) loadClients() {
db.mainConn.mutex.Lock()
defer db.mainConn.mutex.Unlock()
start := time.Now()
newClients := make(map[uint16]string)
rows := db.mainConn.query(db.loadClientsStmt)
rows := db.query(db.loadClientsStmt)
if rows == nil {
log.Error.Print("Failed to load clients from database")
log.WriteStack()

View file

@ -1,21 +1,21 @@
create table approved_clients
(
id mediumint unsigned auto_increment primary key,
peer_id varchar(42) null,
archived tinyint(1) default 0 null
peer_id varchar(42) not null,
archived tinyint(1) default 0 not null
);
create table mod_core
(
mod_option varchar(121) not null primary key,
mod_setting int(12) default 0 not null
mod_setting int(12) not null
);
create table torrent_group_freeleech
(
ID int(10) auto_increment primary key,
GroupID int(10) default 0 not null,
Type enum ('anime', 'music') default 'anime' not null,
GroupID int(10) not null,
Type enum ('anime', 'music') not null,
DownMultiplier float default 1 not null,
UpMultiplier float default 1 not null,
constraint GroupID unique (GroupID, Type)
@ -25,7 +25,7 @@ create table torrents
(
ID int(10) auto_increment primary key,
GroupID int(10) not null,
TorrentType enum ('anime', 'music') default 'anime' not null,
TorrentType enum ('anime', 'music') not null,
info_hash blob not null,
Leechers int(6) default 0 not null,
Seeders int(6) default 0 not null,
@ -34,25 +34,13 @@ create table torrents
DownMultiplier float default 1 not null,
UpMultiplier float default 1 not null,
Status int default 0 not null,
constraint InfoHash unique (info_hash)
constraint InfoHash unique (info_hash (20))
);
create table torrents_group
(
ID int unsigned auto_increment primary key,
Time int(10) default 0 not null
) charset = utf8mb4;
create table torrents_group2
(
ID int unsigned auto_increment primary key,
Time int(10) default 0 not null
) charset = utf8mb4;
create table transfer_history
(
uid int default 0 not null,
fid int default 0 not null,
uid int not null,
fid int not null,
uploaded bigint default 0 not null,
downloaded bigint default 0 not null,
seeding tinyint default 0 not null,
@ -64,7 +52,7 @@ create table transfer_history
starttime int default 0 not null,
last_announce int default 0 not null,
snatched int default 0 not null,
snatched_time int default 0 null,
snatched_time int default 0 not null,
primary key (uid, fid)
);
@ -72,13 +60,13 @@ create table transfer_ips
(
last_announce int unsigned default 0 not null,
starttime int unsigned default 0 not null,
uid int unsigned default 0 not null,
fid int unsigned default 0 not null,
ip int unsigned default 0 not null,
client_id mediumint unsigned default 0 not null,
uid int unsigned not null,
fid int unsigned not null,
ip int unsigned not null,
client_id mediumint unsigned not null,
uploaded bigint unsigned default 0 not null,
downloaded bigint unsigned default 0 not null,
port smallint unsigned zerofill null,
port smallint unsigned zerofill default 0 not null,
primary key (uid, fid, ip, client_id)
);
@ -89,8 +77,8 @@ create table users_main
Downloaded bigint unsigned default 0 not null,
Enabled enum ('0', '1', '2') default '0' not null,
torrent_pass char(32) not null,
rawup bigint unsigned not null,
rawdl bigint unsigned not null,
rawup bigint unsigned default 0 not null,
rawdl bigint unsigned default 0 not null,
DownMultiplier float default 1 not null,
UpMultiplier float default 1 not null,
DisableDownload tinyint(1) default 0 not null,

View file

@ -18,6 +18,7 @@
package database
import (
"chihaya/util"
"fmt"
"os"
"time"
@ -37,10 +38,9 @@ func init() {
func (db *Database) startSerializing() {
go func() {
for !db.terminate {
time.Sleep(time.Duration(serializeInterval) * time.Second)
util.ContextTick(db.ctx, time.Duration(serializeInterval)*time.Second, func() {
db.serialize()
}
})
}()
}
@ -68,10 +68,7 @@ func (db *Database) serialize() {
torrentFile.Close()
}()
db.TorrentsLock.RLock()
defer db.TorrentsLock.RUnlock()
if err = cdb.WriteTorrents(torrentFile, db.Torrents); err != nil {
if err = cdb.WriteTorrents(torrentFile, *db.Torrents.Load()); err != nil {
log.Error.Print("Failed to encode torrents for serialization: ", err)
return err
}
@ -96,10 +93,7 @@ func (db *Database) serialize() {
userFile.Close()
}()
db.UsersLock.RLock()
defer db.UsersLock.RUnlock()
if err = cdb.WriteUsers(userFile, db.Users); err != nil {
if err = cdb.WriteUsers(userFile, *db.Users.Load()); err != nil {
log.Error.Print("Failed to encode users for serialization: ", err)
return err
}
@ -124,6 +118,10 @@ func (db *Database) deserialize() {
start := time.Now()
torrents := 0
peers := 0
func() {
torrentFile, err := os.OpenFile(torrentBinFilename, os.O_RDONLY, 0)
if err != nil {
@ -134,13 +132,19 @@ func (db *Database) deserialize() {
//goland:noinspection GoUnhandledErrorResult
defer torrentFile.Close()
db.TorrentsLock.Lock()
defer db.TorrentsLock.Unlock()
if err = cdb.LoadTorrents(torrentFile, db.Torrents); err != nil {
dbTorrents := make(map[cdb.TorrentHash]*cdb.Torrent)
if err = cdb.LoadTorrents(torrentFile, dbTorrents); err != nil {
log.Error.Print("Failed to deserialize torrent cache: ", err)
return
}
torrents = len(dbTorrents)
for _, t := range dbTorrents {
peers += int(t.LeechersLength.Load()) + int(t.SeedersLength.Load())
}
db.Torrents.Store(&dbTorrents)
}()
func() {
@ -153,33 +157,16 @@ func (db *Database) deserialize() {
//goland:noinspection GoUnhandledErrorResult
defer userFile.Close()
db.UsersLock.Lock()
defer db.UsersLock.Unlock()
if err = cdb.LoadUsers(userFile, db.Users); err != nil {
users := make(map[string]*cdb.User)
if err = cdb.LoadUsers(userFile, users); err != nil {
log.Error.Print("Failed to deserialize user cache: ", err)
return
}
db.Users.Store(&users)
}()
db.TorrentsLock.RLock()
defer db.TorrentsLock.RUnlock()
torrents := len(db.Torrents)
peers := 0
for _, t := range db.Torrents {
func() {
t.RLock()
defer t.RUnlock()
peers += len(t.Leechers) + len(t.Seeders)
}()
}
db.UsersLock.RLock()
defer db.UsersLock.RUnlock()
users := len(db.Users)
users := len(*db.Users.Load())
log.Info.Printf("Loaded %d users, %d torrents and %d peers (%s)",
users, torrents, peers, time.Since(start).String())

View file

@ -18,6 +18,8 @@
package database
import (
"github.com/google/go-cmp/cmp"
"math"
"net"
"reflect"
"testing"
@ -32,13 +34,14 @@ func TestSerializer(t *testing.T) {
testTorrents := make(map[cdb.TorrentHash]*cdb.Torrent)
testUsers := make(map[string]*cdb.User)
testUsers["mUztWMpBYNCqzmge6vGeEUGSrctJbgpQ"] = &cdb.User{
DisableDownload: false,
TrackerHide: false,
ID: 12,
UpMultiplier: 1,
DownMultiplier: 1,
}
testUser := &cdb.User{}
testUser.ID.Store(12)
testUser.DownMultiplier.Store(math.Float64bits(1))
testUser.UpMultiplier.Store(math.Float64bits(1))
testUser.DisableDownload.Store(false)
testUser.TrackerHide.Store(false)
testUsers["mUztWMpBYNCqzmge6vGeEUGSrctJbgpQ"] = testUser
testPeer := &cdb.Peer{
UserID: 12,
@ -57,46 +60,63 @@ func TestSerializer(t *testing.T) {
testTorrentHash := cdb.TorrentHash{
114, 239, 32, 237, 220, 181, 67, 143, 115, 182, 216, 141, 120, 196, 223, 193, 102, 123, 137, 56,
}
testTorrents[testTorrentHash] = &cdb.Torrent{
Status: 1,
Snatched: 100,
ID: 10,
LastAction: time.Now().Unix(),
UpMultiplier: 1,
DownMultiplier: 1,
torrent := &cdb.Torrent{
Seeders: map[cdb.PeerKey]*cdb.Peer{
cdb.NewPeerKey(12, cdb.PeerIDFromRawString("peer_is_twenty_chars")): testPeer,
},
Leechers: map[cdb.PeerKey]*cdb.Peer{},
}
torrent.InitializeLock()
torrent.ID.Store(10)
torrent.Status.Store(1)
torrent.Snatched.Store(100)
torrent.LastAction.Store(time.Now().Unix())
torrent.DownMultiplier.Store(math.Float64bits(1))
torrent.UpMultiplier.Store(math.Float64bits(1))
torrent.SeedersLength.Store(uint32(len(torrent.Seeders)))
torrent.Group.GroupID.Store(1)
torrent.Group.TorrentType.Store(cdb.MustTorrentTypeFromString("anime"))
testTorrents[testTorrentHash] = torrent
// Prepare empty map to populate with test data
db.Torrents = make(map[cdb.TorrentHash]*cdb.Torrent)
db.Users = make(map[string]*cdb.User)
dbTorrents := make(map[cdb.TorrentHash]*cdb.Torrent)
db.Torrents.Store(&dbTorrents)
if err := copier.Copy(&db.Torrents, testTorrents); err != nil {
dbUsers := make(map[string]*cdb.User)
db.Users.Store(&dbUsers)
if err := copier.Copy(&dbTorrents, testTorrents); err != nil {
panic(err)
}
if err := copier.Copy(&db.Users, testUsers); err != nil {
if err := copier.Copy(&dbUsers, testUsers); err != nil {
panic(err)
}
db.serialize()
// Reset map to fully test deserialization
db.Torrents = make(map[cdb.TorrentHash]*cdb.Torrent)
db.Users = make(map[string]*cdb.User)
dbTorrents = make(map[cdb.TorrentHash]*cdb.Torrent)
db.Torrents.Store(&dbTorrents)
dbUsers = make(map[string]*cdb.User)
db.Users.Store(&dbUsers)
db.deserialize()
if !reflect.DeepEqual(db.Torrents, testTorrents) {
dbTorrents = *db.Torrents.Load()
dbUsers = *db.Users.Load()
if !cmp.Equal(dbTorrents, testTorrents, cdb.TorrentTestCompareOptions...) {
t.Fatalf("Torrents (%v) after serialization and deserialization do not match original torrents (%v)!",
db.Torrents, testTorrents)
dbTorrents, testTorrents)
}
if !reflect.DeepEqual(db.Users, testUsers) {
if !reflect.DeepEqual(dbUsers, testUsers) {
t.Fatalf("Users (%v) after serialization and deserialization do not match original users (%v)!",
db.Users, testUsers)
dbUsers, testUsers)
}
}

View file

@ -107,6 +107,7 @@ func LoadTorrents(r io.Reader, torrents map[TorrentHash]*Torrent) error {
}
t := &Torrent{}
t.InitializeLock()
if err := t.Load(version, reader); err != nil {
return err

View file

@ -18,13 +18,18 @@
package types
import (
"context"
"database/sql/driver"
"encoding/binary"
"encoding/hex"
"encoding/json"
"errors"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/viney-shih/go-lock"
"io"
"math"
"sync"
"sync/atomic"
)
const TorrentHashSize = 20
@ -116,42 +121,61 @@ func (h *TorrentHash) UnmarshalText(b []byte) error {
}
type Torrent struct {
// lock This must be taken whenever read or write is made to Seeders or Leechers fields on this torrent.
lock *lock.ChanMutex
Seeders map[PeerKey]*Peer
Leechers map[PeerKey]*Peer
// SeedersLength Contains the length of Seeders. When Seeders is modified this field must be updated
SeedersLength atomic.Uint32
// LeechersLength Contains the length of Leechers. When LeechersLength is modified this field must be updated
LeechersLength atomic.Uint32
Group TorrentGroup
ID uint32
ID atomic.Uint32
Snatched uint16
// Snatched 16 bits
Snatched atomic.Uint32
Status uint8
LastAction int64 // unix time
// Snatched 8 bits
Status atomic.Uint32
LastAction atomic.Int64 // unix time
UpMultiplier float64
DownMultiplier float64
// UpMultiplier float64
UpMultiplier atomic.Uint64
// DownMultiplier float64
DownMultiplier atomic.Uint64
}
// lock This must be taken whenever read or write is made to fields on this torrent.
// Maybe single sync.Mutex is easier to handle, but prevent concurrent access.
lock sync.RWMutex
func (t *Torrent) InitializeLock() {
if t.lock != nil {
panic("already initialized")
}
t.lock = lock.NewChanMutex()
}
func (t *Torrent) Lock() {
t.lock.Lock()
}
func (t *Torrent) TryLockWithContext(ctx context.Context) bool {
return t.lock.TryLockWithContext(ctx)
}
func (t *Torrent) Unlock() {
t.lock.Unlock()
}
func (t *Torrent) RLock() {
t.lock.RLock()
}
func (t *Torrent) RUnlock() {
t.lock.RUnlock()
}
func (t *Torrent) Load(version uint64, reader readerAndByteReader) (err error) {
var (
id uint32
snatched uint16
status uint8
lastAction int64
upMultiplier, downMultiplier float64
)
var varIntLen uint64
if varIntLen, err = binary.ReadUvarint(reader); err != nil {
@ -175,6 +199,8 @@ func (t *Torrent) Load(version uint64, reader readerAndByteReader) (err error) {
t.Seeders[k] = s
}
t.SeedersLength.Store(uint32(len(t.Seeders)))
if varIntLen, err = binary.ReadUvarint(reader); err != nil {
return err
}
@ -195,100 +221,270 @@ func (t *Torrent) Load(version uint64, reader readerAndByteReader) (err error) {
t.Leechers[k] = l
}
t.LeechersLength.Store(uint32(len(t.Leechers)))
if err = t.Group.Load(version, reader); err != nil {
return err
}
if err = binary.Read(reader, binary.LittleEndian, &t.ID); err != nil {
if err = binary.Read(reader, binary.LittleEndian, &id); err != nil {
return err
}
if err = binary.Read(reader, binary.LittleEndian, &t.Snatched); err != nil {
if err = binary.Read(reader, binary.LittleEndian, &snatched); err != nil {
return err
}
if err = binary.Read(reader, binary.LittleEndian, &t.Status); err != nil {
if err = binary.Read(reader, binary.LittleEndian, &status); err != nil {
return err
}
if err = binary.Read(reader, binary.LittleEndian, &t.LastAction); err != nil {
if err = binary.Read(reader, binary.LittleEndian, &lastAction); err != nil {
return err
}
if err = binary.Read(reader, binary.LittleEndian, &t.UpMultiplier); err != nil {
if err = binary.Read(reader, binary.LittleEndian, &upMultiplier); err != nil {
return err
}
return binary.Read(reader, binary.LittleEndian, &t.DownMultiplier)
if err = binary.Read(reader, binary.LittleEndian, &downMultiplier); err != nil {
return err
}
t.ID.Store(id)
t.Snatched.Store(uint32(snatched))
t.Status.Store(uint32(status))
t.LastAction.Store(lastAction)
t.UpMultiplier.Store(math.Float64bits(upMultiplier))
t.DownMultiplier.Store(math.Float64bits(downMultiplier))
return nil
}
func (t *Torrent) Append(preAllocatedBuffer []byte) (buf []byte) {
t.RLock()
defer t.RUnlock()
buf = preAllocatedBuffer
buf = binary.AppendUvarint(buf, uint64(len(t.Seeders)))
for k, s := range t.Seeders {
buf = append(buf, k[:]...)
func() {
// This could be a read-only lock, but this simpler lock is faster overall
t.Lock()
defer t.Unlock()
buf = s.Append(buf)
}
buf = binary.AppendUvarint(buf, uint64(len(t.Seeders)))
buf = binary.AppendUvarint(buf, uint64(len(t.Leechers)))
for k, s := range t.Seeders {
buf = append(buf, k[:]...)
for k, l := range t.Leechers {
buf = append(buf, k[:]...)
buf = s.Append(buf)
}
buf = l.Append(buf)
}
buf = binary.AppendUvarint(buf, uint64(len(t.Leechers)))
for k, l := range t.Leechers {
buf = append(buf, k[:]...)
buf = l.Append(buf)
}
}()
buf = t.Group.Append(buf)
buf = binary.LittleEndian.AppendUint32(buf, t.ID)
buf = binary.LittleEndian.AppendUint16(buf, t.Snatched)
buf = append(buf, t.Status)
buf = binary.LittleEndian.AppendUint64(buf, uint64(t.LastAction))
buf = binary.LittleEndian.AppendUint64(buf, math.Float64bits(t.UpMultiplier))
buf = binary.LittleEndian.AppendUint64(buf, math.Float64bits(t.DownMultiplier))
buf = binary.LittleEndian.AppendUint32(buf, t.ID.Load())
buf = binary.LittleEndian.AppendUint16(buf, uint16(t.Snatched.Load()))
buf = append(buf, uint8(t.Status.Load()))
buf = binary.LittleEndian.AppendUint64(buf, uint64(t.LastAction.Load()))
buf = binary.LittleEndian.AppendUint64(buf, t.UpMultiplier.Load())
buf = binary.LittleEndian.AppendUint64(buf, t.DownMultiplier.Load())
return buf
}
var encodeJSONTorrentMap = make(map[string]any)
var encodeJSONTorrentGroupMap = make(map[string]any)
// MarshalJSON Due to using atomics, JSON will not marshal values within them.
// This is only safe to call from a single thread at once
func (t *Torrent) MarshalJSON() (buf []byte, err error) {
encodeJSONTorrentMap["ID"] = t.ID.Load()
encodeJSONTorrentMap["Seeders"] = t.Seeders
encodeJSONTorrentMap["Leechers"] = t.Leechers
var torrentTypeBuf [8]byte
binary.LittleEndian.PutUint64(torrentTypeBuf[:], t.Group.TorrentType.Load())
i := 0
for ; i < len(torrentTypeBuf); i++ {
if torrentTypeBuf[i] == 0 {
break
}
}
encodeJSONTorrentGroupMap["TorrentType"] = string(torrentTypeBuf[:i])
encodeJSONTorrentGroupMap["GroupID"] = t.Group.GroupID.Load()
encodeJSONTorrentMap["Group"] = encodeJSONTorrentGroupMap
encodeJSONTorrentMap["Snatched"] = uint16(t.Snatched.Load())
encodeJSONTorrentMap["Status"] = uint8(t.Status.Load())
encodeJSONTorrentMap["LastAction"] = t.LastAction.Load()
encodeJSONTorrentMap["UpMultiplier"] = math.Float64frombits(t.UpMultiplier.Load())
encodeJSONTorrentMap["DownMultiplier"] = math.Float64frombits(t.UpMultiplier.Load())
return json.Marshal(encodeJSONTorrentMap)
}
type decodeJSONTorrent struct {
Seeders map[PeerKey]*Peer
Leechers map[PeerKey]*Peer
Group struct {
TorrentType string
GroupID uint32
}
ID uint32
Snatched uint16
Status uint8
LastAction int64
UpMultiplier float64
DownMultiplier float64
}
// UnmarshalJSON Due to using atomics, JSON will not marshal values within them.
// This is only safe to call from a single thread at once
func (t *Torrent) UnmarshalJSON(buf []byte) (err error) {
var torrentJSON decodeJSONTorrent
if err = json.Unmarshal(buf, &torrentJSON); err != nil {
return err
}
t.Seeders = torrentJSON.Seeders
t.Leechers = torrentJSON.Leechers
t.SeedersLength.Store(uint32(len(t.Seeders)))
t.LeechersLength.Store(uint32(len(t.Leechers)))
torrentType, err := TorrentTypeFromString(torrentJSON.Group.TorrentType)
if err != nil {
return err
}
t.Group.TorrentType.Store(torrentType)
t.Group.GroupID.Store(torrentJSON.Group.GroupID)
t.Snatched.Store(uint32(torrentJSON.Snatched))
t.Status.Store(uint32(torrentJSON.Status))
t.LastAction.Store(torrentJSON.LastAction)
t.UpMultiplier.Store(math.Float64bits(torrentJSON.UpMultiplier))
t.DownMultiplier.Store(math.Float64bits(torrentJSON.DownMultiplier))
return nil
}
type TorrentGroupFreeleech struct {
UpMultiplier float64
DownMultiplier float64
}
type TorrentGroup struct {
TorrentType string
GroupID uint32
type TorrentGroupKey [8 + 4]byte
func MustTorrentGroupKeyFromString(torrentType string, groupID uint32) TorrentGroupKey {
k, err := TorrentGroupKeyFromString(torrentType, groupID)
if err != nil {
panic(err)
}
return k
}
func (g *TorrentGroup) Load(_ uint64, reader readerAndByteReader) (err error) {
var varIntLen uint64
func TorrentGroupKeyFromString(torrentType string, groupID uint32) (k TorrentGroupKey, err error) {
t, err := TorrentTypeFromString(torrentType)
if err != nil {
return TorrentGroupKey{}, err
}
if varIntLen, err = binary.ReadUvarint(reader); err != nil {
binary.LittleEndian.PutUint64(k[:], t)
binary.LittleEndian.PutUint32(k[8:], groupID)
return k, nil
}
func MustTorrentTypeFromString(torrentType string) uint64 {
t, err := TorrentTypeFromString(torrentType)
if err != nil {
panic(err)
}
return t
}
func TorrentTypeFromString(torrentType string) (t uint64, err error) {
if len(torrentType) > 8 {
return 0, err
}
var buf [8]byte
copy(buf[:], torrentType)
return binary.LittleEndian.Uint64(buf[:]), nil
}
type TorrentGroup struct {
TorrentType atomic.Uint64
GroupID atomic.Uint32
}
func (g *TorrentGroup) Key() (k TorrentGroupKey) {
binary.LittleEndian.PutUint64(k[:], g.TorrentType.Load())
binary.LittleEndian.PutUint32(k[8:], g.GroupID.Load())
return k
}
var ErrTorrentTypeTooLong = errors.New("torrent type too long, maximum 8 bytes")
func (g *TorrentGroup) Load(version uint64, reader readerAndByteReader) (err error) {
var (
torrentType uint64
groupID uint32
)
if version <= 2 {
var varIntLen uint64
if varIntLen, err = binary.ReadUvarint(reader); err != nil {
return err
}
if varIntLen > 8 {
return ErrTorrentTypeTooLong
}
buf := make([]byte, 8)
if _, err = io.ReadFull(reader, buf[:varIntLen]); err != nil {
return err
}
torrentType = binary.LittleEndian.Uint64(buf)
} else {
if err = binary.Read(reader, binary.LittleEndian, &torrentType); err != nil {
return err
}
}
if err = binary.Read(reader, binary.LittleEndian, &groupID); err != nil {
return err
}
buf := make([]byte, varIntLen)
g.TorrentType.Store(torrentType)
g.GroupID.Store(groupID)
if _, err = io.ReadFull(reader, buf); err != nil {
return err
}
g.TorrentType = string(buf)
return binary.Read(reader, binary.LittleEndian, &g.GroupID)
return nil
}
func (g *TorrentGroup) Append(preAllocatedBuffer []byte) (buf []byte) {
buf = preAllocatedBuffer
buf = binary.AppendUvarint(buf, uint64(len(g.TorrentType)))
buf = append(buf, []byte(g.TorrentType)...)
buf = binary.LittleEndian.AppendUint64(buf, g.TorrentType.Load())
return binary.LittleEndian.AppendUint32(buf, g.GroupID)
return binary.LittleEndian.AppendUint32(buf, g.GroupID.Load())
}
// TorrentCacheFile holds filename used by serializer for this type
@ -296,4 +492,12 @@ var TorrentCacheFile = "torrent-cache"
// TorrentCacheVersion Used to distinguish old versions on the on-disk cache.
// Bump when fields are altered on Torrent, Peer or TorrentGroup structs
const TorrentCacheVersion = 2
const TorrentCacheVersion = 3
var TorrentTestCompareOptions = []cmp.Option{
cmp.AllowUnexported(atomic.Uint32{}),
cmp.AllowUnexported(atomic.Uint64{}),
cmp.AllowUnexported(atomic.Int64{}),
cmp.AllowUnexported(atomic.Bool{}),
cmpopts.IgnoreFields(Torrent{}, "lock"),
}

View file

@ -19,62 +19,121 @@ package types
import (
"encoding/binary"
"encoding/json"
"math"
"sync/atomic"
)
type User struct {
ID uint32
ID atomic.Uint32
DisableDownload bool
DisableDownload atomic.Bool
TrackerHide bool
TrackerHide atomic.Bool
UpMultiplier float64
DownMultiplier float64
// UpMultiplier A float64 under the covers
UpMultiplier atomic.Uint64
// DownMultiplier A float64 under the covers
DownMultiplier atomic.Uint64
}
func (u *User) Load(_ uint64, reader readerAndByteReader) (err error) {
if err = binary.Read(reader, binary.LittleEndian, &u.ID); err != nil {
var (
id uint32
disableDownload, trackerHide bool
upMultiplier, downMultiplier float64
)
if err = binary.Read(reader, binary.LittleEndian, &id); err != nil {
return err
}
if err = binary.Read(reader, binary.LittleEndian, &u.DisableDownload); err != nil {
if err = binary.Read(reader, binary.LittleEndian, &disableDownload); err != nil {
return err
}
if err = binary.Read(reader, binary.LittleEndian, &u.TrackerHide); err != nil {
if err = binary.Read(reader, binary.LittleEndian, &trackerHide); err != nil {
return err
}
if err = binary.Read(reader, binary.LittleEndian, &u.UpMultiplier); err != nil {
if err = binary.Read(reader, binary.LittleEndian, &upMultiplier); err != nil {
return err
}
return binary.Read(reader, binary.LittleEndian, &u.DownMultiplier)
if err = binary.Read(reader, binary.LittleEndian, &downMultiplier); err != nil {
return err
}
u.ID.Store(id)
u.DisableDownload.Store(disableDownload)
u.TrackerHide.Store(trackerHide)
u.UpMultiplier.Store(math.Float64bits(upMultiplier))
u.DownMultiplier.Store(math.Float64bits(downMultiplier))
return nil
}
func (u *User) Append(preAllocatedBuffer []byte) (buf []byte) {
buf = preAllocatedBuffer
buf = binary.LittleEndian.AppendUint32(buf, u.ID)
buf = binary.LittleEndian.AppendUint32(buf, u.ID.Load())
if u.DisableDownload {
if u.DisableDownload.Load() {
buf = append(buf, 1)
} else {
buf = append(buf, 0)
}
if u.TrackerHide {
if u.TrackerHide.Load() {
buf = append(buf, 1)
} else {
buf = append(buf, 0)
}
buf = binary.LittleEndian.AppendUint64(buf, math.Float64bits(u.UpMultiplier))
buf = binary.LittleEndian.AppendUint64(buf, math.Float64bits(u.DownMultiplier))
buf = binary.LittleEndian.AppendUint64(buf, u.UpMultiplier.Load())
buf = binary.LittleEndian.AppendUint64(buf, u.DownMultiplier.Load())
return buf
}
var encodeJSONUserMap = make(map[string]any)
// MarshalJSON Due to using atomics, JSON will not marshal values within them.
// This is only safe to call from a single thread at once
func (u *User) MarshalJSON() (buf []byte, err error) {
encodeJSONUserMap["ID"] = u.ID.Load()
encodeJSONUserMap["DisableDownload"] = u.DisableDownload.Load()
encodeJSONUserMap["TrackerHide"] = u.TrackerHide.Load()
encodeJSONUserMap["UpMultiplier"] = math.Float64frombits(u.UpMultiplier.Load())
encodeJSONUserMap["DownMultiplier"] = math.Float64frombits(u.UpMultiplier.Load())
return json.Marshal(encodeJSONUserMap)
}
type decodeJSONUser struct {
ID uint32
DisableDownload bool
TrackerHide bool
UpMultiplier float64
DownMultiplier float64
}
// UnmarshalJSON Due to using atomics, JSON will not marshal values within them.
// This is only safe to call from a single thread at once
func (u *User) UnmarshalJSON(buf []byte) (err error) {
var userJSON decodeJSONUser
if err = json.Unmarshal(buf, &userJSON); err != nil {
return err
}
u.ID.Store(userJSON.ID)
u.DisableDownload.Store(userJSON.DisableDownload)
u.TrackerHide.Store(userJSON.TrackerHide)
u.UpMultiplier.Store(math.Float64bits(userJSON.UpMultiplier))
u.DownMultiplier.Store(math.Float64bits(userJSON.DownMultiplier))
return nil
}
type UserTorrentPair struct {
UserID uint32
TorrentID uint32

3
go.mod
View file

@ -5,9 +5,11 @@ go 1.19
require (
github.com/go-sql-driver/mysql v1.7.1
github.com/go-testfixtures/testfixtures/v3 v3.9.0
github.com/google/go-cmp v0.5.9
github.com/jinzhu/copier v0.3.5
github.com/prometheus/client_golang v1.15.1
github.com/prometheus/common v0.44.0
github.com/valyala/fasthttp v1.47.0
github.com/viney-shih/go-lock v1.1.2
github.com/zeebo/bencode v1.0.0
)
@ -32,6 +34,7 @@ require (
github.com/prometheus/procfs v0.9.0 // indirect
github.com/segmentio/asm v1.2.0 // indirect
github.com/shopspring/decimal v1.3.1 // indirect
github.com/valyala/bytebufferpool v1.0.0 // indirect
go.opentelemetry.io/otel v1.15.0 // indirect
go.opentelemetry.io/otel/trace v1.15.0 // indirect
golang.org/x/sync v0.1.0 // indirect

7
go.sum
View file

@ -32,6 +32,7 @@ github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEW
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8=
@ -88,6 +89,10 @@ github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8=
github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk=
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
github.com/valyala/fasthttp v1.47.0 h1:y7moDoxYzMooFpT5aHgNgVOQDrS3qlkfiP9mDtGGK9c=
github.com/valyala/fasthttp v1.47.0/go.mod h1:k2zXd82h/7UZc3VOdJ2WaUqt1uZ/XpXAfE9i+HBC3lA=
github.com/viney-shih/go-lock v1.1.2 h1:3TdGTiHZCPqBdTvFbQZQN/TRZzKF3KWw2rFEyKz3YqA=
github.com/viney-shih/go-lock v1.1.2/go.mod h1:Yijm78Ljteb3kRiJrbLAxVntkUukGu5uzSxq/xV7OO8=
github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI=
@ -107,7 +112,7 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.6.0 h1:qfktjS5LUO+fFKeJXZ+ikTRijMmljikvG68fpMMruSc=
golang.org/x/crypto v0.7.0 h1:AvwMYaRytfdeVt3u6mLaxYtErKYjxA2OXjJ1HHq6t3A=
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=

View file

@ -20,7 +20,7 @@ package server
import (
"bytes"
"encoding/json"
"net/http"
"github.com/valyala/fasthttp"
"time"
"chihaya/log"
@ -35,10 +35,10 @@ func alive(buf *bytes.Buffer) int {
res, err := json.Marshal(response{time.Now().UnixMilli(), time.Since(handler.startTime).Milliseconds()})
if err != nil {
log.Error.Print("Failed to marshal JSON alive response: ", err)
return http.StatusInternalServerError
return fasthttp.StatusInternalServerError
}
buf.Write(res)
return http.StatusOK
return fasthttp.StatusOK
}

View file

@ -21,9 +21,9 @@ import (
"bytes"
"context"
"fmt"
"github.com/valyala/fasthttp"
"math"
"net"
"net/http"
"strings"
"time"
@ -111,67 +111,68 @@ func getPublicIPV4(ipAddr string, exists bool) (string, bool) {
return ipAddr, !private
}
func announce(ctx context.Context, qs string, header http.Header, remoteAddr string, user *cdb.User,
db *database.Database, buf *bytes.Buffer) int {
qp, err := params.ParseQuery(qs)
func announce(
ctx context.Context, queryArgs *fasthttp.Args, header *fasthttp.RequestHeader,
remoteAddr net.Addr, user *cdb.User, db *database.Database, buf *bytes.Buffer) int {
qp, err := params.ParseQuery(queryArgs)
if err != nil {
panic(err)
}
// Mandatory parameters
infoHashes := qp.InfoHashes()
peerID, _ := qp.Get("peer_id")
port, portExists := qp.GetUint16("port")
uploaded, uploadedExists := qp.GetUint64("uploaded")
downloaded, downloadedExists := qp.GetUint64("downloaded")
left, leftExists := qp.GetUint64("left")
infoHashes := qp.Params.InfoHashes
peerID := qp.Params.PeerID
port, portExists := qp.Params.Port, qp.Exists.Port
uploaded, uploadedExists := qp.Params.Uploaded, qp.Exists.Uploaded
downloaded, downloadedExists := qp.Params.Downloaded, qp.Exists.Downloaded
left, leftExists := qp.Params.Left, qp.Exists.Left
if infoHashes == nil {
if len(infoHashes) == 0 {
failure("Malformed request - missing info_hash", buf, 1*time.Hour)
return http.StatusOK // Required by torrent clients to interpret failure response
return fasthttp.StatusOK // Required by torrent clients to interpret failure response
} else if len(infoHashes) > 1 {
failure("Malformed request - can only announce singular info_hash", buf, 1*time.Hour)
return http.StatusOK // Required by torrent clients to interpret failure response
return fasthttp.StatusOK // Required by torrent clients to interpret failure response
}
if peerID == "" {
if len(peerID) == 0 {
failure("Malformed request - missing peer_id", buf, 1*time.Hour)
return http.StatusOK // Required by torrent clients to interpret failure response
return fasthttp.StatusOK // Required by torrent clients to interpret failure response
}
if len(peerID) != 20 {
failure("Malformed request - invalid peer_id", buf, 1*time.Hour)
return http.StatusOK // Required by torrent clients to interpret failure response
return fasthttp.StatusOK // Required by torrent clients to interpret failure response
}
if !portExists {
failure("Malformed request - missing port", buf, 1*time.Hour)
return http.StatusOK // Required by torrent clients to interpret failure response
return fasthttp.StatusOK // Required by torrent clients to interpret failure response
}
if strictPort && port < 1024 || port > 65535 {
failure(fmt.Sprintf("Malformed request - port outside of acceptable range (port: %d)", port), buf, 1*time.Hour)
return http.StatusOK // Required by torrent clients to interpret failure response
return fasthttp.StatusOK // Required by torrent clients to interpret failure response
}
if !uploadedExists {
failure("Malformed request - missing uploaded", buf, 1*time.Hour)
return http.StatusOK // Required by torrent clients to interpret failure response
return fasthttp.StatusOK // Required by torrent clients to interpret failure response
}
if !downloadedExists {
failure("Malformed request - missing downloaded", buf, 1*time.Hour)
return http.StatusOK // Required by torrent clients to interpret failure response
return fasthttp.StatusOK // Required by torrent clients to interpret failure response
}
if !leftExists {
failure("Malformed request - missing left", buf, 1*time.Hour)
return http.StatusOK // Required by torrent clients to interpret failure response
return fasthttp.StatusOK // Required by torrent clients to interpret failure response
}
ipAddr := func() string {
ipV4, existsV4 := getPublicIPV4(qp.Get("ipv4")) // First try to get IPv4 address if client sent it
ip, exists := getPublicIPV4(qp.Get("ip")) // ... then try to get public IP if sent by client
ipV4, existsV4 := getPublicIPV4(qp.Params.IPv4, qp.Exists.IPv4) // First try to get IPv4 address if client sent it
ip, exists := getPublicIPV4(qp.Params.IP, qp.Exists.IP) // ... then try to get public IP if sent by client
// Fail if ip and ipv4 are not same, and both are provided
if (existsV4 && exists) && (ip != ipV4) {
@ -188,15 +189,17 @@ func announce(ctx context.Context, qs string, header http.Header, remoteAddr str
// Check for proxied IP header
if proxyHeader, exists := config.Section("http").Get("proxy_header", ""); exists {
if ips, exists := header[proxyHeader]; exists && len(ips) > 0 {
return ips[0]
if ips := header.PeekAll(proxyHeader); len(ips) > 0 {
return string(ips[0])
}
}
// Check for IP in socket
portIndex := strings.LastIndex(remoteAddr, ":")
remoteAddrString := remoteAddr.String()
portIndex := strings.LastIndex(remoteAddrString, ":")
if portIndex != -1 {
return remoteAddr[0:portIndex]
return remoteAddrString[0:portIndex]
}
// Everything else failed
@ -206,45 +209,36 @@ func announce(ctx context.Context, qs string, header http.Header, remoteAddr str
ipBytes := net.ParseIP(ipAddr).To4()
if nil == ipBytes {
failure(fmt.Sprintf("Failed to parse IP address (ip: %s)", ipAddr), buf, 1*time.Hour)
return http.StatusOK // Required by torrent clients to interpret failure response
return fasthttp.StatusOK // Required by torrent clients to interpret failure response
}
clientID, matched := clientApproved(peerID, db)
if !matched {
failure(fmt.Sprintf("Your client is not approved (peer_id: %s)", peerID), buf, 1*time.Hour)
return http.StatusOK // Required by torrent clients to interpret failure response
return fasthttp.StatusOK // Required by torrent clients to interpret failure response
}
if !db.TorrentsLock.RTryLockWithContext(ctx) {
return http.StatusRequestTimeout
}
defer db.TorrentsLock.RUnlock()
torrent, exists := db.Torrents[infoHashes[0]]
torrent, exists := (*db.Torrents.Load())[infoHashes[0]]
if !exists {
failure("This torrent does not exist", buf, 5*time.Minute)
return http.StatusOK // Required by torrent clients to interpret failure response
return fasthttp.StatusOK // Required by torrent clients to interpret failure response
}
// Take torrent lock to read/write on it to prevent race conditions
torrent.Lock()
defer torrent.Unlock()
if torrentStatus := torrent.Status.Load(); torrentStatus == 1 && left == 0 {
log.Info.Printf("Unpruning torrent %d", torrent.ID.Load())
if torrent.Status == 1 && left == 0 {
log.Info.Printf("Unpruning torrent %d", torrent.ID)
torrent.Status = 0
torrent.Status.Store(0)
/* It is okay to do this asynchronously as tracker's internal in-memory state has already been updated for this
torrent. While it is technically possible that we will do this more than once in some cases, the state is of
boolean type so there is no risk of data loss. */
go db.UnPrune(torrent)
} else if torrent.Status != 0 {
failure(fmt.Sprintf("This torrent does not exist (status: %d, left: %d)", torrent.Status, left), buf, 15*time.Minute)
return http.StatusOK // Required by torrent clients to interpret failure response
} else if torrentStatus != 0 {
failure(fmt.Sprintf("This torrent does not exist (status: %d, left: %d)", torrentStatus, left), buf, 15*time.Minute)
return fasthttp.StatusOK // Required by torrent clients to interpret failure response
}
numWant, exists := qp.GetUint16("numwant")
numWant, exists := qp.Params.NumWant, qp.Exists.NumWant
if !exists {
numWant = uint16(defaultNumWant)
} else if numWant > uint16(maxNumWant) {
@ -259,15 +253,21 @@ func announce(ctx context.Context, qs string, header http.Header, remoteAddr str
active = true
)
event, _ := qp.Get("event")
event := qp.Params.Event
completed := event == "completed"
peerKey := cdb.NewPeerKey(user.ID, cdb.PeerIDFromRawString(peerID))
peerKey := cdb.NewPeerKey(user.ID.Load(), cdb.PeerIDFromRawString(peerID))
// Take torrent peers lock to read/write on it to prevent race conditions
if !torrent.TryLockWithContext(ctx) {
return fasthttp.StatusRequestTimeout
}
defer torrent.Unlock()
if left > 0 {
if isDisabledDownload(db, user, torrent) {
failure("Your download privileges are disabled", buf, 1*time.Hour)
return http.StatusOK // Required by torrent clients to interpret failure response
return fasthttp.StatusOK // Required by torrent clients to interpret failure response
}
peer, exists = torrent.Leechers[peerKey]
@ -275,6 +275,8 @@ func announce(ctx context.Context, qs string, header http.Header, remoteAddr str
newPeer = true
peer = &cdb.Peer{}
torrent.Leechers[peerKey] = peer
torrent.LeechersLength.Store(uint32(len(torrent.Leechers)))
}
} else if completed {
peer, exists = torrent.Leechers[peerKey]
@ -282,10 +284,15 @@ func announce(ctx context.Context, qs string, header http.Header, remoteAddr str
newPeer = true
peer = &cdb.Peer{}
torrent.Seeders[peerKey] = peer
torrent.SeedersLength.Store(uint32(len(torrent.Seeders)))
} else {
// Previously tracked peer is now a seeder
torrent.Seeders[peerKey] = peer
delete(torrent.Leechers, peerKey)
torrent.SeedersLength.Store(uint32(len(torrent.Seeders)))
torrent.LeechersLength.Store(uint32(len(torrent.Leechers)))
}
seeding = true
} else {
@ -296,12 +303,17 @@ func announce(ctx context.Context, qs string, header http.Header, remoteAddr str
newPeer = true
peer = &cdb.Peer{}
torrent.Seeders[peerKey] = peer
torrent.SeedersLength.Store(uint32(len(torrent.Seeders)))
} else {
/* Previously tracked peer is now a seeder, however we never received their "completed" event.
Broken client? Unreported snatch? Cross-seeding? Let's not report it as snatch to avoid
over-reporting for cross-seeding */
torrent.Seeders[peerKey] = peer
delete(torrent.Leechers, peerKey)
torrent.SeedersLength.Store(uint32(len(torrent.Seeders)))
torrent.LeechersLength.Store(uint32(len(torrent.Leechers)))
}
}
seeding = true
@ -310,8 +322,8 @@ func announce(ctx context.Context, qs string, header http.Header, remoteAddr str
// Update peer info/stats
if newPeer {
peer.ID = peerKey.PeerID()
peer.UserID = user.ID
peer.TorrentID = torrent.ID
peer.UserID = user.ID.Load()
peer.TorrentID = torrent.ID.Load()
peer.StartTime = now
peer.LastAnnounce = now
peer.Uploaded = uploaded
@ -333,19 +345,19 @@ func announce(ctx context.Context, qs string, header http.Header, remoteAddr str
torrentGroupDownMultiplier := 1.0
torrentGroupUpMultiplier := 1.0
if torrentGroupFreeleech, exists := (*db.TorrentGroupFreeleech.Load())[torrent.Group]; exists {
if torrentGroupFreeleech, exists := (*db.TorrentGroupFreeleech.Load())[torrent.Group.Key()]; exists {
torrentGroupDownMultiplier = torrentGroupFreeleech.DownMultiplier
torrentGroupUpMultiplier = torrentGroupFreeleech.UpMultiplier
}
var deltaDownload int64
if !database.GlobalFreeleech.Load() {
deltaDownload = int64(float64(rawDeltaDownload) * math.Abs(user.DownMultiplier) *
math.Abs(torrentGroupDownMultiplier) * math.Abs(torrent.DownMultiplier))
deltaDownload = int64(float64(rawDeltaDownload) * math.Abs(math.Float64frombits(user.DownMultiplier.Load())) *
math.Abs(torrentGroupDownMultiplier) * math.Abs(math.Float64frombits(torrent.DownMultiplier.Load())))
}
deltaUpload := int64(float64(rawDeltaUpload) * math.Abs(user.UpMultiplier) *
math.Abs(torrentGroupUpMultiplier) * math.Abs(torrent.UpMultiplier))
deltaUpload := int64(float64(rawDeltaUpload) * math.Abs(math.Float64frombits(user.UpMultiplier.Load())) *
math.Abs(torrentGroupUpMultiplier) * math.Abs(math.Float64frombits(torrent.UpMultiplier.Load())))
peer.Uploaded = uploaded
peer.Downloaded = downloaded
peer.Left = left
@ -369,7 +381,7 @@ func announce(ctx context.Context, qs string, header http.Header, remoteAddr str
/* Update torrent last_action only if announced action is seeding.
This allows dead torrents without seeder but with leecher to be proeprly pruned */
if seeding {
torrent.LastAction = now
torrent.LastAction.Store(now)
}
var deltaSnatch uint8
@ -380,17 +392,19 @@ func announce(ctx context.Context, qs string, header http.Header, remoteAddr str
should be gone, allowing the peer to be GC'd. */
if seeding {
delete(torrent.Seeders, peerKey)
torrent.SeedersLength.Store(uint32(len(torrent.Seeders)))
} else {
delete(torrent.Leechers, peerKey)
torrent.LeechersLength.Store(uint32(len(torrent.Leechers)))
}
active = false
} else if completed {
go db.QueueSnatch(peer, now)
db.QueueSnatch(peer, now)
deltaSnatch = 1
}
if user.TrackerHide {
if user.TrackerHide.Load() {
peer.Addr = cdb.NewPeerAddressFromIPPort(net.IP{127, 0, 0, 1}, port) // 127.0.0.1
} else {
peer.Addr = cdb.NewPeerAddressFromIPPort(ipBytes, port)
@ -406,7 +420,7 @@ func announce(ctx context.Context, qs string, header http.Header, remoteAddr str
go record.Record(
peer.TorrentID,
user.ID,
user.ID.Load(),
ipAddr,
port,
event,
@ -418,9 +432,9 @@ func announce(ctx context.Context, qs string, header http.Header, remoteAddr str
left)
// Generate response
seedCount := len(torrent.Seeders)
leechCount := len(torrent.Leechers)
snatchCount := torrent.Snatched
seedCount := int(torrent.SeedersLength.Load())
leechCount := int(torrent.LeechersLength.Load())
snatchCount := uint16(torrent.Snatched.Load())
response := make(map[string]interface{})
response["complete"] = seedCount
@ -435,11 +449,9 @@ func announce(ctx context.Context, qs string, header http.Header, remoteAddr str
response["interval"] = announceInterval + announceDrift
if numWant > 0 && active {
compactString, exists := qp.Get("compact")
compact := !exists || compactString != "0" // Defaults to being compact
compact := !qp.Exists.Compact || !qp.Params.Compact
noPeerIDString, exists := qp.Get("no_peer_id")
noPeerID := exists && noPeerIDString == "1"
noPeerID := qp.Exists.NoPeerID && qp.Params.NoPeerID
var peerCount int
if seeding {
@ -520,14 +532,14 @@ func announce(ctx context.Context, qs string, header http.Header, remoteAddr str
// Early exit before response write
select {
case <-ctx.Done():
return http.StatusRequestTimeout
return fasthttp.StatusRequestTimeout
default:
}
encoder := bencode.NewEncoder(buf)
if err = encoder.Encode(response); err != nil {
if err := encoder.Encode(response); err != nil {
panic(err)
}
return http.StatusOK
return fasthttp.StatusOK
}

View file

@ -20,7 +20,7 @@ package server
import (
"bytes"
"context"
"net/http"
"github.com/valyala/fasthttp"
"time"
"chihaya/collectors"
@ -31,39 +31,28 @@ import (
"github.com/prometheus/common/expfmt"
)
var bearerPrefix = "Bearer "
var bearerPrefix = []byte("Bearer ")
func metrics(ctx context.Context, auth string, db *database.Database, buf *bytes.Buffer) int {
if !db.UsersLock.RTryLockWithContext(ctx) {
return http.StatusRequestTimeout
}
defer db.UsersLock.RUnlock()
if !db.TorrentsLock.RTryLockWithContext(ctx) {
return http.StatusRequestTimeout
}
defer db.TorrentsLock.RUnlock()
func metrics(ctx context.Context, auth []byte, db *database.Database, buf *bytes.Buffer) int {
dbUsers := *db.Users.Load()
dbTorrents := *db.Torrents.Load()
peers := 0
for _, t := range db.Torrents {
func() {
t.RLock()
defer t.RUnlock()
peers += len(t.Leechers) + len(t.Seeders)
}()
for _, t := range dbTorrents {
peers += int(t.LeechersLength.Load()) + int(t.SeedersLength.Load())
}
// Early exit before response write
select {
case <-ctx.Done():
return http.StatusRequestTimeout
return fasthttp.StatusRequestTimeout
default:
}
collectors.UpdateUptime(time.Since(handler.startTime).Seconds())
collectors.UpdateUsers(len(db.Users))
collectors.UpdateTorrents(len(db.Torrents))
collectors.UpdateUsers(len(dbUsers))
collectors.UpdateTorrents(len(dbTorrents))
collectors.UpdateClients(len(*db.Clients.Load()))
collectors.UpdateHitAndRuns(len(*db.HitAndRuns.Load()))
collectors.UpdatePeers(peers)
@ -79,9 +68,9 @@ func metrics(ctx context.Context, auth string, db *database.Database, buf *bytes
}
n := len(bearerPrefix)
if len(auth) > n && auth[:n] == bearerPrefix {
if len(auth) > n && bytes.Equal(auth[:n], bearerPrefix) {
adminToken, exists := config.Section("http").Get("admin_token", "")
if exists && auth[n:] == adminToken {
if exists && bytes.Equal(auth[:n], []byte(adminToken)) {
mfs, _ := prometheus.DefaultGatherer.Gather()
for _, mf := range mfs {
@ -93,5 +82,5 @@ func metrics(ctx context.Context, auth string, db *database.Database, buf *bytes
}
}
return http.StatusOK
return fasthttp.StatusOK
}

View file

@ -19,99 +19,153 @@
package params
import (
"net/url"
"strconv"
"strings"
"bytes"
cdb "chihaya/database/types"
"github.com/valyala/fasthttp"
"strconv"
)
type QueryParam struct {
query string
params map[string]string
infoHashes []cdb.TorrentHash
}
Params struct {
Uploaded uint64
Downloaded uint64
Left uint64
func ParseQuery(query string) (qp *QueryParam, err error) {
qp = &QueryParam{
query: query,
infoHashes: nil,
params: make(map[string]string),
Port uint16
NumWant uint16
PeerID string
IPv4 string
IP string
Event string
TestGarbageUnescape string
Compact bool
NoPeerID bool
InfoHashes []cdb.TorrentHash
}
for query != "" {
key := query
if i := strings.Index(key, "&"); i >= 0 {
key, query = key[:i], key[i+1:]
} else {
query = ""
Exists struct {
Uploaded bool
Downloaded bool
Left bool
Port bool
NumWant bool
PeerID bool
IPv4 bool
IP bool
Event bool
TestGarbageUnescape bool
Compact bool
NoPeerID bool
InfoHashes bool
}
}
var uploadedKey = []byte("uploaded")
var downloadedKey = []byte("downloaded")
var leftKey = []byte("left")
var portKey = []byte("port")
var numWant = []byte("numwant")
var peerIDKey = []byte("peer_id")
var ipv4Key = []byte("ipv4")
var ipKey = []byte("ip")
var eventKey = []byte("event")
var testGarbageUnescapeKey = []byte("!@#")
var infoHashKey = []byte("info_hash")
var compactKey = []byte("compact")
var noPeerIDKey = []byte("no_peer_id")
func ParseQuery(queryArgs *fasthttp.Args) (qp QueryParam, err error) {
var returnError error
queryArgs.VisitAll(func(key, value []byte) {
if returnError != nil {
return
}
if key == "" {
continue
}
value := ""
if i := strings.Index(key, "="); i >= 0 {
key, value = key[:i], key[i+1:]
}
key, err = url.QueryUnescape(key)
if err != nil {
panic(err)
}
value, err = url.QueryUnescape(value)
if err != nil {
panic(err)
}
if key == "info_hash" {
if len(value) == cdb.TorrentHashSize {
qp.infoHashes = append(qp.infoHashes, cdb.TorrentHashFromBytes([]byte(value)))
key = bytes.ToLower(key)
switch true {
case bytes.Equal(key, uploadedKey):
n, err := strconv.ParseUint(string(value), 10, 0)
if err != nil {
returnError = err
return
}
} else {
qp.params[strings.ToLower(key)] = value
qp.Params.Uploaded = n
qp.Exists.Uploaded = true
case bytes.Equal(key, downloadedKey):
n, err := strconv.ParseUint(string(value), 10, 0)
if err != nil {
returnError = err
return
}
qp.Params.Downloaded = n
qp.Exists.Downloaded = true
case bytes.Equal(key, leftKey):
n, err := strconv.ParseUint(string(value), 10, 0)
if err != nil {
returnError = err
return
}
qp.Params.Left = n
qp.Exists.Left = true
case bytes.Equal(key, portKey):
n, err := strconv.ParseUint(string(value), 10, 16)
if err != nil {
returnError = err
return
}
qp.Params.Port = uint16(n)
qp.Exists.Port = true
case bytes.Equal(key, numWant):
n, err := strconv.ParseUint(string(value), 10, 16)
if err != nil {
returnError = err
return
}
qp.Params.NumWant = uint16(n)
qp.Exists.NumWant = true
case bytes.Equal(key, peerIDKey):
qp.Params.PeerID = string(value)
qp.Exists.PeerID = true
case bytes.Equal(key, ipv4Key):
qp.Params.IPv4 = string(value)
qp.Exists.IPv4 = true
case bytes.Equal(key, ipKey):
qp.Params.IP = string(value)
qp.Exists.IP = true
case bytes.Equal(key, eventKey):
qp.Params.Event = string(value)
qp.Exists.Event = true
case bytes.Equal(key, testGarbageUnescapeKey):
qp.Params.TestGarbageUnescape = string(value)
qp.Exists.TestGarbageUnescape = true
case bytes.Equal(key, infoHashKey):
if len(value) == cdb.TorrentHashSize {
qp.Params.InfoHashes = append(qp.Params.InfoHashes, cdb.TorrentHashFromBytes(value))
qp.Exists.InfoHashes = true
}
case bytes.Equal(key, compactKey):
qp.Params.Compact = bytes.Equal(value, []byte{'1'})
qp.Exists.Compact = true
case bytes.Equal(key, noPeerIDKey):
qp.Params.NoPeerID = bytes.Equal(value, []byte{'1'})
qp.Exists.NoPeerID = true
}
}
})
return qp, nil
}
func (qp *QueryParam) getUint(which string, bitSize int) (ret uint64, exists bool) {
str, exists := qp.params[which]
if exists {
var err error
ret, err = strconv.ParseUint(str, 10, bitSize)
if err != nil {
exists = false
}
}
return
}
func (qp *QueryParam) Get(which string) (ret string, exists bool) {
ret, exists = qp.params[which]
return
}
func (qp *QueryParam) GetUint64(which string) (ret uint64, exists bool) {
return qp.getUint(which, 64)
}
func (qp *QueryParam) GetUint16(which string) (ret uint16, exists bool) {
tmp, exists := qp.getUint(which, 16)
ret = uint16(tmp)
return
}
func (qp *QueryParam) InfoHashes() []cdb.TorrentHash {
return qp.infoHashes
}
func (qp *QueryParam) RawQuery() string {
return qp.query
return qp, returnError
}

View file

@ -19,6 +19,7 @@ package params
import (
"fmt"
"github.com/valyala/fasthttp"
"net/url"
"os"
"reflect"
@ -44,83 +45,95 @@ func TestMain(m *testing.M) {
}
func TestParseQuery(t *testing.T) {
query := ""
var queryParsed QueryParam
queryParsed.Params.Event, queryParsed.Exists.Event = "completed", true
queryParsed.Params.Port, queryParsed.Exists.Port = 25362, true
queryParsed.Params.PeerID, queryParsed.Exists.Left = "-CH010-VnpZR7uz31I1A", true
queryParsed.Params.Left, queryParsed.Exists.Left = 0, true
query := fmt.Sprintf("event=%s&port=%d&peer_id=%s&left=%d",
queryParsed.Params.Event,
queryParsed.Params.Port,
queryParsed.Params.PeerID,
queryParsed.Params.Left,
)
for _, infoHash := range infoHashes {
query += "info_hash=" + url.QueryEscape(string(infoHash[:])) + "&"
queryParsed.Params.InfoHashes = append(queryParsed.Params.InfoHashes, infoHash)
queryParsed.Exists.InfoHashes = true
query += "&info_hash=" + url.QueryEscape(string(infoHash[:]))
}
queryMap := make(map[string]string)
queryMap["event"] = "completed"
queryMap["port"] = "25362"
queryMap["peer_id"] = "-CH010-VnpZR7uz31I1A"
queryMap["left"] = "0"
args := fasthttp.Args{}
args.Parse(query)
for k, v := range queryMap {
query += k + "=" + v + "&"
}
query = query[:len(query)-1]
qp, err := ParseQuery(query)
qp, err := ParseQuery(&args)
if err != nil {
panic(err)
t.Fatal(err)
}
if !reflect.DeepEqual(qp.params, queryMap) {
t.Fatalf("Parsed query map (%v) is not deeply equal as original (%v)!", qp.params, queryMap)
}
if !reflect.DeepEqual(qp.infoHashes, infoHashes) {
t.Fatalf("Parsed info hashes (%v) are not deeply equal as original (%v)!", qp.infoHashes, infoHashes)
if !reflect.DeepEqual(qp, queryParsed) {
t.Fatalf("Parsed query map (%v) is not deeply equal as original (%v)!", qp, queryParsed)
}
}
func TestBrokenParseQuery(t *testing.T) {
brokenQueryMap := make(map[string]string)
brokenQueryMap["event"] = "started"
brokenQueryMap["bug"] = ""
brokenQueryMap["yes"] = ""
var brokenQueryParsed QueryParam
brokenQueryParsed.Params.Event, brokenQueryParsed.Exists.Event = "started", true
brokenQueryParsed.Params.IPv4, brokenQueryParsed.Exists.IPv4 = "", true
brokenQueryParsed.Params.IP, brokenQueryParsed.Exists.IP = "", true
qp, err := ParseQuery("event=started&bug=&yes=")
args := fasthttp.Args{}
args.Parse("event=started&ipv4=&ip=")
qp, err := ParseQuery(&args)
if err != nil {
panic(err)
t.Fatal(err)
}
if !reflect.DeepEqual(qp.params, brokenQueryMap) {
t.Fatalf("Parsed query map (%v) is not deeply equal as original (%v)!", qp.params, brokenQueryMap)
if !reflect.DeepEqual(qp, brokenQueryParsed) {
t.Fatalf("Parsed query map (%v) is not deeply equal as original (%v)!", qp, brokenQueryParsed)
}
}
func TestLowerKey(t *testing.T) {
qp, err := ParseQuery("EvEnT=c0mPl3tED")
args := fasthttp.Args{}
args.Parse("EvEnT=c0mPl3tED")
qp, err := ParseQuery(&args)
if err != nil {
panic(err)
t.Fatal(err)
}
if param, exists := qp.Get("event"); !exists || param != "c0mPl3tED" {
if param, exists := qp.Params.Event, qp.Exists.Event; !exists || param != "c0mPl3tED" {
t.Fatalf("Got parsed value %s but expected c0mPl3tED for \"event\"!", param)
}
}
func TestUnescape(t *testing.T) {
qp, err := ParseQuery("%21%40%23=%24%25%5E")
args := fasthttp.Args{}
args.Parse("%21%40%23=%24%25%5E")
qp, err := ParseQuery(&args)
if err != nil {
panic(err)
t.Fatal(err)
}
if param, exists := qp.Get("!@#"); !exists || param != "$%^" {
if param, exists := qp.Params.TestGarbageUnescape, qp.Exists.TestGarbageUnescape; !exists || param != "$%^" {
t.Fatal(fmt.Sprintf("Got parsed value %s but expected", param), "$%^ for \"!@#\"!")
}
}
func TestGet(t *testing.T) {
qp, err := ParseQuery("event=completed")
func TestString(t *testing.T) {
args := fasthttp.Args{}
args.Parse("event=completed")
qp, err := ParseQuery(&args)
if err != nil {
panic(err)
t.Fatal(err)
}
if param, exists := qp.Get("event"); !exists || param != "completed" {
if param, exists := qp.Params.Event, qp.Exists.Event; !exists || param != "completed" {
t.Fatalf("Got parsed value %s but expected completed for \"event\"!", param)
}
}
@ -128,12 +141,15 @@ func TestGet(t *testing.T) {
func TestGetUint64(t *testing.T) {
val := uint64(1<<62 + 42)
qp, err := ParseQuery("left=" + strconv.FormatUint(val, 10))
args := fasthttp.Args{}
args.Parse("left=" + strconv.FormatUint(val, 10))
qp, err := ParseQuery(&args)
if err != nil {
panic(err)
t.Fatal(err)
}
if parsedVal, exists := qp.GetUint64("left"); !exists || parsedVal != val {
if parsedVal, exists := qp.Params.Left, qp.Exists.Left; !exists || parsedVal != val {
t.Fatalf("Got parsed value %v but expected %v for \"left\"!", parsedVal, val)
}
}
@ -141,12 +157,15 @@ func TestGetUint64(t *testing.T) {
func TestGetUint16(t *testing.T) {
val := uint16(1<<15 + 4242)
qp, err := ParseQuery("port=" + strconv.FormatUint(uint64(val), 10))
args := fasthttp.Args{}
args.Parse("port=" + strconv.FormatUint(uint64(val), 10))
qp, err := ParseQuery(&args)
if err != nil {
panic(err)
t.Fatal(err)
}
if parsedVal, exists := qp.GetUint16("port"); !exists || parsedVal != val {
if parsedVal, exists := qp.Params.Port, qp.Exists.Port; !exists || parsedVal != val {
t.Fatalf("Got parsed value %v but expected %v for \"port\"!", parsedVal, val)
}
}
@ -160,25 +179,15 @@ func TestInfoHashes(t *testing.T) {
query = query[:len(query)-1]
qp, err := ParseQuery(query)
args := fasthttp.Args{}
args.Parse(query)
qp, err := ParseQuery(&args)
if err != nil {
panic(err)
t.Fatal(err)
}
if !reflect.DeepEqual(qp.InfoHashes(), infoHashes) {
t.Fatalf("Parsed info hashes (%v) are not deeply equal as original (%v)!", qp.InfoHashes(), infoHashes)
}
}
func TestRawQuery(t *testing.T) {
q := "event=completed&port=25541&left=0&uploaded=0&downloaded=0"
qp, err := ParseQuery(q)
if err != nil {
panic(err)
}
if rq := qp.RawQuery(); rq != q {
t.Fatalf("Got raw query %s but expected %s", rq, q)
if !reflect.DeepEqual(qp.Params.InfoHashes, infoHashes) {
t.Fatalf("Parsed info hashes (%v) are not deeply equal as original (%v)!", qp.Params.InfoHashes, infoHashes)
}
}

View file

@ -19,13 +19,12 @@ package server
import (
"bytes"
"context"
"net/http"
"chihaya/config"
"chihaya/database"
cdb "chihaya/database/types"
"chihaya/server/params"
"context"
"github.com/valyala/fasthttp"
"github.com/zeebo/bencode"
)
@ -37,19 +36,18 @@ func init() {
}
func writeScrapeInfo(torrent *cdb.Torrent) map[string]interface{} {
torrent.RLock()
defer torrent.RUnlock()
ret := make(map[string]interface{})
ret["complete"] = len(torrent.Seeders)
ret["downloaded"] = torrent.Snatched
ret["incomplete"] = len(torrent.Leechers)
ret["complete"] = torrent.SeedersLength.Load()
ret["downloaded"] = torrent.Snatched.Load()
ret["incomplete"] = torrent.LeechersLength.Load()
return ret
}
func scrape(ctx context.Context, qs string, user *cdb.User, db *database.Database, buf *bytes.Buffer) int {
qp, err := params.ParseQuery(qs)
func scrape(
ctx context.Context, queryArgs *fasthttp.Args,
user *cdb.User, db *database.Database, buf *bytes.Buffer) int {
qp, err := params.ParseQuery(queryArgs)
if err != nil {
panic(err)
}
@ -57,14 +55,11 @@ func scrape(ctx context.Context, qs string, user *cdb.User, db *database.Databas
scrapeData := make(map[string]interface{})
fileData := make(map[cdb.TorrentHash]interface{})
if !db.TorrentsLock.RTryLockWithContext(ctx) {
return http.StatusRequestTimeout
}
defer db.TorrentsLock.RUnlock()
dbTorrents := *db.Torrents.Load()
if qp.InfoHashes() != nil {
for _, infoHash := range qp.InfoHashes() {
torrent, exists := db.Torrents[infoHash]
if len(qp.Params.InfoHashes) > 0 {
for _, infoHash := range qp.Params.InfoHashes {
torrent, exists := dbTorrents[infoHash]
if exists {
if !isDisabledDownload(db, user, torrent) {
fileData[infoHash] = writeScrapeInfo(torrent)
@ -86,14 +81,14 @@ func scrape(ctx context.Context, qs string, user *cdb.User, db *database.Databas
// Early exit before response write
select {
case <-ctx.Done():
return http.StatusRequestTimeout
return fasthttp.StatusRequestTimeout
default:
}
encoder := bencode.NewEncoder(buf)
if err = encoder.Encode(scrapeData); err != nil {
if err := encoder.Encode(scrapeData); err != nil {
panic(err)
}
return http.StatusOK
return fasthttp.StatusOK
}

View file

@ -20,10 +20,9 @@ package server
import (
"bytes"
"context"
"github.com/valyala/fasthttp"
"net"
"net/http"
"path"
"strconv"
"sync"
"sync/atomic"
"time"
@ -63,10 +62,10 @@ var (
listener net.Listener
)
func (handler *httpHandler) respond(ctx context.Context, r *http.Request, buf *bytes.Buffer) int {
dir, action := path.Split(r.URL.Path)
func (handler *httpHandler) respond(ctx context.Context, requestContext *fasthttp.RequestCtx, buf *bytes.Buffer) int {
dir, action := path.Split(string(requestContext.Request.URI().Path()))
if action == "" {
return http.StatusNotFound
return fasthttp.StatusNotFound
}
/*
@ -82,7 +81,7 @@ func (handler *httpHandler) respond(ctx context.Context, r *http.Request, buf *b
return alive(buf)
}
return http.StatusNotFound
return fasthttp.StatusNotFound
}
/*
@ -91,31 +90,36 @@ func (handler *httpHandler) respond(ctx context.Context, r *http.Request, buf *b
* ===================================================
*/
user, err := isPasskeyValid(ctx, passkey, handler.db)
if err != nil {
return http.StatusRequestTimeout
} else if user == nil {
user := isPasskeyValid(passkey, handler.db)
if user == nil {
failure("Your passkey is invalid", buf, 1*time.Hour)
return http.StatusOK
return fasthttp.StatusOK
}
switch action {
case "announce":
return announce(ctx, r.URL.RawQuery, r.Header, r.RemoteAddr, user, handler.db, buf)
return announce(
ctx,
requestContext.Request.URI().QueryArgs(),
&requestContext.Request.Header,
requestContext.RemoteAddr(),
user,
handler.db,
buf)
case "scrape":
if enabled, _ := config.GetBool("scrape", true); !enabled {
return http.StatusNotFound
return fasthttp.StatusNotFound
}
return scrape(ctx, r.URL.RawQuery, user, handler.db, buf)
return scrape(ctx, requestContext.Request.URI().QueryArgs(), user, handler.db, buf)
case "metrics":
return metrics(ctx, r.Header.Get("Authorization"), handler.db, buf)
return metrics(ctx, requestContext.Request.Header.PeekBytes([]byte("Authorization")), handler.db, buf)
}
return http.StatusNotFound
return fasthttp.StatusNotFound
}
func (handler *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
func (handler *httpHandler) ServeHTTP(requestContext *fasthttp.RequestCtx) {
if handler.terminate {
return
}
@ -127,9 +131,6 @@ func (handler *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
handler.waitGroup.Add(1)
defer handler.waitGroup.Done()
// Flush buffered data to client
defer w.(http.Flusher).Flush()
buf := handler.bufferPool.Take()
// Mark buf to be returned to bufferPool after we are done with it
defer handler.bufferPool.Give(buf)
@ -137,46 +138,48 @@ func (handler *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Gracefully handle panics so that they're confined to single request and don't crash server
defer func() {
if err := recover(); err != nil {
log.Error.Printf("Recovered from panicking request handler - %v\nURL was: %s", err, r.URL)
log.Error.Printf("Recovered from panicking request handler - %v\nURL was: %s", err, requestContext.URI().String())
log.WriteStack()
collectors.IncrementErroredRequests()
if w.Header().Get("Content-Type") == "" {
if len(requestContext.Response.Header.ContentType()) == 0 {
buf.Reset()
w.WriteHeader(http.StatusInternalServerError)
requestContext.Response.SetStatusCode(fasthttp.StatusInternalServerError)
}
}
}()
// Prepare and start new context with timeout to abort long-running requests
ctx, cancel := context.WithTimeout(r.Context(), handler.contextTimeout)
ctx, cancel := context.WithTimeout(context.Background(), handler.contextTimeout)
defer cancel()
/* Pass flow to handler; note that handler should be responsible for actually cancelling
its own work based on request context cancellation */
status := handler.respond(ctx, r, buf)
status := handler.respond(ctx, requestContext, buf)
select {
case <-ctx.Done():
collectors.IncrementTimeoutRequests()
switch err := ctx.Err(); err {
case context.DeadlineExceeded:
collectors.IncrementTimeoutRequests()
failure("Request context deadline exceeded", buf, 5*time.Minute)
w.Header().Add("Content-Length", strconv.Itoa(buf.Len()))
w.Header().Add("Content-Type", "text/plain")
w.WriteHeader(http.StatusOK) // Required by torrent clients to interpret failure response
_, _ = w.Write(buf.Bytes())
default:
w.WriteHeader(http.StatusRequestTimeout)
requestContext.Request.Header.SetContentLength(buf.Len())
requestContext.Request.Header.SetContentTypeBytes([]byte("text/plain"))
requestContext.Response.SetStatusCode(fasthttp.StatusOK) // Required by torrent clients to interpret failure response
_, _ = requestContext.Write(buf.Bytes())
case context.Canceled:
collectors.IncrementCancelRequests()
requestContext.Response.SetStatusCode(fasthttp.StatusRequestTimeout)
}
default:
w.Header().Add("Content-Length", strconv.Itoa(buf.Len()))
w.Header().Add("Content-Type", "text/plain")
w.WriteHeader(status)
_, _ = w.Write(buf.Bytes())
requestContext.Request.Header.SetContentLength(buf.Len())
requestContext.Request.Header.SetContentTypeBytes([]byte("text/plain"))
requestContext.Response.SetStatusCode(status)
_, _ = requestContext.Write(buf.Bytes())
}
}
@ -190,7 +193,6 @@ func Start() {
addr, _ := config.Section("http").Get("addr", ":34000")
readTimeout, _ := config.Section("http").Section("timeout").GetInt("read", 1)
readHeaderTimeout, _ := config.Section("http").Section("timeout").GetInt("read_header", 2)
writeTimeout, _ := config.Section("http").Section("timeout").GetInt("write", 3)
idleTimeout, _ := config.Section("http").Section("timeout").GetInt("idle", 30)
@ -198,17 +200,23 @@ func Start() {
handler.contextTimeout = time.Duration(writeTimeout)*time.Second - 200*time.Millisecond
// Create new server instance
server := &http.Server{
Handler: handler,
ReadTimeout: time.Duration(readTimeout) * time.Second,
ReadHeaderTimeout: time.Duration(readHeaderTimeout) * time.Second,
WriteTimeout: time.Duration(writeTimeout) * time.Second,
IdleTimeout: time.Duration(idleTimeout) * time.Second,
server := &fasthttp.Server{
Handler: handler.ServeHTTP,
ReadTimeout: time.Duration(readTimeout) * time.Second,
WriteTimeout: time.Duration(writeTimeout) * time.Second,
IdleTimeout: time.Duration(idleTimeout) * time.Second,
GetOnly: true,
DisablePreParseMultipartForm: true,
NoDefaultServerHeader: true,
NoDefaultDate: true,
NoDefaultContentType: true,
CloseOnShutdown: true,
}
if idleTimeout <= 0 {
log.Warning.Print("Setting idleTimeout <= 0 disables Keep-Alive which might negatively impact performance")
server.SetKeepAlivesEnabled(false)
server.DisableKeepalive = true
}
// Start new goroutine to calculate throughput
@ -261,8 +269,7 @@ func Start() {
// Wait for active connections to finish processing
handler.waitGroup.Wait()
// Close server so that it does not Accept(), https://github.com/golang/go/issues/10527
_ = server.Close()
_ = server.Shutdown()
log.Info.Print("Now closed and not accepting any new connections")

View file

@ -19,7 +19,6 @@ package server
import (
"bytes"
"context"
"time"
"chihaya/database"
@ -69,18 +68,13 @@ func clientApproved(peerID string, db *database.Database) (uint16, bool) {
return 0, false
}
func isPasskeyValid(ctx context.Context, passkey string, db *database.Database) (*cdb.User, error) {
if !db.UsersLock.RTryLockWithContext(ctx) {
return nil, ctx.Err()
}
defer db.UsersLock.RUnlock()
user, exists := db.Users[passkey]
func isPasskeyValid(passkey string, db *database.Database) *cdb.User {
user, exists := (*db.Users.Load())[passkey]
if !exists {
return nil, nil
return nil
}
return user, nil
return user
}
func hasHitAndRun(db *database.Database, userID, torrentID uint32) bool {
@ -96,5 +90,5 @@ func hasHitAndRun(db *database.Database, userID, torrentID uint32) bool {
func isDisabledDownload(db *database.Database, user *cdb.User, torrent *cdb.Torrent) bool {
// Only disable download if the torrent doesn't have a HnR against it
return user.DisableDownload && !hasHitAndRun(db, user.ID, torrent.ID)
return user.DisableDownload.Load() && !hasHitAndRun(db, user.ID.Load(), torrent.ID.Load())
}

20
util/context_ticker.go Normal file
View file

@ -0,0 +1,20 @@
package util
import (
"context"
"time"
)
func ContextTick(ctx context.Context, d time.Duration, onTick func()) {
ticker := time.NewTicker(d)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
onTick()
}
}
}