From efa01d5c31f3210b97208821504ccbc36b22f712 Mon Sep 17 00:00:00 2001 From: WeebDataHoarder <57538841+WeebDataHoarder@users.noreply.github.com> Date: Mon, 5 Sep 2022 16:54:15 +0200 Subject: [PATCH] Refactor Queue to use cancellable entries --- audio/queue/entry.go | 45 ++++++++ audio/queue/queue.go | 228 ++++++++++++++++++-------------------- audio/queue/queue_test.go | 6 +- 3 files changed, 154 insertions(+), 125 deletions(-) create mode 100644 audio/queue/entry.go diff --git a/audio/queue/entry.go b/audio/queue/entry.go new file mode 100644 index 0000000..cf5b297 --- /dev/null +++ b/audio/queue/entry.go @@ -0,0 +1,45 @@ +package queue + +import ( + "git.gammaspectra.live/S.O.N.G/Kirika/audio" + "sync" + "sync/atomic" +) + +type EntryCallback func(q *Queue, entry *Entry) + +type Entry struct { + Identifier Identifier + Source audio.Source + ReadSamples atomic.Uint64 + cancel chan struct{} + cancelMutex sync.Mutex + StartCallback EntryCallback + EndCallback EntryCallback + RemoveCallback EntryCallback +} + +func NewEntry(identifier Identifier, source audio.Source, cancel chan struct{}, start, end, remove EntryCallback) *Entry { + return &Entry{ + Identifier: identifier, + Source: source, + cancel: cancel, + StartCallback: start, + EndCallback: end, + RemoveCallback: remove, + } +} + +func (e *Entry) Done() <-chan struct{} { + return e.cancel +} + +func (e *Entry) Cancel() { + e.cancelMutex.Lock() + defer e.cancelMutex.Unlock() + select { + case <-e.Done(): + default: + close(e.cancel) + } +} diff --git a/audio/queue/queue.go b/audio/queue/queue.go index e1f4d7f..62eb799 100644 --- a/audio/queue/queue.go +++ b/audio/queue/queue.go @@ -4,31 +4,20 @@ import ( "git.gammaspectra.live/S.O.N.G/Kirika/audio" "git.gammaspectra.live/S.O.N.G/Kirika/audio/filter" "log" + "runtime" "sync" "sync/atomic" ) -type QueueIdentifier int - -type QueueEntry struct { - Identifier QueueIdentifier - Source audio.Source - ReadSamples atomic.Uint64 - cancel chan bool - StartCallback func(q *Queue, entry *QueueEntry) - EndCallback func(q *Queue, entry *QueueEntry) - RemoveCallback func(q *Queue, entry *QueueEntry) -} +type Identifier int type Queue struct { - queue []*QueueEntry + queue []*Entry output audio.Source - interrupt chan bool - interruptDepth atomic.Int64 - closed bool + closed atomic.Bool lock sync.RWMutex wg sync.WaitGroup - identifierCounter QueueIdentifier + identifierCounter Identifier } func NewQueue(format audio.SourceFormat, bitDepth, sampleRate, channels int) *Queue { @@ -36,9 +25,7 @@ func NewQueue(format audio.SourceFormat, bitDepth, sampleRate, channels int) *Qu log.Panicf("not allowed channel number %d", channels) } - q := &Queue{ - interrupt: make(chan bool, 1), - } + q := &Queue{} switch format { case audio.SourceFloat32: @@ -56,8 +43,8 @@ func NewQueue(format audio.SourceFormat, bitDepth, sampleRate, channels int) *Qu return q } -func spliceHelper[T audio.AllowedSourceTypes](input audio.TypedSource[T]) (output audio.TypedSource[T], cancel chan bool) { - cancel = make(chan bool, 1) +func spliceHelper[T audio.AllowedSourceTypes](input audio.TypedSource[T]) (output audio.TypedSource[T], cancel chan struct{}) { + cancel = make(chan struct{}, 1) output = audio.NewSource[T](input.GetBitDepth(), input.GetSampleRate(), input.GetChannels()) bitDepth := input.GetBitDepth() @@ -89,7 +76,7 @@ func spliceHelper[T audio.AllowedSourceTypes](input audio.TypedSource[T]) (outpu return } -func (q *Queue) spliceSources(input audio.Source) (audio.Source, chan bool) { +func (q *Queue) spliceSources(input audio.Source) (audio.Source, chan struct{}) { switch q.GetSource().GetFormat() { case audio.SourceFloat32: @@ -105,56 +92,75 @@ func (q *Queue) spliceSources(input audio.Source) (audio.Source, chan bool) { return nil, nil } +func (q *Queue) current() *Entry { + q.lock.RLock() + defer q.lock.RUnlock() + if len(q.queue) > 0 { + return q.queue[0] + } + return nil +} + func queueLoopStart[T audio.AllowedSourceTypes](q *Queue) { q.wg.Add(1) go func() { defer q.wg.Done() - var current *QueueEntry - var currentBlocks chan []T + var current *Entry L1: for { - q.lock.RLock() - if q.closed { + if q.closed.Load() { q.output.Close() break L1 } - if len(q.queue) == 0 { //no more entries, wait for interrupt - q.lock.RUnlock() - <-q.interrupt - q.interruptDepth.Add(-1) + + if func() bool { + currentEntry := q.current() + if currentEntry == nil { + return current == nil + } else if current == nil || current.Identifier != currentEntry.Identifier { + current = currentEntry + } + return false + }() { + runtime.Gosched() continue } - if current == nil || current.Identifier != q.queue[0].Identifier { - current = q.queue[0] - currentBlocks = current.Source.(audio.TypedSource[T]).GetBlocks() - } - q.lock.RUnlock() - F1: - for { - select { - case <-q.interrupt: - q.interruptDepth.Add(-1) - //force recheck - break F1 - case block, more := <-currentBlocks: - if !more { - //no more blocks! skip - if current.EndCallback != nil { - current.EndCallback(q, current) + func() { + currentBlocks := current.Source.(audio.TypedSource[T]).GetBlocks() + defer current.Source.Unlock() + + if current.StartCallback != nil && current.ReadSamples.Load() == 0 { + current.StartCallback(q, current) + } + + output := q.output.(audio.TypedSource[T]) + + for { + select { + case <-current.Done(): + //song has been cancelled elsewhere + return + case block, more := <-currentBlocks: + if !more { + //no more blocks to read + current.Cancel() + return + } else { + current.ReadSamples.Add(uint64(len(block) / current.Source.GetChannels())) + output.IngestNative(block, current.Source.GetBitDepth()) } - q.Remove(current.Identifier) - break F1 - } else { - if current.StartCallback != nil && current.ReadSamples.Load() == 0 && len(block) > 0 { - current.StartCallback(q, current) - } - current.ReadSamples.Add(uint64(len(block) / current.Source.GetChannels())) - q.output.(audio.TypedSource[T]).IngestNative(block, current.Source.GetBitDepth()) } } - } + }() + { + //no more blocks! skip + if current.EndCallback != nil { + current.EndCallback(q, current) + } + q.Remove(current.Identifier) + } } }() @@ -172,90 +178,75 @@ func (q *Queue) getFilterChain(source audio.Source) audio.Source { } } -func (q *Queue) AddHead(source audio.Source, startCallback func(q *Queue, entry *QueueEntry), endCallback func(q *Queue, entry *QueueEntry), removeCallback func(q *Queue, entry *QueueEntry)) (identifier QueueIdentifier) { +func (q *Queue) AddHead(source audio.Source, startCallback, endCallback, removeCallback EntryCallback) (identifier Identifier) { q.lock.Lock() + defer q.lock.Unlock() splicedOutput, cancel := q.spliceSources(source) identifier = q.identifierCounter + + entry := NewEntry(identifier, q.getFilterChain(splicedOutput), cancel, startCallback, endCallback, removeCallback) if len(q.queue) > 0 { - q.queue = append(q.queue[:1], append([]*QueueEntry{{ - Identifier: identifier, - Source: q.getFilterChain(splicedOutput), - cancel: cancel, - StartCallback: startCallback, - EndCallback: endCallback, - RemoveCallback: removeCallback, - }}, q.queue[1:]...)...) + q.queue = append(q.queue[:1], append([]*Entry{entry}, q.queue[1:]...)...) } else { - q.queue = append(q.queue, &QueueEntry{ - Identifier: identifier, - Source: q.getFilterChain(splicedOutput), - cancel: cancel, - StartCallback: startCallback, - EndCallback: endCallback, - RemoveCallback: removeCallback, - }) + q.queue = append(q.queue, entry) } q.identifierCounter++ - q.lock.Unlock() - q.sendInterrupt() return } -func (q *Queue) AddTail(source audio.Source, startCallback func(q *Queue, entry *QueueEntry), endCallback func(q *Queue, entry *QueueEntry), removeCallback func(q *Queue, entry *QueueEntry)) (identifier QueueIdentifier) { +func (q *Queue) AddTail(source audio.Source, startCallback, endCallback, removeCallback EntryCallback) (identifier Identifier) { q.lock.Lock() + defer q.lock.Unlock() splicedOutput, cancel := q.spliceSources(source) identifier = q.identifierCounter - q.queue = append(q.queue, &QueueEntry{ - Identifier: identifier, - Source: q.getFilterChain(splicedOutput), - cancel: cancel, - StartCallback: startCallback, - EndCallback: endCallback, - RemoveCallback: removeCallback, - }) + + entry := NewEntry(identifier, q.getFilterChain(splicedOutput), cancel, startCallback, endCallback, removeCallback) + + q.queue = append(q.queue, entry) q.identifierCounter++ - q.lock.Unlock() - q.sendInterrupt() return } func (q *Queue) IsClosed() bool { - q.lock.RLock() - defer q.lock.RUnlock() - return q.closed + return q.closed.Load() } -func (q *Queue) Remove(identifier QueueIdentifier) bool { - q.lock.Lock() +func (q *Queue) Remove(identifier Identifier) bool { + var entry *Entry - for i, e := range q.queue { - if e.Identifier == identifier { - q.sendInterrupt() - e.cancel <- true + func() { + q.lock.Lock() + defer q.lock.Unlock() - e.Source.Unlock() + for i, e := range q.queue { + if e.Identifier == identifier { + e.Cancel() - go audio.NewNullSink().Process(e.Source) - //delete entry - q.queue = append(q.queue[:i], q.queue[i+1:]...) - q.lock.Unlock() - if e.RemoveCallback != nil { - e.RemoveCallback(q, e) + e.Source.Unlock() + + go audio.NewNullSink().Process(e.Source) + //delete entry + q.queue = append(q.queue[:i], q.queue[i+1:]...) + entry = e } - return true } + }() + + if entry != nil { + if entry.RemoveCallback != nil { + entry.RemoveCallback(q, entry) + } + return true } - q.lock.Unlock() return false - } -func (q *Queue) GetQueueHead() *QueueEntry { +func (q *Queue) GetQueueHead() *Entry { q.lock.RLock() defer q.lock.RUnlock() if len(q.queue) > 0 { @@ -264,7 +255,7 @@ func (q *Queue) GetQueueHead() *QueueEntry { return nil } -func (q *Queue) GetQueueTail() (index int, entry *QueueEntry) { +func (q *Queue) GetQueueTail() (index int, entry *Entry) { q.lock.RLock() defer q.lock.RUnlock() if len(q.queue) > 0 { @@ -273,7 +264,7 @@ func (q *Queue) GetQueueTail() (index int, entry *QueueEntry) { return 0, nil } -func (q *Queue) GetQueueIndex(index int) *QueueEntry { +func (q *Queue) GetQueueIndex(index int) *Entry { q.lock.RLock() defer q.lock.RUnlock() if len(q.queue) > index { @@ -282,7 +273,7 @@ func (q *Queue) GetQueueIndex(index int) *QueueEntry { return nil } -func (q *Queue) GetQueueEntry(identifier QueueIdentifier) (index int, entry *QueueEntry) { +func (q *Queue) GetQueueEntry(identifier Identifier) (index int, entry *Entry) { q.lock.RLock() defer q.lock.RUnlock() for i, e := range q.queue { @@ -300,11 +291,11 @@ func (q *Queue) GetQueueSize() int { return len(q.queue) } -func (q *Queue) GetQueue() (entries []*QueueEntry) { +func (q *Queue) GetQueue() (entries []*Entry) { q.lock.RLock() defer q.lock.RUnlock() - entries = make([]*QueueEntry, len(q.queue)) + entries = make([]*Entry, len(q.queue)) copy(entries, q.queue) return @@ -322,18 +313,11 @@ func (q *Queue) GetChannels() int { return q.GetSource().GetChannels() } -func (q *Queue) sendInterrupt() { - //TODO: maybe use len() on channel? - if q.interruptDepth.Load() == 0 { //not waiting on interrupt - q.interruptDepth.Add(1) - q.interrupt <- true - } -} - func (q *Queue) Close() { - if !q.closed { - q.closed = true - q.sendInterrupt() + if !q.closed.Swap(true) { + if current := q.current(); current != nil { + current.Cancel() + } } } diff --git a/audio/queue/queue_test.go b/audio/queue/queue_test.go index 865e974..eb6dbb9 100644 --- a/audio/queue/queue_test.go +++ b/audio/queue/queue_test.go @@ -39,11 +39,11 @@ func TestQueue(t *testing.T) { return } - q.AddTail(source, func(q *Queue, entry *QueueEntry) { + q.AddTail(source, func(q *Queue, entry *Entry) { t.Logf("Started playback of %d %s\n", entry.Identifier, fullPath) - }, func(q *Queue, entry *QueueEntry) { + }, func(q *Queue, entry *Entry) { t.Logf("Finished playback of %d %s: output %d samples\n", entry.Identifier, fullPath, entry.ReadSamples.Load()) - }, func(q *Queue, entry *QueueEntry) { + }, func(q *Queue, entry *Entry) { fp.Close() if q.GetQueueSize() == 0 { t.Log("Finished playback, closing\n")