Compare commits

7 commits: d766e251b4 ... 690a2df3e1

Author | SHA1 | Date
---|---|---
DataHoarder | 690a2df3e1 |
DataHoarder | 5e669c3b60 |
DataHoarder | 4449bde9bb |
DataHoarder | 6fc797f884 |
DataHoarder | ceb2fe909c |
DataHoarder | 326e6adb07 |
 | 046b30b7a9 |

README.md
@@ -58,12 +58,7 @@ Configuration is done in `config.json`, which you'll need to create with the fol
```json
{
  "database": {
    "username": "chihaya",
    "password": "",
    "database": "chihaya",
    "proto": "tcp",
    "addr": "127.0.0.1:3306",

    "dsn": "chihaya:@tcp(127.0.0.1:3306)/chihaya",
    "deadlock_pause": 1,
    "deadlock_retries": 5
  },

@@ -96,7 +91,6 @@ Configuration is done in `config.json`, which you'll need to create with the fol
  "proxy_header": "",
  "timeout": {
    "read": 1,
    "read_header": 2,
    "write": 1,
    "idle": 30
  }

@@ -143,8 +137,7 @@ Configuration is done in `config.json`, which you'll need to create with the fol
- `admin_token` - administrative token used in `Authorization` header to access advanced prometheus statistics
- `proxy_header` - header name to look for user's real IP address, for example `X-Real-Ip`
- `timeout`
  - `read_header` - timeout in seconds for reading request headers
  - `read` - timeout in seconds for reading request; is reset after headers are read
  - `read` - timeout in seconds for reading request
  - `write` - timeout in seconds for writing response (total time spent)
  - `idle` - how long (in seconds) to keep connection open for keep-alive requests
- `announce`
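The four `timeout` values map naturally onto the fields of Go's `http.Server`. Below is a minimal sketch of how they could be wired up; the helper name and parameters are illustrative, only the `http.Server` fields are standard library, and the "reset after headers are read" behaviour described above may be handled in the tracker's own request handling rather than by these fields alone.

```go
package main

import (
	"net/http"
	"time"
)

// newServer wires config timeouts (in seconds) onto an http.Server.
// Parameter names mirror the README keys; the struct fields are stdlib.
func newServer(addr string, read, readHeader, write, idle int) *http.Server {
	return &http.Server{
		Addr:              addr,
		ReadTimeout:       time.Duration(read) * time.Second,       // "read"
		ReadHeaderTimeout: time.Duration(readHeader) * time.Second, // "read_header"
		WriteTimeout:      time.Duration(write) * time.Second,      // "write"
		IdleTimeout:       time.Duration(idle) * time.Second,       // "idle"
	}
}

func main() {
	srv := newServer(":34000", 1, 2, 1, 30)
	_ = srv.ListenAndServe()
}
```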
@@ -133,17 +133,17 @@ func main() {
newUserID++

// Create mapping to get consistent peers
anonUserMapping[user.ID] = newUserID
anonUserMapping[user.ID.Load()] = newUserID

// Replaces user id
user.ID = newUserID
user.ID.Store(newUserID)

// Replaces hidden flag
user.TrackerHide = false
user.TrackerHide.Store(false)

// Replace Up/Down multipliers with baseline
user.UpMultiplier = 1.0
user.DownMultiplier = 1.0
user.UpMultiplier.Store(math.Float64bits(1.0))
user.DownMultiplier.Store(math.Float64bits(1.0))

// Replace passkey with a random provided one of same length
for {
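The hunk above swaps plain struct fields for atomic accessors, with float multipliers packed into a uint64 via `math.Float64bits`. The real `cdb.User` type lives elsewhere in the repository; a minimal sketch of the same pattern with a hypothetical `User` type:

```go
package main

import (
	"fmt"
	"math"
	"sync/atomic"
)

// User is a stand-in for cdb.User: numeric fields are atomics so readers
// never need a lock to observe a consistent value.
type User struct {
	ID             atomic.Uint32
	UpMultiplier   atomic.Uint64 // float64 packed with math.Float64bits
	DownMultiplier atomic.Uint64
	TrackerHide    atomic.Bool
}

func main() {
	u := &User{}
	u.ID.Store(42)
	u.UpMultiplier.Store(math.Float64bits(1.0))
	u.DownMultiplier.Store(math.Float64bits(2.5))
	u.TrackerHide.Store(false)

	// Readers unpack the floats the same way the code above stores them.
	up := math.Float64frombits(u.UpMultiplier.Load())
	down := math.Float64frombits(u.DownMultiplier.Load())
	fmt.Println(u.ID.Load(), up, down, u.TrackerHide.Load())
}
```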
@@ -32,6 +32,7 @@ type AdminCollector struct {
deadlockAbortedMetric *prometheus.Desc
erroredRequestsMetric *prometheus.Desc
timeoutRequestsMetric *prometheus.Desc
cancelRequestsMetric *prometheus.Desc
sqlErrorCountMetric *prometheus.Desc

serializationTimeSummary *prometheus.Histogram

@@ -81,6 +82,7 @@ var (
deadlockAborted = 0
erroredRequests = 0
timeoutRequests = 0
cancelRequests = 0
sqlErrorCount = 0
)

@@ -131,6 +133,8 @@ func NewAdminCollector() *AdminCollector {
"Number of failed requests", nil, nil),
timeoutRequestsMetric: prometheus.NewDesc("chihaya_requests_timeout",
"Number of requests for which context deadline was exceeded", nil, nil),
cancelRequestsMetric: prometheus.NewDesc("chihaya_requests_cancel",
"Number of requests for which context was prematurely cancelled", nil, nil),
sqlErrorCountMetric: prometheus.NewDesc("chihaya_sql_errors_count",
"Number of SQL errors", nil, nil),

@@ -152,6 +156,7 @@ func (collector *AdminCollector) Describe(ch chan<- *prometheus.Desc) {
ch <- collector.deadlockTimeMetric
ch <- collector.erroredRequestsMetric
ch <- collector.timeoutRequestsMetric
ch <- collector.cancelRequestsMetric
ch <- collector.sqlErrorCountMetric

serializationTime.Describe(ch)

@@ -171,6 +176,7 @@ func (collector *AdminCollector) Collect(ch chan<- prometheus.Metric) {
ch <- prometheus.MustNewConstMetric(collector.deadlockTimeMetric, prometheus.CounterValue, deadlockTime.Seconds())
ch <- prometheus.MustNewConstMetric(collector.erroredRequestsMetric, prometheus.CounterValue, float64(erroredRequests))
ch <- prometheus.MustNewConstMetric(collector.timeoutRequestsMetric, prometheus.CounterValue, float64(timeoutRequests))
ch <- prometheus.MustNewConstMetric(collector.cancelRequestsMetric, prometheus.CounterValue, float64(cancelRequests))
ch <- prometheus.MustNewConstMetric(collector.sqlErrorCountMetric, prometheus.CounterValue, float64(sqlErrorCount))

serializationTime.Collect(ch)

@@ -204,6 +210,10 @@ func IncrementTimeoutRequests() {
timeoutRequests++
}

func IncrementCancelRequests() {
cancelRequests++
}

func IncrementSQLErrorCount() {
sqlErrorCount++
}
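The new counters follow the existing shape: a package-level integer bumped by an `Increment...` helper and exported through the collector's `Describe`/`Collect` methods. A stripped-down sketch of that shape with a hypothetical collector and metric name (the real `AdminCollector` has many more metrics):

```go
package main

import "github.com/prometheus/client_golang/prometheus"

var cancelRequests = 0 // bumped from request-handling code

func IncrementCancelRequests() { cancelRequests++ }

// ExampleCollector exposes the counter as a constant metric on each scrape.
type ExampleCollector struct {
	cancelRequestsMetric *prometheus.Desc
}

func NewExampleCollector() *ExampleCollector {
	return &ExampleCollector{
		cancelRequestsMetric: prometheus.NewDesc("example_requests_cancel",
			"Number of requests for which context was prematurely cancelled", nil, nil),
	}
}

func (c *ExampleCollector) Describe(ch chan<- *prometheus.Desc) {
	ch <- c.cancelRequestsMetric
}

func (c *ExampleCollector) Collect(ch chan<- prometheus.Metric) {
	ch <- prometheus.MustNewConstMetric(c.cancelRequestsMetric, prometheus.CounterValue, float64(cancelRequests))
}

func main() {
	prometheus.MustRegister(NewExampleCollector())
}
```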
@@ -40,18 +40,16 @@ func TestMain(m *testing.M) {
panic(err)
}

f, err := os.OpenFile("config.json", os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
f, err := os.OpenFile(configFile, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
if err != nil {
panic(err)
}

configTest = make(Map)
dbConfig := map[string]interface{}{
"username": "chihaya",
"password": "",
"proto": "tcp",
"addr": "127.0.0.1:3306",
"database": "chihaya",
"dsn": "chihaya:@tcp(127.0.0.1:3306)/chihaya",
"deadlock_pause": json.Number("1"),
"deadlock_retries": json.Number("5"),
}
configTest["database"] = dbConfig
configTest["addr"] = ":34000"

@@ -148,5 +146,5 @@ func TestSection(t *testing.T) {
}

func cleanup() {
_ = os.Remove("config.json")
_ = os.Remove(configFile)
}
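The nested map built in `TestMain` only becomes a usable fixture once it is serialized into the config file; the write step is elided from the hunk above. A hedged sketch of what that step could look like, assuming the config package decodes numbers with `UseNumber()` (which would explain why the test expects `json.Number` values back):

```go
package main

import (
	"encoding/json"
	"os"
)

// writeTestConfig serializes the test configuration map to a JSON file.
func writeTestConfig(path string, cfg map[string]interface{}) error {
	f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
	if err != nil {
		return err
	}
	defer f.Close()

	// Integers written here round-trip as json.Number when the decoder
	// on the reading side uses UseNumber().
	return json.NewEncoder(f).Encode(cfg)
}

func main() {
	cfg := map[string]interface{}{
		"addr": ":34000",
		"database": map[string]interface{}{
			"dsn":              "chihaya:@tcp(127.0.0.1:3306)/chihaya",
			"deadlock_pause":   1,
			"deadlock_retries": 5,
		},
	}
	if err := writeTestConfig("config.json", cfg); err != nil {
		panic(err)
	}
}
```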
@@ -19,9 +19,8 @@ package database

import (
"bytes"
"context"
"database/sql"
"fmt"
"github.com/viney-shih/go-lock"
"os"
"sync"
"sync/atomic"

@@ -36,15 +35,7 @@ import (
"github.com/go-sql-driver/mysql"
)

type Connection struct {
sqlDb *sql.DB
mutex sync.Mutex
}

type Database struct {
TorrentsLock lock.RWMutex
UsersLock lock.RWMutex

snatchChannel chan *bytes.Buffer
transferHistoryChannel chan *bytes.Buffer
transferIpsChannel chan *bytes.Buffer

@@ -60,19 +51,21 @@ type Database struct {
cleanStalePeersStmt *sql.Stmt
unPruneTorrentStmt *sql.Stmt

Users map[string]*cdb.User
Users atomic.Pointer[map[string]*cdb.User]
HitAndRuns atomic.Pointer[map[cdb.UserTorrentPair]struct{}]
Torrents map[cdb.TorrentHash]*cdb.Torrent // SHA-1 hash (20 bytes)
TorrentGroupFreeleech atomic.Pointer[map[cdb.TorrentGroup]*cdb.TorrentGroupFreeleech]
Torrents atomic.Pointer[map[cdb.TorrentHash]*cdb.Torrent]
TorrentGroupFreeleech atomic.Pointer[map[cdb.TorrentGroupKey]*cdb.TorrentGroupFreeleech]
Clients atomic.Pointer[map[uint16]string]

mainConn *Connection // Used for reloading and misc queries

bufferPool *util.BufferPool

transferHistoryLock lock.Mutex
transferHistoryLock sync.Mutex

terminate bool
conn *sql.DB

terminate atomic.Bool
ctx context.Context
ctxCancel func()
waitGroup sync.WaitGroup
}

@@ -81,82 +74,76 @@ var (
maxDeadlockRetries int
)

var defaultDsn = map[string]string{
"username": "chihaya",
"password": "",
"proto": "tcp",
"addr": "127.0.0.1:3306",
"database": "chihaya",
}
const defaultDsn = "chihaya:@tcp(127.0.0.1:3306)/chihaya"

func (db *Database) Init() {
db.terminate = false
db.terminate.Store(false)
db.ctx, db.ctxCancel = context.WithCancel(context.Background())

log.Info.Print("Opening database connection...")

db.mainConn = Open()

// Initializing locks
db.TorrentsLock = lock.NewCASMutex()
db.UsersLock = lock.NewCASMutex()
db.conn = Open()

// Used for recording updates, so the max required size should be < 128 bytes. See queue.go for details
db.bufferPool = util.NewBufferPool(128)

var err error

db.loadUsersStmt, err = db.mainConn.sqlDb.Prepare(
db.loadUsersStmt, err = db.conn.Prepare(
"SELECT ID, torrent_pass, DownMultiplier, UpMultiplier, DisableDownload, TrackerHide " +
"FROM users_main WHERE Enabled = '1'")
if err != nil {
panic(err)
}

db.loadHnrStmt, err = db.mainConn.sqlDb.Prepare(
db.loadHnrStmt, err = db.conn.Prepare(
"SELECT h.uid, h.fid FROM transfer_history AS h " +
"JOIN users_main AS u ON u.ID = h.uid WHERE h.hnr = 1 AND u.Enabled = '1'")
if err != nil {
panic(err)
}

db.loadTorrentsStmt, err = db.mainConn.sqlDb.Prepare(
db.loadTorrentsStmt, err = db.conn.Prepare(
"SELECT ID, info_hash, DownMultiplier, UpMultiplier, Snatched, Status, GroupID, TorrentType FROM torrents")
if err != nil {
panic(err)
}

db.loadTorrentGroupFreeleechStmt, err = db.mainConn.sqlDb.Prepare(
db.loadTorrentGroupFreeleechStmt, err = db.conn.Prepare(
"SELECT GroupID, `Type`, DownMultiplier, UpMultiplier FROM torrent_group_freeleech")
if err != nil {
panic(err)
}

db.loadClientsStmt, err = db.mainConn.sqlDb.Prepare(
db.loadClientsStmt, err = db.conn.Prepare(
"SELECT id, peer_id FROM approved_clients WHERE archived = 0")
if err != nil {
panic(err)
}

db.loadFreeleechStmt, err = db.mainConn.sqlDb.Prepare(
db.loadFreeleechStmt, err = db.conn.Prepare(
"SELECT mod_setting FROM mod_core WHERE mod_option = 'global_freeleech'")
if err != nil {
panic(err)
}

db.cleanStalePeersStmt, err = db.mainConn.sqlDb.Prepare(
db.cleanStalePeersStmt, err = db.conn.Prepare(
"UPDATE transfer_history SET active = 0 WHERE last_announce < ? AND active = 1")
if err != nil {
panic(err)
}

db.unPruneTorrentStmt, err = db.mainConn.sqlDb.Prepare(
db.unPruneTorrentStmt, err = db.conn.Prepare(
"UPDATE torrents SET Status = 0 WHERE ID = ?")
if err != nil {
panic(err)
}

db.Users = make(map[string]*cdb.User)
db.Torrents = make(map[cdb.TorrentHash]*cdb.Torrent)
dbUsers := make(map[string]*cdb.User)
db.Users.Store(&dbUsers)

dbTorrents := make(map[cdb.TorrentHash]*cdb.Torrent)
db.Torrents.Store(&dbTorrents)

dbHitAndRuns := make(map[cdb.UserTorrentPair]struct{})
db.HitAndRuns.Store(&dbHitAndRuns)
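The struct changes above replace lock-guarded maps with `atomic.Pointer` to maps: readers `Load()` a snapshot and never block, while reloads build a fresh map and `Store()` it in one step. A minimal sketch of that copy-on-write pattern, with simplified types and a placeholder `loadFromDB`:

```go
package main

import (
	"fmt"
	"sync/atomic"
)

type Torrent struct{ ID uint32 }

type Database struct {
	Torrents atomic.Pointer[map[string]*Torrent]
}

// loadFromDB stands in for the SQL reload; it returns a freshly built map.
func loadFromDB() map[string]*Torrent {
	return map[string]*Torrent{"hash-a": {ID: 1}}
}

func (db *Database) reload() {
	fresh := loadFromDB()
	db.Torrents.Store(&fresh) // single atomic swap, no writer lock
}

func main() {
	db := &Database{}
	empty := make(map[string]*Torrent)
	db.Torrents.Store(&empty)

	db.reload()

	torrents := *db.Torrents.Load() // readers work on a stable snapshot
	fmt.Println(len(torrents))
}
```

Note that this only covers swapping the map itself; mutating values reachable through it (for example a torrent's peer lists) still needs per-object locking, which is what `InitializeLock`/`Lock` provide in the real code.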
@@ -184,7 +171,8 @@ func (db *Database) Init() {
func (db *Database) Terminate() {
log.Info.Print("Terminating database connection...")

db.terminate = true
db.terminate.Store(true)
db.ctxCancel()

log.Info.Print("Closing all flush channels...")
db.closeFlushChannels()

@@ -195,13 +183,11 @@ func (db *Database) Terminate() {
}()

db.waitGroup.Wait()
db.mainConn.mutex.Lock()
_ = db.mainConn.Close()
db.mainConn.mutex.Unlock()
_ = db.conn.Close()
db.serialize()
}

func Open() *Connection {
func Open() *sql.DB {
databaseConfig := config.Section("database")
deadlockWaitTime, _ = databaseConfig.GetInt("deadlock_pause", 1)
maxDeadlockRetries, _ = databaseConfig.GetInt("deadlock_retries", 5)

@@ -217,18 +203,7 @@ func Open() *Connection {
// First try to load the DSN from environment. USeful for tests.
databaseDsn := os.Getenv("DB_DSN")
if databaseDsn == "" {
dbUsername, _ := databaseConfig.Get("username", defaultDsn["username"])
dbPassword, _ := databaseConfig.Get("password", defaultDsn["password"])
dbProto, _ := databaseConfig.Get("proto", defaultDsn["proto"])
dbAddr, _ := databaseConfig.Get("addr", defaultDsn["addr"])
dbDatabase, _ := databaseConfig.Get("database", defaultDsn["database"])
databaseDsn = fmt.Sprintf("%s:%s@%s(%s)/%s",
dbUsername,
dbPassword,
dbProto,
dbAddr,
dbDatabase,
)
databaseDsn, _ = databaseConfig.Get("dsn", defaultDsn)
}

sqlDb, err := sql.Open("mysql", databaseDsn)

@@ -240,16 +215,10 @@ func Open() *Connection {
log.Fatal.Fatalf("Couldn't ping database - %s", err)
}

return &Connection{
sqlDb: sqlDb,
}
return sqlDb
}

func (db *Connection) Close() error {
return db.sqlDb.Close()
}

func (db *Connection) query(stmt *sql.Stmt, args ...interface{}) *sql.Rows { //nolint:unparam
func (db *Database) query(stmt *sql.Stmt, args ...interface{}) *sql.Rows { //nolint:unparam
rows, _ := perform(func() (interface{}, error) {
return stmt.Query(args...)
}).(*sql.Rows)

@@ -257,7 +226,7 @@ func (db *Connection) query(stmt *sql.Stmt, args ...interface{}) *sql.Rows { //n
return rows
}

func (db *Connection) execute(stmt *sql.Stmt, args ...interface{}) sql.Result {
func (db *Database) execute(stmt *sql.Stmt, args ...interface{}) sql.Result {
result, _ := perform(func() (interface{}, error) {
return stmt.Exec(args...)
}).(sql.Result)

@@ -265,9 +234,9 @@ func (db *Connection) execute(stmt *sql.Stmt, args ...interface{}) sql.Result {
return result
}

func (db *Connection) exec(query *bytes.Buffer, args ...interface{}) sql.Result { //nolint:unparam
func (db *Database) exec(query *bytes.Buffer, args ...interface{}) sql.Result { //nolint:unparam
result, _ := perform(func() (interface{}, error) {
return db.sqlDb.Exec(query.String(), args...)
return db.conn.Exec(query.String(), args...)
}).(sql.Result)

return result
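`Open` now resolves a single DSN string instead of assembling one from five fields: the `DB_DSN` environment variable wins (useful for tests), then the `dsn` config key, then the built-in default. A sketch of that precedence with the config lookup stubbed out:

```go
package main

import (
	"database/sql"
	"os"

	_ "github.com/go-sql-driver/mysql"
)

const defaultDsn = "chihaya:@tcp(127.0.0.1:3306)/chihaya"

// configDsn stands in for databaseConfig.Get("dsn", defaultDsn).
func configDsn() string { return defaultDsn }

func openDB() (*sql.DB, error) {
	dsn := os.Getenv("DB_DSN") // environment override, e.g. in CI
	if dsn == "" {
		dsn = configDsn()
	}

	db, err := sql.Open("mysql", dsn)
	if err != nil {
		return nil, err
	}

	return db, db.Ping() // fail fast if the database is unreachable
}

func main() {
	if _, err := openDB(); err != nil {
		panic(err)
	}
}
```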
@@ -19,6 +19,8 @@ package database

import (
"fmt"
"github.com/google/go-cmp/cmp"
"math"
"net"
"os"
"reflect"

@@ -44,7 +46,7 @@ func TestMain(m *testing.M) {
db.Init()

fixtures, err = testfixtures.New(
testfixtures.Database(db.mainConn.sqlDb),
testfixtures.Database(db.conn),
testfixtures.Dialect("mariadb"),
testfixtures.Directory("fixtures"),
testfixtures.DangerousSkipTestDatabaseCheck(),
@@ -59,47 +61,55 @@ func TestMain(m *testing.M) {
func TestLoadUsers(t *testing.T) {
prepareTestDatabase()

db.Users = make(map[string]*cdb.User)
dbUsers := make(map[string]*cdb.User)
db.Users.Store(&dbUsers)

testUser1 := &cdb.User{}
testUser1.ID.Store(1)
testUser1.DownMultiplier.Store(math.Float64bits(1))
testUser1.UpMultiplier.Store(math.Float64bits(1))
testUser1.DisableDownload.Store(false)
testUser1.TrackerHide.Store(false)

testUser2 := &cdb.User{}
testUser2.ID.Store(2)
testUser2.DownMultiplier.Store(math.Float64bits(2))
testUser2.UpMultiplier.Store(math.Float64bits(0.5))
testUser2.DisableDownload.Store(true)
testUser2.TrackerHide.Store(true)

users := map[string]*cdb.User{
"mUztWMpBYNCqzmge6vGeEUGSrctJbgpQ": {
ID: 1,
UpMultiplier: 1,
DownMultiplier: 1,
TrackerHide: false,
DisableDownload: false,
},
"tbHfQDQ9xDaQdsNv5CZBtHPfk7KGzaCw": {
ID: 2,
UpMultiplier: 0.5,
DownMultiplier: 2,
TrackerHide: true,
DisableDownload: true,
},
"mUztWMpBYNCqzmge6vGeEUGSrctJbgpQ": testUser1,
"tbHfQDQ9xDaQdsNv5CZBtHPfk7KGzaCw": testUser2,
}

// Test with fresh data
db.loadUsers()

if len(db.Users) != len(users) {
t.Fatal(fixtureFailure("Did not load all users as expected from fixture file", len(users), len(db.Users)))
dbUsers = *db.Users.Load()

if len(dbUsers) != len(users) {
t.Fatal(fixtureFailure("Did not load all users as expected from fixture file", len(users), len(dbUsers)))
}

for passkey, user := range users {
if !reflect.DeepEqual(user, db.Users[passkey]) {
if !reflect.DeepEqual(user, dbUsers[passkey]) {
t.Fatal(fixtureFailure(
fmt.Sprintf("Did not load user (%s) as expected from fixture file", passkey),
user,
db.Users[passkey]))
dbUsers[passkey]))
}
}

// Now test load on top of existing data
oldUsers := db.Users
oldUsers := dbUsers

db.loadUsers()

if !reflect.DeepEqual(oldUsers, db.Users) {
t.Fatal(fixtureFailure("Did not reload users as expected from fixture file", oldUsers, db.Users))
dbUsers = *db.Users.Load()

if !reflect.DeepEqual(oldUsers, dbUsers) {
t.Fatal(fixtureFailure("Did not reload users as expected from fixture file", oldUsers, dbUsers))
}
}
@@ -133,89 +143,96 @@ func TestLoadHitAndRuns(t *testing.T) {
func TestLoadTorrents(t *testing.T) {
prepareTestDatabase()

db.Torrents = make(map[cdb.TorrentHash]*cdb.Torrent)
dbTorrents := make(map[cdb.TorrentHash]*cdb.Torrent)
db.Torrents.Store(&dbTorrents)

t1 := &cdb.Torrent{
Seeders: map[cdb.PeerKey]*cdb.Peer{},
Leechers: map[cdb.PeerKey]*cdb.Peer{},
}
t1.ID.Store(1)
t1.Status.Store(1)
t1.Snatched.Store(2)
t1.DownMultiplier.Store(math.Float64bits(1))
t1.UpMultiplier.Store(math.Float64bits(1))
t1.Group.GroupID.Store(1)
t1.Group.TorrentType.Store(cdb.MustTorrentTypeFromString("anime"))

t2 := &cdb.Torrent{
Seeders: map[cdb.PeerKey]*cdb.Peer{},
Leechers: map[cdb.PeerKey]*cdb.Peer{},
}
t2.ID.Store(2)
t2.Status.Store(0)
t2.Snatched.Store(0)
t2.DownMultiplier.Store(math.Float64bits(2))
t2.UpMultiplier.Store(math.Float64bits(0.5))
t2.Group.GroupID.Store(1)
t2.Group.TorrentType.Store(cdb.MustTorrentTypeFromString("music"))

t3 := &cdb.Torrent{
Seeders: map[cdb.PeerKey]*cdb.Peer{},
Leechers: map[cdb.PeerKey]*cdb.Peer{},
}
t3.ID.Store(3)
t3.Status.Store(0)
t3.Snatched.Store(0)
t3.DownMultiplier.Store(math.Float64bits(1))
t3.UpMultiplier.Store(math.Float64bits(1))
t3.Group.GroupID.Store(2)
t3.Group.TorrentType.Store(cdb.MustTorrentTypeFromString("anime"))

torrents := map[cdb.TorrentHash]*cdb.Torrent{
{114, 239, 32, 237, 220, 181, 67, 143, 115, 182, 216, 141, 120, 196, 223, 193, 102, 123, 137, 56}: {
ID: 1,
Status: 1,
Snatched: 2,
DownMultiplier: 1,
UpMultiplier: 1,
Seeders: map[cdb.PeerKey]*cdb.Peer{},
Leechers: map[cdb.PeerKey]*cdb.Peer{},
Group: cdb.TorrentGroup{
GroupID: 1,
TorrentType: "anime",
},
},
{22, 168, 45, 221, 87, 225, 140, 177, 94, 34, 242, 225, 196, 234, 222, 46, 187, 131, 177, 155}: {
ID: 2,
Status: 0,
Snatched: 0,
DownMultiplier: 2,
UpMultiplier: 0.5,
Seeders: map[cdb.PeerKey]*cdb.Peer{},
Leechers: map[cdb.PeerKey]*cdb.Peer{},
Group: cdb.TorrentGroup{
GroupID: 1,
TorrentType: "music",
},
},
{89, 252, 84, 49, 177, 28, 118, 28, 148, 205, 62, 185, 8, 37, 234, 110, 109, 200, 165, 241}: {
ID: 3,
Status: 0,
Snatched: 0,
DownMultiplier: 1,
UpMultiplier: 1,
Seeders: map[cdb.PeerKey]*cdb.Peer{},
Leechers: map[cdb.PeerKey]*cdb.Peer{},
Group: cdb.TorrentGroup{
GroupID: 2,
TorrentType: "anime",
},
},
{114, 239, 32, 237, 220, 181, 67, 143, 115, 182, 216, 141, 120, 196, 223, 193, 102, 123, 137, 56}: t1,
{22, 168, 45, 221, 87, 225, 140, 177, 94, 34, 242, 225, 196, 234, 222, 46, 187, 131, 177, 155}: t2,
{89, 252, 84, 49, 177, 28, 118, 28, 148, 205, 62, 185, 8, 37, 234, 110, 109, 200, 165, 241}: t3,
}

for k := range torrents {
torrents[k].InitializeLock()
}

// Test with fresh data
db.loadTorrents()

if len(db.Torrents) != len(torrents) {
dbTorrents = *db.Torrents.Load()

if len(dbTorrents) != len(torrents) {
t.Fatal(fixtureFailure("Did not load all torrents as expected from fixture file",
len(torrents),
len(db.Torrents)))
len(dbTorrents)))
}

for hash, torrent := range torrents {
if !reflect.DeepEqual(torrent, db.Torrents[hash]) {
if !cmp.Equal(torrent, dbTorrents[hash], cdb.TorrentTestCompareOptions...) {
hashHex, _ := hash.MarshalText()
t.Fatal(fixtureFailure(
fmt.Sprintf("Did not load torrent (%s) as expected from fixture file", hash),
fmt.Sprintf("Did not load torrent (%s) as expected from fixture file", string(hashHex)),
torrent,
db.Torrents[hash]))
dbTorrents[hash]))
}
}

// Now test load on top of existing data
oldTorrents := db.Torrents
oldTorrents := dbTorrents

db.loadTorrents()

if !reflect.DeepEqual(oldTorrents, db.Torrents) {
t.Fatal(fixtureFailure("Did not reload torrents as expected from fixture file", oldTorrents, db.Torrents))
dbTorrents = *db.Torrents.Load()

if !cmp.Equal(oldTorrents, dbTorrents, cdb.TorrentTestCompareOptions...) {
t.Fatal(fixtureFailure("Did not reload torrents as expected from fixture file", oldTorrents, dbTorrents))
}
}

func TestLoadGroupsFreeleech(t *testing.T) {
prepareTestDatabase()

dbMap := make(map[cdb.TorrentGroup]*cdb.TorrentGroupFreeleech)
dbMap := make(map[cdb.TorrentGroupKey]*cdb.TorrentGroupFreeleech)
db.TorrentGroupFreeleech.Store(&dbMap)

torrentGroupFreeleech := map[cdb.TorrentGroup]*cdb.TorrentGroupFreeleech{
{
GroupID: 2,
TorrentType: "anime",
}: {
torrentGroupFreeleech := map[cdb.TorrentGroupKey]*cdb.TorrentGroupFreeleech{
cdb.MustTorrentGroupKeyFromString("anime", 2): {
DownMultiplier: 0,
UpMultiplier: 2,
},
@@ -297,36 +314,48 @@ func TestLoadClients(t *testing.T) {
func TestUnPrune(t *testing.T) {
prepareTestDatabase()

h := cdb.TorrentHash{114, 239, 32, 237, 220, 181, 67, 143, 115, 182, 216, 141, 120, 196, 223, 193, 102, 123, 137, 56}
torrent := cdb.Torrent{
Seeders: db.Torrents[h].Seeders,
Leechers: db.Torrents[h].Leechers,
Group: db.Torrents[h].Group,
ID: db.Torrents[h].ID,
Snatched: db.Torrents[h].Snatched,
Status: db.Torrents[h].Status,
LastAction: db.Torrents[h].LastAction,
UpMultiplier: db.Torrents[h].UpMultiplier,
DownMultiplier: db.Torrents[h].DownMultiplier,
}
torrent.Status = 0
dbTorrents := *db.Torrents.Load()

db.UnPrune(db.Torrents[h])
h := cdb.TorrentHash{114, 239, 32, 237, 220, 181, 67, 143, 115, 182, 216, 141, 120, 196, 223, 193, 102, 123, 137, 56}
dbTorrent := dbTorrents[h]

torrent := cdb.Torrent{
Seeders: dbTorrent.Seeders,
Leechers: dbTorrent.Leechers,
}
torrent.SeedersLength.Store(uint32(len(torrent.Seeders)))
torrent.LeechersLength.Store(uint32(len(torrent.Leechers)))
torrent.ID.Store(dbTorrent.ID.Load())
torrent.Status.Store(dbTorrent.Status.Load())
torrent.Snatched.Store(dbTorrent.Snatched.Load())
torrent.DownMultiplier.Store(dbTorrent.DownMultiplier.Load())
torrent.UpMultiplier.Store(dbTorrent.UpMultiplier.Load())
torrent.Group.GroupID.Store(dbTorrent.Group.GroupID.Load())
torrent.Group.TorrentType.Store(dbTorrent.Group.TorrentType.Load())

torrent.InitializeLock()
torrent.Status.Store(0)

db.UnPrune(dbTorrents[h])

db.loadTorrents()

if !reflect.DeepEqual(&torrent, db.Torrents[h]) {
dbTorrents = *db.Torrents.Load()

if !cmp.Equal(&torrent, dbTorrents[h], cdb.TorrentTestCompareOptions...) {
t.Fatal(fixtureFailure(
fmt.Sprintf("Torrent (%x) was not unpruned properly", h),
&torrent,
db.Torrents[h]))
dbTorrents[h]))
}
}

func TestRecordAndFlushUsers(t *testing.T) {
prepareTestDatabase()

testUser := db.Users["tbHfQDQ9xDaQdsNv5CZBtHPfk7KGzaCw"]
dbUsers := *db.Users.Load()

testUser := dbUsers["tbHfQDQ9xDaQdsNv5CZBtHPfk7KGzaCw"]

var (
initUpload int64

@@ -347,11 +376,11 @@ func TestRecordAndFlushUsers(t *testing.T) {

deltaRawDownload = 83472
deltaRawUpload = 245
deltaDownload = int64(float64(deltaRawDownload) * testUser.DownMultiplier)
deltaUpload = int64(float64(deltaRawUpload) * testUser.UpMultiplier)
deltaDownload = int64(float64(deltaRawDownload) * math.Float64frombits(testUser.DownMultiplier.Load()))
deltaUpload = int64(float64(deltaRawUpload) * math.Float64frombits(testUser.UpMultiplier.Load()))

row := db.mainConn.sqlDb.QueryRow("SELECT Uploaded, Downloaded, rawup, rawdl "+
"FROM users_main WHERE ID = ?", testUser.ID)
row := db.conn.QueryRow("SELECT Uploaded, Downloaded, rawup, rawdl "+
"FROM users_main WHERE ID = ?", testUser.ID.Load())

err := row.Scan(&initUpload, &initDownload, &initRawUpload, &initRawDownload)
if err != nil {

@@ -365,8 +394,8 @@ func TestRecordAndFlushUsers(t *testing.T) {
}
time.Sleep(200 * time.Millisecond)

row = db.mainConn.sqlDb.QueryRow("SELECT Uploaded, Downloaded, rawup, rawdl "+
"FROM users_main WHERE ID = ?", testUser.ID)
row = db.conn.QueryRow("SELECT Uploaded, Downloaded, rawup, rawdl "+
"FROM users_main WHERE ID = ?", testUser.ID.Load())

err = row.Scan(&upload, &download, &rawUpload, &rawDownload)
if err != nil {

@@ -446,7 +475,7 @@ func TestRecordAndFlushTransferHistory(t *testing.T) {
deltaActiveTime = 267
deltaSeedTime = 15

row := db.mainConn.sqlDb.QueryRow("SELECT uploaded, downloaded, activetime, seedtime, active, snatched "+
row := db.conn.QueryRow("SELECT uploaded, downloaded, activetime, seedtime, active, snatched "+
"FROM transfer_history WHERE uid = ? AND fid = ?", testPeer.UserID, testPeer.TorrentID)

err := row.Scan(&initRawUpload, &initRawDownload, &initActiveTime, &initSeedTime, &initActive, &initSnatch)

@@ -467,7 +496,7 @@ func TestRecordAndFlushTransferHistory(t *testing.T) {
}
time.Sleep(200 * time.Millisecond)

row = db.mainConn.sqlDb.QueryRow("SELECT uploaded, downloaded, activetime, seedtime, active, snatched "+
row = db.conn.QueryRow("SELECT uploaded, downloaded, activetime, seedtime, active, snatched "+
"FROM transfer_history WHERE uid = ? AND fid = ?", testPeer.UserID, testPeer.TorrentID)

err = row.Scan(&rawUpload, &rawDownload, &activeTime, &seedTime, &active, &snatch)

@@ -528,7 +557,7 @@ func TestRecordAndFlushTransferHistory(t *testing.T) {

var gotStartTime int64

row = db.mainConn.sqlDb.QueryRow("SELECT seeding, starttime, last_announce, remaining "+
row = db.conn.QueryRow("SELECT seeding, starttime, last_announce, remaining "+
"FROM transfer_history WHERE uid = ? AND fid = ?", gotPeer.UserID, gotPeer.TorrentID)

err = row.Scan(&gotPeer.Seeding, &gotStartTime, &gotPeer.LastAnnounce, &gotPeer.Left)

@@ -554,7 +583,14 @@ func TestRecordAndFlushTransferHistory(t *testing.T) {
Left: 65000,
}

db.QueueTransferHistory(testPeer, 0, 1000, 1, 0, 1, true)
db.QueueTransferHistory(
testPeer,
0,
1000,
1,
0,
1,
true)

gotPeer = &cdb.Peer{
UserID: testPeer.UserID,

@@ -566,7 +602,7 @@ func TestRecordAndFlushTransferHistory(t *testing.T) {
}
time.Sleep(200 * time.Millisecond)

row = db.mainConn.sqlDb.QueryRow("SELECT seeding, starttime, last_announce, remaining "+
row = db.conn.QueryRow("SELECT seeding, starttime, last_announce, remaining "+
"FROM transfer_history WHERE uid = ? AND fid = ?", gotPeer.UserID, gotPeer.TorrentID)

err = row.Scan(&gotPeer.Seeding, &gotPeer.StartTime, &gotPeer.LastAnnounce, &gotPeer.Left)

@@ -603,7 +639,7 @@ func TestRecordAndFlushTransferIP(t *testing.T) {
deltaDownload = 236
deltaUpload = 3262

row := db.mainConn.sqlDb.QueryRow("SELECT uploaded, downloaded "+
row := db.conn.QueryRow("SELECT uploaded, downloaded "+
"FROM transfer_ips WHERE uid = ? AND fid = ? AND ip = ? AND client_id = ?",
testPeer.UserID, testPeer.TorrentID, testPeer.Addr.IPNumeric(), testPeer.ClientID)

@@ -619,7 +655,7 @@ func TestRecordAndFlushTransferIP(t *testing.T) {
}
time.Sleep(200 * time.Millisecond)

row = db.mainConn.sqlDb.QueryRow("SELECT uploaded, downloaded "+
row = db.conn.QueryRow("SELECT uploaded, downloaded "+
"FROM transfer_ips WHERE uid = ? AND fid = ? AND ip = ? AND client_id = ?",
testPeer.UserID, testPeer.TorrentID, testPeer.Addr.IPNumeric(), testPeer.ClientID)

@@ -655,7 +691,7 @@ func TestRecordAndFlushTransferIP(t *testing.T) {

var gotStartTime int64

row = db.mainConn.sqlDb.QueryRow("SELECT port, starttime, last_announce "+
row = db.conn.QueryRow("SELECT port, starttime, last_announce "+
"FROM transfer_ips WHERE uid = ? AND fid = ? AND ip = ? AND client_id = ?",
testPeer.UserID, testPeer.TorrentID, testPeer.Addr.IPNumeric(), testPeer.ClientID)

@@ -700,7 +736,7 @@ func TestRecordAndFlushTransferIP(t *testing.T) {
Addr: cdb.NewPeerAddressFromIPPort(testPeer.Addr.IP(), 0),
}

row = db.mainConn.sqlDb.QueryRow("SELECT port, starttime, last_announce "+
row = db.conn.QueryRow("SELECT port, starttime, last_announce "+
"FROM transfer_ips WHERE uid = ? AND fid = ? AND ip = ? AND client_id = ?",
testPeer.UserID, testPeer.TorrentID, testPeer.Addr.IPNumeric(), testPeer.ClientID)

@@ -738,7 +774,7 @@ func TestRecordAndFlushSnatch(t *testing.T) {
}
time.Sleep(200 * time.Millisecond)

row := db.mainConn.sqlDb.QueryRow("SELECT snatched_time "+
row := db.conn.QueryRow("SELECT snatched_time "+
"FROM transfer_history WHERE uid = ? AND fid = ?", testPeer.UserID, testPeer.TorrentID)

err := row.Scan(&snatchTime)

@@ -758,22 +794,24 @@ func TestRecordAndFlushTorrents(t *testing.T) {
prepareTestDatabase()

h := cdb.TorrentHash{114, 239, 32, 237, 220, 181, 67, 143, 115, 182, 216, 141, 120, 196, 223, 193, 102, 123, 137, 56}
torrent := db.Torrents[h]
torrent.LastAction = time.Now().Unix()
torrent := (*db.Torrents.Load())[h]
torrent.LastAction.Store(time.Now().Unix())
torrent.Seeders[cdb.NewPeerKey(1, cdb.PeerIDFromRawString("test_peer_id_num_one"))] = &cdb.Peer{
UserID: 1,
TorrentID: torrent.ID,
TorrentID: torrent.ID.Load(),
ClientID: 1,
StartTime: time.Now().Unix(),
LastAnnounce: time.Now().Unix(),
}
torrent.Leechers[cdb.NewPeerKey(3, cdb.PeerIDFromRawString("test_peer_id_num_two"))] = &cdb.Peer{
UserID: 3,
TorrentID: torrent.ID,
TorrentID: torrent.ID.Load(),
ClientID: 2,
StartTime: time.Now().Unix(),
LastAnnounce: time.Now().Unix(),
}
torrent.SeedersLength.Store(uint32(len(torrent.Seeders)))
torrent.LeechersLength.Store(uint32(len(torrent.Leechers)))

db.QueueTorrent(torrent, 5)

@@ -789,26 +827,26 @@ func TestRecordAndFlushTorrents(t *testing.T) {
numSeeders int
)

row := db.mainConn.sqlDb.QueryRow("SELECT Snatched, last_action, Seeders, Leechers "+
"FROM torrents WHERE ID = ?", torrent.ID)
row := db.conn.QueryRow("SELECT Snatched, last_action, Seeders, Leechers "+
"FROM torrents WHERE ID = ?", torrent.ID.Load())
err := row.Scan(&snatched, &lastAction, &numSeeders, &numLeechers)

if err != nil {
panic(err)
}

if torrent.Snatched+5 != snatched {
if uint16(torrent.Snatched.Load())+5 != snatched {
t.Fatal(fixtureFailure(
fmt.Sprintf("Snatches incorrectly updated in the database for torrent %x", h),
torrent.Snatched+5,
torrent.Snatched.Load()+5,
snatched,
))
}

if torrent.LastAction != lastAction {
if torrent.LastAction.Load() != lastAction {
t.Fatal(fixtureFailure(
fmt.Sprintf("Last incorrectly updated in the database for torrent %x", h),
torrent.LastAction,
torrent.LastAction.Load(),
lastAction,
))
}

@@ -828,6 +866,22 @@ func TestRecordAndFlushTorrents(t *testing.T) {
numLeechers,
))
}

if int(torrent.SeedersLength.Load()) != numSeeders {
t.Fatal(fixtureFailure(
fmt.Sprintf("SeedersLength incorrectly updated in the database for torrent %x", h),
len(torrent.Seeders),
numSeeders,
))
}

if int(torrent.LeechersLength.Load()) != numLeechers {
t.Fatal(fixtureFailure(
fmt.Sprintf("LeechersLength incorrectly updated in the database for torrent %x", h),
len(torrent.Leechers),
numLeechers,
))
}
}

func TestTerminate(_ *testing.T) {
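In these tests `reflect.DeepEqual` gives way to `cmp.Equal` with `cdb.TorrentTestCompareOptions`, since structs holding atomics are awkward to compare field by field. The actual options live in the repository's types package and aren't shown in this compare; a hypothetical equivalent built around `cmp.Comparer` looks like this:

```go
package main

import (
	"fmt"
	"sync/atomic"

	"github.com/google/go-cmp/cmp"
)

type Torrent struct {
	ID     atomic.Uint32
	Status atomic.Uint32
}

// compareOptions is a stand-in for cdb.TorrentTestCompareOptions: compare
// torrents by the values their atomics currently hold, not by raw struct state.
var compareOptions = []cmp.Option{
	cmp.Comparer(func(a, b *Torrent) bool {
		return a.ID.Load() == b.ID.Load() && a.Status.Load() == b.Status.Load()
	}),
}

func main() {
	a, b := &Torrent{}, &Torrent{}
	a.ID.Store(1)
	b.ID.Store(1)
	fmt.Println(cmp.Equal(a, b, compareOptions...)) // true
}
```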
@@ -1,5 +0,0 @@
- ID: 1
  Time: 1592659152

- ID: 2
  Time: 1609032772

@@ -1,2 +0,0 @@
- ID: 1
  Time: 1592659152
@@ -19,8 +19,8 @@ package database

import (
"bytes"
"chihaya/util"
"errors"
"github.com/viney-shih/go-lock"
"time"

"chihaya/collectors"

@@ -82,8 +82,6 @@ func (db *Database) startFlushing() {
db.transferIpsChannel = make(chan *bytes.Buffer, transferIpsFlushBufferSize)
db.snatchChannel = make(chan *bytes.Buffer, snatchFlushBufferSize)

db.transferHistoryLock = lock.NewCASMutex()

go db.flushTorrents()
go db.flushUsers()
go db.flushTransferHistory() // Can not be blocking or it will lock purgeInactivePeers when chan is empty

@@ -114,25 +112,9 @@ func (db *Database) flushTorrents() {
count int
)

conn := Open()

for {
query.Reset()
query.WriteString("CREATE TEMPORARY TABLE IF NOT EXISTS flush_torrents (" +
"ID int unsigned NOT NULL, " +
"Snatched int unsigned NOT NULL DEFAULT 0, " +
"Seeders int unsigned NOT NULL DEFAULT 0, " +
"Leechers int unsigned NOT NULL DEFAULT 0, " +
"last_action int NOT NULL DEFAULT 0, " +
"PRIMARY KEY (ID)) ENGINE=MEMORY")
conn.exec(&query)

query.Reset()
query.WriteString("TRUNCATE flush_torrents")
conn.exec(&query)

query.Reset()
query.WriteString("INSERT INTO flush_torrents VALUES ")
query.WriteString("INSERT IGNORE INTO torrents (ID, Snatched, Seeders, Leechers, last_action) VALUES ")

length := len(db.torrentChannel)

@@ -152,7 +134,7 @@ func (db *Database) flushTorrents() {

if count > 0 {
logFlushes, _ := config.GetBool("log_flushes", true)
if logFlushes && !db.terminate {
if logFlushes && !db.terminate.Load() {
log.Info.Printf("{torrents} Flushing %d", count)
}

@@ -161,18 +143,9 @@ func (db *Database) flushTorrents() {
query.WriteString(" ON DUPLICATE KEY UPDATE Snatched = Snatched + VALUE(Snatched), " +
"Seeders = VALUE(Seeders), Leechers = VALUE(Leechers), " +
"last_action = IF(last_action < VALUE(last_action), VALUE(last_action), last_action)")
conn.exec(&query)
db.exec(&query)

query.Reset()
query.WriteString("UPDATE torrents t, flush_torrents ft SET " +
"t.Snatched = t.Snatched + ft.Snatched, " +
"t.Seeders = ft.Seeders, " +
"t.Leechers = ft.Leechers, " +
"t.last_action = IF(t.last_action < ft.last_action, ft.last_action, t.last_action)" +
"WHERE t.ID = ft.ID")
conn.exec(&query)

if !db.terminate {
if !db.terminate.Load() {
elapsedTime := time.Since(startTime)
collectors.UpdateFlushTime("torrents", elapsedTime)
collectors.UpdateChannelsLen("torrents", count)

@@ -181,14 +154,12 @@ func (db *Database) flushTorrents() {
if length < (torrentFlushBufferSize >> 1) {
time.Sleep(time.Duration(flushSleepInterval) * time.Second)
}
} else if db.terminate {
} else if db.terminate.Load() {
break
} else {
time.Sleep(time.Second)
}
}

_ = conn.Close()
}

func (db *Database) flushUsers() {

@@ -200,25 +171,9 @@ func (db *Database) flushUsers() {
count int
)

conn := Open()

for {
query.Reset()
query.WriteString("CREATE TEMPORARY TABLE IF NOT EXISTS flush_users (" +
"ID int unsigned NOT NULL, " +
"Uploaded bigint unsigned NOT NULL DEFAULT 0, " +
"Downloaded bigint unsigned NOT NULL DEFAULT 0, " +
"rawdl bigint unsigned NOT NULL DEFAULT 0, " +
"rawup bigint unsigned NOT NULL DEFAULT 0, " +
"PRIMARY KEY (ID)) ENGINE=MEMORY")
conn.exec(&query)

query.Reset()
query.WriteString("TRUNCATE flush_users")
conn.exec(&query)

query.Reset()
query.WriteString("INSERT INTO flush_users VALUES ")
query.WriteString("INSERT IGNORE INTO users_main (ID, Uploaded, Downloaded, rawdl, rawup) VALUES ")

length := len(db.userChannel)

@@ -238,7 +193,7 @@ func (db *Database) flushUsers() {

if count > 0 {
logFlushes, _ := config.GetBool("log_flushes", true)
if logFlushes && !db.terminate {
if logFlushes && !db.terminate.Load() {
log.Info.Printf("{users_main} Flushing %d", count)
}

@@ -246,18 +201,9 @@ func (db *Database) flushUsers() {

query.WriteString(" ON DUPLICATE KEY UPDATE Uploaded = Uploaded + VALUE(Uploaded), " +
"Downloaded = Downloaded + VALUE(Downloaded), rawdl = rawdl + VALUE(rawdl), rawup = rawup + VALUE(rawup)")
conn.exec(&query)
db.exec(&query)

query.Reset()
query.WriteString("UPDATE users_main u, flush_users fu SET " +
"u.Uploaded = u.Uploaded + fu.Uploaded, " +
"u.Downloaded = u.Downloaded + fu.Downloaded, " +
"u.rawdl = u.rawdl + fu.rawdl, " +
"u.rawup = u.rawup + fu.rawup " +
"WHERE u.ID = fu.ID")
conn.exec(&query)

if !db.terminate {
if !db.terminate.Load() {
elapsedTime := time.Since(startTime)
collectors.UpdateFlushTime("users", elapsedTime)
collectors.UpdateChannelsLen("users", count)

@@ -266,14 +212,12 @@ func (db *Database) flushUsers() {
if length < (userFlushBufferSize >> 1) {
time.Sleep(time.Duration(flushSleepInterval) * time.Second)
}
} else if db.terminate {
} else if db.terminate.Load() {
break
} else {
time.Sleep(time.Second)
}
}

_ = conn.Close()
}

func (db *Database) flushTransferHistory() {

@@ -285,8 +229,6 @@ func (db *Database) flushTransferHistory() {
count int
)

conn := Open()

for {
length, err := func() (int, error) {
db.transferHistoryLock.Lock()

@@ -314,7 +256,7 @@ func (db *Database) flushTransferHistory() {

if count > 0 {
logFlushes, _ := config.GetBool("log_flushes", true)
if logFlushes && !db.terminate {
if logFlushes && !db.terminate.Load() {
log.Info.Printf("{transfer_history} Flushing %d", count)
}

@@ -326,16 +268,16 @@ func (db *Database) flushTransferHistory() {
"seedtime = seedtime + VALUE(seedtime), last_announce = VALUE(last_announce), " +
"active = VALUE(active), snatched = snatched + VALUE(snatched);")

conn.exec(&query)
db.exec(&query)

if !db.terminate {
if !db.terminate.Load() {
elapsedTime := time.Since(startTime)
collectors.UpdateFlushTime("transfer_history", elapsedTime)
collectors.UpdateChannelsLen("transfer_history", count)
}

return length, nil
} else if db.terminate {
} else if db.terminate.Load() {
return 0, errDbTerminate
}

@@ -350,8 +292,6 @@ func (db *Database) flushTransferHistory() {
time.Sleep(time.Second)
}
}

_ = conn.Close()
}

func (db *Database) flushTransferIps() {

@@ -363,8 +303,6 @@ func (db *Database) flushTransferIps() {
count int
)

conn := Open()

for {
query.Reset()
query.WriteString("INSERT INTO transfer_ips (uid, fid, client_id, ip, port, uploaded, downloaded, " +

@@ -388,7 +326,7 @@ func (db *Database) flushTransferIps() {

if count > 0 {
logFlushes, _ := config.GetBool("log_flushes", true)
if logFlushes && !db.terminate {
if logFlushes && !db.terminate.Load() {
log.Info.Printf("{transfer_ips} Flushing %d", count)
}

@@ -397,9 +335,9 @@ func (db *Database) flushTransferIps() {
// todo: port should be part of PK
query.WriteString("\nON DUPLICATE KEY UPDATE port = VALUE(port), downloaded = downloaded + VALUE(downloaded), " +
"uploaded = uploaded + VALUE(uploaded), last_announce = VALUE(last_announce)")
conn.exec(&query)
db.exec(&query)

if !db.terminate {
if !db.terminate.Load() {
elapsedTime := time.Since(startTime)
collectors.UpdateFlushTime("transfer_ips", elapsedTime)
collectors.UpdateChannelsLen("transfer_ips", count)

@@ -408,14 +346,12 @@ func (db *Database) flushTransferIps() {
if length < (transferIpsFlushBufferSize >> 1) {
time.Sleep(time.Duration(flushSleepInterval) * time.Second)
}
} else if db.terminate {
} else if db.terminate.Load() {
break
} else {
time.Sleep(time.Second)
}
}

_ = conn.Close()
}

func (db *Database) flushSnatches() {

@@ -427,8 +363,6 @@ func (db *Database) flushSnatches() {
count int
)

conn := Open()

for {
query.Reset()
query.WriteString("INSERT INTO transfer_history (uid, fid, snatched_time) VALUES\n")

@@ -451,16 +385,16 @@ func (db *Database) flushSnatches() {

if count > 0 {
logFlushes, _ := config.GetBool("log_flushes", true)
if logFlushes && !db.terminate {
if logFlushes && !db.terminate.Load() {
log.Info.Printf("{snatches} Flushing %d", count)
}

startTime := time.Now()

query.WriteString("\nON DUPLICATE KEY UPDATE snatched_time = VALUE(snatched_time)")
conn.exec(&query)
db.exec(&query)

if !db.terminate {
if !db.terminate.Load() {
elapsedTime := time.Since(startTime)
collectors.UpdateFlushTime("snatches", elapsedTime)
collectors.UpdateChannelsLen("snatches", count)

@@ -469,14 +403,12 @@ func (db *Database) flushSnatches() {
if length < (snatchFlushBufferSize >> 1) {
time.Sleep(time.Duration(flushSleepInterval) * time.Second)
}
} else if db.terminate {
} else if db.terminate.Load() {
break
} else {
time.Sleep(time.Second)
}
}

_ = conn.Close()
}

func (db *Database) purgeInactivePeers() {

@@ -486,7 +418,7 @@ func (db *Database) purgeInactivePeers() {
count int
)

for !db.terminate {
util.ContextTick(db.ctx, time.Duration(purgeInactivePeersInterval)*time.Second, func() {
start = time.Now()
now = start.Unix()
count = 0

@@ -494,45 +426,45 @@ func (db *Database) purgeInactivePeers() {
oldestActive := now - int64(peerInactivityInterval)

// First, remove inactive peers from memory
func() {
db.TorrentsLock.RLock()
defer db.TorrentsLock.RUnlock()
dbTorrents := *db.Torrents.Load()
for _, torrent := range dbTorrents {
func() {
//Take write lock to operate on peer entries
torrent.Lock()
defer torrent.Unlock()

for _, torrent := range db.Torrents {
func() {
//Take write lock to operate on entries
torrent.Lock()
defer torrent.Unlock()
countThisTorrent := count

countThisTorrent := count

for id, peer := range torrent.Leechers {
if peer.LastAnnounce < oldestActive {
delete(torrent.Leechers, id)
count++
}
for id, peer := range torrent.Leechers {
if peer.LastAnnounce < oldestActive {
delete(torrent.Leechers, id)
count++
}
}

if countThisTorrent != count && len(torrent.Leechers) == 0 {
/* Deallocate previous map since Go does not free space used on maps when deleting objects.
We're doing it only for Leechers as potential advantage from freeing one or two Seeders is
virtually nil, while Leechers can incur significant memory leaks due to initial swarm activity. */
torrent.Leechers = make(map[cdb.PeerKey]*cdb.Peer)
}
if countThisTorrent != count && len(torrent.Leechers) == 0 {
/* Deallocate previous map since Go does not free space used on maps when deleting objects.
We're doing it only for Leechers as potential advantage from freeing one or two Seeders is
virtually nil, while Leechers can incur significant memory leaks due to initial swarm activity. */
torrent.Leechers = make(map[cdb.PeerKey]*cdb.Peer)
}

for id, peer := range torrent.Seeders {
if peer.LastAnnounce < oldestActive {
delete(torrent.Seeders, id)
count++
}
for id, peer := range torrent.Seeders {
if peer.LastAnnounce < oldestActive {
delete(torrent.Seeders, id)
count++
}
}

if countThisTorrent != count {
db.QueueTorrent(torrent, 0)
}
}()
}
}()
if countThisTorrent != count {
// Update lengths of peers
torrent.SeedersLength.Store(uint32(len(torrent.Seeders)))
torrent.LeechersLength.Store(uint32(len(torrent.Leechers)))

db.QueueTorrent(torrent, 0)
}
}()
}

elapsedTime := time.Since(start)
collectors.UpdateFlushTime("purging_inactive_peers", elapsedTime)

@@ -547,11 +479,8 @@ func (db *Database) purgeInactivePeers() {
db.transferHistoryLock.Lock()
defer db.transferHistoryLock.Unlock()

db.mainConn.mutex.Lock()
defer db.mainConn.mutex.Unlock()

start = time.Now()
result := db.mainConn.execute(db.cleanStalePeersStmt, oldestActive)
result := db.execute(db.cleanStalePeersStmt, oldestActive)

if result != nil {
rows, err := result.RowsAffected()

@@ -562,7 +491,5 @@ func (db *Database) purgeInactivePeers() {
}
}
}()

time.Sleep(time.Duration(purgeInactivePeersInterval) * time.Second)
}
})
}
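The flush rewrite above drops the temporary `ENGINE=MEMORY` staging tables and writes each batch straight into the target table with a single `INSERT ... ON DUPLICATE KEY UPDATE`, reusing the `Database` connection instead of opening one per goroutine. A hedged sketch of just the statement-building step for the torrents flush (the tuple strings would come from the queue buffers; `VALUES()` is the standard spelling of the `VALUE()` alias used in the diff):

```go
package main

import (
	"bytes"
	"fmt"
)

// buildTorrentFlush assembles one batched upsert from queued
// "(id,snatched,seeders,leechers,last_action)" tuples.
func buildTorrentFlush(tuples []string) string {
	var query bytes.Buffer

	query.WriteString("INSERT IGNORE INTO torrents (ID, Snatched, Seeders, Leechers, last_action) VALUES ")

	for i, t := range tuples {
		if i > 0 {
			query.WriteString(", ")
		}
		query.WriteString(t)
	}

	// Existing rows are updated in place, so no staging table is needed.
	query.WriteString(" ON DUPLICATE KEY UPDATE Snatched = Snatched + VALUES(Snatched), " +
		"Seeders = VALUES(Seeders), Leechers = VALUES(Leechers), " +
		"last_action = IF(last_action < VALUES(last_action), VALUES(last_action), last_action)")

	return query.String()
}

func main() {
	fmt.Println(buildTorrentFlush([]string{"(1,0,5,2,1700000000)", "(2,1,0,9,1700000100)"}))
}
```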
@@ -38,15 +38,15 @@ func (db *Database) QueueTorrent(torrent *cdb.Torrent, deltaSnatch uint8) {
tq := db.bufferPool.Take()

tq.WriteString("(")
tq.WriteString(strconv.FormatUint(uint64(torrent.ID), 10))
tq.WriteString(strconv.FormatUint(uint64(torrent.ID.Load()), 10))
tq.WriteString(",")
tq.WriteString(strconv.FormatUint(uint64(deltaSnatch), 10))
tq.WriteString(",")
tq.WriteString(strconv.FormatInt(int64(len(torrent.Seeders)), 10))
tq.WriteString(strconv.FormatUint(uint64(torrent.SeedersLength.Load()), 10))
tq.WriteString(",")
tq.WriteString(strconv.FormatInt(int64(len(torrent.Leechers)), 10))
tq.WriteString(strconv.FormatUint(uint64(torrent.LeechersLength.Load()), 10))
tq.WriteString(",")
tq.WriteString(strconv.FormatInt(torrent.LastAction, 10))
tq.WriteString(strconv.FormatInt(torrent.LastAction.Load(), 10))
tq.WriteString(")")

select {

@@ -58,11 +58,16 @@ func (db *Database) QueueTorrent(torrent *cdb.Torrent, deltaSnatch uint8) {
}
}

func (db *Database) QueueUser(user *cdb.User, rawDeltaUp, rawDeltaDown, deltaUp, deltaDown int64) {
func (db *Database) QueueUser(
user *cdb.User,
rawDeltaUp,
rawDeltaDown,
deltaUp,
deltaDown int64) {
uq := db.bufferPool.Take()

uq.WriteString("(")
uq.WriteString(strconv.FormatUint(uint64(user.ID), 10))
uq.WriteString(strconv.FormatUint(uint64(user.ID.Load()), 10))
uq.WriteString(",")
uq.WriteString(strconv.FormatInt(deltaUp, 10))
uq.WriteString(",")

@@ -174,7 +179,5 @@ func (db *Database) QueueSnatch(peer *cdb.Peer, now int64) {
}

func (db *Database) UnPrune(torrent *cdb.Torrent) {
db.mainConn.mutex.Lock()
db.mainConn.execute(db.unPruneTorrentStmt, torrent.ID)
db.mainConn.mutex.Unlock()
db.execute(db.unPruneTorrentStmt, torrent.ID.Load())
}
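Both the flush goroutines above and the reload loop below move from `for !db.terminate { time.Sleep(...) }` to `util.ContextTick(db.ctx, interval, fn)`, so `Terminate()` can interrupt a sleeping worker immediately through `ctxCancel()`. The helper's implementation isn't part of this compare; a plausible equivalent under that assumption:

```go
package main

import (
	"context"
	"fmt"
	"time"
)

// contextTick runs fn once per interval until ctx is cancelled.
// Unlike time.Sleep in a loop, cancellation takes effect mid-wait.
func contextTick(ctx context.Context, interval time.Duration, fn func()) {
	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			fn()
		}
	}
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	go contextTick(ctx, 10*time.Millisecond, func() { fmt.Println("tick") })

	time.Sleep(35 * time.Millisecond)
	cancel() // the worker exits promptly, even between ticks
	time.Sleep(10 * time.Millisecond)
}
```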
@ -18,6 +18,7 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"math"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
|
@ -25,6 +26,7 @@ import (
|
|||
"chihaya/config"
|
||||
cdb "chihaya/database/types"
|
||||
"chihaya/log"
|
||||
"chihaya/util"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -49,32 +51,26 @@ func init() {
|
|||
*/
|
||||
func (db *Database) startReloading() {
|
||||
go func() {
|
||||
for !db.terminate {
|
||||
time.Sleep(time.Duration(reloadInterval) * time.Second)
|
||||
|
||||
util.ContextTick(db.ctx, time.Duration(reloadInterval)*time.Second, func() {
|
||||
db.waitGroup.Add(1)
|
||||
defer db.waitGroup.Done()
|
||||
db.loadUsers()
|
||||
db.loadHitAndRuns()
|
||||
db.loadTorrents()
|
||||
db.loadGroupsFreeleech()
|
||||
db.loadConfig()
|
||||
db.loadClients()
|
||||
db.waitGroup.Done()
|
||||
}
|
||||
})
|
||||
}()
|
||||
}
|
||||
|
||||
func (db *Database) loadUsers() {
|
||||
db.UsersLock.Lock()
|
||||
defer db.UsersLock.Unlock()
|
||||
|
||||
db.mainConn.mutex.Lock()
|
||||
defer db.mainConn.mutex.Unlock()
|
||||
|
||||
start := time.Now()
|
||||
newUsers := make(map[string]*cdb.User, len(db.Users))
|
||||
|
||||
rows := db.mainConn.query(db.loadUsersStmt)
|
||||
dbUsers := *db.Users.Load()
|
||||
newUsers := make(map[string]*cdb.User, len(dbUsers))
|
||||
|
||||
rows := db.query(db.loadUsersStmt)
|
||||
if rows == nil {
|
||||
log.Error.Print("Failed to load hit and runs from database")
|
||||
log.WriteStack()
|
||||
|
@ -99,40 +95,37 @@ func (db *Database) loadUsers() {
|
|||
log.WriteStack()
|
||||
}
|
||||
|
||||
if old, exists := db.Users[torrentPass]; exists && old != nil {
|
||||
old.ID = id
|
||||
old.DownMultiplier = downMultiplier
|
||||
old.UpMultiplier = upMultiplier
|
||||
old.DisableDownload = disableDownload
|
||||
old.TrackerHide = trackerHide
|
||||
if old, exists := dbUsers[torrentPass]; exists && old != nil {
|
||||
old.ID.Store(id)
|
||||
old.DownMultiplier.Store(math.Float64bits(downMultiplier))
|
||||
old.UpMultiplier.Store(math.Float64bits(upMultiplier))
|
||||
old.DisableDownload.Store(disableDownload)
|
||||
old.TrackerHide.Store(trackerHide)
|
||||
|
||||
newUsers[torrentPass] = old
|
||||
} else {
|
||||
newUsers[torrentPass] = &cdb.User{
|
||||
ID: id,
|
||||
UpMultiplier: upMultiplier,
|
||||
DownMultiplier: downMultiplier,
|
||||
DisableDownload: disableDownload,
|
||||
TrackerHide: trackerHide,
|
||||
}
|
||||
u := &cdb.User{}
|
||||
u.ID.Store(id)
|
||||
u.DownMultiplier.Store(math.Float64bits(downMultiplier))
|
||||
u.UpMultiplier.Store(math.Float64bits(upMultiplier))
|
||||
u.DisableDownload.Store(disableDownload)
|
||||
u.TrackerHide.Store(trackerHide)
|
||||
newUsers[torrentPass] = u
|
||||
}
|
||||
}
|
||||
|
||||
db.Users = newUsers
|
||||
db.Users.Store(&newUsers)
|
||||
|
||||
elapsedTime := time.Since(start)
|
||||
collectors.UpdateReloadTime("users", elapsedTime)
|
||||
log.Info.Printf("User load complete (%d rows, %s)", len(db.Users), elapsedTime.String())
|
||||
log.Info.Printf("User load complete (%d rows, %s)", len(newUsers), elapsedTime.String())
|
||||
}
|
||||
|
||||
func (db *Database) loadHitAndRuns() {
|
||||
db.mainConn.mutex.Lock()
|
||||
defer db.mainConn.mutex.Unlock()
|
||||
|
||||
start := time.Now()
|
||||
newHnr := make(map[cdb.UserTorrentPair]struct{})
|
||||
|
||||
rows := db.mainConn.query(db.loadHnrStmt)
|
||||
rows := db.query(db.loadHnrStmt)
|
||||
if rows == nil {
|
||||
log.Error.Print("Failed to load hit and runs from database")
|
||||
log.WriteStack()
|
||||
|
@ -169,100 +162,99 @@ func (db *Database) loadHitAndRuns() {
|
|||
func (db *Database) loadTorrents() {
|
||||
var start time.Time
|
||||
|
||||
newTorrents := make(map[cdb.TorrentHash]*cdb.Torrent)
|
||||
dbTorrents := *db.Torrents.Load()
|
||||
|
||||
func() {
|
||||
db.TorrentsLock.RLock()
|
||||
defer db.TorrentsLock.RUnlock()
|
||||
newTorrents := make(map[cdb.TorrentHash]*cdb.Torrent, len(dbTorrents))
|
||||
|
||||
db.mainConn.mutex.Lock()
|
||||
defer db.mainConn.mutex.Unlock()
|
||||
start = time.Now()
|
||||
|
||||
start = time.Now()
|
||||
rows := db.query(db.loadTorrentsStmt)
|
||||
if rows == nil {
|
||||
log.Error.Print("Failed to load torrents from database")
|
||||
log.WriteStack()
|
||||
|
||||
rows := db.mainConn.query(db.loadTorrentsStmt)
|
||||
if rows == nil {
|
||||
log.Error.Print("Failed to load torrents from database")
|
||||
log.WriteStack()
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
_ = rows.Close()
|
||||
}()
|
||||
|
||||
for rows.Next() {
|
||||
var (
|
||||
infoHash cdb.TorrentHash
|
||||
id uint32
|
||||
downMultiplier, upMultiplier float64
|
||||
snatched uint16
|
||||
status uint8
|
||||
group cdb.TorrentGroup
|
||||
)
|
||||
|
||||
if err := rows.Scan(
|
||||
&id,
|
||||
&infoHash,
|
||||
&downMultiplier,
|
||||
&upMultiplier,
|
||||
&snatched,
|
||||
&status,
|
||||
&group.GroupID,
|
||||
&group.TorrentType,
|
||||
); err != nil {
|
||||
log.Error.Printf("Error scanning torrent row: %s", err)
|
||||
log.WriteStack()
|
||||
}
|
||||
|
||||
if old, exists := db.Torrents[infoHash]; exists && old != nil {
|
||||
func() {
|
||||
old.Lock()
|
||||
defer old.Unlock()
|
||||
|
||||
old.ID = id
|
||||
old.DownMultiplier = downMultiplier
|
||||
old.UpMultiplier = upMultiplier
|
||||
old.Snatched = snatched
|
||||
old.Status = status
|
||||
old.Group = group
|
||||
}()
|
||||
|
||||
newTorrents[infoHash] = old
|
||||
} else {
|
||||
newTorrents[infoHash] = &cdb.Torrent{
|
||||
ID: id,
|
||||
UpMultiplier: upMultiplier,
|
||||
DownMultiplier: downMultiplier,
|
||||
Snatched: snatched,
|
||||
Status: status,
|
||||
Group: group,
|
||||
|
||||
Seeders: make(map[cdb.PeerKey]*cdb.Peer),
|
||||
Leechers: make(map[cdb.PeerKey]*cdb.Peer),
|
||||
}
|
||||
}
|
||||
}
|
||||
defer func() {
|
||||
_ = rows.Close()
|
||||
}()
|
||||
|
||||
db.TorrentsLock.Lock()
|
||||
defer db.TorrentsLock.Unlock()
|
||||
db.Torrents = newTorrents
|
||||
for rows.Next() {
|
||||
var (
|
||||
infoHash cdb.TorrentHash
|
||||
id uint32
|
||||
downMultiplier, upMultiplier float64
|
||||
snatched uint16
|
||||
status uint8
|
||||
groupID uint32
|
||||
torrentType string
|
||||
)
|
||||
|
||||
if err := rows.Scan(
|
||||
&id,
|
||||
&infoHash,
|
||||
&downMultiplier,
|
||||
&upMultiplier,
|
||||
&snatched,
|
||||
&status,
|
||||
&groupID,
|
||||
&torrentType,
|
||||
); err != nil {
|
||||
log.Error.Printf("Error scanning torrent row: %s", err)
|
||||
log.WriteStack()
|
||||
}
|
||||
|
||||
torrentTypeUint64, err := cdb.TorrentTypeFromString(torrentType)
|
||||
|
||||
if err != nil {
|
||||
log.Error.Printf("Error storing torrent row: %s", err)
|
||||
log.WriteStack()
|
||||
}
|
||||
|
||||
if old, exists := dbTorrents[infoHash]; exists && old != nil {
|
||||
old.ID.Store(id)
|
||||
old.DownMultiplier.Store(math.Float64bits(downMultiplier))
|
||||
old.UpMultiplier.Store(math.Float64bits(upMultiplier))
|
||||
old.Snatched.Store(uint32(snatched))
|
||||
old.Status.Store(uint32(status))
|
||||
|
||||
old.Group.TorrentType.Store(torrentTypeUint64)
|
||||
old.Group.GroupID.Store(groupID)
|
||||
|
||||
newTorrents[infoHash] = old
|
||||
} else {
|
||||
t := &cdb.Torrent{
|
||||
Seeders: make(map[cdb.PeerKey]*cdb.Peer),
|
||||
Leechers: make(map[cdb.PeerKey]*cdb.Peer),
|
||||
}
|
||||
t.InitializeLock()
|
||||
|
||||
t.ID.Store(id)
|
||||
t.DownMultiplier.Store(math.Float64bits(downMultiplier))
|
||||
t.UpMultiplier.Store(math.Float64bits(upMultiplier))
|
||||
t.Snatched.Store(uint32(snatched))
|
||||
t.Status.Store(uint32(status))
|
||||
|
||||
t.Group.TorrentType.Store(torrentTypeUint64)
|
||||
t.Group.GroupID.Store(groupID)
|
||||
|
||||
newTorrents[infoHash] = t
|
||||
}
|
||||
}
|
||||
|
||||
db.Torrents.Store(&newTorrents)
|
||||
|
||||
elapsedTime := time.Since(start)
|
||||
collectors.UpdateReloadTime("torrents", elapsedTime)
|
||||
log.Info.Printf("Torrent load complete (%d rows, %s)", len(db.Torrents), elapsedTime.String())
|
||||
log.Info.Printf("Torrent load complete (%d rows, %s)", len(newTorrents), elapsedTime.String())
|
||||
}
|
||||
|
||||
func (db *Database) loadGroupsFreeleech() {
db.mainConn.mutex.Lock()
defer db.mainConn.mutex.Unlock()

start := time.Now()
newTorrentGroupFreeleech := make(map[cdb.TorrentGroup]*cdb.TorrentGroupFreeleech)
newTorrentGroupFreeleech := make(map[cdb.TorrentGroupKey]*cdb.TorrentGroupFreeleech)

rows := db.mainConn.query(db.loadTorrentGroupFreeleechStmt)
rows := db.query(db.loadTorrentGroupFreeleechStmt)
if rows == nil {
log.Error.Print("Failed to load torrent group freeleech data from database")
log.WriteStack()

@ -277,15 +269,22 @@ func (db *Database) loadGroupsFreeleech() {
for rows.Next() {
var (
downMultiplier, upMultiplier float64
group cdb.TorrentGroup
groupID uint32
torrentType string
)

if err := rows.Scan(&group.GroupID, &group.TorrentType, &downMultiplier, &upMultiplier); err != nil {
log.Error.Printf("Error scanning torrent row: %s", err)
if err := rows.Scan(&groupID, &torrentType, &downMultiplier, &upMultiplier); err != nil {
log.Error.Printf("Error scanning torrent group freeleech row: %s", err)
log.WriteStack()
}

newTorrentGroupFreeleech[group] = &cdb.TorrentGroupFreeleech{
k, err := cdb.TorrentGroupKeyFromString(torrentType, groupID)
if err != nil {
log.Error.Printf("Error storing torrent group freeleech row: %s", err)
log.WriteStack()
}

newTorrentGroupFreeleech[k] = &cdb.TorrentGroupFreeleech{
UpMultiplier: upMultiplier,
DownMultiplier: downMultiplier,
}
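The freeleech map above is now keyed by `cdb.TorrentGroupKey`, a fixed-size array combining the torrent type and the group ID. Below is a minimal sketch of that packing, assuming the same 8-byte-type plus 4-byte-ID layout that `TorrentGroupKeyFromString` uses later in this diff; `GroupKey`, `newGroupKey` and `Freeleech` are illustrative names.

```go
package main

import (
	"encoding/binary"
	"errors"
	"fmt"
)

// GroupKey packs a torrent type (at most 8 bytes) and a group ID into a
// fixed, comparable array usable directly as a map key.
type GroupKey [8 + 4]byte

func newGroupKey(torrentType string, groupID uint32) (GroupKey, error) {
	var k GroupKey
	if len(torrentType) > 8 {
		return k, errors.New("torrent type longer than 8 bytes")
	}

	copy(k[:8], torrentType) // unused bytes stay zero
	binary.LittleEndian.PutUint32(k[8:], groupID)

	return k, nil
}

func main() {
	type Freeleech struct{ Up, Down float64 }

	freeleech := make(map[GroupKey]*Freeleech)

	k, _ := newGroupKey("anime", 123)
	freeleech[k] = &Freeleech{Up: 1, Down: 0}

	fmt.Println(freeleech[k].Up, freeleech[k].Down)
}
```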
@ -295,14 +294,11 @@ func (db *Database) loadGroupsFreeleech() {
|
|||
|
||||
elapsedTime := time.Since(start)
|
||||
collectors.UpdateReloadTime("groups_freeleech", elapsedTime)
|
||||
log.Info.Printf("Group freeleech load complete (%d rows, %s)", len(db.Torrents), elapsedTime.String())
|
||||
log.Info.Printf("Group freeleech load complete (%d rows, %s)", len(newTorrentGroupFreeleech), elapsedTime.String())
|
||||
}
|
||||
|
||||
func (db *Database) loadConfig() {
|
||||
db.mainConn.mutex.Lock()
|
||||
defer db.mainConn.mutex.Unlock()
|
||||
|
||||
rows := db.mainConn.query(db.loadFreeleechStmt)
|
||||
rows := db.query(db.loadFreeleechStmt)
|
||||
if rows == nil {
|
||||
log.Error.Print("Failed to load config from database")
|
||||
log.WriteStack()
|
||||
|
@ -327,13 +323,10 @@ func (db *Database) loadConfig() {
|
|||
}
|
||||
|
||||
func (db *Database) loadClients() {
|
||||
db.mainConn.mutex.Lock()
|
||||
defer db.mainConn.mutex.Unlock()
|
||||
|
||||
start := time.Now()
|
||||
newClients := make(map[uint16]string)
|
||||
|
||||
rows := db.mainConn.query(db.loadClientsStmt)
|
||||
rows := db.query(db.loadClientsStmt)
|
||||
if rows == nil {
|
||||
log.Error.Print("Failed to load clients from database")
|
||||
log.WriteStack()
|
||||
|
|
|
@ -1,21 +1,21 @@
|
|||
create table approved_clients
|
||||
(
|
||||
id mediumint unsigned auto_increment primary key,
|
||||
peer_id varchar(42) null,
|
||||
archived tinyint(1) default 0 null
|
||||
peer_id varchar(42) not null,
|
||||
archived tinyint(1) default 0 not null
|
||||
);
|
||||
|
||||
create table mod_core
|
||||
(
|
||||
mod_option varchar(121) not null primary key,
|
||||
mod_setting int(12) default 0 not null
|
||||
mod_setting int(12) not null
|
||||
);
|
||||
|
||||
create table torrent_group_freeleech
|
||||
(
|
||||
ID int(10) auto_increment primary key,
|
||||
GroupID int(10) default 0 not null,
|
||||
Type enum ('anime', 'music') default 'anime' not null,
|
||||
GroupID int(10) not null,
|
||||
Type enum ('anime', 'music') not null,
|
||||
DownMultiplier float default 1 not null,
|
||||
UpMultiplier float default 1 not null,
|
||||
constraint GroupID unique (GroupID, Type)
|
||||
|
@ -25,7 +25,7 @@ create table torrents
|
|||
(
|
||||
ID int(10) auto_increment primary key,
|
||||
GroupID int(10) not null,
|
||||
TorrentType enum ('anime', 'music') default 'anime' not null,
|
||||
TorrentType enum ('anime', 'music') not null,
|
||||
info_hash blob not null,
|
||||
Leechers int(6) default 0 not null,
|
||||
Seeders int(6) default 0 not null,
|
||||
|
@ -34,25 +34,13 @@ create table torrents
|
|||
DownMultiplier float default 1 not null,
|
||||
UpMultiplier float default 1 not null,
|
||||
Status int default 0 not null,
|
||||
constraint InfoHash unique (info_hash)
|
||||
constraint InfoHash unique (info_hash (20))
|
||||
);
|
||||
|
||||
create table torrents_group
|
||||
(
|
||||
ID int unsigned auto_increment primary key,
|
||||
Time int(10) default 0 not null
|
||||
) charset = utf8mb4;
|
||||
|
||||
create table torrents_group2
|
||||
(
|
||||
ID int unsigned auto_increment primary key,
|
||||
Time int(10) default 0 not null
|
||||
) charset = utf8mb4;
|
||||
|
||||
create table transfer_history
|
||||
(
|
||||
uid int default 0 not null,
|
||||
fid int default 0 not null,
|
||||
uid int not null,
|
||||
fid int not null,
|
||||
uploaded bigint default 0 not null,
|
||||
downloaded bigint default 0 not null,
|
||||
seeding tinyint default 0 not null,
|
||||
|
@ -64,7 +52,7 @@ create table transfer_history
|
|||
starttime int default 0 not null,
|
||||
last_announce int default 0 not null,
|
||||
snatched int default 0 not null,
|
||||
snatched_time int default 0 null,
|
||||
snatched_time int default 0 not null,
|
||||
primary key (uid, fid)
|
||||
);
|
||||
|
||||
|
@ -72,13 +60,13 @@ create table transfer_ips
|
|||
(
|
||||
last_announce int unsigned default 0 not null,
|
||||
starttime int unsigned default 0 not null,
|
||||
uid int unsigned default 0 not null,
|
||||
fid int unsigned default 0 not null,
|
||||
ip int unsigned default 0 not null,
|
||||
client_id mediumint unsigned default 0 not null,
|
||||
uid int unsigned not null,
|
||||
fid int unsigned not null,
|
||||
ip int unsigned not null,
|
||||
client_id mediumint unsigned not null,
|
||||
uploaded bigint unsigned default 0 not null,
|
||||
downloaded bigint unsigned default 0 not null,
|
||||
port smallint unsigned zerofill null,
|
||||
port smallint unsigned zerofill default 0 not null,
|
||||
primary key (uid, fid, ip, client_id)
|
||||
);
|
||||
|
||||
|
@ -89,8 +77,8 @@ create table users_main
|
|||
Downloaded bigint unsigned default 0 not null,
|
||||
Enabled enum ('0', '1', '2') default '0' not null,
|
||||
torrent_pass char(32) not null,
|
||||
rawup bigint unsigned not null,
|
||||
rawdl bigint unsigned not null,
|
||||
rawup bigint unsigned default 0 not null,
|
||||
rawdl bigint unsigned default 0 not null,
|
||||
DownMultiplier float default 1 not null,
|
||||
UpMultiplier float default 1 not null,
|
||||
DisableDownload tinyint(1) default 0 not null,
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"chihaya/util"
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
@ -37,10 +38,9 @@ func init() {

func (db *Database) startSerializing() {
go func() {
for !db.terminate {
time.Sleep(time.Duration(serializeInterval) * time.Second)
util.ContextTick(db.ctx, time.Duration(serializeInterval)*time.Second, func() {
db.serialize()
}
})
}()
}
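`util.ContextTick` replaces the sleep loop that polled `db.terminate`. Its implementation is not shown in this diff, so the following is only a guess at the general shape: a ticker that invokes the callback until the context is cancelled.

```go
package main

import (
	"context"
	"fmt"
	"time"
)

// contextTick is an assumed sketch of a helper like util.ContextTick:
// run f on every tick until ctx is done.
func contextTick(ctx context.Context, d time.Duration, f func()) {
	ticker := time.NewTicker(d)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			f()
		}
	}
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 350*time.Millisecond)
	defer cancel()

	contextTick(ctx, 100*time.Millisecond, func() {
		fmt.Println("serialize pass")
	})
}
```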
@ -68,10 +68,7 @@ func (db *Database) serialize() {
|
|||
torrentFile.Close()
|
||||
}()
|
||||
|
||||
db.TorrentsLock.RLock()
|
||||
defer db.TorrentsLock.RUnlock()
|
||||
|
||||
if err = cdb.WriteTorrents(torrentFile, db.Torrents); err != nil {
|
||||
if err = cdb.WriteTorrents(torrentFile, *db.Torrents.Load()); err != nil {
|
||||
log.Error.Print("Failed to encode torrents for serialization: ", err)
|
||||
return err
|
||||
}
|
||||
|
@ -96,10 +93,7 @@ func (db *Database) serialize() {
|
|||
userFile.Close()
|
||||
}()
|
||||
|
||||
db.UsersLock.RLock()
|
||||
defer db.UsersLock.RUnlock()
|
||||
|
||||
if err = cdb.WriteUsers(userFile, db.Users); err != nil {
|
||||
if err = cdb.WriteUsers(userFile, *db.Users.Load()); err != nil {
|
||||
log.Error.Print("Failed to encode users for serialization: ", err)
|
||||
return err
|
||||
}
|
||||
|
@ -124,6 +118,10 @@ func (db *Database) deserialize() {
|
|||
|
||||
start := time.Now()
|
||||
|
||||
torrents := 0
|
||||
|
||||
peers := 0
|
||||
|
||||
func() {
|
||||
torrentFile, err := os.OpenFile(torrentBinFilename, os.O_RDONLY, 0)
|
||||
if err != nil {
|
||||
|
@ -134,13 +132,19 @@ func (db *Database) deserialize() {
|
|||
//goland:noinspection GoUnhandledErrorResult
|
||||
defer torrentFile.Close()
|
||||
|
||||
db.TorrentsLock.Lock()
|
||||
defer db.TorrentsLock.Unlock()
|
||||
|
||||
if err = cdb.LoadTorrents(torrentFile, db.Torrents); err != nil {
|
||||
dbTorrents := make(map[cdb.TorrentHash]*cdb.Torrent)
|
||||
if err = cdb.LoadTorrents(torrentFile, dbTorrents); err != nil {
|
||||
log.Error.Print("Failed to deserialize torrent cache: ", err)
|
||||
return
|
||||
}
|
||||
|
||||
torrents = len(dbTorrents)
|
||||
|
||||
for _, t := range dbTorrents {
|
||||
peers += int(t.LeechersLength.Load()) + int(t.SeedersLength.Load())
|
||||
}
|
||||
|
||||
db.Torrents.Store(&dbTorrents)
|
||||
}()
|
||||
|
||||
func() {

@ -153,33 +157,16 @@ func (db *Database) deserialize() {
//goland:noinspection GoUnhandledErrorResult
defer userFile.Close()

db.UsersLock.Lock()
defer db.UsersLock.Unlock()

if err = cdb.LoadUsers(userFile, db.Users); err != nil {
users := make(map[string]*cdb.User)
if err = cdb.LoadUsers(userFile, users); err != nil {
log.Error.Print("Failed to deserialize user cache: ", err)
return
}

db.Users.Store(&users)
}()

db.TorrentsLock.RLock()
defer db.TorrentsLock.RUnlock()

torrents := len(db.Torrents)

peers := 0

for _, t := range db.Torrents {
func() {
t.RLock()
defer t.RUnlock()
peers += len(t.Leechers) + len(t.Seeders)
}()
}

db.UsersLock.RLock()
defer db.UsersLock.RUnlock()
users := len(db.Users)
users := len(*db.Users.Load())

log.Info.Printf("Loaded %d users, %d torrents and %d peers (%s)",
users, torrents, peers, time.Since(start).String())
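The cache files read here are versioned (see the `TorrentCacheVersion` bump later in this diff), which lets `Load` methods branch on the version they find. Below is a toy sketch of that idea, using made-up `record`/`encode`/`decode` names rather than the tracker's actual on-disk layout.

```go
package main

import (
	"bytes"
	"encoding/binary"
	"fmt"
)

// record is a toy stand-in for a cache entry.
type record struct {
	ID     uint32
	Status uint8
}

// encode writes a uvarint version header followed by the fields that
// exist in that version.
func encode(version uint64, r record) []byte {
	buf := binary.AppendUvarint(nil, version)
	buf = binary.LittleEndian.AppendUint32(buf, r.ID)

	if version >= 3 {
		buf = append(buf, r.Status)
	}

	return buf
}

// decode branches on the version it finds, so a new build can still read
// caches produced by an older one.
func decode(data []byte) (r record, err error) {
	reader := bytes.NewReader(data)

	version, err := binary.ReadUvarint(reader)
	if err != nil {
		return r, err
	}

	if err = binary.Read(reader, binary.LittleEndian, &r.ID); err != nil {
		return r, err
	}

	if version >= 3 {
		err = binary.Read(reader, binary.LittleEndian, &r.Status)
	}

	return r, err
}

func main() {
	r, err := decode(encode(3, record{ID: 10, Status: 1}))
	fmt.Println(r, err)
}
```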
|
|
@ -18,6 +18,8 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"math"
|
||||
"net"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
@ -32,13 +34,14 @@ func TestSerializer(t *testing.T) {
|
|||
testTorrents := make(map[cdb.TorrentHash]*cdb.Torrent)
|
||||
testUsers := make(map[string]*cdb.User)
|
||||
|
||||
testUsers["mUztWMpBYNCqzmge6vGeEUGSrctJbgpQ"] = &cdb.User{
|
||||
DisableDownload: false,
|
||||
TrackerHide: false,
|
||||
ID: 12,
|
||||
UpMultiplier: 1,
|
||||
DownMultiplier: 1,
|
||||
}
|
||||
testUser := &cdb.User{}
|
||||
testUser.ID.Store(12)
|
||||
testUser.DownMultiplier.Store(math.Float64bits(1))
|
||||
testUser.UpMultiplier.Store(math.Float64bits(1))
|
||||
testUser.DisableDownload.Store(false)
|
||||
testUser.TrackerHide.Store(false)
|
||||
|
||||
testUsers["mUztWMpBYNCqzmge6vGeEUGSrctJbgpQ"] = testUser
|
||||
|
||||
testPeer := &cdb.Peer{
|
||||
UserID: 12,
|
||||
|
@ -57,46 +60,63 @@ func TestSerializer(t *testing.T) {
|
|||
testTorrentHash := cdb.TorrentHash{
|
||||
114, 239, 32, 237, 220, 181, 67, 143, 115, 182, 216, 141, 120, 196, 223, 193, 102, 123, 137, 56,
|
||||
}
|
||||
testTorrents[testTorrentHash] = &cdb.Torrent{
|
||||
Status: 1,
|
||||
Snatched: 100,
|
||||
ID: 10,
|
||||
LastAction: time.Now().Unix(),
|
||||
UpMultiplier: 1,
|
||||
DownMultiplier: 1,
|
||||
|
||||
torrent := &cdb.Torrent{
|
||||
Seeders: map[cdb.PeerKey]*cdb.Peer{
|
||||
cdb.NewPeerKey(12, cdb.PeerIDFromRawString("peer_is_twenty_chars")): testPeer,
|
||||
},
|
||||
Leechers: map[cdb.PeerKey]*cdb.Peer{},
|
||||
}
|
||||
torrent.InitializeLock()
|
||||
torrent.ID.Store(10)
|
||||
torrent.Status.Store(1)
|
||||
torrent.Snatched.Store(100)
|
||||
torrent.LastAction.Store(time.Now().Unix())
|
||||
torrent.DownMultiplier.Store(math.Float64bits(1))
|
||||
torrent.UpMultiplier.Store(math.Float64bits(1))
|
||||
torrent.SeedersLength.Store(uint32(len(torrent.Seeders)))
|
||||
|
||||
torrent.Group.GroupID.Store(1)
|
||||
torrent.Group.TorrentType.Store(cdb.MustTorrentTypeFromString("anime"))
|
||||
|
||||
testTorrents[testTorrentHash] = torrent
|
||||
|
||||
// Prepare empty map to populate with test data
|
||||
db.Torrents = make(map[cdb.TorrentHash]*cdb.Torrent)
|
||||
db.Users = make(map[string]*cdb.User)
|
||||
dbTorrents := make(map[cdb.TorrentHash]*cdb.Torrent)
|
||||
db.Torrents.Store(&dbTorrents)
|
||||
|
||||
if err := copier.Copy(&db.Torrents, testTorrents); err != nil {
|
||||
dbUsers := make(map[string]*cdb.User)
|
||||
db.Users.Store(&dbUsers)
|
||||
|
||||
if err := copier.Copy(&dbTorrents, testTorrents); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if err := copier.Copy(&db.Users, testUsers); err != nil {
|
||||
if err := copier.Copy(&dbUsers, testUsers); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
db.serialize()
|
||||
|
||||
// Reset map to fully test deserialization
|
||||
db.Torrents = make(map[cdb.TorrentHash]*cdb.Torrent)
|
||||
db.Users = make(map[string]*cdb.User)
|
||||
dbTorrents = make(map[cdb.TorrentHash]*cdb.Torrent)
|
||||
db.Torrents.Store(&dbTorrents)
|
||||
|
||||
dbUsers = make(map[string]*cdb.User)
|
||||
db.Users.Store(&dbUsers)
|
||||
|
||||
db.deserialize()
|
||||
|
||||
if !reflect.DeepEqual(db.Torrents, testTorrents) {
|
||||
dbTorrents = *db.Torrents.Load()
|
||||
dbUsers = *db.Users.Load()
|
||||
|
||||
if !cmp.Equal(dbTorrents, testTorrents, cdb.TorrentTestCompareOptions...) {
|
||||
t.Fatalf("Torrents (%v) after serialization and deserialization do not match original torrents (%v)!",
|
||||
db.Torrents, testTorrents)
|
||||
dbTorrents, testTorrents)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(db.Users, testUsers) {
|
||||
if !reflect.DeepEqual(dbUsers, testUsers) {
|
||||
t.Fatalf("Users (%v) after serialization and deserialization do not match original users (%v)!",
|
||||
db.Users, testUsers)
|
||||
dbUsers, testUsers)
|
||||
}
|
||||
}
|
||||
|
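The test now compares with `go-cmp` instead of `reflect.DeepEqual`, because the structs contain atomics (whose values live in unexported fields) and a lock field that should not affect equality. A small sketch of that option set, mirroring `TorrentTestCompareOptions` with a made-up `Counter` type:

```go
package main

import (
	"fmt"
	"sync"
	"sync/atomic"

	"github.com/google/go-cmp/cmp"
	"github.com/google/go-cmp/cmp/cmpopts"
)

// Counter stands in for a struct full of atomics plus a lock that should
// not take part in equality checks.
type Counter struct {
	Value atomic.Uint32
	lock  sync.Mutex
}

func main() {
	a, b := &Counter{}, &Counter{}
	a.Value.Store(7)
	b.Value.Store(7)

	opts := []cmp.Option{
		cmp.AllowUnexported(atomic.Uint32{}), // atomics keep their value in an unexported field
		cmpopts.IgnoreFields(Counter{}, "lock"),
	}

	fmt.Println(cmp.Equal(a, b, opts...)) // true
}
```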
|
|
@ -107,6 +107,7 @@ func LoadTorrents(r io.Reader, torrents map[TorrentHash]*Torrent) error {
|
|||
}
|
||||
|
||||
t := &Torrent{}
|
||||
t.InitializeLock()
|
||||
|
||||
if err := t.Load(version, reader); err != nil {
|
||||
return err
|
||||
|
|
|
@ -18,13 +18,18 @@
|
|||
package types
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql/driver"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
"github.com/viney-shih/go-lock"
|
||||
"io"
|
||||
"math"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
const TorrentHashSize = 20
|
||||
|
@ -116,42 +121,61 @@ func (h *TorrentHash) UnmarshalText(b []byte) error {
|
|||
}
|
||||
|
||||
type Torrent struct {
// lock This must be taken whenever read or write is made to Seeders or Leechers fields on this torrent.
lock *lock.ChanMutex
Seeders map[PeerKey]*Peer
Leechers map[PeerKey]*Peer

// SeedersLength Contains the length of Seeders. When Seeders is modified this field must be updated
SeedersLength atomic.Uint32
// LeechersLength Contains the length of Leechers. When LeechersLength is modified this field must be updated
LeechersLength atomic.Uint32

Group TorrentGroup
ID uint32
ID atomic.Uint32

Snatched uint16
// Snatched 16 bits
Snatched atomic.Uint32

Status uint8
LastAction int64 // unix time
// Snatched 8 bits
Status atomic.Uint32
LastAction atomic.Int64 // unix time

UpMultiplier float64
DownMultiplier float64
// UpMultiplier float64
UpMultiplier atomic.Uint64
// DownMultiplier float64
DownMultiplier atomic.Uint64
}

// lock This must be taken whenever read or write is made to fields on this torrent.
// Maybe single sync.Mutex is easier to handle, but prevent concurrent access.
lock sync.RWMutex
func (t *Torrent) InitializeLock() {
if t.lock != nil {
panic("already initialized")
}

t.lock = lock.NewChanMutex()
}

func (t *Torrent) Lock() {
t.lock.Lock()
}

func (t *Torrent) TryLockWithContext(ctx context.Context) bool {
return t.lock.TryLockWithContext(ctx)
}

func (t *Torrent) Unlock() {
t.lock.Unlock()
}

func (t *Torrent) RLock() {
t.lock.RLock()
}

func (t *Torrent) RUnlock() {
t.lock.RUnlock()
}
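With the per-torrent lock now a `*lock.ChanMutex`, callers can bound how long they wait by passing a context, which is how the announce handler later in this diff turns contention into a `StatusRequestTimeout`. A minimal, self-contained usage sketch:

```go
package main

import (
	"context"
	"fmt"
	"time"

	lock "github.com/viney-shih/go-lock"
)

func main() {
	mu := lock.NewChanMutex()

	// Simulate another announce holding the torrent lock.
	mu.Lock()
	go func() {
		time.Sleep(200 * time.Millisecond)
		mu.Unlock()
	}()

	// Give up after the request deadline instead of blocking forever.
	ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
	defer cancel()

	if !mu.TryLockWithContext(ctx) {
		fmt.Println("lock not acquired before deadline")
		return
	}
	defer mu.Unlock()

	fmt.Println("lock acquired")
}
```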
||||
|
||||
func (t *Torrent) Load(version uint64, reader readerAndByteReader) (err error) {
|
||||
var (
|
||||
id uint32
|
||||
snatched uint16
|
||||
status uint8
|
||||
lastAction int64
|
||||
upMultiplier, downMultiplier float64
|
||||
)
|
||||
|
||||
var varIntLen uint64
|
||||
|
||||
if varIntLen, err = binary.ReadUvarint(reader); err != nil {
|
||||
|
@ -175,6 +199,8 @@ func (t *Torrent) Load(version uint64, reader readerAndByteReader) (err error) {
|
|||
t.Seeders[k] = s
|
||||
}
|
||||
|
||||
t.SeedersLength.Store(uint32(len(t.Seeders)))
|
||||
|
||||
if varIntLen, err = binary.ReadUvarint(reader); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -195,100 +221,270 @@ func (t *Torrent) Load(version uint64, reader readerAndByteReader) (err error) {
|
|||
t.Leechers[k] = l
|
||||
}
|
||||
|
||||
t.LeechersLength.Store(uint32(len(t.Leechers)))
|
||||
|
||||
if err = t.Group.Load(version, reader); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = binary.Read(reader, binary.LittleEndian, &t.ID); err != nil {
|
||||
if err = binary.Read(reader, binary.LittleEndian, &id); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = binary.Read(reader, binary.LittleEndian, &t.Snatched); err != nil {
|
||||
if err = binary.Read(reader, binary.LittleEndian, &snatched); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = binary.Read(reader, binary.LittleEndian, &t.Status); err != nil {
|
||||
if err = binary.Read(reader, binary.LittleEndian, &status); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = binary.Read(reader, binary.LittleEndian, &t.LastAction); err != nil {
|
||||
if err = binary.Read(reader, binary.LittleEndian, &lastAction); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = binary.Read(reader, binary.LittleEndian, &t.UpMultiplier); err != nil {
|
||||
if err = binary.Read(reader, binary.LittleEndian, &upMultiplier); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return binary.Read(reader, binary.LittleEndian, &t.DownMultiplier)
|
||||
if err = binary.Read(reader, binary.LittleEndian, &downMultiplier); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
t.ID.Store(id)
|
||||
t.Snatched.Store(uint32(snatched))
|
||||
t.Status.Store(uint32(status))
|
||||
t.LastAction.Store(lastAction)
|
||||
t.UpMultiplier.Store(math.Float64bits(upMultiplier))
|
||||
t.DownMultiplier.Store(math.Float64bits(downMultiplier))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *Torrent) Append(preAllocatedBuffer []byte) (buf []byte) {
|
||||
t.RLock()
|
||||
defer t.RUnlock()
|
||||
|
||||
buf = preAllocatedBuffer
|
||||
buf = binary.AppendUvarint(buf, uint64(len(t.Seeders)))
|
||||
|
||||
for k, s := range t.Seeders {
|
||||
buf = append(buf, k[:]...)
|
||||
func() {
|
||||
// This could be a read-only lock, but this simpler lock is faster overall
|
||||
t.Lock()
|
||||
defer t.Unlock()
|
||||
|
||||
buf = s.Append(buf)
|
||||
}
|
||||
buf = binary.AppendUvarint(buf, uint64(len(t.Seeders)))
|
||||
|
||||
buf = binary.AppendUvarint(buf, uint64(len(t.Leechers)))
|
||||
for k, s := range t.Seeders {
|
||||
buf = append(buf, k[:]...)
|
||||
|
||||
for k, l := range t.Leechers {
|
||||
buf = append(buf, k[:]...)
|
||||
buf = s.Append(buf)
|
||||
}
|
||||
|
||||
buf = l.Append(buf)
|
||||
}
|
||||
buf = binary.AppendUvarint(buf, uint64(len(t.Leechers)))
|
||||
|
||||
for k, l := range t.Leechers {
|
||||
buf = append(buf, k[:]...)
|
||||
|
||||
buf = l.Append(buf)
|
||||
}
|
||||
}()
|
||||
|
||||
buf = t.Group.Append(buf)
|
||||
|
||||
buf = binary.LittleEndian.AppendUint32(buf, t.ID)
|
||||
buf = binary.LittleEndian.AppendUint16(buf, t.Snatched)
|
||||
buf = append(buf, t.Status)
|
||||
buf = binary.LittleEndian.AppendUint64(buf, uint64(t.LastAction))
|
||||
buf = binary.LittleEndian.AppendUint64(buf, math.Float64bits(t.UpMultiplier))
|
||||
buf = binary.LittleEndian.AppendUint64(buf, math.Float64bits(t.DownMultiplier))
|
||||
buf = binary.LittleEndian.AppendUint32(buf, t.ID.Load())
|
||||
buf = binary.LittleEndian.AppendUint16(buf, uint16(t.Snatched.Load()))
|
||||
buf = append(buf, uint8(t.Status.Load()))
|
||||
buf = binary.LittleEndian.AppendUint64(buf, uint64(t.LastAction.Load()))
|
||||
buf = binary.LittleEndian.AppendUint64(buf, t.UpMultiplier.Load())
|
||||
buf = binary.LittleEndian.AppendUint64(buf, t.DownMultiplier.Load())
|
||||
|
||||
return buf
|
||||
}
|
||||
|
||||
var encodeJSONTorrentMap = make(map[string]any)
|
||||
var encodeJSONTorrentGroupMap = make(map[string]any)
|
||||
|
||||
// MarshalJSON Due to using atomics, JSON will not marshal values within them.
|
||||
// This is only safe to call from a single thread at once
|
||||
func (t *Torrent) MarshalJSON() (buf []byte, err error) {
|
||||
encodeJSONTorrentMap["ID"] = t.ID.Load()
|
||||
encodeJSONTorrentMap["Seeders"] = t.Seeders
|
||||
encodeJSONTorrentMap["Leechers"] = t.Leechers
|
||||
|
||||
var torrentTypeBuf [8]byte
|
||||
|
||||
binary.LittleEndian.PutUint64(torrentTypeBuf[:], t.Group.TorrentType.Load())
|
||||
|
||||
i := 0
|
||||
|
||||
for ; i < len(torrentTypeBuf); i++ {
|
||||
if torrentTypeBuf[i] == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
encodeJSONTorrentGroupMap["TorrentType"] = string(torrentTypeBuf[:i])
|
||||
encodeJSONTorrentGroupMap["GroupID"] = t.Group.GroupID.Load()
|
||||
encodeJSONTorrentMap["Group"] = encodeJSONTorrentGroupMap
|
||||
encodeJSONTorrentMap["Snatched"] = uint16(t.Snatched.Load())
|
||||
encodeJSONTorrentMap["Status"] = uint8(t.Status.Load())
|
||||
encodeJSONTorrentMap["LastAction"] = t.LastAction.Load()
|
||||
encodeJSONTorrentMap["UpMultiplier"] = math.Float64frombits(t.UpMultiplier.Load())
|
||||
encodeJSONTorrentMap["DownMultiplier"] = math.Float64frombits(t.UpMultiplier.Load())
|
||||
|
||||
return json.Marshal(encodeJSONTorrentMap)
|
||||
}
|
||||
|
||||
type decodeJSONTorrent struct {
|
||||
Seeders map[PeerKey]*Peer
|
||||
Leechers map[PeerKey]*Peer
|
||||
|
||||
Group struct {
|
||||
TorrentType string
|
||||
GroupID uint32
|
||||
}
|
||||
ID uint32
|
||||
Snatched uint16
|
||||
|
||||
Status uint8
|
||||
LastAction int64
|
||||
UpMultiplier float64
|
||||
DownMultiplier float64
|
||||
}
|
||||
|
||||
// UnmarshalJSON Due to using atomics, JSON will not marshal values within them.
|
||||
// This is only safe to call from a single thread at once
|
||||
func (t *Torrent) UnmarshalJSON(buf []byte) (err error) {
|
||||
var torrentJSON decodeJSONTorrent
|
||||
if err = json.Unmarshal(buf, &torrentJSON); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
t.Seeders = torrentJSON.Seeders
|
||||
t.Leechers = torrentJSON.Leechers
|
||||
t.SeedersLength.Store(uint32(len(t.Seeders)))
|
||||
t.LeechersLength.Store(uint32(len(t.Leechers)))
|
||||
|
||||
torrentType, err := TorrentTypeFromString(torrentJSON.Group.TorrentType)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
t.Group.TorrentType.Store(torrentType)
|
||||
t.Group.GroupID.Store(torrentJSON.Group.GroupID)
|
||||
t.Snatched.Store(uint32(torrentJSON.Snatched))
|
||||
t.Status.Store(uint32(torrentJSON.Status))
|
||||
t.LastAction.Store(torrentJSON.LastAction)
|
||||
t.UpMultiplier.Store(math.Float64bits(torrentJSON.UpMultiplier))
|
||||
t.DownMultiplier.Store(math.Float64bits(torrentJSON.DownMultiplier))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type TorrentGroupFreeleech struct {
|
||||
UpMultiplier float64
|
||||
DownMultiplier float64
|
||||
}
|
||||
|
||||
type TorrentGroup struct {
|
||||
TorrentType string
|
||||
GroupID uint32
|
||||
type TorrentGroupKey [8 + 4]byte
|
||||
|
||||
func MustTorrentGroupKeyFromString(torrentType string, groupID uint32) TorrentGroupKey {
|
||||
k, err := TorrentGroupKeyFromString(torrentType, groupID)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return k
|
||||
}
|
||||
|
||||
func (g *TorrentGroup) Load(_ uint64, reader readerAndByteReader) (err error) {
|
||||
var varIntLen uint64
|
||||
func TorrentGroupKeyFromString(torrentType string, groupID uint32) (k TorrentGroupKey, err error) {
|
||||
t, err := TorrentTypeFromString(torrentType)
|
||||
if err != nil {
|
||||
return TorrentGroupKey{}, err
|
||||
}
|
||||
|
||||
if varIntLen, err = binary.ReadUvarint(reader); err != nil {
|
||||
binary.LittleEndian.PutUint64(k[:], t)
|
||||
binary.LittleEndian.PutUint32(k[8:], groupID)
|
||||
|
||||
return k, nil
|
||||
}
|
||||
|
||||
func MustTorrentTypeFromString(torrentType string) uint64 {
|
||||
t, err := TorrentTypeFromString(torrentType)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
func TorrentTypeFromString(torrentType string) (t uint64, err error) {
|
||||
if len(torrentType) > 8 {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
var buf [8]byte
|
||||
|
||||
copy(buf[:], torrentType)
|
||||
|
||||
return binary.LittleEndian.Uint64(buf[:]), nil
|
||||
}
|
||||
|
||||
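`TorrentTypeFromString` packs the short type string into a `uint64` so it can sit inside an `atomic.Uint64`; `MarshalJSON` reverses it by trimming the zero padding. A round-trip sketch of that encoding (the explicit error value here is illustrative; the diff's own function returns a nil error for the too-long case):

```go
package main

import (
	"encoding/binary"
	"errors"
	"fmt"
)

// typeToUint64 packs a short string ("anime", "music", ...) into a uint64.
func typeToUint64(s string) (uint64, error) {
	if len(s) > 8 {
		return 0, errors.New("torrent type longer than 8 bytes")
	}

	var buf [8]byte
	copy(buf[:], s)

	return binary.LittleEndian.Uint64(buf[:]), nil
}

// uint64ToType reverses the packing by trimming the zero padding.
func uint64ToType(v uint64) string {
	var buf [8]byte
	binary.LittleEndian.PutUint64(buf[:], v)

	n := 0
	for ; n < len(buf); n++ {
		if buf[n] == 0 {
			break
		}
	}

	return string(buf[:n])
}

func main() {
	v, _ := typeToUint64("anime")
	fmt.Println(v, uint64ToType(v)) // round-trips back to "anime"
}
```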
type TorrentGroup struct {
|
||||
TorrentType atomic.Uint64
|
||||
GroupID atomic.Uint32
|
||||
}
|
||||
|
||||
func (g *TorrentGroup) Key() (k TorrentGroupKey) {
|
||||
binary.LittleEndian.PutUint64(k[:], g.TorrentType.Load())
|
||||
binary.LittleEndian.PutUint32(k[8:], g.GroupID.Load())
|
||||
|
||||
return k
|
||||
}
|
||||
|
||||
var ErrTorrentTypeTooLong = errors.New("torrent type too long, maximum 8 bytes")
|
||||
|
||||
func (g *TorrentGroup) Load(version uint64, reader readerAndByteReader) (err error) {
|
||||
var (
|
||||
torrentType uint64
|
||||
groupID uint32
|
||||
)
|
||||
|
||||
if version <= 2 {
|
||||
var varIntLen uint64
|
||||
|
||||
if varIntLen, err = binary.ReadUvarint(reader); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if varIntLen > 8 {
|
||||
return ErrTorrentTypeTooLong
|
||||
}
|
||||
|
||||
buf := make([]byte, 8)
|
||||
if _, err = io.ReadFull(reader, buf[:varIntLen]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
torrentType = binary.LittleEndian.Uint64(buf)
|
||||
} else {
|
||||
if err = binary.Read(reader, binary.LittleEndian, &torrentType); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err = binary.Read(reader, binary.LittleEndian, &groupID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
buf := make([]byte, varIntLen)
|
||||
g.TorrentType.Store(torrentType)
|
||||
g.GroupID.Store(groupID)
|
||||
|
||||
if _, err = io.ReadFull(reader, buf); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
g.TorrentType = string(buf)
|
||||
|
||||
return binary.Read(reader, binary.LittleEndian, &g.GroupID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *TorrentGroup) Append(preAllocatedBuffer []byte) (buf []byte) {
|
||||
buf = preAllocatedBuffer
|
||||
buf = binary.AppendUvarint(buf, uint64(len(g.TorrentType)))
|
||||
buf = append(buf, []byte(g.TorrentType)...)
|
||||
buf = binary.LittleEndian.AppendUint64(buf, g.TorrentType.Load())
|
||||
|
||||
return binary.LittleEndian.AppendUint32(buf, g.GroupID)
|
||||
return binary.LittleEndian.AppendUint32(buf, g.GroupID.Load())
|
||||
}
|
||||
|
||||
// TorrentCacheFile holds filename used by serializer for this type
|
||||
|
@ -296,4 +492,12 @@ var TorrentCacheFile = "torrent-cache"
|
|||
|
||||
// TorrentCacheVersion Used to distinguish old versions on the on-disk cache.
|
||||
// Bump when fields are altered on Torrent, Peer or TorrentGroup structs
|
||||
const TorrentCacheVersion = 2
|
||||
const TorrentCacheVersion = 3
|
||||
|
||||
var TorrentTestCompareOptions = []cmp.Option{
|
||||
cmp.AllowUnexported(atomic.Uint32{}),
|
||||
cmp.AllowUnexported(atomic.Uint64{}),
|
||||
cmp.AllowUnexported(atomic.Int64{}),
|
||||
cmp.AllowUnexported(atomic.Bool{}),
|
||||
cmpopts.IgnoreFields(Torrent{}, "lock"),
|
||||
}
|
||||
|
|
|
@ -19,62 +19,121 @@ package types
|
|||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"math"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
type User struct {
ID uint32
ID atomic.Uint32

DisableDownload bool
DisableDownload atomic.Bool

TrackerHide bool
TrackerHide atomic.Bool

UpMultiplier float64
DownMultiplier float64
// UpMultiplier A float64 under the covers
UpMultiplier atomic.Uint64
// DownMultiplier A float64 under the covers
DownMultiplier atomic.Uint64
}
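The multipliers are now plain `atomic.Uint64` fields holding the bit pattern of a float64, written with `math.Float64bits` and read back with `math.Float64frombits`. A minimal sketch of that pattern wrapped in a tiny helper type (the `atomicFloat` name is illustrative):

```go
package main

import (
	"fmt"
	"math"
	"sync/atomic"
)

// atomicFloat stores a float64 bit-for-bit inside an atomic.Uint64.
type atomicFloat struct {
	bits atomic.Uint64
}

func (f *atomicFloat) Store(v float64) { f.bits.Store(math.Float64bits(v)) }
func (f *atomicFloat) Load() float64   { return math.Float64frombits(f.bits.Load()) }

func main() {
	var up atomicFloat
	up.Store(1.5)

	// Readers on other goroutines get a torn-free value without any lock.
	fmt.Println(up.Load() * 1024) // 1536
}
```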
||||
|
||||
func (u *User) Load(_ uint64, reader readerAndByteReader) (err error) {
|
||||
if err = binary.Read(reader, binary.LittleEndian, &u.ID); err != nil {
|
||||
var (
|
||||
id uint32
|
||||
disableDownload, trackerHide bool
|
||||
upMultiplier, downMultiplier float64
|
||||
)
|
||||
|
||||
if err = binary.Read(reader, binary.LittleEndian, &id); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = binary.Read(reader, binary.LittleEndian, &u.DisableDownload); err != nil {
|
||||
if err = binary.Read(reader, binary.LittleEndian, &disableDownload); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = binary.Read(reader, binary.LittleEndian, &u.TrackerHide); err != nil {
|
||||
if err = binary.Read(reader, binary.LittleEndian, &trackerHide); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = binary.Read(reader, binary.LittleEndian, &u.UpMultiplier); err != nil {
|
||||
if err = binary.Read(reader, binary.LittleEndian, &upMultiplier); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return binary.Read(reader, binary.LittleEndian, &u.DownMultiplier)
|
||||
if err = binary.Read(reader, binary.LittleEndian, &downMultiplier); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
u.ID.Store(id)
|
||||
u.DisableDownload.Store(disableDownload)
|
||||
u.TrackerHide.Store(trackerHide)
|
||||
u.UpMultiplier.Store(math.Float64bits(upMultiplier))
|
||||
u.DownMultiplier.Store(math.Float64bits(downMultiplier))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *User) Append(preAllocatedBuffer []byte) (buf []byte) {
|
||||
buf = preAllocatedBuffer
|
||||
buf = binary.LittleEndian.AppendUint32(buf, u.ID)
|
||||
buf = binary.LittleEndian.AppendUint32(buf, u.ID.Load())
|
||||
|
||||
if u.DisableDownload {
|
||||
if u.DisableDownload.Load() {
|
||||
buf = append(buf, 1)
|
||||
} else {
|
||||
buf = append(buf, 0)
|
||||
}
|
||||
|
||||
if u.TrackerHide {
|
||||
if u.TrackerHide.Load() {
|
||||
buf = append(buf, 1)
|
||||
} else {
|
||||
buf = append(buf, 0)
|
||||
}
|
||||
|
||||
buf = binary.LittleEndian.AppendUint64(buf, math.Float64bits(u.UpMultiplier))
|
||||
buf = binary.LittleEndian.AppendUint64(buf, math.Float64bits(u.DownMultiplier))
|
||||
buf = binary.LittleEndian.AppendUint64(buf, u.UpMultiplier.Load())
|
||||
buf = binary.LittleEndian.AppendUint64(buf, u.DownMultiplier.Load())
|
||||
|
||||
return buf
|
||||
}
|
||||
|
||||
var encodeJSONUserMap = make(map[string]any)
|
||||
|
||||
// MarshalJSON Due to using atomics, JSON will not marshal values within them.
|
||||
// This is only safe to call from a single thread at once
|
||||
func (u *User) MarshalJSON() (buf []byte, err error) {
|
||||
encodeJSONUserMap["ID"] = u.ID.Load()
|
||||
encodeJSONUserMap["DisableDownload"] = u.DisableDownload.Load()
|
||||
encodeJSONUserMap["TrackerHide"] = u.TrackerHide.Load()
|
||||
encodeJSONUserMap["UpMultiplier"] = math.Float64frombits(u.UpMultiplier.Load())
|
||||
encodeJSONUserMap["DownMultiplier"] = math.Float64frombits(u.UpMultiplier.Load())
|
||||
|
||||
return json.Marshal(encodeJSONUserMap)
|
||||
}
|
||||
|
||||
type decodeJSONUser struct {
|
||||
ID uint32
|
||||
DisableDownload bool
|
||||
TrackerHide bool
|
||||
UpMultiplier float64
|
||||
DownMultiplier float64
|
||||
}
|
||||
|
||||
// UnmarshalJSON Due to using atomics, JSON will not marshal values within them.
|
||||
// This is only safe to call from a single thread at once
|
||||
func (u *User) UnmarshalJSON(buf []byte) (err error) {
|
||||
var userJSON decodeJSONUser
|
||||
if err = json.Unmarshal(buf, &userJSON); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
u.ID.Store(userJSON.ID)
|
||||
u.DisableDownload.Store(userJSON.DisableDownload)
|
||||
u.TrackerHide.Store(userJSON.TrackerHide)
|
||||
u.UpMultiplier.Store(math.Float64bits(userJSON.UpMultiplier))
|
||||
u.DownMultiplier.Store(math.Float64bits(userJSON.DownMultiplier))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type UserTorrentPair struct {
|
||||
UserID uint32
|
||||
TorrentID uint32
|
||||
|
|
3
go.mod
3
go.mod
|
@ -5,9 +5,11 @@ go 1.19
|
|||
require (
|
||||
github.com/go-sql-driver/mysql v1.7.1
|
||||
github.com/go-testfixtures/testfixtures/v3 v3.9.0
|
||||
github.com/google/go-cmp v0.5.9
|
||||
github.com/jinzhu/copier v0.3.5
|
||||
github.com/prometheus/client_golang v1.15.1
|
||||
github.com/prometheus/common v0.44.0
|
||||
github.com/valyala/fasthttp v1.47.0
|
||||
github.com/viney-shih/go-lock v1.1.2
|
||||
github.com/zeebo/bencode v1.0.0
|
||||
)
|
||||
|
@ -32,6 +34,7 @@ require (
|
|||
github.com/prometheus/procfs v0.9.0 // indirect
|
||||
github.com/segmentio/asm v1.2.0 // indirect
|
||||
github.com/shopspring/decimal v1.3.1 // indirect
|
||||
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
||||
go.opentelemetry.io/otel v1.15.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.15.0 // indirect
|
||||
golang.org/x/sync v0.1.0 // indirect
|
||||
|
|
7
go.sum
7
go.sum
|
@ -32,6 +32,7 @@ github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEW
|
|||
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
|
||||
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
|
||||
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8=
|
||||
|
@ -88,6 +89,10 @@ github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5
|
|||
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8=
|
||||
github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk=
|
||||
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
|
||||
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
|
||||
github.com/valyala/fasthttp v1.47.0 h1:y7moDoxYzMooFpT5aHgNgVOQDrS3qlkfiP9mDtGGK9c=
|
||||
github.com/valyala/fasthttp v1.47.0/go.mod h1:k2zXd82h/7UZc3VOdJ2WaUqt1uZ/XpXAfE9i+HBC3lA=
|
||||
github.com/viney-shih/go-lock v1.1.2 h1:3TdGTiHZCPqBdTvFbQZQN/TRZzKF3KWw2rFEyKz3YqA=
|
||||
github.com/viney-shih/go-lock v1.1.2/go.mod h1:Yijm78Ljteb3kRiJrbLAxVntkUukGu5uzSxq/xV7OO8=
|
||||
github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI=
|
||||
|
@ -107,7 +112,7 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
|
|||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||
golang.org/x/crypto v0.6.0 h1:qfktjS5LUO+fFKeJXZ+ikTRijMmljikvG68fpMMruSc=
|
||||
golang.org/x/crypto v0.7.0 h1:AvwMYaRytfdeVt3u6mLaxYtErKYjxA2OXjJ1HHq6t3A=
|
||||
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
|
|
|
@ -20,7 +20,7 @@ package server
|
|||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"github.com/valyala/fasthttp"
|
||||
"time"
|
||||
|
||||
"chihaya/log"
|
||||
|
@ -35,10 +35,10 @@ func alive(buf *bytes.Buffer) int {
|
|||
res, err := json.Marshal(response{time.Now().UnixMilli(), time.Since(handler.startTime).Milliseconds()})
|
||||
if err != nil {
|
||||
log.Error.Print("Failed to marshal JSON alive response: ", err)
|
||||
return http.StatusInternalServerError
|
||||
return fasthttp.StatusInternalServerError
|
||||
}
|
||||
|
||||
buf.Write(res)
|
||||
|
||||
return http.StatusOK
|
||||
return fasthttp.StatusOK
|
||||
}
|
||||
|
|
|
@ -21,9 +21,9 @@ import (
|
|||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/valyala/fasthttp"
|
||||
"math"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
@ -111,67 +111,68 @@ func getPublicIPV4(ipAddr string, exists bool) (string, bool) {
|
|||
return ipAddr, !private
|
||||
}
|
||||
|
||||
func announce(ctx context.Context, qs string, header http.Header, remoteAddr string, user *cdb.User,
|
||||
db *database.Database, buf *bytes.Buffer) int {
|
||||
qp, err := params.ParseQuery(qs)
|
||||
func announce(
|
||||
ctx context.Context, queryArgs *fasthttp.Args, header *fasthttp.RequestHeader,
|
||||
remoteAddr net.Addr, user *cdb.User, db *database.Database, buf *bytes.Buffer) int {
|
||||
qp, err := params.ParseQuery(queryArgs)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Mandatory parameters
|
||||
infoHashes := qp.InfoHashes()
|
||||
peerID, _ := qp.Get("peer_id")
|
||||
port, portExists := qp.GetUint16("port")
|
||||
uploaded, uploadedExists := qp.GetUint64("uploaded")
|
||||
downloaded, downloadedExists := qp.GetUint64("downloaded")
|
||||
left, leftExists := qp.GetUint64("left")
|
||||
infoHashes := qp.Params.InfoHashes
|
||||
peerID := qp.Params.PeerID
|
||||
port, portExists := qp.Params.Port, qp.Exists.Port
|
||||
uploaded, uploadedExists := qp.Params.Uploaded, qp.Exists.Uploaded
|
||||
downloaded, downloadedExists := qp.Params.Downloaded, qp.Exists.Downloaded
|
||||
left, leftExists := qp.Params.Left, qp.Exists.Left
|
||||
|
||||
if infoHashes == nil {
|
||||
if len(infoHashes) == 0 {
|
||||
failure("Malformed request - missing info_hash", buf, 1*time.Hour)
|
||||
return http.StatusOK // Required by torrent clients to interpret failure response
|
||||
return fasthttp.StatusOK // Required by torrent clients to interpret failure response
|
||||
} else if len(infoHashes) > 1 {
|
||||
failure("Malformed request - can only announce singular info_hash", buf, 1*time.Hour)
|
||||
return http.StatusOK // Required by torrent clients to interpret failure response
|
||||
return fasthttp.StatusOK // Required by torrent clients to interpret failure response
|
||||
}
|
||||
|
||||
if peerID == "" {
|
||||
if len(peerID) == 0 {
|
||||
failure("Malformed request - missing peer_id", buf, 1*time.Hour)
|
||||
return http.StatusOK // Required by torrent clients to interpret failure response
|
||||
return fasthttp.StatusOK // Required by torrent clients to interpret failure response
|
||||
}
|
||||
|
||||
if len(peerID) != 20 {
|
||||
failure("Malformed request - invalid peer_id", buf, 1*time.Hour)
|
||||
return http.StatusOK // Required by torrent clients to interpret failure response
|
||||
return fasthttp.StatusOK // Required by torrent clients to interpret failure response
|
||||
}
|
||||
|
||||
if !portExists {
|
||||
failure("Malformed request - missing port", buf, 1*time.Hour)
|
||||
return http.StatusOK // Required by torrent clients to interpret failure response
|
||||
return fasthttp.StatusOK // Required by torrent clients to interpret failure response
|
||||
}
|
||||
|
||||
if strictPort && port < 1024 || port > 65535 {
|
||||
failure(fmt.Sprintf("Malformed request - port outside of acceptable range (port: %d)", port), buf, 1*time.Hour)
|
||||
return http.StatusOK // Required by torrent clients to interpret failure response
|
||||
return fasthttp.StatusOK // Required by torrent clients to interpret failure response
|
||||
}
|
||||
|
||||
if !uploadedExists {
|
||||
failure("Malformed request - missing uploaded", buf, 1*time.Hour)
|
||||
return http.StatusOK // Required by torrent clients to interpret failure response
|
||||
return fasthttp.StatusOK // Required by torrent clients to interpret failure response
|
||||
}
|
||||
|
||||
if !downloadedExists {
|
||||
failure("Malformed request - missing downloaded", buf, 1*time.Hour)
|
||||
return http.StatusOK // Required by torrent clients to interpret failure response
|
||||
return fasthttp.StatusOK // Required by torrent clients to interpret failure response
|
||||
}
|
||||
|
||||
if !leftExists {
|
||||
failure("Malformed request - missing left", buf, 1*time.Hour)
|
||||
return http.StatusOK // Required by torrent clients to interpret failure response
|
||||
return fasthttp.StatusOK // Required by torrent clients to interpret failure response
|
||||
}
|
||||
|
||||
ipAddr := func() string {
|
||||
ipV4, existsV4 := getPublicIPV4(qp.Get("ipv4")) // First try to get IPv4 address if client sent it
|
||||
ip, exists := getPublicIPV4(qp.Get("ip")) // ... then try to get public IP if sent by client
|
||||
ipV4, existsV4 := getPublicIPV4(qp.Params.IPv4, qp.Exists.IPv4) // First try to get IPv4 address if client sent it
|
||||
ip, exists := getPublicIPV4(qp.Params.IP, qp.Exists.IP) // ... then try to get public IP if sent by client
|
||||
|
||||
// Fail if ip and ipv4 are not same, and both are provided
|
||||
if (existsV4 && exists) && (ip != ipV4) {
|
||||
|
@ -188,15 +189,17 @@ func announce(ctx context.Context, qs string, header http.Header, remoteAddr str
|
|||
|
||||
// Check for proxied IP header
|
||||
if proxyHeader, exists := config.Section("http").Get("proxy_header", ""); exists {
|
||||
if ips, exists := header[proxyHeader]; exists && len(ips) > 0 {
|
||||
return ips[0]
|
||||
if ips := header.PeekAll(proxyHeader); len(ips) > 0 {
|
||||
return string(ips[0])
|
||||
}
|
||||
}
|
||||
|
||||
// Check for IP in socket
|
||||
portIndex := strings.LastIndex(remoteAddr, ":")
|
||||
remoteAddrString := remoteAddr.String()
|
||||
portIndex := strings.LastIndex(remoteAddrString, ":")
|
||||
|
||||
if portIndex != -1 {
|
||||
return remoteAddr[0:portIndex]
|
||||
return remoteAddrString[0:portIndex]
|
||||
}
|
||||
|
||||
// Everything else failed
|
||||
|
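The announce handler resolves the client IP by checking the configured proxy header first and falling back to the socket address. A simplified sketch of that order of precedence (`resolveIP` is an illustrative helper; it uses `Peek` and `net.SplitHostPort` instead of the handler's `PeekAll` and manual index math):

```go
package main

import (
	"fmt"
	"net"

	"github.com/valyala/fasthttp"
)

// resolveIP prefers the proxy header when configured and present,
// otherwise strips the port from the remote socket address.
func resolveIP(header *fasthttp.RequestHeader, remoteAddr, proxyHeader string) string {
	if proxyHeader != "" {
		if v := header.Peek(proxyHeader); len(v) > 0 {
			return string(v)
		}
	}

	host, _, err := net.SplitHostPort(remoteAddr)
	if err != nil {
		return remoteAddr // no port present, use as-is
	}

	return host
}

func main() {
	var h fasthttp.RequestHeader
	h.Set("X-Real-Ip", "203.0.113.7")

	fmt.Println(resolveIP(&h, "10.0.0.2:51413", "X-Real-Ip")) // 203.0.113.7
	fmt.Println(resolveIP(&h, "10.0.0.2:51413", ""))          // 10.0.0.2
}
```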
@ -206,45 +209,36 @@ func announce(ctx context.Context, qs string, header http.Header, remoteAddr str
|
|||
ipBytes := net.ParseIP(ipAddr).To4()
|
||||
if nil == ipBytes {
|
||||
failure(fmt.Sprintf("Failed to parse IP address (ip: %s)", ipAddr), buf, 1*time.Hour)
|
||||
return http.StatusOK // Required by torrent clients to interpret failure response
|
||||
return fasthttp.StatusOK // Required by torrent clients to interpret failure response
|
||||
}
|
||||
|
||||
clientID, matched := clientApproved(peerID, db)
|
||||
if !matched {
|
||||
failure(fmt.Sprintf("Your client is not approved (peer_id: %s)", peerID), buf, 1*time.Hour)
|
||||
return http.StatusOK // Required by torrent clients to interpret failure response
|
||||
return fasthttp.StatusOK // Required by torrent clients to interpret failure response
|
||||
}
|
||||
|
||||
if !db.TorrentsLock.RTryLockWithContext(ctx) {
|
||||
return http.StatusRequestTimeout
|
||||
}
|
||||
defer db.TorrentsLock.RUnlock()
|
||||
|
||||
torrent, exists := db.Torrents[infoHashes[0]]
|
||||
torrent, exists := (*db.Torrents.Load())[infoHashes[0]]
|
||||
if !exists {
|
||||
failure("This torrent does not exist", buf, 5*time.Minute)
|
||||
return http.StatusOK // Required by torrent clients to interpret failure response
|
||||
return fasthttp.StatusOK // Required by torrent clients to interpret failure response
|
||||
}
|
||||
|
||||
// Take torrent lock to read/write on it to prevent race conditions
|
||||
torrent.Lock()
|
||||
defer torrent.Unlock()
|
||||
if torrentStatus := torrent.Status.Load(); torrentStatus == 1 && left == 0 {
|
||||
log.Info.Printf("Unpruning torrent %d", torrent.ID.Load())
|
||||
|
||||
if torrent.Status == 1 && left == 0 {
|
||||
log.Info.Printf("Unpruning torrent %d", torrent.ID)
|
||||
|
||||
torrent.Status = 0
|
||||
torrent.Status.Store(0)
|
||||
|
||||
/* It is okay to do this asynchronously as tracker's internal in-memory state has already been updated for this
|
||||
torrent. While it is technically possible that we will do this more than once in some cases, the state is of
|
||||
boolean type so there is no risk of data loss. */
|
||||
go db.UnPrune(torrent)
|
||||
} else if torrent.Status != 0 {
|
||||
failure(fmt.Sprintf("This torrent does not exist (status: %d, left: %d)", torrent.Status, left), buf, 15*time.Minute)
|
||||
return http.StatusOK // Required by torrent clients to interpret failure response
|
||||
} else if torrentStatus != 0 {
|
||||
failure(fmt.Sprintf("This torrent does not exist (status: %d, left: %d)", torrentStatus, left), buf, 15*time.Minute)
|
||||
return fasthttp.StatusOK // Required by torrent clients to interpret failure response
|
||||
}
|
||||
|
||||
numWant, exists := qp.GetUint16("numwant")
|
||||
numWant, exists := qp.Params.NumWant, qp.Exists.NumWant
|
||||
if !exists {
|
||||
numWant = uint16(defaultNumWant)
|
||||
} else if numWant > uint16(maxNumWant) {
|
||||
|
@ -259,15 +253,21 @@ func announce(ctx context.Context, qs string, header http.Header, remoteAddr str
|
|||
active = true
|
||||
)
|
||||
|
||||
event, _ := qp.Get("event")
|
||||
event := qp.Params.Event
|
||||
completed := event == "completed"
|
||||
|
||||
peerKey := cdb.NewPeerKey(user.ID, cdb.PeerIDFromRawString(peerID))
|
||||
peerKey := cdb.NewPeerKey(user.ID.Load(), cdb.PeerIDFromRawString(peerID))
|
||||
|
||||
// Take torrent peers lock to read/write on it to prevent race conditions
|
||||
if !torrent.TryLockWithContext(ctx) {
|
||||
return fasthttp.StatusRequestTimeout
|
||||
}
|
||||
defer torrent.Unlock()
|
||||
|
||||
if left > 0 {
|
||||
if isDisabledDownload(db, user, torrent) {
|
||||
failure("Your download privileges are disabled", buf, 1*time.Hour)
|
||||
return http.StatusOK // Required by torrent clients to interpret failure response
|
||||
return fasthttp.StatusOK // Required by torrent clients to interpret failure response
|
||||
}
|
||||
|
||||
peer, exists = torrent.Leechers[peerKey]
|
||||
|
@ -275,6 +275,8 @@ func announce(ctx context.Context, qs string, header http.Header, remoteAddr str
|
|||
newPeer = true
|
||||
peer = &cdb.Peer{}
|
||||
torrent.Leechers[peerKey] = peer
|
||||
|
||||
torrent.LeechersLength.Store(uint32(len(torrent.Leechers)))
|
||||
}
|
||||
} else if completed {
|
||||
peer, exists = torrent.Leechers[peerKey]
|
||||
|
@ -282,10 +284,15 @@ func announce(ctx context.Context, qs string, header http.Header, remoteAddr str
|
|||
newPeer = true
|
||||
peer = &cdb.Peer{}
|
||||
torrent.Seeders[peerKey] = peer
|
||||
|
||||
torrent.SeedersLength.Store(uint32(len(torrent.Seeders)))
|
||||
} else {
|
||||
// Previously tracked peer is now a seeder
|
||||
torrent.Seeders[peerKey] = peer
|
||||
delete(torrent.Leechers, peerKey)
|
||||
|
||||
torrent.SeedersLength.Store(uint32(len(torrent.Seeders)))
|
||||
torrent.LeechersLength.Store(uint32(len(torrent.Leechers)))
|
||||
}
|
||||
seeding = true
|
||||
} else {
|
||||
|
@ -296,12 +303,17 @@ func announce(ctx context.Context, qs string, header http.Header, remoteAddr str
|
|||
newPeer = true
|
||||
peer = &cdb.Peer{}
|
||||
torrent.Seeders[peerKey] = peer
|
||||
|
||||
torrent.SeedersLength.Store(uint32(len(torrent.Seeders)))
|
||||
} else {
|
||||
/* Previously tracked peer is now a seeder, however we never received their "completed" event.
|
||||
Broken client? Unreported snatch? Cross-seeding? Let's not report it as snatch to avoid
|
||||
over-reporting for cross-seeding */
|
||||
torrent.Seeders[peerKey] = peer
|
||||
delete(torrent.Leechers, peerKey)
|
||||
|
||||
torrent.SeedersLength.Store(uint32(len(torrent.Seeders)))
|
||||
torrent.LeechersLength.Store(uint32(len(torrent.Leechers)))
|
||||
}
|
||||
}
|
||||
seeding = true
|
||||
|
@ -310,8 +322,8 @@ func announce(ctx context.Context, qs string, header http.Header, remoteAddr str
|
|||
// Update peer info/stats
|
||||
if newPeer {
|
||||
peer.ID = peerKey.PeerID()
|
||||
peer.UserID = user.ID
|
||||
peer.TorrentID = torrent.ID
|
||||
peer.UserID = user.ID.Load()
|
||||
peer.TorrentID = torrent.ID.Load()
|
||||
peer.StartTime = now
|
||||
peer.LastAnnounce = now
|
||||
peer.Uploaded = uploaded
|
||||
|
@ -333,19 +345,19 @@ func announce(ctx context.Context, qs string, header http.Header, remoteAddr str
|
|||
torrentGroupDownMultiplier := 1.0
|
||||
torrentGroupUpMultiplier := 1.0
|
||||
|
||||
if torrentGroupFreeleech, exists := (*db.TorrentGroupFreeleech.Load())[torrent.Group]; exists {
|
||||
if torrentGroupFreeleech, exists := (*db.TorrentGroupFreeleech.Load())[torrent.Group.Key()]; exists {
|
||||
torrentGroupDownMultiplier = torrentGroupFreeleech.DownMultiplier
|
||||
torrentGroupUpMultiplier = torrentGroupFreeleech.UpMultiplier
|
||||
}
|
||||
|
||||
var deltaDownload int64
|
||||
if !database.GlobalFreeleech.Load() {
|
||||
deltaDownload = int64(float64(rawDeltaDownload) * math.Abs(user.DownMultiplier) *
|
||||
math.Abs(torrentGroupDownMultiplier) * math.Abs(torrent.DownMultiplier))
|
||||
deltaDownload = int64(float64(rawDeltaDownload) * math.Abs(math.Float64frombits(user.DownMultiplier.Load())) *
|
||||
math.Abs(torrentGroupDownMultiplier) * math.Abs(math.Float64frombits(torrent.DownMultiplier.Load())))
|
||||
}
|
||||
|
||||
deltaUpload := int64(float64(rawDeltaUpload) * math.Abs(user.UpMultiplier) *
|
||||
math.Abs(torrentGroupUpMultiplier) * math.Abs(torrent.UpMultiplier))
|
||||
deltaUpload := int64(float64(rawDeltaUpload) * math.Abs(math.Float64frombits(user.UpMultiplier.Load())) *
|
||||
math.Abs(torrentGroupUpMultiplier) * math.Abs(math.Float64frombits(torrent.UpMultiplier.Load())))
|
||||
peer.Uploaded = uploaded
|
||||
peer.Downloaded = downloaded
|
||||
peer.Left = left
|
||||
|
@ -369,7 +381,7 @@ func announce(ctx context.Context, qs string, header http.Header, remoteAddr str
|
|||
/* Update torrent last_action only if announced action is seeding.
|
||||
This allows dead torrents without seeders but with leechers to be properly pruned */
|
||||
if seeding {
|
||||
torrent.LastAction = now
|
||||
torrent.LastAction.Store(now)
|
||||
}
|
||||
|
||||
var deltaSnatch uint8
|
||||
|
@ -380,17 +392,19 @@ func announce(ctx context.Context, qs string, header http.Header, remoteAddr str
|
|||
should be gone, allowing the peer to be GC'd. */
|
||||
if seeding {
|
||||
delete(torrent.Seeders, peerKey)
|
||||
torrent.SeedersLength.Store(uint32(len(torrent.Seeders)))
|
||||
} else {
|
||||
delete(torrent.Leechers, peerKey)
|
||||
torrent.LeechersLength.Store(uint32(len(torrent.Leechers)))
|
||||
}
|
||||
|
||||
active = false
|
||||
} else if completed {
|
||||
go db.QueueSnatch(peer, now)
|
||||
db.QueueSnatch(peer, now)
|
||||
deltaSnatch = 1
|
||||
}
|
||||
|
||||
if user.TrackerHide {
|
||||
if user.TrackerHide.Load() {
|
||||
peer.Addr = cdb.NewPeerAddressFromIPPort(net.IP{127, 0, 0, 1}, port) // 127.0.0.1
|
||||
} else {
|
||||
peer.Addr = cdb.NewPeerAddressFromIPPort(ipBytes, port)
|
||||
|
@ -406,7 +420,7 @@ func announce(ctx context.Context, qs string, header http.Header, remoteAddr str
|
|||
|
||||
go record.Record(
|
||||
peer.TorrentID,
|
||||
user.ID,
|
||||
user.ID.Load(),
|
||||
ipAddr,
|
||||
port,
|
||||
event,
|
||||
|
@ -418,9 +432,9 @@ func announce(ctx context.Context, qs string, header http.Header, remoteAddr str
|
|||
left)
|
||||
|
||||
// Generate response
|
||||
seedCount := len(torrent.Seeders)
|
||||
leechCount := len(torrent.Leechers)
|
||||
snatchCount := torrent.Snatched
|
||||
seedCount := int(torrent.SeedersLength.Load())
|
||||
leechCount := int(torrent.LeechersLength.Load())
|
||||
snatchCount := uint16(torrent.Snatched.Load())
|
||||
|
||||
response := make(map[string]interface{})
|
||||
response["complete"] = seedCount
|
||||
|
@ -435,11 +449,9 @@ func announce(ctx context.Context, qs string, header http.Header, remoteAddr str
|
|||
response["interval"] = announceInterval + announceDrift
|
||||
|
||||
if numWant > 0 && active {
|
||||
compactString, exists := qp.Get("compact")
|
||||
compact := !exists || compactString != "0" // Defaults to being compact
|
||||
compact := !qp.Exists.Compact || !qp.Params.Compact
|
||||
|
||||
noPeerIDString, exists := qp.Get("no_peer_id")
|
||||
noPeerID := exists && noPeerIDString == "1"
|
||||
noPeerID := qp.Exists.NoPeerID && qp.Params.NoPeerID
|
||||
|
||||
var peerCount int
|
||||
if seeding {
|
||||
|
@ -520,14 +532,14 @@ func announce(ctx context.Context, qs string, header http.Header, remoteAddr str
|
|||
// Early exit before response write
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return http.StatusRequestTimeout
|
||||
return fasthttp.StatusRequestTimeout
|
||||
default:
|
||||
}
|
||||
|
||||
encoder := bencode.NewEncoder(buf)
|
||||
if err = encoder.Encode(response); err != nil {
|
||||
if err := encoder.Encode(response); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return http.StatusOK
|
||||
return fasthttp.StatusOK
|
||||
}
|
||||
|
|
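Every rejected announce above goes through a `failure(message, buf, interval)` helper whose body is not part of this diff. One plausible shape for it, using the `zeebo/bencode` encoder the handler already relies on; the `"failure reason"` and `"retry in"` keys follow the common tracker convention and are an assumption here:

```go
package main

import (
	"bytes"
	"fmt"
	"time"

	"github.com/zeebo/bencode"
)

// failure writes a bencoded error dictionary into the response buffer.
// The field names are assumed, not taken from the tracker's source.
func failure(msg string, buf *bytes.Buffer, interval time.Duration) {
	resp := map[string]interface{}{
		"failure reason": msg,
		"retry in":       int64(interval / time.Second),
	}

	if err := bencode.NewEncoder(buf).Encode(resp); err != nil {
		panic(err)
	}
}

func main() {
	var buf bytes.Buffer
	failure("Malformed request - missing port", &buf, time.Hour)
	fmt.Println(buf.String())
}
```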
|
@ -20,7 +20,7 @@ package server
|
|||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"net/http"
|
||||
"github.com/valyala/fasthttp"
|
||||
"time"
|
||||
|
||||
"chihaya/collectors"
|
||||
|
@@ -31,39 +31,28 @@
	"github.com/prometheus/common/expfmt"
)

-var bearerPrefix = "Bearer "
+var bearerPrefix = []byte("Bearer ")

-func metrics(ctx context.Context, auth string, db *database.Database, buf *bytes.Buffer) int {
-	if !db.UsersLock.RTryLockWithContext(ctx) {
-		return http.StatusRequestTimeout
-	}
-	defer db.UsersLock.RUnlock()
-
-	if !db.TorrentsLock.RTryLockWithContext(ctx) {
-		return http.StatusRequestTimeout
-	}
-	defer db.TorrentsLock.RUnlock()
+func metrics(ctx context.Context, auth []byte, db *database.Database, buf *bytes.Buffer) int {
+	dbUsers := *db.Users.Load()
+	dbTorrents := *db.Torrents.Load()

	peers := 0

-	for _, t := range db.Torrents {
-		func() {
-			t.RLock()
-			defer t.RUnlock()
-			peers += len(t.Leechers) + len(t.Seeders)
-		}()
+	for _, t := range dbTorrents {
+		peers += int(t.LeechersLength.Load()) + int(t.SeedersLength.Load())
	}

	// Early exit before response write
	select {
	case <-ctx.Done():
-		return http.StatusRequestTimeout
+		return fasthttp.StatusRequestTimeout
	default:
	}

	collectors.UpdateUptime(time.Since(handler.startTime).Seconds())
-	collectors.UpdateUsers(len(db.Users))
-	collectors.UpdateTorrents(len(db.Torrents))
+	collectors.UpdateUsers(len(dbUsers))
+	collectors.UpdateTorrents(len(dbTorrents))
	collectors.UpdateClients(len(*db.Clients.Load()))
	collectors.UpdateHitAndRuns(len(*db.HitAndRuns.Load()))
	collectors.UpdatePeers(peers)
@@ -79,9 +68,9 @@ func metrics(ctx context.Context, auth string, db *database.Database, buf *bytes
	}

	n := len(bearerPrefix)
-	if len(auth) > n && auth[:n] == bearerPrefix {
+	if len(auth) > n && bytes.Equal(auth[:n], bearerPrefix) {
		adminToken, exists := config.Section("http").Get("admin_token", "")
-		if exists && auth[n:] == adminToken {
+		if exists && bytes.Equal(auth[n:], []byte(adminToken)) {
			mfs, _ := prometheus.DefaultGatherer.Gather()

			for _, mf := range mfs {
@@ -93,5 +82,5 @@ func metrics(ctx context.Context, auth string, db *database.Database, buf *bytes
		}
	}

-	return http.StatusOK
+	return fasthttp.StatusOK
}
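The metrics handler above no longer takes UsersLock or TorrentsLock; it reads whole-map snapshots through db.Users.Load() and db.Torrents.Load(). A minimal sketch of that snapshot pattern follows; registry, replace and lookup are illustrative names rather than the project's API, and it assumes Go 1.19+ for atomic.Pointer.

package main

import (
	"fmt"
	"sync/atomic"
)

type User struct{ ID uint32 }

// registry stands in for the database wrapper: readers load an immutable
// snapshot of the map, while a writer builds a fresh map and swaps the
// pointer, so request handlers never block on a read lock.
type registry struct {
	users atomic.Pointer[map[string]*User]
}

func (r *registry) replace(next map[string]*User) {
	r.users.Store(&next)
}

func (r *registry) lookup(passkey string) (*User, bool) {
	u, ok := (*r.users.Load())[passkey]
	return u, ok
}

func main() {
	r := &registry{}
	r.replace(map[string]*User{"passkey123": {ID: 1}})

	if u, ok := r.lookup("passkey123"); ok {
		fmt.Println(u.ID)
	}
}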
@@ -19,99 +19,153 @@
 package params

 import (
-	"net/url"
-	"strconv"
-	"strings"
-
+	"bytes"
 	cdb "chihaya/database/types"
+	"github.com/valyala/fasthttp"
+	"strconv"
 )

 type QueryParam struct {
-	query      string
-	params     map[string]string
-	infoHashes []cdb.TorrentHash
-}
-
-func ParseQuery(query string) (qp *QueryParam, err error) {
-	qp = &QueryParam{
-		query:      query,
-		infoHashes: nil,
-		params:     make(map[string]string),
-	}
-
-	for query != "" {
-		key := query
-		if i := strings.Index(key, "&"); i >= 0 {
-			key, query = key[:i], key[i+1:]
-		} else {
-			query = ""
-		}
-
-		if key == "" {
-			continue
-		}
-
-		value := ""
-		if i := strings.Index(key, "="); i >= 0 {
-			key, value = key[:i], key[i+1:]
-		}
-
-		key, err = url.QueryUnescape(key)
-		if err != nil {
-			panic(err)
-		}
-
-		value, err = url.QueryUnescape(value)
-		if err != nil {
-			panic(err)
-		}
-
-		if key == "info_hash" {
-			if len(value) == cdb.TorrentHashSize {
-				qp.infoHashes = append(qp.infoHashes, cdb.TorrentHashFromBytes([]byte(value)))
-			}
-		} else {
-			qp.params[strings.ToLower(key)] = value
-		}
-	}
-
-	return qp, nil
-}
-
-func (qp *QueryParam) getUint(which string, bitSize int) (ret uint64, exists bool) {
-	str, exists := qp.params[which]
-	if exists {
-		var err error
-
-		ret, err = strconv.ParseUint(str, 10, bitSize)
-		if err != nil {
-			exists = false
-		}
-	}
-
-	return
-}
-
-func (qp *QueryParam) Get(which string) (ret string, exists bool) {
-	ret, exists = qp.params[which]
-	return
-}
-
-func (qp *QueryParam) GetUint64(which string) (ret uint64, exists bool) {
-	return qp.getUint(which, 64)
-}
-
-func (qp *QueryParam) GetUint16(which string) (ret uint16, exists bool) {
-	tmp, exists := qp.getUint(which, 16)
-	ret = uint16(tmp)
-
-	return
-}
-
-func (qp *QueryParam) InfoHashes() []cdb.TorrentHash {
-	return qp.infoHashes
-}
-
-func (qp *QueryParam) RawQuery() string {
-	return qp.query
+	Params struct {
+		Uploaded   uint64
+		Downloaded uint64
+		Left       uint64
+
+		Port    uint16
+		NumWant uint16
+
+		PeerID string
+		IPv4   string
+		IP     string
+		Event  string
+
+		TestGarbageUnescape string
+
+		Compact  bool
+		NoPeerID bool
+
+		InfoHashes []cdb.TorrentHash
+	}
+
+	Exists struct {
+		Uploaded   bool
+		Downloaded bool
+		Left       bool
+
+		Port    bool
+		NumWant bool
+
+		PeerID bool
+		IPv4   bool
+		IP     bool
+		Event  bool
+
+		TestGarbageUnescape bool
+
+		Compact  bool
+		NoPeerID bool
+
+		InfoHashes bool
+	}
+}
+
+var uploadedKey = []byte("uploaded")
+var downloadedKey = []byte("downloaded")
+var leftKey = []byte("left")
+
+var portKey = []byte("port")
+var numWant = []byte("numwant")
+
+var peerIDKey = []byte("peer_id")
+var ipv4Key = []byte("ipv4")
+var ipKey = []byte("ip")
+var eventKey = []byte("event")
+
+var testGarbageUnescapeKey = []byte("!@#")
+
+var infoHashKey = []byte("info_hash")
+
+var compactKey = []byte("compact")
+var noPeerIDKey = []byte("no_peer_id")
+
+func ParseQuery(queryArgs *fasthttp.Args) (qp QueryParam, err error) {
+	var returnError error
+
+	queryArgs.VisitAll(func(key, value []byte) {
+		if returnError != nil {
+			return
+		}
+
+		key = bytes.ToLower(key)
+		switch true {
+		case bytes.Equal(key, uploadedKey):
+			n, err := strconv.ParseUint(string(value), 10, 0)
+			if err != nil {
+				returnError = err
+				return
+			}
+			qp.Params.Uploaded = n
+			qp.Exists.Uploaded = true
+		case bytes.Equal(key, downloadedKey):
+			n, err := strconv.ParseUint(string(value), 10, 0)
+			if err != nil {
+				returnError = err
+				return
+			}
+			qp.Params.Downloaded = n
+			qp.Exists.Downloaded = true
+		case bytes.Equal(key, leftKey):
+			n, err := strconv.ParseUint(string(value), 10, 0)
+			if err != nil {
+				returnError = err
+				return
+			}
+			qp.Params.Left = n
+			qp.Exists.Left = true
+		case bytes.Equal(key, portKey):
+			n, err := strconv.ParseUint(string(value), 10, 16)
+			if err != nil {
+				returnError = err
+				return
+			}
+			qp.Params.Port = uint16(n)
+			qp.Exists.Port = true
+		case bytes.Equal(key, numWant):
+			n, err := strconv.ParseUint(string(value), 10, 16)
+			if err != nil {
+				returnError = err
+				return
+			}
+			qp.Params.NumWant = uint16(n)
+			qp.Exists.NumWant = true
+		case bytes.Equal(key, peerIDKey):
+			qp.Params.PeerID = string(value)
+			qp.Exists.PeerID = true
+		case bytes.Equal(key, ipv4Key):
+			qp.Params.IPv4 = string(value)
+			qp.Exists.IPv4 = true
+		case bytes.Equal(key, ipKey):
+			qp.Params.IP = string(value)
+			qp.Exists.IP = true
+		case bytes.Equal(key, eventKey):
+			qp.Params.Event = string(value)
+			qp.Exists.Event = true
+		case bytes.Equal(key, testGarbageUnescapeKey):
+			qp.Params.TestGarbageUnescape = string(value)
+			qp.Exists.TestGarbageUnescape = true
+		case bytes.Equal(key, infoHashKey):
+			if len(value) == cdb.TorrentHashSize {
+				qp.Params.InfoHashes = append(qp.Params.InfoHashes, cdb.TorrentHashFromBytes(value))
+				qp.Exists.InfoHashes = true
+			}
+		case bytes.Equal(key, compactKey):
+			qp.Params.Compact = bytes.Equal(value, []byte{'1'})
+			qp.Exists.Compact = true
+		case bytes.Equal(key, noPeerIDKey):
+			qp.Params.NoPeerID = bytes.Equal(value, []byte{'1'})
+			qp.Exists.NoPeerID = true
+		}
+	})
+
+	return qp, returnError
 }
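For reference, the rewritten parser is driven by a fasthttp.Args value rather than a raw query string, and callers now check the Exists flags instead of map lookups. A small usage sketch based on the call sites visible above; the query values themselves are made up.

package main

import (
	"fmt"

	"chihaya/server/params"
	"github.com/valyala/fasthttp"
)

func main() {
	// fasthttp.Args can be filled directly, as the tests below do; in the
	// server it comes from requestContext.Request.URI().QueryArgs().
	args := fasthttp.Args{}
	args.Parse("event=started&port=6881&left=12345&compact=1")

	qp, err := params.ParseQuery(&args)
	if err != nil {
		panic(err)
	}

	// Presence and value are separate fields instead of a map lookup.
	if qp.Exists.Port {
		fmt.Println("port:", qp.Params.Port)
	}

	if qp.Exists.Left && qp.Params.Left == 0 {
		fmt.Println("peer is a seeder")
	}
}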
@@ -19,6 +19,7 @@ package params

import (
	"fmt"
+	"github.com/valyala/fasthttp"
	"net/url"
	"os"
	"reflect"
@@ -44,83 +45,95 @@ func TestMain(m *testing.M) {
 }

 func TestParseQuery(t *testing.T) {
-	query := ""
+	var queryParsed QueryParam
+	queryParsed.Params.Event, queryParsed.Exists.Event = "completed", true
+	queryParsed.Params.Port, queryParsed.Exists.Port = 25362, true
+	queryParsed.Params.PeerID, queryParsed.Exists.PeerID = "-CH010-VnpZR7uz31I1A", true
+	queryParsed.Params.Left, queryParsed.Exists.Left = 0, true
+
+	query := fmt.Sprintf("event=%s&port=%d&peer_id=%s&left=%d",
+		queryParsed.Params.Event,
+		queryParsed.Params.Port,
+		queryParsed.Params.PeerID,
+		queryParsed.Params.Left,
+	)

	for _, infoHash := range infoHashes {
-		query += "info_hash=" + url.QueryEscape(string(infoHash[:])) + "&"
+		queryParsed.Params.InfoHashes = append(queryParsed.Params.InfoHashes, infoHash)
+		queryParsed.Exists.InfoHashes = true
+		query += "&info_hash=" + url.QueryEscape(string(infoHash[:]))
	}

-	queryMap := make(map[string]string)
-	queryMap["event"] = "completed"
-	queryMap["port"] = "25362"
-	queryMap["peer_id"] = "-CH010-VnpZR7uz31I1A"
-	queryMap["left"] = "0"
-
-	for k, v := range queryMap {
-		query += k + "=" + v + "&"
-	}
-
-	query = query[:len(query)-1]
-
-	qp, err := ParseQuery(query)
+	args := fasthttp.Args{}
+	args.Parse(query)
+
+	qp, err := ParseQuery(&args)
	if err != nil {
-		panic(err)
+		t.Fatal(err)
	}

-	if !reflect.DeepEqual(qp.params, queryMap) {
-		t.Fatalf("Parsed query map (%v) is not deeply equal as original (%v)!", qp.params, queryMap)
-	}
-
-	if !reflect.DeepEqual(qp.infoHashes, infoHashes) {
-		t.Fatalf("Parsed info hashes (%v) are not deeply equal as original (%v)!", qp.infoHashes, infoHashes)
+	if !reflect.DeepEqual(qp, queryParsed) {
+		t.Fatalf("Parsed query map (%v) is not deeply equal as original (%v)!", qp, queryParsed)
	}
 }

 func TestBrokenParseQuery(t *testing.T) {
-	brokenQueryMap := make(map[string]string)
-	brokenQueryMap["event"] = "started"
-	brokenQueryMap["bug"] = ""
-	brokenQueryMap["yes"] = ""
+	var brokenQueryParsed QueryParam
+	brokenQueryParsed.Params.Event, brokenQueryParsed.Exists.Event = "started", true
+	brokenQueryParsed.Params.IPv4, brokenQueryParsed.Exists.IPv4 = "", true
+	brokenQueryParsed.Params.IP, brokenQueryParsed.Exists.IP = "", true

-	qp, err := ParseQuery("event=started&bug=&yes=")
+	args := fasthttp.Args{}
+	args.Parse("event=started&ipv4=&ip=")
+
+	qp, err := ParseQuery(&args)
	if err != nil {
-		panic(err)
+		t.Fatal(err)
	}

-	if !reflect.DeepEqual(qp.params, brokenQueryMap) {
-		t.Fatalf("Parsed query map (%v) is not deeply equal as original (%v)!", qp.params, brokenQueryMap)
+	if !reflect.DeepEqual(qp, brokenQueryParsed) {
+		t.Fatalf("Parsed query map (%v) is not deeply equal as original (%v)!", qp, brokenQueryParsed)
	}
 }

 func TestLowerKey(t *testing.T) {
-	qp, err := ParseQuery("EvEnT=c0mPl3tED")
+	args := fasthttp.Args{}
+	args.Parse("EvEnT=c0mPl3tED")
+
+	qp, err := ParseQuery(&args)
	if err != nil {
-		panic(err)
+		t.Fatal(err)
	}

-	if param, exists := qp.Get("event"); !exists || param != "c0mPl3tED" {
+	if param, exists := qp.Params.Event, qp.Exists.Event; !exists || param != "c0mPl3tED" {
		t.Fatalf("Got parsed value %s but expected c0mPl3tED for \"event\"!", param)
	}
 }

 func TestUnescape(t *testing.T) {
-	qp, err := ParseQuery("%21%40%23=%24%25%5E")
+	args := fasthttp.Args{}
+	args.Parse("%21%40%23=%24%25%5E")
+
+	qp, err := ParseQuery(&args)
	if err != nil {
-		panic(err)
+		t.Fatal(err)
	}

-	if param, exists := qp.Get("!@#"); !exists || param != "$%^" {
+	if param, exists := qp.Params.TestGarbageUnescape, qp.Exists.TestGarbageUnescape; !exists || param != "$%^" {
		t.Fatal(fmt.Sprintf("Got parsed value %s but expected", param), "$%^ for \"!@#\"!")
	}
 }

-func TestGet(t *testing.T) {
-	qp, err := ParseQuery("event=completed")
+func TestString(t *testing.T) {
+	args := fasthttp.Args{}
+	args.Parse("event=completed")
+
+	qp, err := ParseQuery(&args)
	if err != nil {
-		panic(err)
+		t.Fatal(err)
	}

-	if param, exists := qp.Get("event"); !exists || param != "completed" {
+	if param, exists := qp.Params.Event, qp.Exists.Event; !exists || param != "completed" {
		t.Fatalf("Got parsed value %s but expected completed for \"event\"!", param)
	}
 }
@@ -128,12 +141,15 @@ func TestGet(t *testing.T) {
 func TestGetUint64(t *testing.T) {
	val := uint64(1<<62 + 42)

-	qp, err := ParseQuery("left=" + strconv.FormatUint(val, 10))
+	args := fasthttp.Args{}
+	args.Parse("left=" + strconv.FormatUint(val, 10))
+
+	qp, err := ParseQuery(&args)
	if err != nil {
-		panic(err)
+		t.Fatal(err)
	}

-	if parsedVal, exists := qp.GetUint64("left"); !exists || parsedVal != val {
+	if parsedVal, exists := qp.Params.Left, qp.Exists.Left; !exists || parsedVal != val {
		t.Fatalf("Got parsed value %v but expected %v for \"left\"!", parsedVal, val)
	}
 }
@@ -141,12 +157,15 @@ func TestGetUint64(t *testing.T) {
 func TestGetUint16(t *testing.T) {
	val := uint16(1<<15 + 4242)

-	qp, err := ParseQuery("port=" + strconv.FormatUint(uint64(val), 10))
+	args := fasthttp.Args{}
+	args.Parse("port=" + strconv.FormatUint(uint64(val), 10))
+
+	qp, err := ParseQuery(&args)
	if err != nil {
-		panic(err)
+		t.Fatal(err)
	}

-	if parsedVal, exists := qp.GetUint16("port"); !exists || parsedVal != val {
+	if parsedVal, exists := qp.Params.Port, qp.Exists.Port; !exists || parsedVal != val {
		t.Fatalf("Got parsed value %v but expected %v for \"port\"!", parsedVal, val)
	}
 }
@@ -160,25 +179,15 @@ func TestInfoHashes(t *testing.T) {

	query = query[:len(query)-1]

-	qp, err := ParseQuery(query)
+	args := fasthttp.Args{}
+	args.Parse(query)
+
+	qp, err := ParseQuery(&args)
	if err != nil {
-		panic(err)
+		t.Fatal(err)
	}

-	if !reflect.DeepEqual(qp.InfoHashes(), infoHashes) {
-		t.Fatalf("Parsed info hashes (%v) are not deeply equal as original (%v)!", qp.InfoHashes(), infoHashes)
-	}
-}
-
-func TestRawQuery(t *testing.T) {
-	q := "event=completed&port=25541&left=0&uploaded=0&downloaded=0"
-
-	qp, err := ParseQuery(q)
-	if err != nil {
-		panic(err)
-	}
-
-	if rq := qp.RawQuery(); rq != q {
-		t.Fatalf("Got raw query %s but expected %s", rq, q)
+	if !reflect.DeepEqual(qp.Params.InfoHashes, infoHashes) {
+		t.Fatalf("Parsed info hashes (%v) are not deeply equal as original (%v)!", qp.Params.InfoHashes, infoHashes)
	}
 }
@@ -19,13 +19,12 @@ package server

import (
	"bytes"
-	"context"
-	"net/http"

	"chihaya/config"
	"chihaya/database"
	cdb "chihaya/database/types"
	"chihaya/server/params"
-
+	"context"
+	"github.com/valyala/fasthttp"
	"github.com/zeebo/bencode"
)
@@ -37,19 +36,18 @@ func init() {
 }

 func writeScrapeInfo(torrent *cdb.Torrent) map[string]interface{} {
-	torrent.RLock()
-	defer torrent.RUnlock()
-
	ret := make(map[string]interface{})
-	ret["complete"] = len(torrent.Seeders)
-	ret["downloaded"] = torrent.Snatched
-	ret["incomplete"] = len(torrent.Leechers)
+	ret["complete"] = torrent.SeedersLength.Load()
+	ret["downloaded"] = torrent.Snatched.Load()
+	ret["incomplete"] = torrent.LeechersLength.Load()

	return ret
 }

-func scrape(ctx context.Context, qs string, user *cdb.User, db *database.Database, buf *bytes.Buffer) int {
-	qp, err := params.ParseQuery(qs)
+func scrape(
+	ctx context.Context, queryArgs *fasthttp.Args,
+	user *cdb.User, db *database.Database, buf *bytes.Buffer) int {
+	qp, err := params.ParseQuery(queryArgs)
	if err != nil {
		panic(err)
	}
@@ -57,14 +55,11 @@ func scrape(ctx context.Context, qs string, user *cdb.User, db *database.Databas
	scrapeData := make(map[string]interface{})
	fileData := make(map[cdb.TorrentHash]interface{})

-	if !db.TorrentsLock.RTryLockWithContext(ctx) {
-		return http.StatusRequestTimeout
-	}
-	defer db.TorrentsLock.RUnlock()
+	dbTorrents := *db.Torrents.Load()

-	if qp.InfoHashes() != nil {
-		for _, infoHash := range qp.InfoHashes() {
-			torrent, exists := db.Torrents[infoHash]
+	if len(qp.Params.InfoHashes) > 0 {
+		for _, infoHash := range qp.Params.InfoHashes {
+			torrent, exists := dbTorrents[infoHash]
			if exists {
				if !isDisabledDownload(db, user, torrent) {
					fileData[infoHash] = writeScrapeInfo(torrent)
@@ -86,14 +81,14 @@ func scrape(ctx context.Context, qs string, user *cdb.User, db *database.Databas
	// Early exit before response write
	select {
	case <-ctx.Done():
-		return http.StatusRequestTimeout
+		return fasthttp.StatusRequestTimeout
	default:
	}

	encoder := bencode.NewEncoder(buf)
-	if err = encoder.Encode(scrapeData); err != nil {
+	if err := encoder.Encode(scrapeData); err != nil {
		panic(err)
	}

-	return http.StatusOK
+	return fasthttp.StatusOK
 }
server/server.go
@@ -20,10 +20,9 @@ package server
import (
	"bytes"
	"context"
+	"github.com/valyala/fasthttp"
	"net"
-	"net/http"
	"path"
-	"strconv"
	"sync"
	"sync/atomic"
	"time"
@@ -63,10 +62,10 @@
	listener net.Listener
)

-func (handler *httpHandler) respond(ctx context.Context, r *http.Request, buf *bytes.Buffer) int {
-	dir, action := path.Split(r.URL.Path)
+func (handler *httpHandler) respond(ctx context.Context, requestContext *fasthttp.RequestCtx, buf *bytes.Buffer) int {
+	dir, action := path.Split(string(requestContext.Request.URI().Path()))
	if action == "" {
-		return http.StatusNotFound
+		return fasthttp.StatusNotFound
	}

	/*
@@ -82,7 +81,7 @@ func (handler *httpHandler) respond(ctx context.Context, r *http.Request, buf *b
		return alive(buf)
	}

-	return http.StatusNotFound
+	return fasthttp.StatusNotFound
 }

	/*
@@ -91,31 +90,36 @@ func (handler *httpHandler) respond(ctx context.Context, r *http.Request, buf *b
	 * ===================================================
	 */

-	user, err := isPasskeyValid(ctx, passkey, handler.db)
-	if err != nil {
-		return http.StatusRequestTimeout
-	} else if user == nil {
+	user := isPasskeyValid(passkey, handler.db)
+	if user == nil {
		failure("Your passkey is invalid", buf, 1*time.Hour)
-		return http.StatusOK
+		return fasthttp.StatusOK
	}

	switch action {
	case "announce":
-		return announce(ctx, r.URL.RawQuery, r.Header, r.RemoteAddr, user, handler.db, buf)
+		return announce(
+			ctx,
+			requestContext.Request.URI().QueryArgs(),
+			&requestContext.Request.Header,
+			requestContext.RemoteAddr(),
+			user,
+			handler.db,
+			buf)
	case "scrape":
		if enabled, _ := config.GetBool("scrape", true); !enabled {
-			return http.StatusNotFound
+			return fasthttp.StatusNotFound
		}

-		return scrape(ctx, r.URL.RawQuery, user, handler.db, buf)
+		return scrape(ctx, requestContext.Request.URI().QueryArgs(), user, handler.db, buf)
	case "metrics":
-		return metrics(ctx, r.Header.Get("Authorization"), handler.db, buf)
+		return metrics(ctx, requestContext.Request.Header.PeekBytes([]byte("Authorization")), handler.db, buf)
	}

-	return http.StatusNotFound
+	return fasthttp.StatusNotFound
 }

-func (handler *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+func (handler *httpHandler) ServeHTTP(requestContext *fasthttp.RequestCtx) {
	if handler.terminate {
		return
	}
@@ -127,9 +131,6 @@ func (handler *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	handler.waitGroup.Add(1)
	defer handler.waitGroup.Done()

-	// Flush buffered data to client
-	defer w.(http.Flusher).Flush()
-
	buf := handler.bufferPool.Take()
	// Mark buf to be returned to bufferPool after we are done with it
	defer handler.bufferPool.Give(buf)
@@ -137,46 +138,48 @@ func (handler *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	// Gracefully handle panics so that they're confined to single request and don't crash server
	defer func() {
		if err := recover(); err != nil {
-			log.Error.Printf("Recovered from panicking request handler - %v\nURL was: %s", err, r.URL)
+			log.Error.Printf("Recovered from panicking request handler - %v\nURL was: %s", err, requestContext.URI().String())
			log.WriteStack()

			collectors.IncrementErroredRequests()

-			if w.Header().Get("Content-Type") == "" {
+			if len(requestContext.Response.Header.ContentType()) == 0 {
				buf.Reset()
-				w.WriteHeader(http.StatusInternalServerError)
+				requestContext.Response.SetStatusCode(fasthttp.StatusInternalServerError)
			}
		}
	}()

	// Prepare and start new context with timeout to abort long-running requests
-	ctx, cancel := context.WithTimeout(r.Context(), handler.contextTimeout)
+	ctx, cancel := context.WithTimeout(context.Background(), handler.contextTimeout)
	defer cancel()

	/* Pass flow to handler; note that handler should be responsible for actually cancelling
	its own work based on request context cancellation */
-	status := handler.respond(ctx, r, buf)
+	status := handler.respond(ctx, requestContext, buf)

	select {
	case <-ctx.Done():
-		collectors.IncrementTimeoutRequests()
-
		switch err := ctx.Err(); err {
		case context.DeadlineExceeded:
+			collectors.IncrementTimeoutRequests()
+
			failure("Request context deadline exceeded", buf, 5*time.Minute)

-			w.Header().Add("Content-Length", strconv.Itoa(buf.Len()))
-			w.Header().Add("Content-Type", "text/plain")
-			w.WriteHeader(http.StatusOK) // Required by torrent clients to interpret failure response
-			_, _ = w.Write(buf.Bytes())
-		default:
-			w.WriteHeader(http.StatusRequestTimeout)
+			requestContext.Request.Header.SetContentLength(buf.Len())
+			requestContext.Request.Header.SetContentTypeBytes([]byte("text/plain"))
+			requestContext.Response.SetStatusCode(fasthttp.StatusOK) // Required by torrent clients to interpret failure response
+			_, _ = requestContext.Write(buf.Bytes())
+		case context.Canceled:
+			collectors.IncrementCancelRequests()
+
+			requestContext.Response.SetStatusCode(fasthttp.StatusRequestTimeout)
		}
	default:
-		w.Header().Add("Content-Length", strconv.Itoa(buf.Len()))
-		w.Header().Add("Content-Type", "text/plain")
-		w.WriteHeader(status)
-		_, _ = w.Write(buf.Bytes())
+		requestContext.Request.Header.SetContentLength(buf.Len())
+		requestContext.Request.Header.SetContentTypeBytes([]byte("text/plain"))
+		requestContext.Response.SetStatusCode(status)
+		_, _ = requestContext.Write(buf.Bytes())
	}
 }
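The branching above distinguishes a deadline (the client still gets a bencoded failure body) from an explicit cancellation (the connection is already gone, so only the cancel counter is bumped). A stripped-down sketch of that ctx.Err() branching; the handler body and timings are illustrative only.

package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

func handle(ctx context.Context) string {
	select {
	case <-ctx.Done():
		switch {
		case errors.Is(ctx.Err(), context.DeadlineExceeded):
			// Deadline hit: still worth writing a failure body for the client.
			return "failure: request context deadline exceeded"
		case errors.Is(ctx.Err(), context.Canceled):
			// Caller went away: record it and drop the response.
			return "cancelled"
		}
	case <-time.After(10 * time.Millisecond):
		// Simulated work finished in time.
	}

	return "ok"
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
	defer cancel()

	time.Sleep(2 * time.Millisecond)
	fmt.Println(handle(ctx)) // prints the deadline branch
}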
@@ -190,7 +193,6 @@ func Start() {

	addr, _ := config.Section("http").Get("addr", ":34000")
	readTimeout, _ := config.Section("http").Section("timeout").GetInt("read", 1)
-	readHeaderTimeout, _ := config.Section("http").Section("timeout").GetInt("read_header", 2)
	writeTimeout, _ := config.Section("http").Section("timeout").GetInt("write", 3)
	idleTimeout, _ := config.Section("http").Section("timeout").GetInt("idle", 30)

@@ -198,17 +200,23 @@
	handler.contextTimeout = time.Duration(writeTimeout)*time.Second - 200*time.Millisecond

	// Create new server instance
-	server := &http.Server{
-		Handler:           handler,
-		ReadTimeout:       time.Duration(readTimeout) * time.Second,
-		ReadHeaderTimeout: time.Duration(readHeaderTimeout) * time.Second,
-		WriteTimeout:      time.Duration(writeTimeout) * time.Second,
-		IdleTimeout:       time.Duration(idleTimeout) * time.Second,
+	server := &fasthttp.Server{
+		Handler:                      handler.ServeHTTP,
+		ReadTimeout:                  time.Duration(readTimeout) * time.Second,
+		WriteTimeout:                 time.Duration(writeTimeout) * time.Second,
+		IdleTimeout:                  time.Duration(idleTimeout) * time.Second,
+		GetOnly:                      true,
+		DisablePreParseMultipartForm: true,
+		NoDefaultServerHeader:        true,
+		NoDefaultDate:                true,
+		NoDefaultContentType:         true,
+		CloseOnShutdown:              true,
	}

	if idleTimeout <= 0 {
		log.Warning.Print("Setting idleTimeout <= 0 disables Keep-Alive which might negatively impact performance")
-		server.SetKeepAlivesEnabled(false)
+
+		server.DisableKeepalive = true
	}

	// Start new goroutine to calculate throughput
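For context, fasthttp.Server takes a plain handler function rather than an http.Handler, and it has no separate read-header timeout, which is why read_header disappears from the timeout config above. A minimal standalone sketch with assumed values, not the project's wiring:

package main

import (
	"log"
	"time"

	"github.com/valyala/fasthttp"
)

func main() {
	server := &fasthttp.Server{
		// fasthttp wants a func(*fasthttp.RequestCtx), not an http.Handler.
		Handler: func(ctx *fasthttp.RequestCtx) {
			ctx.SetStatusCode(fasthttp.StatusOK)
			ctx.SetBodyString("alive")
		},
		ReadTimeout:  1 * time.Second,
		WriteTimeout: 1 * time.Second,
		IdleTimeout:  30 * time.Second,
		GetOnly:      true,
	}

	// ":34000" mirrors the default addr used above; any listen address works.
	if err := server.ListenAndServe(":34000"); err != nil {
		log.Fatal(err)
	}
}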
@@ -261,8 +269,7 @@ func Start() {
	// Wait for active connections to finish processing
	handler.waitGroup.Wait()

-	// Close server so that it does not Accept(), https://github.com/golang/go/issues/10527
-	_ = server.Close()
+	_ = server.Shutdown()

	log.Info.Print("Now closed and not accepting any new connections")

@@ -19,7 +19,6 @@ package server

import (
	"bytes"
-	"context"
	"time"

	"chihaya/database"
@@ -69,18 +68,13 @@ func clientApproved(peerID string, db *database.Database) (uint16, bool) {
	return 0, false
 }

-func isPasskeyValid(ctx context.Context, passkey string, db *database.Database) (*cdb.User, error) {
-	if !db.UsersLock.RTryLockWithContext(ctx) {
-		return nil, ctx.Err()
-	}
-	defer db.UsersLock.RUnlock()
-
-	user, exists := db.Users[passkey]
+func isPasskeyValid(passkey string, db *database.Database) *cdb.User {
+	user, exists := (*db.Users.Load())[passkey]
	if !exists {
-		return nil, nil
+		return nil
	}

-	return user, nil
+	return user
 }

 func hasHitAndRun(db *database.Database, userID, torrentID uint32) bool {
@@ -96,5 +90,5 @@ func hasHitAndRun(db *database.Database, userID, torrentID uint32) bool {

 func isDisabledDownload(db *database.Database, user *cdb.User, torrent *cdb.Torrent) bool {
	// Only disable download if the torrent doesn't have a HnR against it
-	return user.DisableDownload && !hasHitAndRun(db, user.ID, torrent.ID)
+	return user.DisableDownload.Load() && !hasHitAndRun(db, user.ID.Load(), torrent.ID.Load())
 }
util/context_ticker.go (new file)
@@ -0,0 +1,20 @@
+package util
+
+import (
+	"context"
+	"time"
+)
+
+func ContextTick(ctx context.Context, d time.Duration, onTick func()) {
+	ticker := time.NewTicker(d)
+	defer ticker.Stop()
+
+	for {
+		select {
+		case <-ctx.Done():
+			return
+		case <-ticker.C:
+			onTick()
+		}
+	}
+}
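ContextTick is a small helper for periodic work that stops as soon as its context does. A usage sketch based on the signature above; the interval and the work done per tick are made up.

package main

import (
	"context"
	"fmt"
	"time"

	"chihaya/util"
)

func main() {
	ctx, cancel := context.WithCancel(context.Background())

	// Fires every 100ms until cancel() is called, then the goroutine exits.
	go util.ContextTick(ctx, 100*time.Millisecond, func() {
		fmt.Println("tick")
	})

	time.Sleep(350 * time.Millisecond)
	cancel()
}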