Hibiki/panako/strategy.go

319 lines
9.5 KiB
Go

package panako
import (
"git.gammaspectra.live/S.O.N.G/Kirika/audio"
"git.gammaspectra.live/S.O.N.G/Kirika/audio/filter"
"log"
"math"
"slices"
"sort"
"sync"
"time"
)
// QueryResult describes one accepted match between a query and a stored
// reference resource, as produced by Strategy.QueryStoreRecords.
type QueryResult struct {
	// QueryStart and QueryStop are the positions, within the query audio, of
	// the first and last hit that survived filtering.
	QueryStart, QueryStop time.Duration

	// ReferenceResourceId identifies the stored resource that was matched.
	ReferenceResourceId ResourceId
	// ReferenceStart and ReferenceStop are the positions, within the
	// reference audio, of the first and last hit that survived filtering.
	ReferenceStart, ReferenceStop time.Duration

	// Score is the number of hits that survived filtering.
	Score int
	// TimeFactor is the time-stretch estimate 1/(1-slope) derived from the
	// anchor hits; FrequencyFactor is the reference-to-query frequency ratio
	// (in Hertz) of the anchor hit.
	TimeFactor, FrequencyFactor float64
	// PercentOfSecondsWithMatches is the fraction (nominally 0..1) of
	// one-second bins of the matched reference span that contain a hit.
	PercentOfSecondsWithMatches float64
}
// Strategy ties an Instance (fingerprinting configuration and processor
// factory) to a Store, and offers the store/query entry points built on them.
// Create it with NewStrategy.
type Strategy struct {
	// instance supplies the event point processor as well as the sample
	// rate, block size and query tuning parameters read by the methods.
	instance *Instance
	// store receives fingerprints (StorePanakoPrints) and answers match
	// lookups (GetPanakoMatches).
	store Store
	// resampleQuality is handed to filter.NewResampleFilter when a source is
	// converted to fingerprints; NewStrategy defaults it to
	// filter.QualityFastest.
	resampleQuality filter.ResampleQuality
	// absoluteLatency is the processor latency converted to a duration by
	// NewStrategy; it is passed to Instance.BlockToDuration whenever a block
	// index is turned into a timestamp.
	absoluteLatency time.Duration
}
// NewStrategy builds a Strategy for the given instance and store.
//
// An optional resampler quality may be supplied; only the first value is
// used, and filter.QualityFastest applies when none is given. The function
// obtains one event point processor to measure the processing latency and
// panics (through log.Panic) if that processor cannot be created.
func NewStrategy(instance *Instance, store Store, resamplerQuality ...filter.ResampleQuality) *Strategy {
	var quality filter.ResampleQuality = filter.QualityFastest
	if len(resamplerQuality) > 0 {
		quality = resamplerQuality[0]
	}

	s := &Strategy{
		instance:        instance,
		store:           store,
		resampleQuality: quality,
	}

	processor, err := s.instance.GetEventPointProcessor()
	if err != nil {
		log.Panic(err)
	}

	// Latency is reported by the processor and divided by the sample rate to
	// obtain seconds, then scaled to a time.Duration.
	latencySeconds := float64(Hertz(processor.GetLatency()) / s.instance.SampleRate)
	s.absoluteLatency = time.Duration(latencySeconds * float64(time.Second))

	// The probe processor is finished off right away; its fingerprints are
	// discarded. NOTE(review): presumably this releases the processor's
	// resources — confirm against the processor implementation.
	processor.GetFingerprints()
	processor.ProcessingFinished()

	return s
}
// StoreSource fingerprints source, stores the fingerprints under resourceId
// and returns the timestamp of the last fingerprint's t3 block. It returns 0
// when the source produced no fingerprints (the empty set is still handed to
// the store).
func (s *Strategy) StoreSource(resourceId ResourceId, source audio.Source) time.Duration {
	prints := s.SourceToFingerprints(source)
	s.StoreFingerprints(resourceId, prints)

	if len(prints) == 0 {
		return 0
	}

	lastIndex := len(prints) - 1
	return s.instance.BlockToDuration(prints[lastIndex].t3(), s.absoluteLatency)
}
// SourceToFingerprints runs source through a mono filter and a resampler
// configured with the instance's sample rate and block size, converts the
// result to float32 blocks and fingerprints them.
func (s *Strategy) SourceToFingerprints(source audio.Source) []Fingerprint {
	resampler := filter.NewResampleFilter(int(s.instance.SampleRate), s.resampleQuality, s.instance.BlockSize)
	chain := filter.NewFilterChain(source, filter.MonoFilter{}, resampler)
	blocks := chain.ToFloat32().GetBlocks()
	return s.BlockChannelToFingerprints(blocks)
}
// BlockChannelToFingerprints feeds every block received on channel to a
// fresh event point processor and returns the fingerprints it produced.
// It panics if the processor cannot be obtained or if processing fails.
func (s *Strategy) BlockChannelToFingerprints(channel chan []float32) []Fingerprint {
	processor, err := s.instance.GetEventPointProcessor()
	if err != nil {
		panic(err)
	}

	// Consume the channel through the processor.
	err = processor.ProcessBlockChannel(channel)
	if err != nil {
		panic(err)
	}

	fingerprints := processor.GetFingerprints()
	return fingerprints
}
// QueryFingerprintsAsync runs QueryFingerprints on a new goroutine and hands
// the results to callback from that goroutine. It returns immediately.
func (s *Strategy) QueryFingerprintsAsync(prints []Fingerprint, callback func(queryResults []QueryResult)) {
	run := func() {
		results := s.QueryFingerprints(prints)
		callback(results)
	}
	go run()
}
// QueryFingerprints converts prints to store records (tagged with resource
// id 0, since a query has no stored identity) and queries the store with
// them.
func (s *Strategy) QueryFingerprints(prints []Fingerprint) []QueryResult {
	queryRecords := getRecordsFromPrints(ResourceId(0), prints, s.instance.HashType)
	return s.QueryStoreRecords(queryRecords)
}
// QueryStoreRecords looks the given query records up in the store and returns
// the reference resources whose hits line up consistently with the query.
//
// Hits are grouped per reference resource. Resources with fewer than
// QueryMinimumHitsBeforeFiltering hits are dropped; every remaining resource
// is evaluated on its own goroutine and contributes at most one QueryResult.
// The returned slice is ordered by descending Score; results with equal
// scores keep the order in which the workers finished, which is not
// deterministic. The result is nil when nothing matched.
func (s *Strategy) QueryStoreRecords(records []StoreRecord) (queryResults []QueryResult) {
	// Bucket every store hit by the reference resource it belongs to.
	hitsPerResource := make(map[ResourceId][]*MatchedRecord)
	for record := range s.store.GetPanakoMatches(records, s.instance.QueryRange) {
		hitsPerResource[record.Match.ResourceId] = append(hitsPerResource[record.Match.ResourceId], record)
	}

	var wg sync.WaitGroup
	// Workers only ever append, so an exclusive lock is all that is needed.
	var queryResultMutex sync.Mutex

	for resourceId, hitList := range hitsPerResource {
		if len(hitList) < s.instance.QueryMinimumHitsBeforeFiltering {
			continue
		}

		wg.Add(1)
		go func(resourceId ResourceId, hitList []*MatchedRecord) {
			defer wg.Done()

			result, ok := s.matchResourceHits(resourceId, hitList)
			if !ok {
				return
			}

			queryResultMutex.Lock()
			defer queryResultMutex.Unlock()
			queryResults = append(queryResults, result)
		}(resourceId, hitList)
	}
	wg.Wait()

	// Sort results by score, highest first. sort.SliceStable sorts the slice
	// in place.
	sort.SliceStable(queryResults, func(i, j int) bool {
		return queryResults[j].Score < queryResults[i].Score
	})
	return queryResults
}

// matchResourceHits evaluates the hits of a single reference resource and
// returns the first acceptable QueryResult, or false when none qualifies.
//
// hitList is sorted in place by query time. Anchor hits are picked from the
// head and the tail of the sorted list; each head/tail anchor pair defines a
// line (delta time as a function of query time) against which all hits are
// filtered.
func (s *Strategy) matchResourceHits(resourceId ResourceId, hitList []*MatchedRecord) (QueryResult, bool) {
	// Sort by query time. Compare explicitly rather than subtracting, so the
	// comparator cannot overflow int on platforms where int is 32 bits wide.
	slices.SortStableFunc(hitList, func(a, b *MatchedRecord) int {
		switch {
		case a.Query.Time < b.Query.Time:
			return -1
		case a.Query.Time > b.Query.Time:
			return 1
		default:
			return 0
		}
	})

	// Size of the head/tail windows: a fraction of the list, clamped to
	// [QueryMinimumHitsBeforeFiltering, QueryMaximumHitListSize].
	ix := len(hitList) / s.instance.QueryHitListDivisor
	if s.instance.QueryMinimumHitsBeforeFiltering > ix {
		ix = s.instance.QueryMinimumHitsBeforeFiltering
	}
	if s.instance.QueryMaximumHitListSize < ix {
		ix = s.instance.QueryMaximumHitListSize
	}

	// View the first and last hits.
	firstHits := hitList[:ix]
	lastHits := hitList[len(hitList)-ix:]

	//TODO: improvement: select several pairs from each
	var x1i []int     // indexes into firstHits of the head anchors
	var ffi []float64 // frequency factor of each head anchor, parallel to x1i
	var x2i []int     // indexes into lastHits of the tail anchors

	// Head anchors: the first hits whose delta t equals the most common
	// delta t of the head window.
	y1 := mostCommonDeltaTforHitList(firstHits)
	for i, hit := range firstHits {
		if hit.DeltaTime() != y1 {
			continue
		}
		x1i = append(x1i, i)
		ffi = append(ffi, float64(s.instance.BinToHertz(hit.Match.Frequency)/s.instance.BinToHertz(hit.Query.Frequency)))
		if len(x1i) == s.instance.QueryHitListDeltaTimeDepth {
			break
		}
	}

	// Tail anchors: scanning backwards, the last hits whose delta t equals
	// the most common delta t of the tail window.
	y2 := mostCommonDeltaTforHitList(lastHits)
	for i := len(lastHits) - 1; i >= 0; i-- {
		if lastHits[i].DeltaTime() != y2 {
			continue
		}
		x2i = append(x2i, i)
		if len(x2i) == s.instance.QueryHitListDeltaTimeDepth {
			break
		}
	}

	if len(x1i) == 0 || len(x2i) == 0 {
		return QueryResult{}, false
	}

	// Threshold in time bins.
	threshold := float64(s.instance.QueryRange)

	minTimeFactor := 1 - s.instance.QueryTimeRatioFactor
	maxTimeFactor := 1 + s.instance.QueryTimeRatioFactor
	minFreqFactor := 1 - s.instance.QueryFrequencyRatioFactor
	maxFreqFactor := 1 + s.instance.QueryFrequencyRatioFactor

	for i, j := range x1i {
		// The frequency factor depends on the head anchor only, so an
		// unreasonable one rules out every pairing with it.
		frequencyFactor := ffi[i]
		if !(frequencyFactor > minFreqFactor && frequencyFactor < maxFreqFactor) {
			continue
		}

		x1 := int64(firstHits[j].Query.Time)
		for _, l := range x2i {
			x2 := int64(lastHits[l].Query.Time)
			if x2 == x1 {
				// Both anchors share a query time: the slope is undefined
				// (division by zero), so the pair cannot describe a line.
				continue
			}

			slope := float64(y2-y1) / float64(x2-x1)
			offset := float64(-x1)*slope + float64(y1)
			timeFactor := 1. / (1. - slope)

			// Only continue processing when the time factor is reasonable.
			if !(timeFactor > minTimeFactor && timeFactor < maxTimeFactor) {
				continue
			}

			// Keep the hits whose delta t lies within the threshold of the
			// value predicted by the anchor line.
			var filteredHits []*MatchedRecord
			for _, hit := range hitList {
				yPredicted := slope*float64(hit.Query.Time) + offset
				if math.Abs(float64(hit.DeltaTime())-yPredicted) <= threshold {
					filteredHits = append(filteredHits, hit)
				}
			}

			// Ignore anchor pairs with too few filtered hits remaining. The
			// explicit empty check protects the indexing below even if the
			// configured minimum is negative.
			if len(filteredHits) == 0 || len(filteredHits) <= s.instance.QueryMinimumHitsAfterFiltering {
				continue
			}

			first := filteredHits[0]
			last := filteredHits[len(filteredHits)-1]

			queryStart := s.instance.BlockToDuration(first.Query.Time, s.absoluteLatency)
			queryStop := s.instance.BlockToDuration(last.Query.Time, s.absoluteLatency)
			if queryStop-queryStart < s.instance.QueryDurationMinimum {
				continue
			}

			refStart := s.instance.BlockToDuration(first.Match.Time, s.absoluteLatency)
			refStop := s.instance.BlockToDuration(last.Match.Time, s.absoluteLatency)

			// Record which one-second bins of the reference span contain at
			// least one hit. Ideally every second has matches; note that the
			// last second might not be a full second.
			secondsWithMatches := make(map[int]struct{})
			for _, hit := range filteredHits {
				secondBin := int((s.instance.BlockToDuration(hit.Match.Time, s.absoluteLatency) - refStart).Seconds())
				secondsWithMatches[secondBin] = struct{}{}
			}

			numberOfMatchingSeconds := math.Ceil((refStop - refStart).Seconds())
			emptySeconds := numberOfMatchingSeconds - float64(len(secondsWithMatches))
			// A zero-length reference span yields NaN or ±Inf here; NaN and
			// -Inf fail the comparison below, while +Inf would pass it.
			percentOfSecondsWithMatches := 1. - (emptySeconds / numberOfMatchingSeconds)
			if percentOfSecondsWithMatches < s.instance.QueryDurationPercentageMinimum {
				continue
			}

			// The first anchor pair that passes every check wins.
			return QueryResult{
				QueryStart:                  queryStart,
				QueryStop:                   queryStop,
				ReferenceResourceId:         resourceId,
				ReferenceStart:              refStart,
				ReferenceStop:               refStop,
				Score:                       len(filteredHits),
				TimeFactor:                  timeFactor,
				FrequencyFactor:             frequencyFactor,
				PercentOfSecondsWithMatches: percentOfSecondsWithMatches,
			}, true
		}
	}

	return QueryResult{}, false
}
// mostCommonDeltaTforHitList returns the delta time that occurs most often
// among the hits in hitList, or 0 when hitList is empty.
//
// When several delta times share the highest count, the one that appears
// first in hitList wins. Selecting the winner by walking the hit list rather
// than the counting map keeps the result deterministic: Go map iteration
// order is unspecified, so picking from the map would resolve ties
// differently from run to run.
func mostCommonDeltaTforHitList(hitList []*MatchedRecord) int64 {
	countPerDiff := make(map[int64]int, len(hitList))
	for _, hit := range hitList {
		countPerDiff[hit.DeltaTime()]++
	}

	maxCount := 0
	var mostCommonDeltaT int64
	for _, hit := range hitList {
		deltaT := hit.DeltaTime()
		// Strictly greater: a later delta with an equal count never
		// displaces an earlier one.
		if count := countPerDiff[deltaT]; count > maxCount {
			maxCount = count
			mostCommonDeltaT = deltaT
		}
	}
	return mostCommonDeltaT
}
// getRecordsFromPrints converts each fingerprint into a StoreRecord owned by
// resourceId. The record carries the fingerprint's hash (computed for
// hashType) together with its T1 time and F1 frequency. The result has
// exactly one record per fingerprint, in the same order, and is never nil.
func getRecordsFromPrints(resourceId ResourceId, prints []Fingerprint, hashType FingerprintHashType) []StoreRecord {
	records := make([]StoreRecord, len(prints))
	for i, fp := range prints {
		rec := &records[i]
		rec.ResourceId = resourceId
		rec.Hash = fp.GetHash(hashType)
		rec.Time = fp.T1()
		rec.Frequency = fp.F1()
	}
	return records
}
// StoreFingerprints converts prints to store records owned by resourceId,
// hashed with the instance's hash type, and writes them to the store.
func (s *Strategy) StoreFingerprints(resourceId ResourceId, prints []Fingerprint) {
	storeRecords := getRecordsFromPrints(resourceId, prints, s.instance.HashType)
	s.store.StorePanakoPrints(storeRecords)
}