diff --git a/go.mod b/go.mod index 0b7053a..bd7ea5d 100644 --- a/go.mod +++ b/go.mod @@ -13,6 +13,7 @@ require ( github.com/go-zeromq/zmq4 v0.16.1-0.20240124085909-e75c615ba1b3 github.com/goccy/go-json v0.10.2 github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc + golang.org/x/sync v0.7.0 golang.org/x/sys v0.19.0 lukechampine.com/uint128 v1.3.0 ) @@ -21,7 +22,6 @@ require ( github.com/bahlo/generic-list-go v0.2.0 // indirect github.com/dolthub/maphash v0.1.0 // indirect golang.org/x/crypto v0.22.0 // indirect - golang.org/x/sync v0.7.0 // indirect golang.org/x/text v0.14.0 // indirect ) diff --git a/p2pool/sidechain/sidechain.go b/p2pool/sidechain/sidechain.go index a96eb08..863313e 100644 --- a/p2pool/sidechain/sidechain.go +++ b/p2pool/sidechain/sidechain.go @@ -805,15 +805,13 @@ func (c *SideChain) verifyBlock(block *PoolBlock) (verification error, invalid e var hashers []*sha3.HasherState - var anyErr atomic.Value - defer func() { for _, h := range hashers { crypto.PutKeccak256Hasher(h) } }() - if !utils.SplitWork(-2, uint64(len(rewards)), func(workIndex uint64, workerIndex int) error { + if err := utils.SplitWork(-2, uint64(len(rewards)), func(workIndex uint64, workerIndex int) error { out := block.Main.Coinbase.Outputs[workIndex] if rewards[workIndex] != out.Reward { return fmt.Errorf("has invalid reward at index %d, got %d, expected %d", workIndex, out.Reward, rewards[workIndex]) @@ -828,10 +826,8 @@ func (c *SideChain) verifyBlock(block *PoolBlock) (verification error, invalid e }, func(routines, routineIndex int) error { hashers = append(hashers, crypto.GetKeccak256Hasher()) return nil - }, func(routineIndex int, err error) { - anyErr.Store(err) - }) { - return nil, anyErr.Load().(error) + }); err != nil { + return nil, err } } diff --git a/p2pool/sidechain/utils.go b/p2pool/sidechain/utils.go index 705eacf..718587d 100644 --- a/p2pool/sidechain/utils.go +++ b/p2pool/sidechain/utils.go @@ -51,7 +51,7 @@ func CalculateOutputs(block *PoolBlock, consensus *Consensus, difficultyByHeight } }() - utils.SplitWork(-2, n, func(workIndex uint64, workerIndex int) error { + err := utils.SplitWork(-2, n, func(workIndex uint64, workerIndex int) error { output := transaction.Output{ Index: workIndex, Type: txType, @@ -65,7 +65,11 @@ func CalculateOutputs(block *PoolBlock, consensus *Consensus, difficultyByHeight }, func(routines, routineIndex int) error { hashers = append(hashers, crypto.GetKeccak256Hasher()) return nil - }, nil) + }) + + if err != nil { + return nil, 0 + } return outputs, bottomHeight } diff --git a/utils/concurrency.go b/utils/concurrency.go index 5b3a7a5..ae40c35 100644 --- a/utils/concurrency.go +++ b/utils/concurrency.go @@ -1,12 +1,12 @@ package utils import ( + "golang.org/x/sync/errgroup" "runtime" - "sync" "sync/atomic" ) -func SplitWork(routines int, workSize uint64, do func(workIndex uint64, routineIndex int) error, init func(routines, routineIndex int) error, errorFunc func(routineIndex int, err error)) bool { +func SplitWork(routines int, workSize uint64, do func(workIndex uint64, routineIndex int) error, init func(routines, routineIndex int) error) error { if routines <= 0 { routines = max(runtime.NumCPU()-routines, 4) } @@ -17,39 +17,30 @@ func SplitWork(routines int, workSize uint64, do func(workIndex uint64, routineI var counter atomic.Uint64 - var wg sync.WaitGroup - var failed atomic.Bool for routineIndex := 0; routineIndex < routines; routineIndex++ { - if init != nil { - if err := init(routines, routineIndex); err != nil { - if errorFunc != nil { - errorFunc(routineIndex, err) - } - failed.Store(true) - continue - } + if err := init(routines, routineIndex); err != nil { + return err } - wg.Add(1) - go func(routineIndex int) { - defer wg.Done() + } + + var eg errgroup.Group + + for routineIndex := 0; routineIndex < routines; routineIndex++ { + innerRoutineIndex := routineIndex + eg.Go(func() error { var err error + for { workIndex := counter.Add(1) if workIndex > workSize { - return + return nil } - if err = do(workIndex-1, routineIndex); err != nil { - if errorFunc != nil { - errorFunc(routineIndex, err) - failed.Store(true) - } - return + if err = do(workIndex-1, innerRoutineIndex); err != nil { + return err } } - }(routineIndex) + }) } - wg.Wait() - - return !failed.Load() + return eg.Wait() }