Explicitly use errgroup for SplitWork

This commit is contained in:
DataHoarder 2024-04-11 05:25:24 +02:00
parent 5136295d91
commit 8adaa81245
Signed by: DataHoarder
SSH key fingerprint: SHA256:OLTRf6Fl87G52SiR7sWLGNzlJt4WOX+tfI2yxo0z7xk
4 changed files with 27 additions and 36 deletions

2
go.mod
View file

@ -13,6 +13,7 @@ require (
github.com/go-zeromq/zmq4 v0.16.1-0.20240124085909-e75c615ba1b3 github.com/go-zeromq/zmq4 v0.16.1-0.20240124085909-e75c615ba1b3
github.com/goccy/go-json v0.10.2 github.com/goccy/go-json v0.10.2
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc
golang.org/x/sync v0.7.0
golang.org/x/sys v0.19.0 golang.org/x/sys v0.19.0
lukechampine.com/uint128 v1.3.0 lukechampine.com/uint128 v1.3.0
) )
@ -21,7 +22,6 @@ require (
github.com/bahlo/generic-list-go v0.2.0 // indirect github.com/bahlo/generic-list-go v0.2.0 // indirect
github.com/dolthub/maphash v0.1.0 // indirect github.com/dolthub/maphash v0.1.0 // indirect
golang.org/x/crypto v0.22.0 // indirect golang.org/x/crypto v0.22.0 // indirect
golang.org/x/sync v0.7.0 // indirect
golang.org/x/text v0.14.0 // indirect golang.org/x/text v0.14.0 // indirect
) )

View file

@ -805,15 +805,13 @@ func (c *SideChain) verifyBlock(block *PoolBlock) (verification error, invalid e
var hashers []*sha3.HasherState var hashers []*sha3.HasherState
var anyErr atomic.Value
defer func() { defer func() {
for _, h := range hashers { for _, h := range hashers {
crypto.PutKeccak256Hasher(h) crypto.PutKeccak256Hasher(h)
} }
}() }()
if !utils.SplitWork(-2, uint64(len(rewards)), func(workIndex uint64, workerIndex int) error { if err := utils.SplitWork(-2, uint64(len(rewards)), func(workIndex uint64, workerIndex int) error {
out := block.Main.Coinbase.Outputs[workIndex] out := block.Main.Coinbase.Outputs[workIndex]
if rewards[workIndex] != out.Reward { if rewards[workIndex] != out.Reward {
return fmt.Errorf("has invalid reward at index %d, got %d, expected %d", workIndex, out.Reward, rewards[workIndex]) return fmt.Errorf("has invalid reward at index %d, got %d, expected %d", workIndex, out.Reward, rewards[workIndex])
@ -828,10 +826,8 @@ func (c *SideChain) verifyBlock(block *PoolBlock) (verification error, invalid e
}, func(routines, routineIndex int) error { }, func(routines, routineIndex int) error {
hashers = append(hashers, crypto.GetKeccak256Hasher()) hashers = append(hashers, crypto.GetKeccak256Hasher())
return nil return nil
}, func(routineIndex int, err error) { }); err != nil {
anyErr.Store(err) return nil, err
}) {
return nil, anyErr.Load().(error)
} }
} }

View file

@ -51,7 +51,7 @@ func CalculateOutputs(block *PoolBlock, consensus *Consensus, difficultyByHeight
} }
}() }()
utils.SplitWork(-2, n, func(workIndex uint64, workerIndex int) error { err := utils.SplitWork(-2, n, func(workIndex uint64, workerIndex int) error {
output := transaction.Output{ output := transaction.Output{
Index: workIndex, Index: workIndex,
Type: txType, Type: txType,
@ -65,7 +65,11 @@ func CalculateOutputs(block *PoolBlock, consensus *Consensus, difficultyByHeight
}, func(routines, routineIndex int) error { }, func(routines, routineIndex int) error {
hashers = append(hashers, crypto.GetKeccak256Hasher()) hashers = append(hashers, crypto.GetKeccak256Hasher())
return nil return nil
}, nil) })
if err != nil {
return nil, 0
}
return outputs, bottomHeight return outputs, bottomHeight
} }

View file

@ -1,12 +1,12 @@
package utils package utils
import ( import (
"golang.org/x/sync/errgroup"
"runtime" "runtime"
"sync"
"sync/atomic" "sync/atomic"
) )
func SplitWork(routines int, workSize uint64, do func(workIndex uint64, routineIndex int) error, init func(routines, routineIndex int) error, errorFunc func(routineIndex int, err error)) bool { func SplitWork(routines int, workSize uint64, do func(workIndex uint64, routineIndex int) error, init func(routines, routineIndex int) error) error {
if routines <= 0 { if routines <= 0 {
routines = max(runtime.NumCPU()-routines, 4) routines = max(runtime.NumCPU()-routines, 4)
} }
@ -17,39 +17,30 @@ func SplitWork(routines int, workSize uint64, do func(workIndex uint64, routineI
var counter atomic.Uint64 var counter atomic.Uint64
var wg sync.WaitGroup
var failed atomic.Bool
for routineIndex := 0; routineIndex < routines; routineIndex++ { for routineIndex := 0; routineIndex < routines; routineIndex++ {
if init != nil { if err := init(routines, routineIndex); err != nil {
if err := init(routines, routineIndex); err != nil { return err
if errorFunc != nil {
errorFunc(routineIndex, err)
}
failed.Store(true)
continue
}
} }
wg.Add(1) }
go func(routineIndex int) {
defer wg.Done() var eg errgroup.Group
for routineIndex := 0; routineIndex < routines; routineIndex++ {
innerRoutineIndex := routineIndex
eg.Go(func() error {
var err error var err error
for { for {
workIndex := counter.Add(1) workIndex := counter.Add(1)
if workIndex > workSize { if workIndex > workSize {
return return nil
} }
if err = do(workIndex-1, routineIndex); err != nil { if err = do(workIndex-1, innerRoutineIndex); err != nil {
if errorFunc != nil { return err
errorFunc(routineIndex, err)
failed.Store(true)
}
return
} }
} }
}(routineIndex) })
} }
wg.Wait() return eg.Wait()
return !failed.Load()
} }