Compare commits

...

3 commits

19 changed files with 432 additions and 517 deletions

View file

@ -58,12 +58,7 @@ Configuration is done in `config.json`, which you'll need to create with the fol
```json
{
"database": {
"username": "chihaya",
"password": "",
"database": "chihaya",
"proto": "tcp",
"addr": "127.0.0.1:3306",
"dsn": "chihaya:@tcp(127.0.0.1:3306)/chihaya",
"deadlock_pause": 1,
"deadlock_retries": 5
},

View file

@ -133,17 +133,17 @@ func main() {
newUserID++
// Create mapping to get consistent peers
anonUserMapping[user.ID] = newUserID
anonUserMapping[user.ID.Load()] = newUserID
// Replaces user id
user.ID = newUserID
user.ID.Store(newUserID)
// Replaces hidden flag
user.TrackerHide = false
user.TrackerHide.Store(false)
// Replace Up/Down multipliers with baseline
user.UpMultiplier = 1.0
user.DownMultiplier = 1.0
user.UpMultiplier.Store(math.Float64bits(1.0))
user.DownMultiplier.Store(math.Float64bits(1.0))
// Replace passkey with a random provided one of same length
for {

View file

@ -40,18 +40,16 @@ func TestMain(m *testing.M) {
panic(err)
}
f, err := os.OpenFile("config.json", os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
f, err := os.OpenFile(configFile, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
if err != nil {
panic(err)
}
configTest = make(Map)
dbConfig := map[string]interface{}{
"username": "chihaya",
"password": "",
"proto": "tcp",
"addr": "127.0.0.1:3306",
"database": "chihaya",
"dsn": "chihaya:@tcp(127.0.0.1:3306)/chihaya",
"deadlock_pause": json.Number("1"),
"deadlock_retries": json.Number("5"),
}
configTest["database"] = dbConfig
configTest["addr"] = ":34000"
@ -148,5 +146,5 @@ func TestSection(t *testing.T) {
}
func cleanup() {
_ = os.Remove("config.json")
_ = os.Remove(configFile)
}

View file

@ -19,9 +19,8 @@ package database
import (
"bytes"
"context"
"database/sql"
"fmt"
"github.com/viney-shih/go-lock"
"os"
"sync"
"sync/atomic"
@ -36,15 +35,7 @@ import (
"github.com/go-sql-driver/mysql"
)
type Connection struct {
sqlDb *sql.DB
mutex sync.Mutex
}
type Database struct {
TorrentsLock lock.RWMutex
UsersLock lock.RWMutex
snatchChannel chan *bytes.Buffer
transferHistoryChannel chan *bytes.Buffer
transferIpsChannel chan *bytes.Buffer
@ -60,19 +51,21 @@ type Database struct {
cleanStalePeersStmt *sql.Stmt
unPruneTorrentStmt *sql.Stmt
Users map[string]*cdb.User
Users atomic.Pointer[map[string]*cdb.User]
HitAndRuns atomic.Pointer[map[cdb.UserTorrentPair]struct{}]
Torrents map[cdb.TorrentHash]*cdb.Torrent // SHA-1 hash (20 bytes)
Torrents atomic.Pointer[map[cdb.TorrentHash]*cdb.Torrent]
TorrentGroupFreeleech atomic.Pointer[map[cdb.TorrentGroup]*cdb.TorrentGroupFreeleech]
Clients atomic.Pointer[map[uint16]string]
mainConn *Connection // Used for reloading and misc queries
bufferPool *util.BufferPool
transferHistoryLock lock.Mutex
transferHistoryLock sync.Mutex
terminate bool
conn *sql.DB
terminate atomic.Bool
ctx context.Context
ctxCancel func()
waitGroup sync.WaitGroup
}
@ -81,82 +74,76 @@ var (
maxDeadlockRetries int
)
var defaultDsn = map[string]string{
"username": "chihaya",
"password": "",
"proto": "tcp",
"addr": "127.0.0.1:3306",
"database": "chihaya",
}
const defaultDsn = "chihaya:@tcp(127.0.0.1:3306)/chihaya"
func (db *Database) Init() {
db.terminate = false
db.terminate.Store(false)
db.ctx, db.ctxCancel = context.WithCancel(context.Background())
log.Info.Print("Opening database connection...")
db.mainConn = Open()
// Initializing locks
db.TorrentsLock = lock.NewCASMutex()
db.UsersLock = lock.NewCASMutex()
db.conn = Open()
// Used for recording updates, so the max required size should be < 128 bytes. See queue.go for details
db.bufferPool = util.NewBufferPool(128)
var err error
db.loadUsersStmt, err = db.mainConn.sqlDb.Prepare(
db.loadUsersStmt, err = db.conn.Prepare(
"SELECT ID, torrent_pass, DownMultiplier, UpMultiplier, DisableDownload, TrackerHide " +
"FROM users_main WHERE Enabled = '1'")
if err != nil {
panic(err)
}
db.loadHnrStmt, err = db.mainConn.sqlDb.Prepare(
db.loadHnrStmt, err = db.conn.Prepare(
"SELECT h.uid, h.fid FROM transfer_history AS h " +
"JOIN users_main AS u ON u.ID = h.uid WHERE h.hnr = 1 AND u.Enabled = '1'")
if err != nil {
panic(err)
}
db.loadTorrentsStmt, err = db.mainConn.sqlDb.Prepare(
db.loadTorrentsStmt, err = db.conn.Prepare(
"SELECT ID, info_hash, DownMultiplier, UpMultiplier, Snatched, Status, GroupID, TorrentType FROM torrents")
if err != nil {
panic(err)
}
db.loadTorrentGroupFreeleechStmt, err = db.mainConn.sqlDb.Prepare(
db.loadTorrentGroupFreeleechStmt, err = db.conn.Prepare(
"SELECT GroupID, `Type`, DownMultiplier, UpMultiplier FROM torrent_group_freeleech")
if err != nil {
panic(err)
}
db.loadClientsStmt, err = db.mainConn.sqlDb.Prepare(
db.loadClientsStmt, err = db.conn.Prepare(
"SELECT id, peer_id FROM approved_clients WHERE archived = 0")
if err != nil {
panic(err)
}
db.loadFreeleechStmt, err = db.mainConn.sqlDb.Prepare(
db.loadFreeleechStmt, err = db.conn.Prepare(
"SELECT mod_setting FROM mod_core WHERE mod_option = 'global_freeleech'")
if err != nil {
panic(err)
}
db.cleanStalePeersStmt, err = db.mainConn.sqlDb.Prepare(
db.cleanStalePeersStmt, err = db.conn.Prepare(
"UPDATE transfer_history SET active = 0 WHERE last_announce < ? AND active = 1")
if err != nil {
panic(err)
}
db.unPruneTorrentStmt, err = db.mainConn.sqlDb.Prepare(
db.unPruneTorrentStmt, err = db.conn.Prepare(
"UPDATE torrents SET Status = 0 WHERE ID = ?")
if err != nil {
panic(err)
}
db.Users = make(map[string]*cdb.User)
db.Torrents = make(map[cdb.TorrentHash]*cdb.Torrent)
dbUsers := make(map[string]*cdb.User)
db.Users.Store(&dbUsers)
dbTorrents := make(map[cdb.TorrentHash]*cdb.Torrent)
db.Torrents.Store(&dbTorrents)
dbHitAndRuns := make(map[cdb.UserTorrentPair]struct{})
db.HitAndRuns.Store(&dbHitAndRuns)
@ -184,7 +171,8 @@ func (db *Database) Init() {
func (db *Database) Terminate() {
log.Info.Print("Terminating database connection...")
db.terminate = true
db.terminate.Store(true)
db.ctxCancel()
log.Info.Print("Closing all flush channels...")
db.closeFlushChannels()
@ -195,13 +183,11 @@ func (db *Database) Terminate() {
}()
db.waitGroup.Wait()
db.mainConn.mutex.Lock()
_ = db.mainConn.Close()
db.mainConn.mutex.Unlock()
_ = db.conn.Close()
db.serialize()
}
func Open() *Connection {
func Open() *sql.DB {
databaseConfig := config.Section("database")
deadlockWaitTime, _ = databaseConfig.GetInt("deadlock_pause", 1)
maxDeadlockRetries, _ = databaseConfig.GetInt("deadlock_retries", 5)
@ -217,18 +203,7 @@ func Open() *Connection {
// First try to load the DSN from environment. Useful for tests.
databaseDsn := os.Getenv("DB_DSN")
if databaseDsn == "" {
dbUsername, _ := databaseConfig.Get("username", defaultDsn["username"])
dbPassword, _ := databaseConfig.Get("password", defaultDsn["password"])
dbProto, _ := databaseConfig.Get("proto", defaultDsn["proto"])
dbAddr, _ := databaseConfig.Get("addr", defaultDsn["addr"])
dbDatabase, _ := databaseConfig.Get("database", defaultDsn["database"])
databaseDsn = fmt.Sprintf("%s:%s@%s(%s)/%s",
dbUsername,
dbPassword,
dbProto,
dbAddr,
dbDatabase,
)
databaseDsn, _ = databaseConfig.Get("dsn", defaultDsn)
}
sqlDb, err := sql.Open("mysql", databaseDsn)
@ -240,16 +215,10 @@ func Open() *Connection {
log.Fatal.Fatalf("Couldn't ping database - %s", err)
}
return &Connection{
sqlDb: sqlDb,
}
return sqlDb
}
func (db *Connection) Close() error {
return db.sqlDb.Close()
}
func (db *Connection) query(stmt *sql.Stmt, args ...interface{}) *sql.Rows { //nolint:unparam
func (db *Database) query(stmt *sql.Stmt, args ...interface{}) *sql.Rows { //nolint:unparam
rows, _ := perform(func() (interface{}, error) {
return stmt.Query(args...)
}).(*sql.Rows)
@ -257,7 +226,7 @@ func (db *Connection) query(stmt *sql.Stmt, args ...interface{}) *sql.Rows { //n
return rows
}
func (db *Connection) execute(stmt *sql.Stmt, args ...interface{}) sql.Result {
func (db *Database) execute(stmt *sql.Stmt, args ...interface{}) sql.Result {
result, _ := perform(func() (interface{}, error) {
return stmt.Exec(args...)
}).(sql.Result)
@ -265,9 +234,9 @@ func (db *Connection) execute(stmt *sql.Stmt, args ...interface{}) sql.Result {
return result
}
func (db *Connection) exec(query *bytes.Buffer, args ...interface{}) sql.Result { //nolint:unparam
func (db *Database) exec(query *bytes.Buffer, args ...interface{}) sql.Result { //nolint:unparam
result, _ := perform(func() (interface{}, error) {
return db.sqlDb.Exec(query.String(), args...)
return db.conn.Exec(query.String(), args...)
}).(sql.Result)
return result

View file

@ -19,6 +19,7 @@ package database
import (
"fmt"
"math"
"net"
"os"
"reflect"
@ -44,7 +45,7 @@ func TestMain(m *testing.M) {
db.Init()
fixtures, err = testfixtures.New(
testfixtures.Database(db.mainConn.sqlDb),
testfixtures.Database(db.conn),
testfixtures.Dialect("mariadb"),
testfixtures.Directory("fixtures"),
testfixtures.DangerousSkipTestDatabaseCheck(),
@ -59,47 +60,55 @@ func TestMain(m *testing.M) {
func TestLoadUsers(t *testing.T) {
prepareTestDatabase()
db.Users = make(map[string]*cdb.User)
dbUsers := make(map[string]*cdb.User)
db.Users.Store(&dbUsers)
testUser1 := &cdb.User{}
testUser1.ID.Store(1)
testUser1.DownMultiplier.Store(math.Float64bits(1))
testUser1.UpMultiplier.Store(math.Float64bits(1))
testUser1.DisableDownload.Store(false)
testUser1.TrackerHide.Store(false)
testUser2 := &cdb.User{}
testUser2.ID.Store(2)
testUser2.DownMultiplier.Store(math.Float64bits(2))
testUser2.UpMultiplier.Store(math.Float64bits(0.5))
testUser2.DisableDownload.Store(true)
testUser2.TrackerHide.Store(true)
users := map[string]*cdb.User{
"mUztWMpBYNCqzmge6vGeEUGSrctJbgpQ": {
ID: 1,
UpMultiplier: 1,
DownMultiplier: 1,
TrackerHide: false,
DisableDownload: false,
},
"tbHfQDQ9xDaQdsNv5CZBtHPfk7KGzaCw": {
ID: 2,
UpMultiplier: 0.5,
DownMultiplier: 2,
TrackerHide: true,
DisableDownload: true,
},
"mUztWMpBYNCqzmge6vGeEUGSrctJbgpQ": testUser1,
"tbHfQDQ9xDaQdsNv5CZBtHPfk7KGzaCw": testUser2,
}
// Test with fresh data
db.loadUsers()
if len(db.Users) != len(users) {
t.Fatal(fixtureFailure("Did not load all users as expected from fixture file", len(users), len(db.Users)))
dbUsers = *db.Users.Load()
if len(dbUsers) != len(users) {
t.Fatal(fixtureFailure("Did not load all users as expected from fixture file", len(users), len(dbUsers)))
}
for passkey, user := range users {
if !reflect.DeepEqual(user, db.Users[passkey]) {
if !reflect.DeepEqual(user, dbUsers[passkey]) {
t.Fatal(fixtureFailure(
fmt.Sprintf("Did not load user (%s) as expected from fixture file", passkey),
user,
db.Users[passkey]))
dbUsers[passkey]))
}
}
// Now test load on top of existing data
oldUsers := db.Users
oldUsers := dbUsers
db.loadUsers()
if !reflect.DeepEqual(oldUsers, db.Users) {
t.Fatal(fixtureFailure("Did not reload users as expected from fixture file", oldUsers, db.Users))
dbUsers = *db.Users.Load()
if !reflect.DeepEqual(oldUsers, dbUsers) {
t.Fatal(fixtureFailure("Did not reload users as expected from fixture file", oldUsers, dbUsers))
}
}
@ -133,7 +142,8 @@ func TestLoadHitAndRuns(t *testing.T) {
func TestLoadTorrents(t *testing.T) {
prepareTestDatabase()
db.Torrents = make(map[cdb.TorrentHash]*cdb.Torrent)
dbTorrents := make(map[cdb.TorrentHash]*cdb.Torrent)
db.Torrents.Store(&dbTorrents)
torrents := map[cdb.TorrentHash]*cdb.Torrent{
{114, 239, 32, 237, 220, 181, 67, 143, 115, 182, 216, 141, 120, 196, 223, 193, 102, 123, 137, 56}: {
@ -180,28 +190,32 @@ func TestLoadTorrents(t *testing.T) {
// Test with fresh data
db.loadTorrents()
if len(db.Torrents) != len(torrents) {
dbTorrents = *db.Torrents.Load()
if len(dbTorrents) != len(torrents) {
t.Fatal(fixtureFailure("Did not load all torrents as expected from fixture file",
len(torrents),
len(db.Torrents)))
len(dbTorrents)))
}
for hash, torrent := range torrents {
if !reflect.DeepEqual(torrent, db.Torrents[hash]) {
if !reflect.DeepEqual(torrent, dbTorrents[hash]) {
t.Fatal(fixtureFailure(
fmt.Sprintf("Did not load torrent (%s) as expected from fixture file", hash),
torrent,
db.Torrents[hash]))
dbTorrents[hash]))
}
}
// Now test load on top of existing data
oldTorrents := db.Torrents
oldTorrents := dbTorrents
db.loadTorrents()
if !reflect.DeepEqual(oldTorrents, db.Torrents) {
t.Fatal(fixtureFailure("Did not reload torrents as expected from fixture file", oldTorrents, db.Torrents))
dbTorrents = *db.Torrents.Load()
if !reflect.DeepEqual(oldTorrents, dbTorrents) {
t.Fatal(fixtureFailure("Did not reload torrents as expected from fixture file", oldTorrents, dbTorrents))
}
}
@ -297,36 +311,43 @@ func TestLoadClients(t *testing.T) {
func TestUnPrune(t *testing.T) {
prepareTestDatabase()
dbTorrents := *db.Torrents.Load()
h := cdb.TorrentHash{114, 239, 32, 237, 220, 181, 67, 143, 115, 182, 216, 141, 120, 196, 223, 193, 102, 123, 137, 56}
torrent := cdb.Torrent{
Seeders: db.Torrents[h].Seeders,
Leechers: db.Torrents[h].Leechers,
Group: db.Torrents[h].Group,
ID: db.Torrents[h].ID,
Snatched: db.Torrents[h].Snatched,
Status: db.Torrents[h].Status,
LastAction: db.Torrents[h].LastAction,
UpMultiplier: db.Torrents[h].UpMultiplier,
DownMultiplier: db.Torrents[h].DownMultiplier,
Seeders: dbTorrents[h].Seeders,
Leechers: dbTorrents[h].Leechers,
Group: dbTorrents[h].Group,
ID: dbTorrents[h].ID,
Snatched: dbTorrents[h].Snatched,
Status: dbTorrents[h].Status,
LastAction: dbTorrents[h].LastAction,
UpMultiplier: dbTorrents[h].UpMultiplier,
DownMultiplier: dbTorrents[h].DownMultiplier,
}
torrent.InitializeLock()
torrent.Status = 0
db.UnPrune(db.Torrents[h])
db.UnPrune(dbTorrents[h])
db.loadTorrents()
if !reflect.DeepEqual(&torrent, db.Torrents[h]) {
dbTorrents = *db.Torrents.Load()
if !reflect.DeepEqual(&torrent, dbTorrents[h]) {
t.Fatal(fixtureFailure(
fmt.Sprintf("Torrent (%x) was not unpruned properly", h),
&torrent,
db.Torrents[h]))
dbTorrents[h]))
}
}
func TestRecordAndFlushUsers(t *testing.T) {
prepareTestDatabase()
testUser := db.Users["tbHfQDQ9xDaQdsNv5CZBtHPfk7KGzaCw"]
dbUsers := *db.Users.Load()
testUser := dbUsers["tbHfQDQ9xDaQdsNv5CZBtHPfk7KGzaCw"]
var (
initUpload int64
@ -347,11 +368,11 @@ func TestRecordAndFlushUsers(t *testing.T) {
deltaRawDownload = 83472
deltaRawUpload = 245
deltaDownload = int64(float64(deltaRawDownload) * testUser.DownMultiplier)
deltaUpload = int64(float64(deltaRawUpload) * testUser.UpMultiplier)
deltaDownload = int64(float64(deltaRawDownload) * math.Float64frombits(testUser.DownMultiplier.Load()))
deltaUpload = int64(float64(deltaRawUpload) * math.Float64frombits(testUser.UpMultiplier.Load()))
row := db.mainConn.sqlDb.QueryRow("SELECT Uploaded, Downloaded, rawup, rawdl "+
"FROM users_main WHERE ID = ?", testUser.ID)
row := db.conn.QueryRow("SELECT Uploaded, Downloaded, rawup, rawdl "+
"FROM users_main WHERE ID = ?", testUser.ID.Load())
err := row.Scan(&initUpload, &initDownload, &initRawUpload, &initRawDownload)
if err != nil {
@ -365,8 +386,8 @@ func TestRecordAndFlushUsers(t *testing.T) {
}
time.Sleep(200 * time.Millisecond)
row = db.mainConn.sqlDb.QueryRow("SELECT Uploaded, Downloaded, rawup, rawdl "+
"FROM users_main WHERE ID = ?", testUser.ID)
row = db.conn.QueryRow("SELECT Uploaded, Downloaded, rawup, rawdl "+
"FROM users_main WHERE ID = ?", testUser.ID.Load())
err = row.Scan(&upload, &download, &rawUpload, &rawDownload)
if err != nil {
@ -446,7 +467,7 @@ func TestRecordAndFlushTransferHistory(t *testing.T) {
deltaActiveTime = 267
deltaSeedTime = 15
row := db.mainConn.sqlDb.QueryRow("SELECT uploaded, downloaded, activetime, seedtime, active, snatched "+
row := db.conn.QueryRow("SELECT uploaded, downloaded, activetime, seedtime, active, snatched "+
"FROM transfer_history WHERE uid = ? AND fid = ?", testPeer.UserID, testPeer.TorrentID)
err := row.Scan(&initRawUpload, &initRawDownload, &initActiveTime, &initSeedTime, &initActive, &initSnatch)
@ -467,7 +488,7 @@ func TestRecordAndFlushTransferHistory(t *testing.T) {
}
time.Sleep(200 * time.Millisecond)
row = db.mainConn.sqlDb.QueryRow("SELECT uploaded, downloaded, activetime, seedtime, active, snatched "+
row = db.conn.QueryRow("SELECT uploaded, downloaded, activetime, seedtime, active, snatched "+
"FROM transfer_history WHERE uid = ? AND fid = ?", testPeer.UserID, testPeer.TorrentID)
err = row.Scan(&rawUpload, &rawDownload, &activeTime, &seedTime, &active, &snatch)
@ -528,7 +549,7 @@ func TestRecordAndFlushTransferHistory(t *testing.T) {
var gotStartTime int64
row = db.mainConn.sqlDb.QueryRow("SELECT seeding, starttime, last_announce, remaining "+
row = db.conn.QueryRow("SELECT seeding, starttime, last_announce, remaining "+
"FROM transfer_history WHERE uid = ? AND fid = ?", gotPeer.UserID, gotPeer.TorrentID)
err = row.Scan(&gotPeer.Seeding, &gotStartTime, &gotPeer.LastAnnounce, &gotPeer.Left)
@ -573,7 +594,7 @@ func TestRecordAndFlushTransferHistory(t *testing.T) {
}
time.Sleep(200 * time.Millisecond)
row = db.mainConn.sqlDb.QueryRow("SELECT seeding, starttime, last_announce, remaining "+
row = db.conn.QueryRow("SELECT seeding, starttime, last_announce, remaining "+
"FROM transfer_history WHERE uid = ? AND fid = ?", gotPeer.UserID, gotPeer.TorrentID)
err = row.Scan(&gotPeer.Seeding, &gotPeer.StartTime, &gotPeer.LastAnnounce, &gotPeer.Left)
@ -610,7 +631,7 @@ func TestRecordAndFlushTransferIP(t *testing.T) {
deltaDownload = 236
deltaUpload = 3262
row := db.mainConn.sqlDb.QueryRow("SELECT uploaded, downloaded "+
row := db.conn.QueryRow("SELECT uploaded, downloaded "+
"FROM transfer_ips WHERE uid = ? AND fid = ? AND ip = ? AND client_id = ?",
testPeer.UserID, testPeer.TorrentID, testPeer.Addr.IPNumeric(), testPeer.ClientID)
@ -626,7 +647,7 @@ func TestRecordAndFlushTransferIP(t *testing.T) {
}
time.Sleep(200 * time.Millisecond)
row = db.mainConn.sqlDb.QueryRow("SELECT uploaded, downloaded "+
row = db.conn.QueryRow("SELECT uploaded, downloaded "+
"FROM transfer_ips WHERE uid = ? AND fid = ? AND ip = ? AND client_id = ?",
testPeer.UserID, testPeer.TorrentID, testPeer.Addr.IPNumeric(), testPeer.ClientID)
@ -662,7 +683,7 @@ func TestRecordAndFlushTransferIP(t *testing.T) {
var gotStartTime int64
row = db.mainConn.sqlDb.QueryRow("SELECT port, starttime, last_announce "+
row = db.conn.QueryRow("SELECT port, starttime, last_announce "+
"FROM transfer_ips WHERE uid = ? AND fid = ? AND ip = ? AND client_id = ?",
testPeer.UserID, testPeer.TorrentID, testPeer.Addr.IPNumeric(), testPeer.ClientID)
@ -707,7 +728,7 @@ func TestRecordAndFlushTransferIP(t *testing.T) {
Addr: cdb.NewPeerAddressFromIPPort(testPeer.Addr.IP(), 0),
}
row = db.mainConn.sqlDb.QueryRow("SELECT port, starttime, last_announce "+
row = db.conn.QueryRow("SELECT port, starttime, last_announce "+
"FROM transfer_ips WHERE uid = ? AND fid = ? AND ip = ? AND client_id = ?",
testPeer.UserID, testPeer.TorrentID, testPeer.Addr.IPNumeric(), testPeer.ClientID)
@ -745,7 +766,7 @@ func TestRecordAndFlushSnatch(t *testing.T) {
}
time.Sleep(200 * time.Millisecond)
row := db.mainConn.sqlDb.QueryRow("SELECT snatched_time "+
row := db.conn.QueryRow("SELECT snatched_time "+
"FROM transfer_history WHERE uid = ? AND fid = ?", testPeer.UserID, testPeer.TorrentID)
err := row.Scan(&snatchTime)
@ -765,7 +786,7 @@ func TestRecordAndFlushTorrents(t *testing.T) {
prepareTestDatabase()
h := cdb.TorrentHash{114, 239, 32, 237, 220, 181, 67, 143, 115, 182, 216, 141, 120, 196, 223, 193, 102, 123, 137, 56}
torrent := db.Torrents[h]
torrent := (*db.Torrents.Load())[h]
torrent.LastAction = time.Now().Unix()
torrent.Seeders[cdb.NewPeerKey(1, cdb.PeerIDFromRawString("test_peer_id_num_one"))] = &cdb.Peer{
UserID: 1,
@ -796,7 +817,7 @@ func TestRecordAndFlushTorrents(t *testing.T) {
numSeeders int
)
row := db.mainConn.sqlDb.QueryRow("SELECT Snatched, last_action, Seeders, Leechers "+
row := db.conn.QueryRow("SELECT Snatched, last_action, Seeders, Leechers "+
"FROM torrents WHERE ID = ?", torrent.ID)
err := row.Scan(&snatched, &lastAction, &numSeeders, &numLeechers)

View file

@ -19,8 +19,8 @@ package database
import (
"bytes"
"chihaya/util"
"errors"
"github.com/viney-shih/go-lock"
"time"
"chihaya/collectors"
@ -82,8 +82,6 @@ func (db *Database) startFlushing() {
db.transferIpsChannel = make(chan *bytes.Buffer, transferIpsFlushBufferSize)
db.snatchChannel = make(chan *bytes.Buffer, snatchFlushBufferSize)
db.transferHistoryLock = lock.NewCASMutex()
go db.flushTorrents()
go db.flushUsers()
go db.flushTransferHistory() // Can not be blocking or it will lock purgeInactivePeers when chan is empty
@ -114,25 +112,9 @@ func (db *Database) flushTorrents() {
count int
)
conn := Open()
for {
query.Reset()
query.WriteString("CREATE TEMPORARY TABLE IF NOT EXISTS flush_torrents (" +
"ID int unsigned NOT NULL, " +
"Snatched int unsigned NOT NULL DEFAULT 0, " +
"Seeders int unsigned NOT NULL DEFAULT 0, " +
"Leechers int unsigned NOT NULL DEFAULT 0, " +
"last_action int NOT NULL DEFAULT 0, " +
"PRIMARY KEY (ID)) ENGINE=MEMORY")
conn.exec(&query)
query.Reset()
query.WriteString("TRUNCATE flush_torrents")
conn.exec(&query)
query.Reset()
query.WriteString("INSERT INTO flush_torrents VALUES ")
query.WriteString("INSERT IGNORE INTO torrents (ID, Snatched, Seeders, Leechers, last_action) VALUES ")
length := len(db.torrentChannel)
@ -152,7 +134,7 @@ func (db *Database) flushTorrents() {
if count > 0 {
logFlushes, _ := config.GetBool("log_flushes", true)
if logFlushes && !db.terminate {
if logFlushes && !db.terminate.Load() {
log.Info.Printf("{torrents} Flushing %d", count)
}
@ -161,18 +143,9 @@ func (db *Database) flushTorrents() {
query.WriteString(" ON DUPLICATE KEY UPDATE Snatched = Snatched + VALUE(Snatched), " +
"Seeders = VALUE(Seeders), Leechers = VALUE(Leechers), " +
"last_action = IF(last_action < VALUE(last_action), VALUE(last_action), last_action)")
conn.exec(&query)
db.exec(&query)
query.Reset()
query.WriteString("UPDATE torrents t, flush_torrents ft SET " +
"t.Snatched = t.Snatched + ft.Snatched, " +
"t.Seeders = ft.Seeders, " +
"t.Leechers = ft.Leechers, " +
"t.last_action = IF(t.last_action < ft.last_action, ft.last_action, t.last_action)" +
"WHERE t.ID = ft.ID")
conn.exec(&query)
if !db.terminate {
if !db.terminate.Load() {
elapsedTime := time.Since(startTime)
collectors.UpdateFlushTime("torrents", elapsedTime)
collectors.UpdateChannelsLen("torrents", count)
@ -181,14 +154,12 @@ func (db *Database) flushTorrents() {
if length < (torrentFlushBufferSize >> 1) {
time.Sleep(time.Duration(flushSleepInterval) * time.Second)
}
} else if db.terminate {
} else if db.terminate.Load() {
break
} else {
time.Sleep(time.Second)
}
}
_ = conn.Close()
}
func (db *Database) flushUsers() {
@ -200,25 +171,9 @@ func (db *Database) flushUsers() {
count int
)
conn := Open()
for {
query.Reset()
query.WriteString("CREATE TEMPORARY TABLE IF NOT EXISTS flush_users (" +
"ID int unsigned NOT NULL, " +
"Uploaded bigint unsigned NOT NULL DEFAULT 0, " +
"Downloaded bigint unsigned NOT NULL DEFAULT 0, " +
"rawdl bigint unsigned NOT NULL DEFAULT 0, " +
"rawup bigint unsigned NOT NULL DEFAULT 0, " +
"PRIMARY KEY (ID)) ENGINE=MEMORY")
conn.exec(&query)
query.Reset()
query.WriteString("TRUNCATE flush_users")
conn.exec(&query)
query.Reset()
query.WriteString("INSERT INTO flush_users VALUES ")
query.WriteString("INSERT IGNORE INTO users_main (ID, Uploaded, Downloaded, rawdl, rawup) VALUES ")
length := len(db.userChannel)
@ -238,7 +193,7 @@ func (db *Database) flushUsers() {
if count > 0 {
logFlushes, _ := config.GetBool("log_flushes", true)
if logFlushes && !db.terminate {
if logFlushes && !db.terminate.Load() {
log.Info.Printf("{users_main} Flushing %d", count)
}
@ -246,18 +201,9 @@ func (db *Database) flushUsers() {
query.WriteString(" ON DUPLICATE KEY UPDATE Uploaded = Uploaded + VALUE(Uploaded), " +
"Downloaded = Downloaded + VALUE(Downloaded), rawdl = rawdl + VALUE(rawdl), rawup = rawup + VALUE(rawup)")
conn.exec(&query)
db.exec(&query)
query.Reset()
query.WriteString("UPDATE users_main u, flush_users fu SET " +
"u.Uploaded = u.Uploaded + fu.Uploaded, " +
"u.Downloaded = u.Downloaded + fu.Downloaded, " +
"u.rawdl = u.rawdl + fu.rawdl, " +
"u.rawup = u.rawup + fu.rawup " +
"WHERE u.ID = fu.ID")
conn.exec(&query)
if !db.terminate {
if !db.terminate.Load() {
elapsedTime := time.Since(startTime)
collectors.UpdateFlushTime("users", elapsedTime)
collectors.UpdateChannelsLen("users", count)
@ -266,14 +212,12 @@ func (db *Database) flushUsers() {
if length < (userFlushBufferSize >> 1) {
time.Sleep(time.Duration(flushSleepInterval) * time.Second)
}
} else if db.terminate {
} else if db.terminate.Load() {
break
} else {
time.Sleep(time.Second)
}
}
_ = conn.Close()
}
func (db *Database) flushTransferHistory() {
@ -285,8 +229,6 @@ func (db *Database) flushTransferHistory() {
count int
)
conn := Open()
for {
length, err := func() (int, error) {
db.transferHistoryLock.Lock()
@ -314,7 +256,7 @@ func (db *Database) flushTransferHistory() {
if count > 0 {
logFlushes, _ := config.GetBool("log_flushes", true)
if logFlushes && !db.terminate {
if logFlushes && !db.terminate.Load() {
log.Info.Printf("{transfer_history} Flushing %d", count)
}
@ -326,16 +268,16 @@ func (db *Database) flushTransferHistory() {
"seedtime = seedtime + VALUE(seedtime), last_announce = VALUE(last_announce), " +
"active = VALUE(active), snatched = snatched + VALUE(snatched);")
conn.exec(&query)
db.exec(&query)
if !db.terminate {
if !db.terminate.Load() {
elapsedTime := time.Since(startTime)
collectors.UpdateFlushTime("transfer_history", elapsedTime)
collectors.UpdateChannelsLen("transfer_history", count)
}
return length, nil
} else if db.terminate {
} else if db.terminate.Load() {
return 0, errDbTerminate
}
@ -350,8 +292,6 @@ func (db *Database) flushTransferHistory() {
time.Sleep(time.Second)
}
}
_ = conn.Close()
}
func (db *Database) flushTransferIps() {
@ -363,8 +303,6 @@ func (db *Database) flushTransferIps() {
count int
)
conn := Open()
for {
query.Reset()
query.WriteString("INSERT INTO transfer_ips (uid, fid, client_id, ip, port, uploaded, downloaded, " +
@ -388,7 +326,7 @@ func (db *Database) flushTransferIps() {
if count > 0 {
logFlushes, _ := config.GetBool("log_flushes", true)
if logFlushes && !db.terminate {
if logFlushes && !db.terminate.Load() {
log.Info.Printf("{transfer_ips} Flushing %d", count)
}
@ -397,9 +335,9 @@ func (db *Database) flushTransferIps() {
// todo: port should be part of PK
query.WriteString("\nON DUPLICATE KEY UPDATE port = VALUE(port), downloaded = downloaded + VALUE(downloaded), " +
"uploaded = uploaded + VALUE(uploaded), last_announce = VALUE(last_announce)")
conn.exec(&query)
db.exec(&query)
if !db.terminate {
if !db.terminate.Load() {
elapsedTime := time.Since(startTime)
collectors.UpdateFlushTime("transfer_ips", elapsedTime)
collectors.UpdateChannelsLen("transfer_ips", count)
@ -408,14 +346,12 @@ func (db *Database) flushTransferIps() {
if length < (transferIpsFlushBufferSize >> 1) {
time.Sleep(time.Duration(flushSleepInterval) * time.Second)
}
} else if db.terminate {
} else if db.terminate.Load() {
break
} else {
time.Sleep(time.Second)
}
}
_ = conn.Close()
}
func (db *Database) flushSnatches() {
@ -427,8 +363,6 @@ func (db *Database) flushSnatches() {
count int
)
conn := Open()
for {
query.Reset()
query.WriteString("INSERT INTO transfer_history (uid, fid, snatched_time) VALUES\n")
@ -451,16 +385,16 @@ func (db *Database) flushSnatches() {
if count > 0 {
logFlushes, _ := config.GetBool("log_flushes", true)
if logFlushes && !db.terminate {
if logFlushes && !db.terminate.Load() {
log.Info.Printf("{snatches} Flushing %d", count)
}
startTime := time.Now()
query.WriteString("\nON DUPLICATE KEY UPDATE snatched_time = VALUE(snatched_time)")
conn.exec(&query)
db.exec(&query)
if !db.terminate {
if !db.terminate.Load() {
elapsedTime := time.Since(startTime)
collectors.UpdateFlushTime("snatches", elapsedTime)
collectors.UpdateChannelsLen("snatches", count)
@ -469,14 +403,12 @@ func (db *Database) flushSnatches() {
if length < (snatchFlushBufferSize >> 1) {
time.Sleep(time.Duration(flushSleepInterval) * time.Second)
}
} else if db.terminate {
} else if db.terminate.Load() {
break
} else {
time.Sleep(time.Second)
}
}
_ = conn.Close()
}
func (db *Database) purgeInactivePeers() {
@ -486,7 +418,7 @@ func (db *Database) purgeInactivePeers() {
count int
)
for !db.terminate {
util.ContextTick(db.ctx, time.Duration(purgeInactivePeersInterval)*time.Second, func() {
start = time.Now()
now = start.Unix()
count = 0
@ -494,45 +426,41 @@ func (db *Database) purgeInactivePeers() {
oldestActive := now - int64(peerInactivityInterval)
// First, remove inactive peers from memory
func() {
db.TorrentsLock.RLock()
defer db.TorrentsLock.RUnlock()
dbTorrents := *db.Torrents.Load()
for _, torrent := range dbTorrents {
func() {
//Take write lock to operate on entries
torrent.Lock()
defer torrent.Unlock()
for _, torrent := range db.Torrents {
func() {
//Take write lock to operate on entries
torrent.Lock()
defer torrent.Unlock()
countThisTorrent := count
countThisTorrent := count
for id, peer := range torrent.Leechers {
if peer.LastAnnounce < oldestActive {
delete(torrent.Leechers, id)
count++
}
for id, peer := range torrent.Leechers {
if peer.LastAnnounce < oldestActive {
delete(torrent.Leechers, id)
count++
}
}
if countThisTorrent != count && len(torrent.Leechers) == 0 {
/* Deallocate previous map since Go does not free space used on maps when deleting objects.
We're doing it only for Leechers as potential advantage from freeing one or two Seeders is
virtually nil, while Leechers can incur significant memory leaks due to initial swarm activity. */
torrent.Leechers = make(map[cdb.PeerKey]*cdb.Peer)
}
if countThisTorrent != count && len(torrent.Leechers) == 0 {
/* Deallocate previous map since Go does not free space used on maps when deleting objects.
We're doing it only for Leechers as potential advantage from freeing one or two Seeders is
virtually nil, while Leechers can incur significant memory leaks due to initial swarm activity. */
torrent.Leechers = make(map[cdb.PeerKey]*cdb.Peer)
}
for id, peer := range torrent.Seeders {
if peer.LastAnnounce < oldestActive {
delete(torrent.Seeders, id)
count++
}
for id, peer := range torrent.Seeders {
if peer.LastAnnounce < oldestActive {
delete(torrent.Seeders, id)
count++
}
}
if countThisTorrent != count {
db.QueueTorrent(torrent, 0)
}
}()
}
}()
if countThisTorrent != count {
db.QueueTorrent(torrent, 0)
}
}()
}
elapsedTime := time.Since(start)
collectors.UpdateFlushTime("purging_inactive_peers", elapsedTime)
@ -547,11 +475,8 @@ func (db *Database) purgeInactivePeers() {
db.transferHistoryLock.Lock()
defer db.transferHistoryLock.Unlock()
db.mainConn.mutex.Lock()
defer db.mainConn.mutex.Unlock()
start = time.Now()
result := db.mainConn.execute(db.cleanStalePeersStmt, oldestActive)
result := db.execute(db.cleanStalePeersStmt, oldestActive)
if result != nil {
rows, err := result.RowsAffected()
@ -562,7 +487,5 @@ func (db *Database) purgeInactivePeers() {
}
}
}()
time.Sleep(time.Duration(purgeInactivePeersInterval) * time.Second)
}
})
}

View file

@ -67,7 +67,7 @@ func (db *Database) QueueUser(
uq := db.bufferPool.Take()
uq.WriteString("(")
uq.WriteString(strconv.FormatUint(uint64(user.ID), 10))
uq.WriteString(strconv.FormatUint(uint64(user.ID.Load()), 10))
uq.WriteString(",")
uq.WriteString(strconv.FormatInt(deltaUp, 10))
uq.WriteString(",")
@ -179,7 +179,5 @@ func (db *Database) QueueSnatch(peer *cdb.Peer, now int64) {
}
func (db *Database) UnPrune(torrent *cdb.Torrent) {
db.mainConn.mutex.Lock()
db.mainConn.execute(db.unPruneTorrentStmt, torrent.ID)
db.mainConn.mutex.Unlock()
db.execute(db.unPruneTorrentStmt, torrent.ID)
}

View file

@ -18,6 +18,7 @@
package database
import (
"math"
"sync/atomic"
"time"
@ -25,6 +26,7 @@ import (
"chihaya/config"
cdb "chihaya/database/types"
"chihaya/log"
"chihaya/util"
)
var (
@ -49,32 +51,26 @@ func init() {
*/
func (db *Database) startReloading() {
go func() {
for !db.terminate {
time.Sleep(time.Duration(reloadInterval) * time.Second)
util.ContextTick(db.ctx, time.Duration(reloadInterval)*time.Second, func() {
db.waitGroup.Add(1)
defer db.waitGroup.Done()
db.loadUsers()
db.loadHitAndRuns()
db.loadTorrents()
db.loadGroupsFreeleech()
db.loadConfig()
db.loadClients()
db.waitGroup.Done()
}
})
}()
}
func (db *Database) loadUsers() {
db.UsersLock.Lock()
defer db.UsersLock.Unlock()
db.mainConn.mutex.Lock()
defer db.mainConn.mutex.Unlock()
start := time.Now()
newUsers := make(map[string]*cdb.User, len(db.Users))
rows := db.mainConn.query(db.loadUsersStmt)
dbUsers := *db.Users.Load()
newUsers := make(map[string]*cdb.User, len(dbUsers))
rows := db.query(db.loadUsersStmt)
if rows == nil {
log.Error.Print("Failed to load hit and runs from database")
log.WriteStack()
@ -99,40 +95,37 @@ func (db *Database) loadUsers() {
log.WriteStack()
}
if old, exists := db.Users[torrentPass]; exists && old != nil {
old.ID = id
old.DownMultiplier = downMultiplier
old.UpMultiplier = upMultiplier
old.DisableDownload = disableDownload
old.TrackerHide = trackerHide
if old, exists := dbUsers[torrentPass]; exists && old != nil {
old.ID.Store(id)
old.DownMultiplier.Store(math.Float64bits(downMultiplier))
old.UpMultiplier.Store(math.Float64bits(upMultiplier))
old.DisableDownload.Store(disableDownload)
old.TrackerHide.Store(trackerHide)
newUsers[torrentPass] = old
} else {
newUsers[torrentPass] = &cdb.User{
ID: id,
UpMultiplier: upMultiplier,
DownMultiplier: downMultiplier,
DisableDownload: disableDownload,
TrackerHide: trackerHide,
}
u := &cdb.User{}
u.ID.Store(id)
u.DownMultiplier.Store(math.Float64bits(downMultiplier))
u.UpMultiplier.Store(math.Float64bits(upMultiplier))
u.DisableDownload.Store(disableDownload)
u.TrackerHide.Store(trackerHide)
newUsers[torrentPass] = u
}
}
db.Users = newUsers
db.Users.Store(&newUsers)
elapsedTime := time.Since(start)
collectors.UpdateReloadTime("users", elapsedTime)
log.Info.Printf("User load complete (%d rows, %s)", len(db.Users), elapsedTime.String())
log.Info.Printf("User load complete (%d rows, %s)", len(newUsers), elapsedTime.String())
}
func (db *Database) loadHitAndRuns() {
db.mainConn.mutex.Lock()
defer db.mainConn.mutex.Unlock()
start := time.Now()
newHnr := make(map[cdb.UserTorrentPair]struct{})
rows := db.mainConn.query(db.loadHnrStmt)
rows := db.query(db.loadHnrStmt)
if rows == nil {
log.Error.Print("Failed to load hit and runs from database")
log.WriteStack()
@ -169,100 +162,90 @@ func (db *Database) loadHitAndRuns() {
func (db *Database) loadTorrents() {
var start time.Time
newTorrents := make(map[cdb.TorrentHash]*cdb.Torrent)
dbTorrents := *db.Torrents.Load()
func() {
db.TorrentsLock.RLock()
defer db.TorrentsLock.RUnlock()
newTorrents := make(map[cdb.TorrentHash]*cdb.Torrent, len(dbTorrents))
db.mainConn.mutex.Lock()
defer db.mainConn.mutex.Unlock()
start = time.Now()
start = time.Now()
rows := db.query(db.loadTorrentsStmt)
if rows == nil {
log.Error.Print("Failed to load torrents from database")
log.WriteStack()
rows := db.mainConn.query(db.loadTorrentsStmt)
if rows == nil {
log.Error.Print("Failed to load torrents from database")
log.WriteStack()
return
}
return
}
defer func() {
_ = rows.Close()
}()
for rows.Next() {
var (
infoHash cdb.TorrentHash
id uint32
downMultiplier, upMultiplier float64
snatched uint16
status uint8
group cdb.TorrentGroup
)
if err := rows.Scan(
&id,
&infoHash,
&downMultiplier,
&upMultiplier,
&snatched,
&status,
&group.GroupID,
&group.TorrentType,
); err != nil {
log.Error.Printf("Error scanning torrent row: %s", err)
log.WriteStack()
}
if old, exists := db.Torrents[infoHash]; exists && old != nil {
func() {
old.Lock()
defer old.Unlock()
old.ID = id
old.DownMultiplier = downMultiplier
old.UpMultiplier = upMultiplier
old.Snatched = snatched
old.Status = status
old.Group = group
}()
newTorrents[infoHash] = old
} else {
newTorrents[infoHash] = &cdb.Torrent{
ID: id,
UpMultiplier: upMultiplier,
DownMultiplier: downMultiplier,
Snatched: snatched,
Status: status,
Group: group,
Seeders: make(map[cdb.PeerKey]*cdb.Peer),
Leechers: make(map[cdb.PeerKey]*cdb.Peer),
}
}
}
defer func() {
_ = rows.Close()
}()
db.TorrentsLock.Lock()
defer db.TorrentsLock.Unlock()
db.Torrents = newTorrents
for rows.Next() {
var (
infoHash cdb.TorrentHash
id uint32
downMultiplier, upMultiplier float64
snatched uint16
status uint8
group cdb.TorrentGroup
)
if err := rows.Scan(
&id,
&infoHash,
&downMultiplier,
&upMultiplier,
&snatched,
&status,
&group.GroupID,
&group.TorrentType,
); err != nil {
log.Error.Printf("Error scanning torrent row: %s", err)
log.WriteStack()
}
if old, exists := dbTorrents[infoHash]; exists && old != nil {
func() {
old.Lock()
defer old.Unlock()
old.ID = id
old.DownMultiplier = downMultiplier
old.UpMultiplier = upMultiplier
old.Snatched = snatched
old.Status = status
old.Group = group
}()
newTorrents[infoHash] = old
} else {
newTorrents[infoHash] = &cdb.Torrent{
ID: id,
UpMultiplier: upMultiplier,
DownMultiplier: downMultiplier,
Snatched: snatched,
Status: status,
Group: group,
Seeders: make(map[cdb.PeerKey]*cdb.Peer),
Leechers: make(map[cdb.PeerKey]*cdb.Peer),
}
newTorrents[infoHash].InitializeLock()
}
}
db.Torrents.Store(&newTorrents)
elapsedTime := time.Since(start)
collectors.UpdateReloadTime("torrents", elapsedTime)
log.Info.Printf("Torrent load complete (%d rows, %s)", len(db.Torrents), elapsedTime.String())
log.Info.Printf("Torrent load complete (%d rows, %s)", len(newTorrents), elapsedTime.String())
}
func (db *Database) loadGroupsFreeleech() {
db.mainConn.mutex.Lock()
defer db.mainConn.mutex.Unlock()
start := time.Now()
newTorrentGroupFreeleech := make(map[cdb.TorrentGroup]*cdb.TorrentGroupFreeleech)
rows := db.mainConn.query(db.loadTorrentGroupFreeleechStmt)
rows := db.query(db.loadTorrentGroupFreeleechStmt)
if rows == nil {
log.Error.Print("Failed to load torrent group freeleech data from database")
log.WriteStack()
@ -295,14 +278,11 @@ func (db *Database) loadGroupsFreeleech() {
elapsedTime := time.Since(start)
collectors.UpdateReloadTime("groups_freeleech", elapsedTime)
log.Info.Printf("Group freeleech load complete (%d rows, %s)", len(db.Torrents), elapsedTime.String())
log.Info.Printf("Group freeleech load complete (%d rows, %s)", len(newTorrentGroupFreeleech), elapsedTime.String())
}
func (db *Database) loadConfig() {
db.mainConn.mutex.Lock()
defer db.mainConn.mutex.Unlock()
rows := db.mainConn.query(db.loadFreeleechStmt)
rows := db.query(db.loadFreeleechStmt)
if rows == nil {
log.Error.Print("Failed to load config from database")
log.WriteStack()
@ -327,13 +307,10 @@ func (db *Database) loadConfig() {
}
func (db *Database) loadClients() {
db.mainConn.mutex.Lock()
defer db.mainConn.mutex.Unlock()
start := time.Now()
newClients := make(map[uint16]string)
rows := db.mainConn.query(db.loadClientsStmt)
rows := db.query(db.loadClientsStmt)
if rows == nil {
log.Error.Print("Failed to load clients from database")
log.WriteStack()

View file

@ -18,6 +18,7 @@
package database
import (
"chihaya/util"
"fmt"
"os"
"time"
@ -37,10 +38,9 @@ func init() {
func (db *Database) startSerializing() {
go func() {
for !db.terminate {
time.Sleep(time.Duration(serializeInterval) * time.Second)
util.ContextTick(db.ctx, time.Duration(serializeInterval)*time.Second, func() {
db.serialize()
}
})
}()
}
@ -68,10 +68,7 @@ func (db *Database) serialize() {
torrentFile.Close()
}()
db.TorrentsLock.RLock()
defer db.TorrentsLock.RUnlock()
if err = cdb.WriteTorrents(torrentFile, db.Torrents); err != nil {
if err = cdb.WriteTorrents(torrentFile, *db.Torrents.Load()); err != nil {
log.Error.Print("Failed to encode torrents for serialization: ", err)
return err
}
@ -96,10 +93,7 @@ func (db *Database) serialize() {
userFile.Close()
}()
db.UsersLock.RLock()
defer db.UsersLock.RUnlock()
if err = cdb.WriteUsers(userFile, db.Users); err != nil {
if err = cdb.WriteUsers(userFile, *db.Users.Load()); err != nil {
log.Error.Print("Failed to encode users for serialization: ", err)
return err
}
@ -124,6 +118,10 @@ func (db *Database) deserialize() {
start := time.Now()
torrents := 0
peers := 0
func() {
torrentFile, err := os.OpenFile(torrentBinFilename, os.O_RDONLY, 0)
if err != nil {
@ -134,13 +132,19 @@ func (db *Database) deserialize() {
//goland:noinspection GoUnhandledErrorResult
defer torrentFile.Close()
db.TorrentsLock.Lock()
defer db.TorrentsLock.Unlock()
if err = cdb.LoadTorrents(torrentFile, db.Torrents); err != nil {
dbTorrents := make(map[cdb.TorrentHash]*cdb.Torrent)
if err = cdb.LoadTorrents(torrentFile, dbTorrents); err != nil {
log.Error.Print("Failed to deserialize torrent cache: ", err)
return
}
torrents = len(dbTorrents)
for _, t := range dbTorrents {
peers += len(t.Leechers) + len(t.Seeders)
}
db.Torrents.Store(&dbTorrents)
}()
func() {
@ -153,33 +157,16 @@ func (db *Database) deserialize() {
//goland:noinspection GoUnhandledErrorResult
defer userFile.Close()
db.UsersLock.Lock()
defer db.UsersLock.Unlock()
if err = cdb.LoadUsers(userFile, db.Users); err != nil {
users := make(map[string]*cdb.User)
if err = cdb.LoadUsers(userFile, users); err != nil {
log.Error.Print("Failed to deserialize user cache: ", err)
return
}
db.Users.Store(&users)
}()
db.TorrentsLock.RLock()
defer db.TorrentsLock.RUnlock()
torrents := len(db.Torrents)
peers := 0
for _, t := range db.Torrents {
func() {
t.RLock()
defer t.RUnlock()
peers += len(t.Leechers) + len(t.Seeders)
}()
}
db.UsersLock.RLock()
defer db.UsersLock.RUnlock()
users := len(db.Users)
users := len(*db.Users.Load())
log.Info.Printf("Loaded %d users, %d torrents and %d peers (%s)",
users, torrents, peers, time.Since(start).String())

View file

@ -18,6 +18,7 @@
package database
import (
"math"
"net"
"reflect"
"testing"
@ -32,13 +33,14 @@ func TestSerializer(t *testing.T) {
testTorrents := make(map[cdb.TorrentHash]*cdb.Torrent)
testUsers := make(map[string]*cdb.User)
testUsers["mUztWMpBYNCqzmge6vGeEUGSrctJbgpQ"] = &cdb.User{
DisableDownload: false,
TrackerHide: false,
ID: 12,
UpMultiplier: 1,
DownMultiplier: 1,
}
testUser := &cdb.User{}
testUser.ID.Store(12)
testUser.DownMultiplier.Store(math.Float64bits(1))
testUser.UpMultiplier.Store(math.Float64bits(1))
testUser.DisableDownload.Store(false)
testUser.TrackerHide.Store(false)
testUsers["mUztWMpBYNCqzmge6vGeEUGSrctJbgpQ"] = testUser
testPeer := &cdb.Peer{
UserID: 12,
@ -69,34 +71,44 @@ func TestSerializer(t *testing.T) {
},
Leechers: map[cdb.PeerKey]*cdb.Peer{},
}
testTorrents[testTorrentHash].InitializeLock()
// Prepare empty map to populate with test data
db.Torrents = make(map[cdb.TorrentHash]*cdb.Torrent)
db.Users = make(map[string]*cdb.User)
dbTorrents := make(map[cdb.TorrentHash]*cdb.Torrent)
db.Torrents.Store(&dbTorrents)
if err := copier.Copy(&db.Torrents, testTorrents); err != nil {
dbUsers := make(map[string]*cdb.User)
db.Users.Store(&dbUsers)
if err := copier.Copy(&dbTorrents, testTorrents); err != nil {
panic(err)
}
if err := copier.Copy(&db.Users, testUsers); err != nil {
if err := copier.Copy(&dbUsers, testUsers); err != nil {
panic(err)
}
db.serialize()
// Reset map to fully test deserialization
db.Torrents = make(map[cdb.TorrentHash]*cdb.Torrent)
db.Users = make(map[string]*cdb.User)
dbTorrents = make(map[cdb.TorrentHash]*cdb.Torrent)
db.Torrents.Store(&dbTorrents)
dbUsers = make(map[string]*cdb.User)
db.Users.Store(&dbUsers)
db.deserialize()
if !reflect.DeepEqual(db.Torrents, testTorrents) {
dbTorrents = *db.Torrents.Load()
dbUsers = *db.Users.Load()
if !reflect.DeepEqual(dbTorrents, testTorrents) {
t.Fatalf("Torrents (%v) after serialization and deserialization do not match original torrents (%v)!",
db.Torrents, testTorrents)
dbTorrents, testTorrents)
}
if !reflect.DeepEqual(db.Users, testUsers) {
if !reflect.DeepEqual(dbUsers, testUsers) {
t.Fatalf("Users (%v) after serialization and deserialization do not match original users (%v)!",
db.Users, testUsers)
dbUsers, testUsers)
}
}

View file

@ -107,6 +107,7 @@ func LoadTorrents(r io.Reader, torrents map[TorrentHash]*Torrent) error {
}
t := &Torrent{}
t.InitializeLock()
if err := t.Load(version, reader); err != nil {
return err

View file

@ -18,13 +18,14 @@
package types
import (
"context"
"database/sql/driver"
"encoding/binary"
"encoding/hex"
"errors"
"github.com/viney-shih/go-lock"
"io"
"math"
"sync"
)
const TorrentHashSize = 20
@ -131,14 +132,25 @@ type Torrent struct {
DownMultiplier float64
// lock This must be taken whenever read or write is made to fields on this torrent.
// Maybe single sync.Mutex is easier to handle, but prevent concurrent access.
lock sync.RWMutex
lock *lock.CASMutex
}
// InitializeLock allocates the torrent's CAS mutex (viney-shih/go-lock).
// It must be called exactly once after constructing a Torrent and before
// any Lock/RLock call; calling it a second time panics so that
// double-initialization bugs surface immediately.
func (t *Torrent) InitializeLock() {
	if t.lock != nil {
		panic("already initialized")
	}
	t.lock = lock.NewCASMutex()
}
// Lock takes the torrent's write lock, blocking until it is available.
// InitializeLock must have been called first.
func (t *Torrent) Lock() {
	t.lock.Lock()
}
// TryLockWithContext attempts to take the write lock, giving up when ctx
// is canceled or its deadline passes. It reports whether the lock was
// acquired; on false the caller must not touch the torrent's fields.
func (t *Torrent) TryLockWithContext(ctx context.Context) bool {
	return t.lock.TryLockWithContext(ctx)
}
// Unlock releases the torrent's write lock previously taken with Lock or
// a successful TryLockWithContext.
func (t *Torrent) Unlock() {
	t.lock.Unlock()
}
@ -151,6 +163,10 @@ func (t *Torrent) RUnlock() {
t.lock.RUnlock()
}
// RTryLockWithContext attempts to take the read lock, giving up when ctx
// is canceled or its deadline passes. It reports whether the lock was
// acquired; on false the caller must not read the torrent's fields.
func (t *Torrent) RTryLockWithContext(ctx context.Context) bool {
	return t.lock.RTryLockWithContext(ctx)
}
func (t *Torrent) Load(version uint64, reader readerAndByteReader) (err error) {
var varIntLen uint64

View file

@ -20,57 +20,76 @@ package types
import (
"encoding/binary"
"math"
"sync/atomic"
)
type User struct {
ID uint32
ID atomic.Uint32
DisableDownload bool
DisableDownload atomic.Bool
TrackerHide bool
TrackerHide atomic.Bool
UpMultiplier float64
DownMultiplier float64
// UpMultiplier A float64 under the covers
UpMultiplier atomic.Uint64
// DownMultiplier A float64 under the covers
DownMultiplier atomic.Uint64
}
func (u *User) Load(_ uint64, reader readerAndByteReader) (err error) {
if err = binary.Read(reader, binary.LittleEndian, &u.ID); err != nil {
var (
id uint32
disableDownload, trackerHide bool
upMultiplier, downMultiplier float64
)
if err = binary.Read(reader, binary.LittleEndian, &id); err != nil {
return err
}
if err = binary.Read(reader, binary.LittleEndian, &u.DisableDownload); err != nil {
if err = binary.Read(reader, binary.LittleEndian, &disableDownload); err != nil {
return err
}
if err = binary.Read(reader, binary.LittleEndian, &u.TrackerHide); err != nil {
if err = binary.Read(reader, binary.LittleEndian, &trackerHide); err != nil {
return err
}
if err = binary.Read(reader, binary.LittleEndian, &u.UpMultiplier); err != nil {
if err = binary.Read(reader, binary.LittleEndian, &upMultiplier); err != nil {
return err
}
return binary.Read(reader, binary.LittleEndian, &u.DownMultiplier)
if err = binary.Read(reader, binary.LittleEndian, &downMultiplier); err != nil {
return err
}
u.ID.Store(id)
u.DisableDownload.Store(disableDownload)
u.TrackerHide.Store(trackerHide)
u.UpMultiplier.Store(math.Float64bits(upMultiplier))
u.DownMultiplier.Store(math.Float64bits(downMultiplier))
return nil
}
func (u *User) Append(preAllocatedBuffer []byte) (buf []byte) {
buf = preAllocatedBuffer
buf = binary.LittleEndian.AppendUint32(buf, u.ID)
buf = binary.LittleEndian.AppendUint32(buf, u.ID.Load())
if u.DisableDownload {
if u.DisableDownload.Load() {
buf = append(buf, 1)
} else {
buf = append(buf, 0)
}
if u.TrackerHide {
if u.TrackerHide.Load() {
buf = append(buf, 1)
} else {
buf = append(buf, 0)
}
buf = binary.LittleEndian.AppendUint64(buf, math.Float64bits(u.UpMultiplier))
buf = binary.LittleEndian.AppendUint64(buf, math.Float64bits(u.DownMultiplier))
buf = binary.LittleEndian.AppendUint64(buf, u.UpMultiplier.Load())
buf = binary.LittleEndian.AppendUint64(buf, u.DownMultiplier.Load())
return buf
}

View file

@ -215,19 +215,16 @@ func announce(ctx context.Context, qs string, header http.Header, remoteAddr str
return http.StatusOK // Required by torrent clients to interpret failure response
}
if !db.TorrentsLock.RTryLockWithContext(ctx) {
return http.StatusRequestTimeout
}
defer db.TorrentsLock.RUnlock()
torrent, exists := db.Torrents[infoHashes[0]]
torrent, exists := (*db.Torrents.Load())[infoHashes[0]]
if !exists {
failure("This torrent does not exist", buf, 5*time.Minute)
return http.StatusOK // Required by torrent clients to interpret failure response
}
// Take torrent lock to read/write on it to prevent race conditions
torrent.Lock()
if !torrent.TryLockWithContext(ctx) {
return http.StatusRequestTimeout
}
defer torrent.Unlock()
if torrent.Status == 1 && left == 0 {
@ -262,7 +259,7 @@ func announce(ctx context.Context, qs string, header http.Header, remoteAddr str
event, _ := qp.Get("event")
completed := event == "completed"
peerKey := cdb.NewPeerKey(user.ID, cdb.PeerIDFromRawString(peerID))
peerKey := cdb.NewPeerKey(user.ID.Load(), cdb.PeerIDFromRawString(peerID))
if left > 0 {
if isDisabledDownload(db, user, torrent) {
@ -310,7 +307,7 @@ func announce(ctx context.Context, qs string, header http.Header, remoteAddr str
// Update peer info/stats
if newPeer {
peer.ID = peerKey.PeerID()
peer.UserID = user.ID
peer.UserID = user.ID.Load()
peer.TorrentID = torrent.ID
peer.StartTime = now
peer.LastAnnounce = now
@ -340,11 +337,11 @@ func announce(ctx context.Context, qs string, header http.Header, remoteAddr str
var deltaDownload int64
if !database.GlobalFreeleech.Load() {
deltaDownload = int64(float64(rawDeltaDownload) * math.Abs(user.DownMultiplier) *
deltaDownload = int64(float64(rawDeltaDownload) * math.Abs(math.Float64frombits(user.DownMultiplier.Load())) *
math.Abs(torrentGroupDownMultiplier) * math.Abs(torrent.DownMultiplier))
}
deltaUpload := int64(float64(rawDeltaUpload) * math.Abs(user.UpMultiplier) *
deltaUpload := int64(float64(rawDeltaUpload) * math.Abs(math.Float64frombits(user.UpMultiplier.Load())) *
math.Abs(torrentGroupUpMultiplier) * math.Abs(torrent.UpMultiplier))
peer.Uploaded = uploaded
peer.Downloaded = downloaded
@ -390,7 +387,7 @@ func announce(ctx context.Context, qs string, header http.Header, remoteAddr str
deltaSnatch = 1
}
if user.TrackerHide {
if user.TrackerHide.Load() {
peer.Addr = cdb.NewPeerAddressFromIPPort(net.IP{127, 0, 0, 1}, port) // 127.0.0.1
} else {
peer.Addr = cdb.NewPeerAddressFromIPPort(ipBytes, port)
@ -406,7 +403,7 @@ func announce(ctx context.Context, qs string, header http.Header, remoteAddr str
go record.Record(
peer.TorrentID,
user.ID,
user.ID.Load(),
ipAddr,
port,
event,

View file

@ -34,19 +34,12 @@ import (
var bearerPrefix = "Bearer "
func metrics(ctx context.Context, auth string, db *database.Database, buf *bytes.Buffer) int {
if !db.UsersLock.RTryLockWithContext(ctx) {
return http.StatusRequestTimeout
}
defer db.UsersLock.RUnlock()
if !db.TorrentsLock.RTryLockWithContext(ctx) {
return http.StatusRequestTimeout
}
defer db.TorrentsLock.RUnlock()
dbUsers := *db.Users.Load()
dbTorrents := *db.Torrents.Load()
peers := 0
for _, t := range db.Torrents {
for _, t := range dbTorrents {
func() {
t.RLock()
defer t.RUnlock()
@ -62,8 +55,8 @@ func metrics(ctx context.Context, auth string, db *database.Database, buf *bytes
}
collectors.UpdateUptime(time.Since(handler.startTime).Seconds())
collectors.UpdateUsers(len(db.Users))
collectors.UpdateTorrents(len(db.Torrents))
collectors.UpdateUsers(len(dbUsers))
collectors.UpdateTorrents(len(dbTorrents))
collectors.UpdateClients(len(*db.Clients.Load()))
collectors.UpdateHitAndRuns(len(*db.HitAndRuns.Load()))
collectors.UpdatePeers(peers)

View file

@ -57,14 +57,11 @@ func scrape(ctx context.Context, qs string, user *cdb.User, db *database.Databas
scrapeData := make(map[string]interface{})
fileData := make(map[cdb.TorrentHash]interface{})
if !db.TorrentsLock.RTryLockWithContext(ctx) {
return http.StatusRequestTimeout
}
defer db.TorrentsLock.RUnlock()
dbTorrents := *db.Torrents.Load()
if qp.InfoHashes() != nil {
for _, infoHash := range qp.InfoHashes() {
torrent, exists := db.Torrents[infoHash]
torrent, exists := dbTorrents[infoHash]
if exists {
if !isDisabledDownload(db, user, torrent) {
fileData[infoHash] = writeScrapeInfo(torrent)

View file

@ -91,10 +91,8 @@ func (handler *httpHandler) respond(ctx context.Context, r *http.Request, buf *b
* ===================================================
*/
user, err := isPasskeyValid(ctx, passkey, handler.db)
if err != nil {
return http.StatusRequestTimeout
} else if user == nil {
user := isPasskeyValid(passkey, handler.db)
if user == nil {
failure("Your passkey is invalid", buf, 1*time.Hour)
return http.StatusOK
}

View file

@ -19,7 +19,6 @@ package server
import (
"bytes"
"context"
"time"
"chihaya/database"
@ -69,18 +68,13 @@ func clientApproved(peerID string, db *database.Database) (uint16, bool) {
return 0, false
}
func isPasskeyValid(ctx context.Context, passkey string, db *database.Database) (*cdb.User, error) {
if !db.UsersLock.RTryLockWithContext(ctx) {
return nil, ctx.Err()
}
defer db.UsersLock.RUnlock()
user, exists := db.Users[passkey]
func isPasskeyValid(passkey string, db *database.Database) *cdb.User {
user, exists := (*db.Users.Load())[passkey]
if !exists {
return nil, nil
return nil
}
return user, nil
return user
}
func hasHitAndRun(db *database.Database, userID, torrentID uint32) bool {
@ -96,5 +90,5 @@ func hasHitAndRun(db *database.Database, userID, torrentID uint32) bool {
func isDisabledDownload(db *database.Database, user *cdb.User, torrent *cdb.Torrent) bool {
// Only disable download if the torrent doesn't have a HnR against it
return user.DisableDownload && !hasHitAndRun(db, user.ID, torrent.ID)
return user.DisableDownload.Load() && !hasHitAndRun(db, user.ID.Load(), torrent.ID)
}

20
util/context_ticker.go Normal file
View file

@ -0,0 +1,20 @@
package util
import (
"context"
"time"
)
func ContextTick(ctx context.Context, d time.Duration, onTick func()) {
ticker := time.NewTicker(d)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
onTick()
}
}
}