Compare commits

...

3 commits

16 changed files with 599 additions and 419 deletions

View file

@ -58,12 +58,7 @@ Configuration is done in `config.json`, which you'll need to create with the fol
```json
{
"database": {
"username": "chihaya",
"password": "",
"database": "chihaya",
"proto": "tcp",
"addr": "127.0.0.1:3306",
"dsn": "chihaya:@tcp(127.0.0.1:3306)/chihaya",
"deadlock_pause": 1,
"deadlock_retries": 5
},

View file

@ -40,18 +40,16 @@ func TestMain(m *testing.M) {
panic(err)
}
f, err := os.OpenFile("config.json", os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
f, err := os.OpenFile(configFile, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
if err != nil {
panic(err)
}
configTest = make(Map)
dbConfig := map[string]interface{}{
"username": "chihaya",
"password": "",
"proto": "tcp",
"addr": "127.0.0.1:3306",
"database": "chihaya",
"dsn": "chihaya:@tcp(127.0.0.1:3306)/chihaya",
"deadlock_pause": json.Number("1"),
"deadlock_retries": json.Number("5"),
}
configTest["database"] = dbConfig
configTest["addr"] = ":34000"
@ -148,5 +146,5 @@ func TestSection(t *testing.T) {
}
func cleanup() {
_ = os.Remove("config.json")
_ = os.Remove(configFile)
}

View file

@ -19,8 +19,8 @@ package database
import (
"bytes"
"context"
"database/sql"
"fmt"
"os"
"sync"
"sync/atomic"
@ -35,11 +35,6 @@ import (
"github.com/go-sql-driver/mysql"
)
type Connection struct {
sqlDb *sql.DB
mutex sync.Mutex
}
type Database struct {
snatchChannel chan *bytes.Buffer
transferHistoryChannel chan *bytes.Buffer
@ -59,16 +54,18 @@ type Database struct {
Users atomic.Pointer[map[string]*cdb.User]
HitAndRuns atomic.Pointer[map[cdb.UserTorrentPair]struct{}]
Torrents atomic.Pointer[map[cdb.TorrentHash]*cdb.Torrent]
TorrentGroupFreeleech atomic.Pointer[map[cdb.TorrentGroup]*cdb.TorrentGroupFreeleech]
TorrentGroupFreeleech atomic.Pointer[map[cdb.TorrentGroupKey]*cdb.TorrentGroupFreeleech]
Clients atomic.Pointer[map[uint16]string]
mainConn *Connection // Used for reloading and misc queries
bufferPool *util.BufferPool
transferHistoryLock sync.Mutex
terminate bool
conn *sql.DB
terminate atomic.Bool
ctx context.Context
ctxCancel func()
waitGroup sync.WaitGroup
}
@ -77,71 +74,66 @@ var (
maxDeadlockRetries int
)
var defaultDsn = map[string]string{
"username": "chihaya",
"password": "",
"proto": "tcp",
"addr": "127.0.0.1:3306",
"database": "chihaya",
}
const defaultDsn = "chihaya:@tcp(127.0.0.1:3306)/chihaya"
func (db *Database) Init() {
db.terminate = false
db.terminate.Store(false)
db.ctx, db.ctxCancel = context.WithCancel(context.Background())
log.Info.Print("Opening database connection...")
db.mainConn = Open()
db.conn = Open()
// Used for recording updates, so the max required size should be < 128 bytes. See queue.go for details
db.bufferPool = util.NewBufferPool(128)
var err error
db.loadUsersStmt, err = db.mainConn.sqlDb.Prepare(
db.loadUsersStmt, err = db.conn.Prepare(
"SELECT ID, torrent_pass, DownMultiplier, UpMultiplier, DisableDownload, TrackerHide " +
"FROM users_main WHERE Enabled = '1'")
if err != nil {
panic(err)
}
db.loadHnrStmt, err = db.mainConn.sqlDb.Prepare(
db.loadHnrStmt, err = db.conn.Prepare(
"SELECT h.uid, h.fid FROM transfer_history AS h " +
"JOIN users_main AS u ON u.ID = h.uid WHERE h.hnr = 1 AND u.Enabled = '1'")
if err != nil {
panic(err)
}
db.loadTorrentsStmt, err = db.mainConn.sqlDb.Prepare(
db.loadTorrentsStmt, err = db.conn.Prepare(
"SELECT ID, info_hash, DownMultiplier, UpMultiplier, Snatched, Status, GroupID, TorrentType FROM torrents")
if err != nil {
panic(err)
}
db.loadTorrentGroupFreeleechStmt, err = db.mainConn.sqlDb.Prepare(
db.loadTorrentGroupFreeleechStmt, err = db.conn.Prepare(
"SELECT GroupID, `Type`, DownMultiplier, UpMultiplier FROM torrent_group_freeleech")
if err != nil {
panic(err)
}
db.loadClientsStmt, err = db.mainConn.sqlDb.Prepare(
db.loadClientsStmt, err = db.conn.Prepare(
"SELECT id, peer_id FROM approved_clients WHERE archived = 0")
if err != nil {
panic(err)
}
db.loadFreeleechStmt, err = db.mainConn.sqlDb.Prepare(
db.loadFreeleechStmt, err = db.conn.Prepare(
"SELECT mod_setting FROM mod_core WHERE mod_option = 'global_freeleech'")
if err != nil {
panic(err)
}
db.cleanStalePeersStmt, err = db.mainConn.sqlDb.Prepare(
db.cleanStalePeersStmt, err = db.conn.Prepare(
"UPDATE transfer_history SET active = 0 WHERE last_announce < ? AND active = 1")
if err != nil {
panic(err)
}
db.unPruneTorrentStmt, err = db.mainConn.sqlDb.Prepare(
db.unPruneTorrentStmt, err = db.conn.Prepare(
"UPDATE torrents SET Status = 0 WHERE ID = ?")
if err != nil {
panic(err)
@ -179,7 +171,8 @@ func (db *Database) Init() {
func (db *Database) Terminate() {
log.Info.Print("Terminating database connection...")
db.terminate = true
db.terminate.Store(true)
db.ctxCancel()
log.Info.Print("Closing all flush channels...")
db.closeFlushChannels()
@ -190,13 +183,11 @@ func (db *Database) Terminate() {
}()
db.waitGroup.Wait()
db.mainConn.mutex.Lock()
_ = db.mainConn.Close()
db.mainConn.mutex.Unlock()
_ = db.conn.Close()
db.serialize()
}
func Open() *Connection {
func Open() *sql.DB {
databaseConfig := config.Section("database")
deadlockWaitTime, _ = databaseConfig.GetInt("deadlock_pause", 1)
maxDeadlockRetries, _ = databaseConfig.GetInt("deadlock_retries", 5)
@ -212,18 +203,7 @@ func Open() *Connection {
// First try to load the DSN from environment. Useful for tests.
databaseDsn := os.Getenv("DB_DSN")
if databaseDsn == "" {
dbUsername, _ := databaseConfig.Get("username", defaultDsn["username"])
dbPassword, _ := databaseConfig.Get("password", defaultDsn["password"])
dbProto, _ := databaseConfig.Get("proto", defaultDsn["proto"])
dbAddr, _ := databaseConfig.Get("addr", defaultDsn["addr"])
dbDatabase, _ := databaseConfig.Get("database", defaultDsn["database"])
databaseDsn = fmt.Sprintf("%s:%s@%s(%s)/%s",
dbUsername,
dbPassword,
dbProto,
dbAddr,
dbDatabase,
)
databaseDsn, _ = databaseConfig.Get("dsn", defaultDsn)
}
sqlDb, err := sql.Open("mysql", databaseDsn)
@ -235,16 +215,10 @@ func Open() *Connection {
log.Fatal.Fatalf("Couldn't ping database - %s", err)
}
return &Connection{
sqlDb: sqlDb,
}
return sqlDb
}
func (db *Connection) Close() error {
return db.sqlDb.Close()
}
func (db *Connection) query(stmt *sql.Stmt, args ...interface{}) *sql.Rows { //nolint:unparam
func (db *Database) query(stmt *sql.Stmt, args ...interface{}) *sql.Rows { //nolint:unparam
rows, _ := perform(func() (interface{}, error) {
return stmt.Query(args...)
}).(*sql.Rows)
@ -252,7 +226,7 @@ func (db *Connection) query(stmt *sql.Stmt, args ...interface{}) *sql.Rows { //n
return rows
}
func (db *Connection) execute(stmt *sql.Stmt, args ...interface{}) sql.Result {
func (db *Database) execute(stmt *sql.Stmt, args ...interface{}) sql.Result {
result, _ := perform(func() (interface{}, error) {
return stmt.Exec(args...)
}).(sql.Result)
@ -260,9 +234,9 @@ func (db *Connection) execute(stmt *sql.Stmt, args ...interface{}) sql.Result {
return result
}
func (db *Connection) exec(query *bytes.Buffer, args ...interface{}) sql.Result { //nolint:unparam
func (db *Database) exec(query *bytes.Buffer, args ...interface{}) sql.Result { //nolint:unparam
result, _ := perform(func() (interface{}, error) {
return db.sqlDb.Exec(query.String(), args...)
return db.conn.Exec(query.String(), args...)
}).(sql.Result)
return result

View file

@ -20,7 +20,6 @@ package database
import (
"fmt"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"math"
"net"
"os"
@ -47,7 +46,7 @@ func TestMain(m *testing.M) {
db.Init()
fixtures, err = testfixtures.New(
testfixtures.Database(db.mainConn.sqlDb),
testfixtures.Database(db.conn),
testfixtures.Dialect("mariadb"),
testfixtures.Directory("fixtures"),
testfixtures.DangerousSkipTestDatabaseCheck(),
@ -147,46 +146,50 @@ func TestLoadTorrents(t *testing.T) {
dbTorrents := make(map[cdb.TorrentHash]*cdb.Torrent)
db.Torrents.Store(&dbTorrents)
t1 := &cdb.Torrent{
Seeders: map[cdb.PeerKey]*cdb.Peer{},
Leechers: map[cdb.PeerKey]*cdb.Peer{},
}
t1.ID.Store(1)
t1.Status.Store(1)
t1.Snatched.Store(2)
t1.DownMultiplier.Store(math.Float64bits(1))
t1.UpMultiplier.Store(math.Float64bits(1))
t1.Group.GroupID.Store(1)
t1.Group.TorrentType.Store(cdb.MustTorrentTypeFromString("anime"))
t2 := &cdb.Torrent{
Seeders: map[cdb.PeerKey]*cdb.Peer{},
Leechers: map[cdb.PeerKey]*cdb.Peer{},
}
t2.ID.Store(2)
t2.Status.Store(0)
t2.Snatched.Store(0)
t2.DownMultiplier.Store(math.Float64bits(2))
t2.UpMultiplier.Store(math.Float64bits(0.5))
t2.Group.GroupID.Store(1)
t2.Group.TorrentType.Store(cdb.MustTorrentTypeFromString("music"))
t3 := &cdb.Torrent{
Seeders: map[cdb.PeerKey]*cdb.Peer{},
Leechers: map[cdb.PeerKey]*cdb.Peer{},
}
t3.ID.Store(3)
t3.Status.Store(0)
t3.Snatched.Store(0)
t3.DownMultiplier.Store(math.Float64bits(1))
t3.UpMultiplier.Store(math.Float64bits(1))
t3.Group.GroupID.Store(2)
t3.Group.TorrentType.Store(cdb.MustTorrentTypeFromString("anime"))
torrents := map[cdb.TorrentHash]*cdb.Torrent{
{114, 239, 32, 237, 220, 181, 67, 143, 115, 182, 216, 141, 120, 196, 223, 193, 102, 123, 137, 56}: {
ID: 1,
Status: 1,
Snatched: 2,
DownMultiplier: 1,
UpMultiplier: 1,
Seeders: map[cdb.PeerKey]*cdb.Peer{},
Leechers: map[cdb.PeerKey]*cdb.Peer{},
Group: cdb.TorrentGroup{
GroupID: 1,
TorrentType: "anime",
},
},
{22, 168, 45, 221, 87, 225, 140, 177, 94, 34, 242, 225, 196, 234, 222, 46, 187, 131, 177, 155}: {
ID: 2,
Status: 0,
Snatched: 0,
DownMultiplier: 2,
UpMultiplier: 0.5,
Seeders: map[cdb.PeerKey]*cdb.Peer{},
Leechers: map[cdb.PeerKey]*cdb.Peer{},
Group: cdb.TorrentGroup{
GroupID: 1,
TorrentType: "music",
},
},
{89, 252, 84, 49, 177, 28, 118, 28, 148, 205, 62, 185, 8, 37, 234, 110, 109, 200, 165, 241}: {
ID: 3,
Status: 0,
Snatched: 0,
DownMultiplier: 1,
UpMultiplier: 1,
Seeders: map[cdb.PeerKey]*cdb.Peer{},
Leechers: map[cdb.PeerKey]*cdb.Peer{},
Group: cdb.TorrentGroup{
GroupID: 2,
TorrentType: "anime",
},
},
{114, 239, 32, 237, 220, 181, 67, 143, 115, 182, 216, 141, 120, 196, 223, 193, 102, 123, 137, 56}: t1,
{22, 168, 45, 221, 87, 225, 140, 177, 94, 34, 242, 225, 196, 234, 222, 46, 187, 131, 177, 155}: t2,
{89, 252, 84, 49, 177, 28, 118, 28, 148, 205, 62, 185, 8, 37, 234, 110, 109, 200, 165, 241}: t3,
}
for k := range torrents {
torrents[k].InitializeLock()
}
// Test with fresh data
@ -201,7 +204,7 @@ func TestLoadTorrents(t *testing.T) {
}
for hash, torrent := range torrents {
if !cmp.Equal(torrent, dbTorrents[hash], cmpopts.IgnoreFields(cdb.Torrent{}, "lock")) {
if !cmp.Equal(torrent, dbTorrents[hash], cdb.TorrentTestCompareOptions...) {
hashHex, _ := hash.MarshalText()
t.Fatal(fixtureFailure(
fmt.Sprintf("Did not load torrent (%s) as expected from fixture file", string(hashHex)),
@ -217,7 +220,7 @@ func TestLoadTorrents(t *testing.T) {
dbTorrents = *db.Torrents.Load()
if !cmp.Equal(oldTorrents, dbTorrents, cmpopts.IgnoreFields(cdb.Torrent{}, "lock")) {
if !cmp.Equal(oldTorrents, dbTorrents, cdb.TorrentTestCompareOptions...) {
t.Fatal(fixtureFailure("Did not reload torrents as expected from fixture file", oldTorrents, dbTorrents))
}
}
@ -225,14 +228,11 @@ func TestLoadTorrents(t *testing.T) {
func TestLoadGroupsFreeleech(t *testing.T) {
prepareTestDatabase()
dbMap := make(map[cdb.TorrentGroup]*cdb.TorrentGroupFreeleech)
dbMap := make(map[cdb.TorrentGroupKey]*cdb.TorrentGroupFreeleech)
db.TorrentGroupFreeleech.Store(&dbMap)
torrentGroupFreeleech := map[cdb.TorrentGroup]*cdb.TorrentGroupFreeleech{
{
GroupID: 2,
TorrentType: "anime",
}: {
torrentGroupFreeleech := map[cdb.TorrentGroupKey]*cdb.TorrentGroupFreeleech{
cdb.MustTorrentGroupKeyFromString("anime", 2): {
DownMultiplier: 0,
UpMultiplier: 2,
},
@ -317,19 +317,24 @@ func TestUnPrune(t *testing.T) {
dbTorrents := *db.Torrents.Load()
h := cdb.TorrentHash{114, 239, 32, 237, 220, 181, 67, 143, 115, 182, 216, 141, 120, 196, 223, 193, 102, 123, 137, 56}
dbTorrent := dbTorrents[h]
torrent := cdb.Torrent{
Seeders: dbTorrents[h].Seeders,
Leechers: dbTorrents[h].Leechers,
Group: dbTorrents[h].Group,
ID: dbTorrents[h].ID,
Snatched: dbTorrents[h].Snatched,
Status: dbTorrents[h].Status,
LastAction: dbTorrents[h].LastAction,
UpMultiplier: dbTorrents[h].UpMultiplier,
DownMultiplier: dbTorrents[h].DownMultiplier,
Seeders: dbTorrent.Seeders,
Leechers: dbTorrent.Leechers,
}
torrent.SeedersLength.Store(uint32(len(torrent.Seeders)))
torrent.LeechersLength.Store(uint32(len(torrent.Leechers)))
torrent.ID.Store(dbTorrent.ID.Load())
torrent.Status.Store(dbTorrent.Status.Load())
torrent.Snatched.Store(dbTorrent.Snatched.Load())
torrent.DownMultiplier.Store(dbTorrent.DownMultiplier.Load())
torrent.UpMultiplier.Store(dbTorrent.UpMultiplier.Load())
torrent.Group.GroupID.Store(dbTorrent.Group.GroupID.Load())
torrent.Group.TorrentType.Store(dbTorrent.Group.TorrentType.Load())
torrent.InitializeLock()
torrent.Status = 0
torrent.Status.Store(0)
db.UnPrune(dbTorrents[h])
@ -337,7 +342,7 @@ func TestUnPrune(t *testing.T) {
dbTorrents = *db.Torrents.Load()
if !cmp.Equal(&torrent, dbTorrents[h], cmpopts.IgnoreFields(cdb.Torrent{}, "lock")) {
if !cmp.Equal(&torrent, dbTorrents[h], cdb.TorrentTestCompareOptions...) {
t.Fatal(fixtureFailure(
fmt.Sprintf("Torrent (%x) was not unpruned properly", h),
&torrent,
@ -374,7 +379,7 @@ func TestRecordAndFlushUsers(t *testing.T) {
deltaDownload = int64(float64(deltaRawDownload) * math.Float64frombits(testUser.DownMultiplier.Load()))
deltaUpload = int64(float64(deltaRawUpload) * math.Float64frombits(testUser.UpMultiplier.Load()))
row := db.mainConn.sqlDb.QueryRow("SELECT Uploaded, Downloaded, rawup, rawdl "+
row := db.conn.QueryRow("SELECT Uploaded, Downloaded, rawup, rawdl "+
"FROM users_main WHERE ID = ?", testUser.ID.Load())
err := row.Scan(&initUpload, &initDownload, &initRawUpload, &initRawDownload)
@ -389,7 +394,7 @@ func TestRecordAndFlushUsers(t *testing.T) {
}
time.Sleep(200 * time.Millisecond)
row = db.mainConn.sqlDb.QueryRow("SELECT Uploaded, Downloaded, rawup, rawdl "+
row = db.conn.QueryRow("SELECT Uploaded, Downloaded, rawup, rawdl "+
"FROM users_main WHERE ID = ?", testUser.ID.Load())
err = row.Scan(&upload, &download, &rawUpload, &rawDownload)
@ -470,7 +475,7 @@ func TestRecordAndFlushTransferHistory(t *testing.T) {
deltaActiveTime = 267
deltaSeedTime = 15
row := db.mainConn.sqlDb.QueryRow("SELECT uploaded, downloaded, activetime, seedtime, active, snatched "+
row := db.conn.QueryRow("SELECT uploaded, downloaded, activetime, seedtime, active, snatched "+
"FROM transfer_history WHERE uid = ? AND fid = ?", testPeer.UserID, testPeer.TorrentID)
err := row.Scan(&initRawUpload, &initRawDownload, &initActiveTime, &initSeedTime, &initActive, &initSnatch)
@ -491,7 +496,7 @@ func TestRecordAndFlushTransferHistory(t *testing.T) {
}
time.Sleep(200 * time.Millisecond)
row = db.mainConn.sqlDb.QueryRow("SELECT uploaded, downloaded, activetime, seedtime, active, snatched "+
row = db.conn.QueryRow("SELECT uploaded, downloaded, activetime, seedtime, active, snatched "+
"FROM transfer_history WHERE uid = ? AND fid = ?", testPeer.UserID, testPeer.TorrentID)
err = row.Scan(&rawUpload, &rawDownload, &activeTime, &seedTime, &active, &snatch)
@ -552,7 +557,7 @@ func TestRecordAndFlushTransferHistory(t *testing.T) {
var gotStartTime int64
row = db.mainConn.sqlDb.QueryRow("SELECT seeding, starttime, last_announce, remaining "+
row = db.conn.QueryRow("SELECT seeding, starttime, last_announce, remaining "+
"FROM transfer_history WHERE uid = ? AND fid = ?", gotPeer.UserID, gotPeer.TorrentID)
err = row.Scan(&gotPeer.Seeding, &gotStartTime, &gotPeer.LastAnnounce, &gotPeer.Left)
@ -597,7 +602,7 @@ func TestRecordAndFlushTransferHistory(t *testing.T) {
}
time.Sleep(200 * time.Millisecond)
row = db.mainConn.sqlDb.QueryRow("SELECT seeding, starttime, last_announce, remaining "+
row = db.conn.QueryRow("SELECT seeding, starttime, last_announce, remaining "+
"FROM transfer_history WHERE uid = ? AND fid = ?", gotPeer.UserID, gotPeer.TorrentID)
err = row.Scan(&gotPeer.Seeding, &gotPeer.StartTime, &gotPeer.LastAnnounce, &gotPeer.Left)
@ -634,7 +639,7 @@ func TestRecordAndFlushTransferIP(t *testing.T) {
deltaDownload = 236
deltaUpload = 3262
row := db.mainConn.sqlDb.QueryRow("SELECT uploaded, downloaded "+
row := db.conn.QueryRow("SELECT uploaded, downloaded "+
"FROM transfer_ips WHERE uid = ? AND fid = ? AND ip = ? AND client_id = ?",
testPeer.UserID, testPeer.TorrentID, testPeer.Addr.IPNumeric(), testPeer.ClientID)
@ -650,7 +655,7 @@ func TestRecordAndFlushTransferIP(t *testing.T) {
}
time.Sleep(200 * time.Millisecond)
row = db.mainConn.sqlDb.QueryRow("SELECT uploaded, downloaded "+
row = db.conn.QueryRow("SELECT uploaded, downloaded "+
"FROM transfer_ips WHERE uid = ? AND fid = ? AND ip = ? AND client_id = ?",
testPeer.UserID, testPeer.TorrentID, testPeer.Addr.IPNumeric(), testPeer.ClientID)
@ -686,7 +691,7 @@ func TestRecordAndFlushTransferIP(t *testing.T) {
var gotStartTime int64
row = db.mainConn.sqlDb.QueryRow("SELECT port, starttime, last_announce "+
row = db.conn.QueryRow("SELECT port, starttime, last_announce "+
"FROM transfer_ips WHERE uid = ? AND fid = ? AND ip = ? AND client_id = ?",
testPeer.UserID, testPeer.TorrentID, testPeer.Addr.IPNumeric(), testPeer.ClientID)
@ -731,7 +736,7 @@ func TestRecordAndFlushTransferIP(t *testing.T) {
Addr: cdb.NewPeerAddressFromIPPort(testPeer.Addr.IP(), 0),
}
row = db.mainConn.sqlDb.QueryRow("SELECT port, starttime, last_announce "+
row = db.conn.QueryRow("SELECT port, starttime, last_announce "+
"FROM transfer_ips WHERE uid = ? AND fid = ? AND ip = ? AND client_id = ?",
testPeer.UserID, testPeer.TorrentID, testPeer.Addr.IPNumeric(), testPeer.ClientID)
@ -769,7 +774,7 @@ func TestRecordAndFlushSnatch(t *testing.T) {
}
time.Sleep(200 * time.Millisecond)
row := db.mainConn.sqlDb.QueryRow("SELECT snatched_time "+
row := db.conn.QueryRow("SELECT snatched_time "+
"FROM transfer_history WHERE uid = ? AND fid = ?", testPeer.UserID, testPeer.TorrentID)
err := row.Scan(&snatchTime)
@ -790,21 +795,23 @@ func TestRecordAndFlushTorrents(t *testing.T) {
h := cdb.TorrentHash{114, 239, 32, 237, 220, 181, 67, 143, 115, 182, 216, 141, 120, 196, 223, 193, 102, 123, 137, 56}
torrent := (*db.Torrents.Load())[h]
torrent.LastAction = time.Now().Unix()
torrent.LastAction.Store(time.Now().Unix())
torrent.Seeders[cdb.NewPeerKey(1, cdb.PeerIDFromRawString("test_peer_id_num_one"))] = &cdb.Peer{
UserID: 1,
TorrentID: torrent.ID,
TorrentID: torrent.ID.Load(),
ClientID: 1,
StartTime: time.Now().Unix(),
LastAnnounce: time.Now().Unix(),
}
torrent.Leechers[cdb.NewPeerKey(3, cdb.PeerIDFromRawString("test_peer_id_num_two"))] = &cdb.Peer{
UserID: 3,
TorrentID: torrent.ID,
TorrentID: torrent.ID.Load(),
ClientID: 2,
StartTime: time.Now().Unix(),
LastAnnounce: time.Now().Unix(),
}
torrent.SeedersLength.Store(uint32(len(torrent.Seeders)))
torrent.LeechersLength.Store(uint32(len(torrent.Leechers)))
db.QueueTorrent(torrent, 5)
@ -820,26 +827,26 @@ func TestRecordAndFlushTorrents(t *testing.T) {
numSeeders int
)
row := db.mainConn.sqlDb.QueryRow("SELECT Snatched, last_action, Seeders, Leechers "+
"FROM torrents WHERE ID = ?", torrent.ID)
row := db.conn.QueryRow("SELECT Snatched, last_action, Seeders, Leechers "+
"FROM torrents WHERE ID = ?", torrent.ID.Load())
err := row.Scan(&snatched, &lastAction, &numSeeders, &numLeechers)
if err != nil {
panic(err)
}
if torrent.Snatched+5 != snatched {
if uint16(torrent.Snatched.Load())+5 != snatched {
t.Fatal(fixtureFailure(
fmt.Sprintf("Snatches incorrectly updated in the database for torrent %x", h),
torrent.Snatched+5,
torrent.Snatched.Load()+5,
snatched,
))
}
if torrent.LastAction != lastAction {
if torrent.LastAction.Load() != lastAction {
t.Fatal(fixtureFailure(
fmt.Sprintf("Last incorrectly updated in the database for torrent %x", h),
torrent.LastAction,
torrent.LastAction.Load(),
lastAction,
))
}
@ -859,6 +866,22 @@ func TestRecordAndFlushTorrents(t *testing.T) {
numLeechers,
))
}
if int(torrent.SeedersLength.Load()) != numSeeders {
t.Fatal(fixtureFailure(
fmt.Sprintf("SeedersLength incorrectly updated in the database for torrent %x", h),
len(torrent.Seeders),
numSeeders,
))
}
if int(torrent.LeechersLength.Load()) != numLeechers {
t.Fatal(fixtureFailure(
fmt.Sprintf("LeechersLength incorrectly updated in the database for torrent %x", h),
len(torrent.Leechers),
numLeechers,
))
}
}
func TestTerminate(_ *testing.T) {

View file

@ -19,6 +19,7 @@ package database
import (
"bytes"
"chihaya/util"
"errors"
"time"
@ -111,25 +112,9 @@ func (db *Database) flushTorrents() {
count int
)
conn := Open()
for {
query.Reset()
query.WriteString("CREATE TEMPORARY TABLE IF NOT EXISTS flush_torrents (" +
"ID int unsigned NOT NULL, " +
"Snatched int unsigned NOT NULL DEFAULT 0, " +
"Seeders int unsigned NOT NULL DEFAULT 0, " +
"Leechers int unsigned NOT NULL DEFAULT 0, " +
"last_action int NOT NULL DEFAULT 0, " +
"PRIMARY KEY (ID)) ENGINE=MEMORY")
conn.exec(&query)
query.Reset()
query.WriteString("TRUNCATE flush_torrents")
conn.exec(&query)
query.Reset()
query.WriteString("INSERT INTO flush_torrents VALUES ")
query.WriteString("INSERT IGNORE INTO torrents (ID, Snatched, Seeders, Leechers, last_action) VALUES ")
length := len(db.torrentChannel)
@ -149,7 +134,7 @@ func (db *Database) flushTorrents() {
if count > 0 {
logFlushes, _ := config.GetBool("log_flushes", true)
if logFlushes && !db.terminate {
if logFlushes && !db.terminate.Load() {
log.Info.Printf("{torrents} Flushing %d", count)
}
@ -158,18 +143,9 @@ func (db *Database) flushTorrents() {
query.WriteString(" ON DUPLICATE KEY UPDATE Snatched = Snatched + VALUE(Snatched), " +
"Seeders = VALUE(Seeders), Leechers = VALUE(Leechers), " +
"last_action = IF(last_action < VALUE(last_action), VALUE(last_action), last_action)")
conn.exec(&query)
db.exec(&query)
query.Reset()
query.WriteString("UPDATE torrents t, flush_torrents ft SET " +
"t.Snatched = t.Snatched + ft.Snatched, " +
"t.Seeders = ft.Seeders, " +
"t.Leechers = ft.Leechers, " +
"t.last_action = IF(t.last_action < ft.last_action, ft.last_action, t.last_action)" +
"WHERE t.ID = ft.ID")
conn.exec(&query)
if !db.terminate {
if !db.terminate.Load() {
elapsedTime := time.Since(startTime)
collectors.UpdateFlushTime("torrents", elapsedTime)
collectors.UpdateChannelsLen("torrents", count)
@ -178,14 +154,12 @@ func (db *Database) flushTorrents() {
if length < (torrentFlushBufferSize >> 1) {
time.Sleep(time.Duration(flushSleepInterval) * time.Second)
}
} else if db.terminate {
} else if db.terminate.Load() {
break
} else {
time.Sleep(time.Second)
}
}
_ = conn.Close()
}
func (db *Database) flushUsers() {
@ -197,25 +171,9 @@ func (db *Database) flushUsers() {
count int
)
conn := Open()
for {
query.Reset()
query.WriteString("CREATE TEMPORARY TABLE IF NOT EXISTS flush_users (" +
"ID int unsigned NOT NULL, " +
"Uploaded bigint unsigned NOT NULL DEFAULT 0, " +
"Downloaded bigint unsigned NOT NULL DEFAULT 0, " +
"rawdl bigint unsigned NOT NULL DEFAULT 0, " +
"rawup bigint unsigned NOT NULL DEFAULT 0, " +
"PRIMARY KEY (ID)) ENGINE=MEMORY")
conn.exec(&query)
query.Reset()
query.WriteString("TRUNCATE flush_users")
conn.exec(&query)
query.Reset()
query.WriteString("INSERT INTO flush_users VALUES ")
query.WriteString("INSERT IGNORE INTO users_main (ID, Uploaded, Downloaded, rawdl, rawup) VALUES ")
length := len(db.userChannel)
@ -235,7 +193,7 @@ func (db *Database) flushUsers() {
if count > 0 {
logFlushes, _ := config.GetBool("log_flushes", true)
if logFlushes && !db.terminate {
if logFlushes && !db.terminate.Load() {
log.Info.Printf("{users_main} Flushing %d", count)
}
@ -243,18 +201,9 @@ func (db *Database) flushUsers() {
query.WriteString(" ON DUPLICATE KEY UPDATE Uploaded = Uploaded + VALUE(Uploaded), " +
"Downloaded = Downloaded + VALUE(Downloaded), rawdl = rawdl + VALUE(rawdl), rawup = rawup + VALUE(rawup)")
conn.exec(&query)
db.exec(&query)
query.Reset()
query.WriteString("UPDATE users_main u, flush_users fu SET " +
"u.Uploaded = u.Uploaded + fu.Uploaded, " +
"u.Downloaded = u.Downloaded + fu.Downloaded, " +
"u.rawdl = u.rawdl + fu.rawdl, " +
"u.rawup = u.rawup + fu.rawup " +
"WHERE u.ID = fu.ID")
conn.exec(&query)
if !db.terminate {
if !db.terminate.Load() {
elapsedTime := time.Since(startTime)
collectors.UpdateFlushTime("users", elapsedTime)
collectors.UpdateChannelsLen("users", count)
@ -263,14 +212,12 @@ func (db *Database) flushUsers() {
if length < (userFlushBufferSize >> 1) {
time.Sleep(time.Duration(flushSleepInterval) * time.Second)
}
} else if db.terminate {
} else if db.terminate.Load() {
break
} else {
time.Sleep(time.Second)
}
}
_ = conn.Close()
}
func (db *Database) flushTransferHistory() {
@ -282,8 +229,6 @@ func (db *Database) flushTransferHistory() {
count int
)
conn := Open()
for {
length, err := func() (int, error) {
db.transferHistoryLock.Lock()
@ -311,7 +256,7 @@ func (db *Database) flushTransferHistory() {
if count > 0 {
logFlushes, _ := config.GetBool("log_flushes", true)
if logFlushes && !db.terminate {
if logFlushes && !db.terminate.Load() {
log.Info.Printf("{transfer_history} Flushing %d", count)
}
@ -323,16 +268,16 @@ func (db *Database) flushTransferHistory() {
"seedtime = seedtime + VALUE(seedtime), last_announce = VALUE(last_announce), " +
"active = VALUE(active), snatched = snatched + VALUE(snatched);")
conn.exec(&query)
db.exec(&query)
if !db.terminate {
if !db.terminate.Load() {
elapsedTime := time.Since(startTime)
collectors.UpdateFlushTime("transfer_history", elapsedTime)
collectors.UpdateChannelsLen("transfer_history", count)
}
return length, nil
} else if db.terminate {
} else if db.terminate.Load() {
return 0, errDbTerminate
}
@ -347,8 +292,6 @@ func (db *Database) flushTransferHistory() {
time.Sleep(time.Second)
}
}
_ = conn.Close()
}
func (db *Database) flushTransferIps() {
@ -360,8 +303,6 @@ func (db *Database) flushTransferIps() {
count int
)
conn := Open()
for {
query.Reset()
query.WriteString("INSERT INTO transfer_ips (uid, fid, client_id, ip, port, uploaded, downloaded, " +
@ -385,7 +326,7 @@ func (db *Database) flushTransferIps() {
if count > 0 {
logFlushes, _ := config.GetBool("log_flushes", true)
if logFlushes && !db.terminate {
if logFlushes && !db.terminate.Load() {
log.Info.Printf("{transfer_ips} Flushing %d", count)
}
@ -394,9 +335,9 @@ func (db *Database) flushTransferIps() {
// todo: port should be part of PK
query.WriteString("\nON DUPLICATE KEY UPDATE port = VALUE(port), downloaded = downloaded + VALUE(downloaded), " +
"uploaded = uploaded + VALUE(uploaded), last_announce = VALUE(last_announce)")
conn.exec(&query)
db.exec(&query)
if !db.terminate {
if !db.terminate.Load() {
elapsedTime := time.Since(startTime)
collectors.UpdateFlushTime("transfer_ips", elapsedTime)
collectors.UpdateChannelsLen("transfer_ips", count)
@ -405,14 +346,12 @@ func (db *Database) flushTransferIps() {
if length < (transferIpsFlushBufferSize >> 1) {
time.Sleep(time.Duration(flushSleepInterval) * time.Second)
}
} else if db.terminate {
} else if db.terminate.Load() {
break
} else {
time.Sleep(time.Second)
}
}
_ = conn.Close()
}
func (db *Database) flushSnatches() {
@ -424,8 +363,6 @@ func (db *Database) flushSnatches() {
count int
)
conn := Open()
for {
query.Reset()
query.WriteString("INSERT INTO transfer_history (uid, fid, snatched_time) VALUES\n")
@ -448,16 +385,16 @@ func (db *Database) flushSnatches() {
if count > 0 {
logFlushes, _ := config.GetBool("log_flushes", true)
if logFlushes && !db.terminate {
if logFlushes && !db.terminate.Load() {
log.Info.Printf("{snatches} Flushing %d", count)
}
startTime := time.Now()
query.WriteString("\nON DUPLICATE KEY UPDATE snatched_time = VALUE(snatched_time)")
conn.exec(&query)
db.exec(&query)
if !db.terminate {
if !db.terminate.Load() {
elapsedTime := time.Since(startTime)
collectors.UpdateFlushTime("snatches", elapsedTime)
collectors.UpdateChannelsLen("snatches", count)
@ -466,14 +403,12 @@ func (db *Database) flushSnatches() {
if length < (snatchFlushBufferSize >> 1) {
time.Sleep(time.Duration(flushSleepInterval) * time.Second)
}
} else if db.terminate {
} else if db.terminate.Load() {
break
} else {
time.Sleep(time.Second)
}
}
_ = conn.Close()
}
func (db *Database) purgeInactivePeers() {
@ -483,7 +418,7 @@ func (db *Database) purgeInactivePeers() {
count int
)
for !db.terminate {
util.ContextTick(db.ctx, time.Duration(purgeInactivePeersInterval)*time.Second, func() {
start = time.Now()
now = start.Unix()
count = 0
@ -494,7 +429,7 @@ func (db *Database) purgeInactivePeers() {
dbTorrents := *db.Torrents.Load()
for _, torrent := range dbTorrents {
func() {
//Take write lock to operate on entries
//Take write lock to operate on peer entries
torrent.Lock()
defer torrent.Unlock()
@ -522,6 +457,10 @@ func (db *Database) purgeInactivePeers() {
}
if countThisTorrent != count {
// Update lengths of peers
torrent.SeedersLength.Store(uint32(len(torrent.Seeders)))
torrent.LeechersLength.Store(uint32(len(torrent.Leechers)))
db.QueueTorrent(torrent, 0)
}
}()
@ -540,11 +479,8 @@ func (db *Database) purgeInactivePeers() {
db.transferHistoryLock.Lock()
defer db.transferHistoryLock.Unlock()
db.mainConn.mutex.Lock()
defer db.mainConn.mutex.Unlock()
start = time.Now()
result := db.mainConn.execute(db.cleanStalePeersStmt, oldestActive)
result := db.execute(db.cleanStalePeersStmt, oldestActive)
if result != nil {
rows, err := result.RowsAffected()
@ -555,7 +491,5 @@ func (db *Database) purgeInactivePeers() {
}
}
}()
time.Sleep(time.Duration(purgeInactivePeersInterval) * time.Second)
}
})
}

View file

@ -38,15 +38,15 @@ func (db *Database) QueueTorrent(torrent *cdb.Torrent, deltaSnatch uint8) {
tq := db.bufferPool.Take()
tq.WriteString("(")
tq.WriteString(strconv.FormatUint(uint64(torrent.ID), 10))
tq.WriteString(strconv.FormatUint(uint64(torrent.ID.Load()), 10))
tq.WriteString(",")
tq.WriteString(strconv.FormatUint(uint64(deltaSnatch), 10))
tq.WriteString(",")
tq.WriteString(strconv.FormatInt(int64(len(torrent.Seeders)), 10))
tq.WriteString(strconv.FormatUint(uint64(torrent.SeedersLength.Load()), 10))
tq.WriteString(",")
tq.WriteString(strconv.FormatInt(int64(len(torrent.Leechers)), 10))
tq.WriteString(strconv.FormatUint(uint64(torrent.LeechersLength.Load()), 10))
tq.WriteString(",")
tq.WriteString(strconv.FormatInt(torrent.LastAction, 10))
tq.WriteString(strconv.FormatInt(torrent.LastAction.Load(), 10))
tq.WriteString(")")
select {
@ -179,7 +179,5 @@ func (db *Database) QueueSnatch(peer *cdb.Peer, now int64) {
}
func (db *Database) UnPrune(torrent *cdb.Torrent) {
db.mainConn.mutex.Lock()
db.mainConn.execute(db.unPruneTorrentStmt, torrent.ID)
db.mainConn.mutex.Unlock()
db.execute(db.unPruneTorrentStmt, torrent.ID.Load())
}

View file

@ -26,6 +26,7 @@ import (
"chihaya/config"
cdb "chihaya/database/types"
"chihaya/log"
"chihaya/util"
)
var (
@ -50,31 +51,26 @@ func init() {
*/
func (db *Database) startReloading() {
go func() {
for !db.terminate {
time.Sleep(time.Duration(reloadInterval) * time.Second)
util.ContextTick(db.ctx, time.Duration(reloadInterval)*time.Second, func() {
db.waitGroup.Add(1)
defer db.waitGroup.Done()
db.loadUsers()
db.loadHitAndRuns()
db.loadTorrents()
db.loadGroupsFreeleech()
db.loadConfig()
db.loadClients()
db.waitGroup.Done()
}
})
}()
}
func (db *Database) loadUsers() {
db.mainConn.mutex.Lock()
defer db.mainConn.mutex.Unlock()
start := time.Now()
dbUsers := *db.Users.Load()
newUsers := make(map[string]*cdb.User, len(dbUsers))
rows := db.mainConn.query(db.loadUsersStmt)
rows := db.query(db.loadUsersStmt)
if rows == nil {
log.Error.Print("Failed to load hit and runs from database")
log.WriteStack()
@ -126,13 +122,10 @@ func (db *Database) loadUsers() {
}
func (db *Database) loadHitAndRuns() {
db.mainConn.mutex.Lock()
defer db.mainConn.mutex.Unlock()
start := time.Now()
newHnr := make(map[cdb.UserTorrentPair]struct{})
rows := db.mainConn.query(db.loadHnrStmt)
rows := db.query(db.loadHnrStmt)
if rows == nil {
log.Error.Print("Failed to load hit and runs from database")
log.WriteStack()
@ -173,12 +166,9 @@ func (db *Database) loadTorrents() {
newTorrents := make(map[cdb.TorrentHash]*cdb.Torrent, len(dbTorrents))
db.mainConn.mutex.Lock()
defer db.mainConn.mutex.Unlock()
start = time.Now()
rows := db.mainConn.query(db.loadTorrentsStmt)
rows := db.query(db.loadTorrentsStmt)
if rows == nil {
log.Error.Print("Failed to load torrents from database")
log.WriteStack()
@ -197,7 +187,8 @@ func (db *Database) loadTorrents() {
downMultiplier, upMultiplier float64
snatched uint16
status uint8
group cdb.TorrentGroup
groupID uint32
torrentType string
)
if err := rows.Scan(
@ -207,40 +198,48 @@ func (db *Database) loadTorrents() {
&upMultiplier,
&snatched,
&status,
&group.GroupID,
&group.TorrentType,
&groupID,
&torrentType,
); err != nil {
log.Error.Printf("Error scanning torrent row: %s", err)
log.WriteStack()
}
if old, exists := dbTorrents[infoHash]; exists && old != nil {
func() {
old.Lock()
defer old.Unlock()
torrentTypeUint64, err := cdb.TorrentTypeFromString(torrentType)
old.ID = id
old.DownMultiplier = downMultiplier
old.UpMultiplier = upMultiplier
old.Snatched = snatched
old.Status = status
old.Group = group
}()
if err != nil {
log.Error.Printf("Error storing torrent row: %s", err)
log.WriteStack()
}
if old, exists := dbTorrents[infoHash]; exists && old != nil {
old.ID.Store(id)
old.DownMultiplier.Store(math.Float64bits(downMultiplier))
old.UpMultiplier.Store(math.Float64bits(upMultiplier))
old.Snatched.Store(uint32(snatched))
old.Status.Store(uint32(status))
old.Group.TorrentType.Store(torrentTypeUint64)
old.Group.GroupID.Store(groupID)
newTorrents[infoHash] = old
} else {
newTorrents[infoHash] = &cdb.Torrent{
ID: id,
UpMultiplier: upMultiplier,
DownMultiplier: downMultiplier,
Snatched: snatched,
Status: status,
Group: group,
t := &cdb.Torrent{
Seeders: make(map[cdb.PeerKey]*cdb.Peer),
Leechers: make(map[cdb.PeerKey]*cdb.Peer),
}
newTorrents[infoHash].InitializeLock()
t.InitializeLock()
t.ID.Store(id)
t.DownMultiplier.Store(math.Float64bits(downMultiplier))
t.UpMultiplier.Store(math.Float64bits(upMultiplier))
t.Snatched.Store(uint32(snatched))
t.Status.Store(uint32(status))
t.Group.TorrentType.Store(torrentTypeUint64)
t.Group.GroupID.Store(groupID)
newTorrents[infoHash] = t
}
}
@ -252,13 +251,10 @@ func (db *Database) loadTorrents() {
}
func (db *Database) loadGroupsFreeleech() {
db.mainConn.mutex.Lock()
defer db.mainConn.mutex.Unlock()
start := time.Now()
newTorrentGroupFreeleech := make(map[cdb.TorrentGroup]*cdb.TorrentGroupFreeleech)
newTorrentGroupFreeleech := make(map[cdb.TorrentGroupKey]*cdb.TorrentGroupFreeleech)
rows := db.mainConn.query(db.loadTorrentGroupFreeleechStmt)
rows := db.query(db.loadTorrentGroupFreeleechStmt)
if rows == nil {
log.Error.Print("Failed to load torrent group freeleech data from database")
log.WriteStack()
@ -273,15 +269,22 @@ func (db *Database) loadGroupsFreeleech() {
for rows.Next() {
var (
downMultiplier, upMultiplier float64
group cdb.TorrentGroup
groupID uint32
torrentType string
)
if err := rows.Scan(&group.GroupID, &group.TorrentType, &downMultiplier, &upMultiplier); err != nil {
log.Error.Printf("Error scanning torrent row: %s", err)
if err := rows.Scan(&groupID, &torrentType, &downMultiplier, &upMultiplier); err != nil {
log.Error.Printf("Error scanning torrent group freeleech row: %s", err)
log.WriteStack()
}
newTorrentGroupFreeleech[group] = &cdb.TorrentGroupFreeleech{
k, err := cdb.TorrentGroupKeyFromString(torrentType, groupID)
if err != nil {
log.Error.Printf("Error storing torrent group freeleech row: %s", err)
log.WriteStack()
}
newTorrentGroupFreeleech[k] = &cdb.TorrentGroupFreeleech{
UpMultiplier: upMultiplier,
DownMultiplier: downMultiplier,
}
@ -295,10 +298,7 @@ func (db *Database) loadGroupsFreeleech() {
}
func (db *Database) loadConfig() {
db.mainConn.mutex.Lock()
defer db.mainConn.mutex.Unlock()
rows := db.mainConn.query(db.loadFreeleechStmt)
rows := db.query(db.loadFreeleechStmt)
if rows == nil {
log.Error.Print("Failed to load config from database")
log.WriteStack()
@ -323,13 +323,10 @@ func (db *Database) loadConfig() {
}
func (db *Database) loadClients() {
db.mainConn.mutex.Lock()
defer db.mainConn.mutex.Unlock()
start := time.Now()
newClients := make(map[uint16]string)
rows := db.mainConn.query(db.loadClientsStmt)
rows := db.query(db.loadClientsStmt)
if rows == nil {
log.Error.Print("Failed to load clients from database")
log.WriteStack()

View file

@ -18,6 +18,7 @@
package database
import (
"chihaya/util"
"fmt"
"os"
"time"
@ -37,10 +38,9 @@ func init() {
func (db *Database) startSerializing() {
go func() {
for !db.terminate {
time.Sleep(time.Duration(serializeInterval) * time.Second)
util.ContextTick(db.ctx, time.Duration(serializeInterval)*time.Second, func() {
db.serialize()
}
})
}()
}
@ -141,7 +141,7 @@ func (db *Database) deserialize() {
torrents = len(dbTorrents)
for _, t := range dbTorrents {
peers += len(t.Leechers) + len(t.Seeders)
peers += int(t.LeechersLength.Load()) + int(t.SeedersLength.Load())
}
db.Torrents.Store(&dbTorrents)

View file

@ -19,7 +19,6 @@ package database
import (
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"math"
"net"
"reflect"
@ -61,19 +60,26 @@ func TestSerializer(t *testing.T) {
testTorrentHash := cdb.TorrentHash{
114, 239, 32, 237, 220, 181, 67, 143, 115, 182, 216, 141, 120, 196, 223, 193, 102, 123, 137, 56,
}
testTorrents[testTorrentHash] = &cdb.Torrent{
Status: 1,
Snatched: 100,
ID: 10,
LastAction: time.Now().Unix(),
UpMultiplier: 1,
DownMultiplier: 1,
torrent := &cdb.Torrent{
Seeders: map[cdb.PeerKey]*cdb.Peer{
cdb.NewPeerKey(12, cdb.PeerIDFromRawString("peer_is_twenty_chars")): testPeer,
},
Leechers: map[cdb.PeerKey]*cdb.Peer{},
}
testTorrents[testTorrentHash].InitializeLock()
torrent.InitializeLock()
torrent.ID.Store(10)
torrent.Status.Store(1)
torrent.Snatched.Store(100)
torrent.LastAction.Store(time.Now().Unix())
torrent.DownMultiplier.Store(math.Float64bits(1))
torrent.UpMultiplier.Store(math.Float64bits(1))
torrent.SeedersLength.Store(uint32(len(torrent.Seeders)))
torrent.Group.GroupID.Store(1)
torrent.Group.TorrentType.Store(cdb.MustTorrentTypeFromString("anime"))
testTorrents[testTorrentHash] = torrent
// Prepare empty map to populate with test data
dbTorrents := make(map[cdb.TorrentHash]*cdb.Torrent)
@ -104,7 +110,7 @@ func TestSerializer(t *testing.T) {
dbTorrents = *db.Torrents.Load()
dbUsers = *db.Users.Load()
if !cmp.Equal(dbTorrents, testTorrents, cmpopts.IgnoreFields(cdb.Torrent{}, "lock")) {
if !cmp.Equal(dbTorrents, testTorrents, cdb.TorrentTestCompareOptions...) {
t.Fatalf("Torrents (%v) after serialization and deserialization do not match original torrents (%v)!",
dbTorrents, testTorrents)
}

View file

@ -22,10 +22,14 @@ import (
"database/sql/driver"
"encoding/binary"
"encoding/hex"
"encoding/json"
"errors"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/viney-shih/go-lock"
"io"
"math"
"sync/atomic"
)
const TorrentHashSize = 20
@ -117,22 +121,30 @@ func (h *TorrentHash) UnmarshalText(b []byte) error {
}
type Torrent struct {
// lock This must be taken whenever read or write is made to Seeders or Leechers fields on this torrent.
lock *lock.ChanMutex
Seeders map[PeerKey]*Peer
Leechers map[PeerKey]*Peer
// SeedersLength Contains the length of Seeders. When Seeders is modified this field must be updated
SeedersLength atomic.Uint32
// LeechersLength Contains the length of Leechers. When LeechersLength is modified this field must be updated
LeechersLength atomic.Uint32
Group TorrentGroup
ID uint32
ID atomic.Uint32
Snatched uint16
// Snatched 16 bits
Snatched atomic.Uint32
Status uint8
LastAction int64 // unix time
// Status 8 bits
Status atomic.Uint32
LastAction atomic.Int64 // unix time
UpMultiplier float64
DownMultiplier float64
// lock This must be taken whenever read or write is made to fields on this torrent.
lock *lock.CASMutex
// UpMultiplier float64
UpMultiplier atomic.Uint64
// DownMultiplier float64
DownMultiplier atomic.Uint64
}
func (t *Torrent) InitializeLock() {
@ -140,7 +152,7 @@ func (t *Torrent) InitializeLock() {
panic("already initialized")
}
t.lock = lock.NewCASMutex()
t.lock = lock.NewChanMutex()
}
func (t *Torrent) Lock() {
@ -155,19 +167,15 @@ func (t *Torrent) Unlock() {
t.lock.Unlock()
}
func (t *Torrent) RLock() {
t.lock.RLock()
}
func (t *Torrent) RUnlock() {
t.lock.RUnlock()
}
func (t *Torrent) RTryLockWithContext(ctx context.Context) bool {
return t.lock.RTryLockWithContext(ctx)
}
func (t *Torrent) Load(version uint64, reader readerAndByteReader) (err error) {
var (
id uint32
snatched uint16
status uint8
lastAction int64
upMultiplier, downMultiplier float64
)
var varIntLen uint64
if varIntLen, err = binary.ReadUvarint(reader); err != nil {
@ -191,6 +199,8 @@ func (t *Torrent) Load(version uint64, reader readerAndByteReader) (err error) {
t.Seeders[k] = s
}
t.SeedersLength.Store(uint32(len(t.Seeders)))
if varIntLen, err = binary.ReadUvarint(reader); err != nil {
return err
}
@ -211,100 +221,270 @@ func (t *Torrent) Load(version uint64, reader readerAndByteReader) (err error) {
t.Leechers[k] = l
}
t.LeechersLength.Store(uint32(len(t.Leechers)))
if err = t.Group.Load(version, reader); err != nil {
return err
}
if err = binary.Read(reader, binary.LittleEndian, &t.ID); err != nil {
if err = binary.Read(reader, binary.LittleEndian, &id); err != nil {
return err
}
if err = binary.Read(reader, binary.LittleEndian, &t.Snatched); err != nil {
if err = binary.Read(reader, binary.LittleEndian, &snatched); err != nil {
return err
}
if err = binary.Read(reader, binary.LittleEndian, &t.Status); err != nil {
if err = binary.Read(reader, binary.LittleEndian, &status); err != nil {
return err
}
if err = binary.Read(reader, binary.LittleEndian, &t.LastAction); err != nil {
if err = binary.Read(reader, binary.LittleEndian, &lastAction); err != nil {
return err
}
if err = binary.Read(reader, binary.LittleEndian, &t.UpMultiplier); err != nil {
if err = binary.Read(reader, binary.LittleEndian, &upMultiplier); err != nil {
return err
}
return binary.Read(reader, binary.LittleEndian, &t.DownMultiplier)
if err = binary.Read(reader, binary.LittleEndian, &downMultiplier); err != nil {
return err
}
t.ID.Store(id)
t.Snatched.Store(uint32(snatched))
t.Status.Store(uint32(status))
t.LastAction.Store(lastAction)
t.UpMultiplier.Store(math.Float64bits(upMultiplier))
t.DownMultiplier.Store(math.Float64bits(downMultiplier))
return nil
}
func (t *Torrent) Append(preAllocatedBuffer []byte) (buf []byte) {
t.RLock()
defer t.RUnlock()
buf = preAllocatedBuffer
buf = binary.AppendUvarint(buf, uint64(len(t.Seeders)))
for k, s := range t.Seeders {
buf = append(buf, k[:]...)
func() {
// This could be a read-only lock, but this simpler lock is faster overall
t.Lock()
defer t.Unlock()
buf = s.Append(buf)
}
buf = binary.AppendUvarint(buf, uint64(len(t.Seeders)))
buf = binary.AppendUvarint(buf, uint64(len(t.Leechers)))
for k, s := range t.Seeders {
buf = append(buf, k[:]...)
for k, l := range t.Leechers {
buf = append(buf, k[:]...)
buf = s.Append(buf)
}
buf = l.Append(buf)
}
buf = binary.AppendUvarint(buf, uint64(len(t.Leechers)))
for k, l := range t.Leechers {
buf = append(buf, k[:]...)
buf = l.Append(buf)
}
}()
buf = t.Group.Append(buf)
buf = binary.LittleEndian.AppendUint32(buf, t.ID)
buf = binary.LittleEndian.AppendUint16(buf, t.Snatched)
buf = append(buf, t.Status)
buf = binary.LittleEndian.AppendUint64(buf, uint64(t.LastAction))
buf = binary.LittleEndian.AppendUint64(buf, math.Float64bits(t.UpMultiplier))
buf = binary.LittleEndian.AppendUint64(buf, math.Float64bits(t.DownMultiplier))
buf = binary.LittleEndian.AppendUint32(buf, t.ID.Load())
buf = binary.LittleEndian.AppendUint16(buf, uint16(t.Snatched.Load()))
buf = append(buf, uint8(t.Status.Load()))
buf = binary.LittleEndian.AppendUint64(buf, uint64(t.LastAction.Load()))
buf = binary.LittleEndian.AppendUint64(buf, t.UpMultiplier.Load())
buf = binary.LittleEndian.AppendUint64(buf, t.DownMultiplier.Load())
return buf
}
var encodeJSONTorrentMap = make(map[string]any)
var encodeJSONTorrentGroupMap = make(map[string]any)
// MarshalJSON Due to using atomics, JSON will not marshal values within them.
// This is only safe to call from a single thread at once
func (t *Torrent) MarshalJSON() (buf []byte, err error) {
	encodeJSONTorrentMap["ID"] = t.ID.Load()
	encodeJSONTorrentMap["Seeders"] = t.Seeders
	encodeJSONTorrentMap["Leechers"] = t.Leechers

	// TorrentType is kept packed little-endian in a uint64; unpack it back to
	// its string form by trimming the zero-byte padding.
	var torrentTypeBuf [8]byte

	binary.LittleEndian.PutUint64(torrentTypeBuf[:], t.Group.TorrentType.Load())

	i := 0
	for ; i < len(torrentTypeBuf); i++ {
		if torrentTypeBuf[i] == 0 {
			break
		}
	}

	encodeJSONTorrentGroupMap["TorrentType"] = string(torrentTypeBuf[:i])
	encodeJSONTorrentGroupMap["GroupID"] = t.Group.GroupID.Load()
	encodeJSONTorrentMap["Group"] = encodeJSONTorrentGroupMap
	encodeJSONTorrentMap["Snatched"] = uint16(t.Snatched.Load())
	encodeJSONTorrentMap["Status"] = uint8(t.Status.Load())
	encodeJSONTorrentMap["LastAction"] = t.LastAction.Load()
	encodeJSONTorrentMap["UpMultiplier"] = math.Float64frombits(t.UpMultiplier.Load())
	// Fix: this previously read t.UpMultiplier, serializing the upload
	// multiplier under the "DownMultiplier" key as well.
	encodeJSONTorrentMap["DownMultiplier"] = math.Float64frombits(t.DownMultiplier.Load())

	return json.Marshal(encodeJSONTorrentMap)
}
// decodeJSONTorrent mirrors Torrent with plain (non-atomic) fields so that
// encoding/json can populate it directly; Torrent.UnmarshalJSON then copies
// the decoded values into the atomic fields.
type decodeJSONTorrent struct {
	Seeders map[PeerKey]*Peer
	Leechers map[PeerKey]*Peer
	// Group carries the torrent type in its original string form; it is
	// re-packed into a uint64 on unmarshal (see TorrentTypeFromString).
	Group struct {
		TorrentType string
		GroupID uint32
	}
	ID uint32
	Snatched uint16
	Status uint8
	LastAction int64 // unix time
	UpMultiplier float64
	DownMultiplier float64
}
// UnmarshalJSON Due to using atomics, JSON will not marshal values within them.
// This is only safe to call from a single thread at once
func (t *Torrent) UnmarshalJSON(buf []byte) (err error) {
	var decoded decodeJSONTorrent
	if err = json.Unmarshal(buf, &decoded); err != nil {
		return err
	}

	// Peer maps are adopted as-is; the cached length counters must be kept
	// in sync with them.
	t.Seeders = decoded.Seeders
	t.Leechers = decoded.Leechers
	t.SeedersLength.Store(uint32(len(decoded.Seeders)))
	t.LeechersLength.Store(uint32(len(decoded.Leechers)))

	// Re-pack the group's torrent type string into its uint64 form.
	packedType, err := TorrentTypeFromString(decoded.Group.TorrentType)
	if err != nil {
		return err
	}

	t.Group.TorrentType.Store(packedType)
	t.Group.GroupID.Store(decoded.Group.GroupID)
	t.Snatched.Store(uint32(decoded.Snatched))
	t.Status.Store(uint32(decoded.Status))
	t.LastAction.Store(decoded.LastAction)
	t.UpMultiplier.Store(math.Float64bits(decoded.UpMultiplier))
	t.DownMultiplier.Store(math.Float64bits(decoded.DownMultiplier))

	return nil
}
// TorrentGroupFreeleech holds the per-group upload/download multiplier
// overrides loaded from the database (see loadGroupsFreeleech).
type TorrentGroupFreeleech struct {
	UpMultiplier float64
	DownMultiplier float64
}
type TorrentGroup struct {
TorrentType string
GroupID uint32
type TorrentGroupKey [8 + 4]byte
// MustTorrentGroupKeyFromString is like TorrentGroupKeyFromString but panics
// when the torrent type cannot be packed; intended for initialization and tests.
func MustTorrentGroupKeyFromString(torrentType string, groupID uint32) TorrentGroupKey {
	key, err := TorrentGroupKeyFromString(torrentType, groupID)
	if err != nil {
		panic(err)
	}

	return key
}
func (g *TorrentGroup) Load(_ uint64, reader readerAndByteReader) (err error) {
var varIntLen uint64
func TorrentGroupKeyFromString(torrentType string, groupID uint32) (k TorrentGroupKey, err error) {
t, err := TorrentTypeFromString(torrentType)
if err != nil {
return TorrentGroupKey{}, err
}
if varIntLen, err = binary.ReadUvarint(reader); err != nil {
binary.LittleEndian.PutUint64(k[:], t)
binary.LittleEndian.PutUint32(k[8:], groupID)
return k, nil
}
// MustTorrentTypeFromString is like TorrentTypeFromString but panics when the
// conversion fails; intended for initialization code and tests.
func MustTorrentTypeFromString(torrentType string) uint64 {
	packed, err := TorrentTypeFromString(torrentType)
	if err != nil {
		panic(err)
	}

	return packed
}
func TorrentTypeFromString(torrentType string) (t uint64, err error) {
if len(torrentType) > 8 {
return 0, err
}
var buf [8]byte
copy(buf[:], torrentType)
return binary.LittleEndian.Uint64(buf[:]), nil
}
// TorrentGroup identifies the group a torrent belongs to. TorrentType holds
// an up-to-8-byte string packed little-endian into a uint64 (see
// TorrentTypeFromString); both fields are atomics so they can be updated
// concurrently with readers.
type TorrentGroup struct {
	TorrentType atomic.Uint64
	GroupID atomic.Uint32
}
// Key packs the group's torrent type and ID into a fixed-size value usable as
// a map key (same layout as TorrentGroupKeyFromString produces).
func (g *TorrentGroup) Key() TorrentGroupKey {
	var key TorrentGroupKey

	binary.LittleEndian.PutUint64(key[:8], g.TorrentType.Load())
	binary.LittleEndian.PutUint32(key[8:], g.GroupID.Load())

	return key
}
// ErrTorrentTypeTooLong is returned when a torrent type string exceeds the
// 8 bytes that fit in its packed uint64 representation.
var ErrTorrentTypeTooLong = errors.New("torrent type too long, maximum 8 bytes")
func (g *TorrentGroup) Load(version uint64, reader readerAndByteReader) (err error) {
var (
torrentType uint64
groupID uint32
)
if version <= 2 {
var varIntLen uint64
if varIntLen, err = binary.ReadUvarint(reader); err != nil {
return err
}
if varIntLen > 8 {
return ErrTorrentTypeTooLong
}
buf := make([]byte, 8)
if _, err = io.ReadFull(reader, buf[:varIntLen]); err != nil {
return err
}
torrentType = binary.LittleEndian.Uint64(buf)
} else {
if err = binary.Read(reader, binary.LittleEndian, &torrentType); err != nil {
return err
}
}
if err = binary.Read(reader, binary.LittleEndian, &groupID); err != nil {
return err
}
buf := make([]byte, varIntLen)
g.TorrentType.Store(torrentType)
g.GroupID.Store(groupID)
if _, err = io.ReadFull(reader, buf); err != nil {
return err
}
g.TorrentType = string(buf)
return binary.Read(reader, binary.LittleEndian, &g.GroupID)
return nil
}
func (g *TorrentGroup) Append(preAllocatedBuffer []byte) (buf []byte) {
buf = preAllocatedBuffer
buf = binary.AppendUvarint(buf, uint64(len(g.TorrentType)))
buf = append(buf, []byte(g.TorrentType)...)
buf = binary.LittleEndian.AppendUint64(buf, g.TorrentType.Load())
return binary.LittleEndian.AppendUint32(buf, g.GroupID)
return binary.LittleEndian.AppendUint32(buf, g.GroupID.Load())
}
// TorrentCacheFile holds filename used by serializer for this type
@ -312,4 +492,12 @@ var TorrentCacheFile = "torrent-cache"
// TorrentCacheVersion Used to distinguish old versions on the on-disk cache.
// Bump when fields are altered on Torrent, Peer or TorrentGroup structs
const TorrentCacheVersion = 2
const TorrentCacheVersion = 3
// TorrentTestCompareOptions configures go-cmp for comparing Torrent values in
// tests: the atomic wrappers are compared through their unexported fields and
// the non-comparable lock is ignored.
var TorrentTestCompareOptions = []cmp.Option{
	cmp.AllowUnexported(atomic.Uint32{}, atomic.Uint64{}, atomic.Int64{}, atomic.Bool{}),
	cmpopts.IgnoreFields(Torrent{}, "lock"),
}

View file

@ -19,6 +19,7 @@ package types
import (
"encoding/binary"
"encoding/json"
"math"
"sync/atomic"
)
@ -94,6 +95,45 @@ func (u *User) Append(preAllocatedBuffer []byte) (buf []byte) {
return buf
}
var encodeJSONUserMap = make(map[string]any)
// MarshalJSON Due to using atomics, JSON will not marshal values within them.
// This is only safe to call from a single thread at once
func (u *User) MarshalJSON() (buf []byte, err error) {
encodeJSONUserMap["ID"] = u.ID.Load()
encodeJSONUserMap["DisableDownload"] = u.DisableDownload.Load()
encodeJSONUserMap["TrackerHide"] = u.TrackerHide.Load()
encodeJSONUserMap["UpMultiplier"] = math.Float64frombits(u.UpMultiplier.Load())
encodeJSONUserMap["DownMultiplier"] = math.Float64frombits(u.UpMultiplier.Load())
return json.Marshal(encodeJSONUserMap)
}
type decodeJSONUser struct {
ID uint32
DisableDownload bool
TrackerHide bool
UpMultiplier float64
DownMultiplier float64
}
// UnmarshalJSON Due to using atomics, JSON will not marshal values within them.
// This is only safe to call from a single thread at once
func (u *User) UnmarshalJSON(buf []byte) (err error) {
var userJSON decodeJSONUser
if err = json.Unmarshal(buf, &userJSON); err != nil {
return err
}
u.ID.Store(userJSON.ID)
u.DisableDownload.Store(userJSON.DisableDownload)
u.TrackerHide.Store(userJSON.TrackerHide)
u.UpMultiplier.Store(math.Float64bits(userJSON.UpMultiplier))
u.DownMultiplier.Store(math.Float64bits(userJSON.DownMultiplier))
return nil
}
type UserTorrentPair struct {
UserID uint32
TorrentID uint32

View file

@ -221,23 +221,17 @@ func announce(ctx context.Context, qs string, header http.Header, remoteAddr str
return http.StatusOK // Required by torrent clients to interpret failure response
}
// Take torrent lock to read/write on it to prevent race conditions
if !torrent.TryLockWithContext(ctx) {
return http.StatusRequestTimeout
}
defer torrent.Unlock()
if torrentStatus := torrent.Status.Load(); torrentStatus == 1 && left == 0 {
log.Info.Printf("Unpruning torrent %d", torrent.ID.Load())
if torrent.Status == 1 && left == 0 {
log.Info.Printf("Unpruning torrent %d", torrent.ID)
torrent.Status = 0
torrent.Status.Store(0)
/* It is okay to do this asynchronously as tracker's internal in-memory state has already been updated for this
torrent. While it is technically possible that we will do this more than once in some cases, the state is of
boolean type so there is no risk of data loss. */
go db.UnPrune(torrent)
} else if torrent.Status != 0 {
failure(fmt.Sprintf("This torrent does not exist (status: %d, left: %d)", torrent.Status, left), buf, 15*time.Minute)
} else if torrentStatus != 0 {
failure(fmt.Sprintf("This torrent does not exist (status: %d, left: %d)", torrentStatus, left), buf, 15*time.Minute)
return http.StatusOK // Required by torrent clients to interpret failure response
}
@ -261,6 +255,12 @@ func announce(ctx context.Context, qs string, header http.Header, remoteAddr str
peerKey := cdb.NewPeerKey(user.ID.Load(), cdb.PeerIDFromRawString(peerID))
// Take torrent peers lock to read/write on it to prevent race conditions
if !torrent.TryLockWithContext(ctx) {
return http.StatusRequestTimeout
}
defer torrent.Unlock()
if left > 0 {
if isDisabledDownload(db, user, torrent) {
failure("Your download privileges are disabled", buf, 1*time.Hour)
@ -272,6 +272,8 @@ func announce(ctx context.Context, qs string, header http.Header, remoteAddr str
newPeer = true
peer = &cdb.Peer{}
torrent.Leechers[peerKey] = peer
torrent.LeechersLength.Store(uint32(len(torrent.Leechers)))
}
} else if completed {
peer, exists = torrent.Leechers[peerKey]
@ -279,10 +281,15 @@ func announce(ctx context.Context, qs string, header http.Header, remoteAddr str
newPeer = true
peer = &cdb.Peer{}
torrent.Seeders[peerKey] = peer
torrent.SeedersLength.Store(uint32(len(torrent.Seeders)))
} else {
// Previously tracked peer is now a seeder
torrent.Seeders[peerKey] = peer
delete(torrent.Leechers, peerKey)
torrent.SeedersLength.Store(uint32(len(torrent.Seeders)))
torrent.LeechersLength.Store(uint32(len(torrent.Leechers)))
}
seeding = true
} else {
@ -293,12 +300,17 @@ func announce(ctx context.Context, qs string, header http.Header, remoteAddr str
newPeer = true
peer = &cdb.Peer{}
torrent.Seeders[peerKey] = peer
torrent.SeedersLength.Store(uint32(len(torrent.Seeders)))
} else {
/* Previously tracked peer is now a seeder, however we never received their "completed" event.
Broken client? Unreported snatch? Cross-seeding? Let's not report it as snatch to avoid
over-reporting for cross-seeding */
torrent.Seeders[peerKey] = peer
delete(torrent.Leechers, peerKey)
torrent.SeedersLength.Store(uint32(len(torrent.Seeders)))
torrent.LeechersLength.Store(uint32(len(torrent.Leechers)))
}
}
seeding = true
@ -308,7 +320,7 @@ func announce(ctx context.Context, qs string, header http.Header, remoteAddr str
if newPeer {
peer.ID = peerKey.PeerID()
peer.UserID = user.ID.Load()
peer.TorrentID = torrent.ID
peer.TorrentID = torrent.ID.Load()
peer.StartTime = now
peer.LastAnnounce = now
peer.Uploaded = uploaded
@ -330,7 +342,7 @@ func announce(ctx context.Context, qs string, header http.Header, remoteAddr str
torrentGroupDownMultiplier := 1.0
torrentGroupUpMultiplier := 1.0
if torrentGroupFreeleech, exists := (*db.TorrentGroupFreeleech.Load())[torrent.Group]; exists {
if torrentGroupFreeleech, exists := (*db.TorrentGroupFreeleech.Load())[torrent.Group.Key()]; exists {
torrentGroupDownMultiplier = torrentGroupFreeleech.DownMultiplier
torrentGroupUpMultiplier = torrentGroupFreeleech.UpMultiplier
}
@ -338,11 +350,11 @@ func announce(ctx context.Context, qs string, header http.Header, remoteAddr str
var deltaDownload int64
if !database.GlobalFreeleech.Load() {
deltaDownload = int64(float64(rawDeltaDownload) * math.Abs(math.Float64frombits(user.DownMultiplier.Load())) *
math.Abs(torrentGroupDownMultiplier) * math.Abs(torrent.DownMultiplier))
math.Abs(torrentGroupDownMultiplier) * math.Abs(math.Float64frombits(torrent.DownMultiplier.Load())))
}
deltaUpload := int64(float64(rawDeltaUpload) * math.Abs(math.Float64frombits(user.UpMultiplier.Load())) *
math.Abs(torrentGroupUpMultiplier) * math.Abs(torrent.UpMultiplier))
math.Abs(torrentGroupUpMultiplier) * math.Abs(math.Float64frombits(torrent.UpMultiplier.Load())))
peer.Uploaded = uploaded
peer.Downloaded = downloaded
peer.Left = left
@ -366,7 +378,7 @@ func announce(ctx context.Context, qs string, header http.Header, remoteAddr str
/* Update torrent last_action only if announced action is seeding.
This allows dead torrents without seeder but with leecher to be properly pruned */
if seeding {
torrent.LastAction = now
torrent.LastAction.Store(now)
}
var deltaSnatch uint8
@ -377,8 +389,10 @@ func announce(ctx context.Context, qs string, header http.Header, remoteAddr str
should be gone, allowing the peer to be GC'd. */
if seeding {
delete(torrent.Seeders, peerKey)
torrent.SeedersLength.Store(uint32(len(torrent.Seeders)))
} else {
delete(torrent.Leechers, peerKey)
torrent.LeechersLength.Store(uint32(len(torrent.Leechers)))
}
active = false
@ -415,9 +429,9 @@ func announce(ctx context.Context, qs string, header http.Header, remoteAddr str
left)
// Generate response
seedCount := len(torrent.Seeders)
leechCount := len(torrent.Leechers)
snatchCount := torrent.Snatched
seedCount := int(torrent.SeedersLength.Load())
leechCount := int(torrent.LeechersLength.Load())
snatchCount := uint16(torrent.Snatched.Load())
response := make(map[string]interface{})
response["complete"] = seedCount

View file

@ -40,11 +40,7 @@ func metrics(ctx context.Context, auth string, db *database.Database, buf *bytes
peers := 0
for _, t := range dbTorrents {
func() {
t.RLock()
defer t.RUnlock()
peers += len(t.Leechers) + len(t.Seeders)
}()
peers += int(t.LeechersLength.Load()) + int(t.SeedersLength.Load())
}
// Early exit before response write

View file

@ -37,13 +37,10 @@ func init() {
}
func writeScrapeInfo(torrent *cdb.Torrent) map[string]interface{} {
torrent.RLock()
defer torrent.RUnlock()
ret := make(map[string]interface{})
ret["complete"] = len(torrent.Seeders)
ret["downloaded"] = torrent.Snatched
ret["incomplete"] = len(torrent.Leechers)
ret["complete"] = torrent.SeedersLength.Load()
ret["downloaded"] = torrent.Snatched.Load()
ret["incomplete"] = torrent.LeechersLength.Load()
return ret
}

View file

@ -90,5 +90,5 @@ func hasHitAndRun(db *database.Database, userID, torrentID uint32) bool {
func isDisabledDownload(db *database.Database, user *cdb.User, torrent *cdb.Torrent) bool {
// Only disable download if the torrent doesn't have a HnR against it
return user.DisableDownload.Load() && !hasHitAndRun(db, user.ID.Load(), torrent.ID)
return user.DisableDownload.Load() && !hasHitAndRun(db, user.ID.Load(), torrent.ID.Load())
}

20
util/context_ticker.go Normal file
View file

@ -0,0 +1,20 @@
package util
import (
"context"
"time"
)
func ContextTick(ctx context.Context, d time.Duration, onTick func()) {
ticker := time.NewTicker(d)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
onTick()
}
}
}