zmq4: add timeout support on send
Add internal/errgroup package to support cancellable error groups. Fixes #147. Authored-by: Sergey Egorov <sergey.egorov@teleste.com>
This commit is contained in:
parent
e16dc3e41e
commit
16ca7c091b
107
internal/errgroup/errgroup.go
Normal file
107
internal/errgroup/errgroup.go
Normal file
|
@ -0,0 +1,107 @@
|
|||
// Copyright 2023 The go-zeromq Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package errgroup is bit more advanced than golang.org/x/sync/errgroup.
|
||||
// Major difference is that when error group is created with WithContext
|
||||
// the parent context would implicitly cancel all functions called by Go method.
|
||||
package errgroup
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
// The Group is superior errgroup.Group which aborts whole group
|
||||
// execution when parent context is cancelled
|
||||
type Group struct {
|
||||
grp *errgroup.Group
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
// WithContext creates Group and store inside parent context
|
||||
// so the Go method would respect parent context cancellation
|
||||
func WithContext(ctx context.Context) (*Group, context.Context) {
|
||||
grp, child_ctx := errgroup.WithContext(ctx)
|
||||
return &Group{grp: grp, ctx: ctx}, child_ctx
|
||||
}
|
||||
|
||||
// Go runs the provided f function in a dedicated goroutine and waits for its
|
||||
// completion or for the parent context cancellation.
|
||||
func (g *Group) Go(f func() error) {
|
||||
g.getErrGroup().Go(g.wrap(f))
|
||||
}
|
||||
|
||||
// Wait blocks until all function calls from the Go method have returned, then
|
||||
// returns the first non-nil error (if any) from them.
|
||||
// If the error group was created via WithContext then the Wait returns error
|
||||
// of cancelled parent context prior any functions calls complete.
|
||||
func (g *Group) Wait() error {
|
||||
return g.getErrGroup().Wait()
|
||||
}
|
||||
|
||||
// SetLimit limits the number of active goroutines in this group to at most n.
|
||||
// A negative value indicates no limit.
|
||||
//
|
||||
// Any subsequent call to the Go method will block until it can add an active
|
||||
// goroutine without exceeding the configured limit.
|
||||
//
|
||||
// The limit must not be modified while any goroutines in the group are active.
|
||||
func (g *Group) SetLimit(n int) {
|
||||
g.getErrGroup().SetLimit(n)
|
||||
}
|
||||
|
||||
// TryGo calls the given function in a new goroutine only if the number of
|
||||
// active goroutines in the group is currently below the configured limit.
|
||||
//
|
||||
// The return value reports whether the goroutine was started.
|
||||
func (g *Group) TryGo(f func() error) bool {
|
||||
return g.getErrGroup().TryGo(g.wrap(f))
|
||||
}
|
||||
|
||||
func (g *Group) wrap(f func() error) func() error {
|
||||
if g.ctx == nil {
|
||||
return f
|
||||
}
|
||||
|
||||
return func() error {
|
||||
// If parent context is canceled,
|
||||
// just return its error and do not call func f
|
||||
select {
|
||||
case <-g.ctx.Done():
|
||||
return g.ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
// Create return channel and call func f
|
||||
// Buffered channel is used as the following select
|
||||
// may be exiting by context cancellation
|
||||
// and in such case the write to channel can be block
|
||||
// and cause the go routine leak
|
||||
ch := make(chan error, 1)
|
||||
go func() {
|
||||
ch <- f()
|
||||
}()
|
||||
|
||||
// Wait func f complete or
|
||||
// parent context to be cancelled,
|
||||
select {
|
||||
case err := <-ch:
|
||||
return err
|
||||
case <-g.ctx.Done():
|
||||
return g.ctx.Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// The getErrGroup returns actual x/sync/errgroup.Group.
|
||||
// If the group is not allocated it would implicitly allocate it.
|
||||
// Thats allows the internal/errgroup.Group be fully
|
||||
// compatible to x/sync/errgroup.Group
|
||||
func (g *Group) getErrGroup() *errgroup.Group {
|
||||
if g.grp == nil {
|
||||
g.grp = &errgroup.Group{}
|
||||
}
|
||||
return g.grp
|
||||
}
|
124
internal/errgroup/errgroup_test.go
Normal file
124
internal/errgroup/errgroup_test.go
Normal file
|
@ -0,0 +1,124 @@
|
|||
// Copyright 2023 The go-zeromq Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package errgroup
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
// TestRegularErrGroupDoesNotRespectParentContext checks regular errgroup behavior
|
||||
// where errgroup.WithContext does not respect the parent context
|
||||
func TestRegularErrGroupDoesNotRespectParentContext(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
eg, _ := errgroup.WithContext(ctx)
|
||||
|
||||
what := fmt.Errorf("func generated error")
|
||||
ch := make(chan error)
|
||||
eg.Go(func() error { return <-ch })
|
||||
|
||||
cancel() // abort parent context
|
||||
ch <- what // signal the func in regular errgroup to fail
|
||||
err := eg.Wait()
|
||||
|
||||
// The error shall be one returned by the function
|
||||
// as regular errgroup.WithContext does not respect parent context
|
||||
if err != what {
|
||||
t.Errorf("invalid error. got=%+v, want=%+v", err, what)
|
||||
}
|
||||
}
|
||||
|
||||
// TestErrGroupWithContextCanCallFunctions checks the errgroup operations
|
||||
// are fine working and errgroup called function can return error
|
||||
func TestErrGroupWithContextCanCallFunctions(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
eg, _ := WithContext(ctx)
|
||||
|
||||
what := fmt.Errorf("func generated error")
|
||||
ch := make(chan error)
|
||||
eg.Go(func() error { return <-ch })
|
||||
|
||||
ch <- what // signal the func in errgroup to fail
|
||||
err := eg.Wait() // wait errgroup complete and read error
|
||||
|
||||
// The error shall be one returned by the function
|
||||
if err != what {
|
||||
t.Errorf("invalid error. got=%+v, want=%+v", err, what)
|
||||
}
|
||||
}
|
||||
|
||||
// TestErrGroupWithContextDoesRespectParentContext checks the errgroup operations
|
||||
// are cancellable by parent context
|
||||
func TestErrGroupWithContextDoesRespectParentContext(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
eg, _ := WithContext(ctx)
|
||||
|
||||
s1 := make(chan struct{})
|
||||
s2 := make(chan struct{})
|
||||
eg.Go(func() error {
|
||||
s1 <- struct{}{}
|
||||
<-s2
|
||||
return fmt.Errorf("func generated error")
|
||||
})
|
||||
|
||||
// We have no set limit to errgroup so
|
||||
// shall be able to start function via TryGo
|
||||
if ok := eg.TryGo(func() error { return nil }); !ok {
|
||||
t.Errorf("Expected TryGo to be able start function!!!")
|
||||
}
|
||||
|
||||
<-s1 // wait for function to start
|
||||
cancel() // abort parent context
|
||||
|
||||
eg.Go(func() error {
|
||||
t.Errorf("The parent context was already cancelled and this function shall not be called!!!")
|
||||
return nil
|
||||
})
|
||||
|
||||
s2 <- struct{}{} // signal the func in regular errgroup to fail
|
||||
err := eg.Wait() // wait errgroup complete and read error
|
||||
|
||||
// The error shall be one returned by the function
|
||||
// as regular errgroup.WithContext does not respect parent context
|
||||
if err != context.Canceled {
|
||||
t.Errorf("expected a context.Canceled error, got=%+v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestErrGroupFallback tests fallback logic to be compatible with x/sync/errgroup
|
||||
func TestErrGroupFallback(t *testing.T) {
|
||||
eg := Group{}
|
||||
eg.SetLimit(2)
|
||||
|
||||
ch1 := make(chan error)
|
||||
eg.Go(func() error { return <-ch1 })
|
||||
|
||||
ch2 := make(chan error)
|
||||
ok := eg.TryGo(func() error { return <-ch2 })
|
||||
if !ok {
|
||||
t.Errorf("Expected errgroup.TryGo to success!!!")
|
||||
}
|
||||
|
||||
// The limit set to 2, so 3rd function shall not be possible to call
|
||||
ok = eg.TryGo(func() error {
|
||||
t.Errorf("This function is unexpected to be called!!!")
|
||||
return nil
|
||||
})
|
||||
if ok {
|
||||
t.Errorf("Expected errgroup.TryGo to fail!!!")
|
||||
}
|
||||
|
||||
ch1 <- nil
|
||||
ch2 <- nil
|
||||
err := eg.Wait()
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("expected a nil error, got=%+v", err)
|
||||
}
|
||||
}
|
5
msgio.go
5
msgio.go
|
@ -9,6 +9,7 @@ import (
|
|||
"io"
|
||||
"sync"
|
||||
|
||||
errgrp "github.com/go-zeromq/zmq4/internal/errgroup"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
|
@ -63,11 +64,11 @@ func (q *qreader) Close() error {
|
|||
}
|
||||
|
||||
func (q *qreader) addConn(r *Conn) {
|
||||
go q.listen(q.ctx, r)
|
||||
q.mu.Lock()
|
||||
q.sem.enable()
|
||||
q.rs = append(q.rs, r)
|
||||
q.mu.Unlock()
|
||||
go q.listen(q.ctx, r)
|
||||
}
|
||||
|
||||
func (q *qreader) rmConn(r *Conn) {
|
||||
|
@ -167,7 +168,7 @@ func (mw *mwriter) rmConn(w *Conn) {
|
|||
|
||||
func (w *mwriter) write(ctx context.Context, msg Msg) error {
|
||||
w.sem.lock(ctx)
|
||||
grp, _ := errgroup.WithContext(ctx)
|
||||
grp, _ := errgrp.WithContext(ctx)
|
||||
w.mu.Lock()
|
||||
for i := range w.ws {
|
||||
ww := w.ws[i]
|
||||
|
|
|
@ -44,6 +44,13 @@ func WithDialerTimeout(timeout time.Duration) Option {
|
|||
}
|
||||
}
|
||||
|
||||
// WithTimeout sets the timeout value for socket operations
|
||||
func WithTimeout(timeout time.Duration) Option {
|
||||
return func(s *socket) {
|
||||
s.timeout = timeout
|
||||
}
|
||||
}
|
||||
|
||||
// WithLogger sets a dedicated log.Logger for the socket.
|
||||
func WithLogger(msg *log.Logger) Option {
|
||||
return func(s *socket) {
|
||||
|
|
6
pub.go
6
pub.go
|
@ -39,7 +39,7 @@ func (pub *pubSocket) Close() error {
|
|||
// Send puts the message on the outbound send queue.
|
||||
// Send blocks until the message can be queued or the send deadline expires.
|
||||
func (pub *pubSocket) Send(msg Msg) error {
|
||||
ctx, cancel := context.WithTimeout(pub.sck.ctx, pub.sck.timeout())
|
||||
ctx, cancel := context.WithTimeout(pub.sck.ctx, pub.sck.Timeout())
|
||||
defer cancel()
|
||||
return pub.sck.w.write(ctx, msg)
|
||||
}
|
||||
|
@ -49,7 +49,7 @@ func (pub *pubSocket) Send(msg Msg) error {
|
|||
// The message will be sent as a multipart message.
|
||||
func (pub *pubSocket) SendMulti(msg Msg) error {
|
||||
msg.multipart = true
|
||||
ctx, cancel := context.WithTimeout(pub.sck.ctx, pub.sck.timeout())
|
||||
ctx, cancel := context.WithTimeout(pub.sck.ctx, pub.sck.Timeout())
|
||||
defer cancel()
|
||||
return pub.sck.w.write(ctx, msg)
|
||||
}
|
||||
|
@ -149,11 +149,11 @@ func (q *pubQReader) Close() error {
|
|||
}
|
||||
|
||||
func (q *pubQReader) addConn(r *Conn) {
|
||||
go q.listen(q.ctx, r)
|
||||
q.mu.Lock()
|
||||
q.sem.enable()
|
||||
q.rs = append(q.rs, r)
|
||||
q.mu.Unlock()
|
||||
go q.listen(q.ctx, r)
|
||||
}
|
||||
|
||||
func (q *pubQReader) rmConn(r *Conn) {
|
||||
|
|
4
rep.go
4
rep.go
|
@ -34,7 +34,7 @@ func (rep *repSocket) Close() error {
|
|||
// Send puts the message on the outbound send queue.
|
||||
// Send blocks until the message can be queued or the send deadline expires.
|
||||
func (rep *repSocket) Send(msg Msg) error {
|
||||
ctx, cancel := context.WithTimeout(rep.sck.ctx, rep.sck.timeout())
|
||||
ctx, cancel := context.WithTimeout(rep.sck.ctx, rep.sck.Timeout())
|
||||
defer cancel()
|
||||
return rep.sck.w.write(ctx, msg)
|
||||
}
|
||||
|
@ -44,7 +44,7 @@ func (rep *repSocket) Send(msg Msg) error {
|
|||
// The message will be sent as a multipart message.
|
||||
func (rep *repSocket) SendMulti(msg Msg) error {
|
||||
msg.multipart = true
|
||||
ctx, cancel := context.WithTimeout(rep.sck.ctx, rep.sck.timeout())
|
||||
ctx, cancel := context.WithTimeout(rep.sck.ctx, rep.sck.Timeout())
|
||||
defer cancel()
|
||||
return rep.sck.w.write(ctx, msg)
|
||||
}
|
||||
|
|
|
@ -134,6 +134,7 @@ func TestCancellation(t *testing.T) {
|
|||
|
||||
defer wg.Done()
|
||||
repCtx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
rep := zmq4.NewRep(repCtx)
|
||||
defer rep.Close()
|
||||
|
||||
|
|
4
req.go
4
req.go
|
@ -35,7 +35,7 @@ func (req *reqSocket) Close() error {
|
|||
// Send puts the message on the outbound send queue.
|
||||
// Send blocks until the message can be queued or the send deadline expires.
|
||||
func (req *reqSocket) Send(msg Msg) error {
|
||||
ctx, cancel := context.WithTimeout(req.sck.ctx, req.sck.timeout())
|
||||
ctx, cancel := context.WithTimeout(req.sck.ctx, req.sck.Timeout())
|
||||
defer cancel()
|
||||
return req.sck.w.write(ctx, msg)
|
||||
}
|
||||
|
@ -45,7 +45,7 @@ func (req *reqSocket) Send(msg Msg) error {
|
|||
// The message will be sent as a multipart message.
|
||||
func (req *reqSocket) SendMulti(msg Msg) error {
|
||||
msg.multipart = true
|
||||
ctx, cancel := context.WithTimeout(req.sck.ctx, req.sck.timeout())
|
||||
ctx, cancel := context.WithTimeout(req.sck.ctx, req.sck.Timeout())
|
||||
defer cancel()
|
||||
return req.sck.w.write(ctx, msg)
|
||||
}
|
||||
|
|
|
@ -35,7 +35,7 @@ func (router *routerSocket) Close() error {
|
|||
// Send puts the message on the outbound send queue.
|
||||
// Send blocks until the message can be queued or the send deadline expires.
|
||||
func (router *routerSocket) Send(msg Msg) error {
|
||||
ctx, cancel := context.WithTimeout(router.sck.ctx, router.sck.timeout())
|
||||
ctx, cancel := context.WithTimeout(router.sck.ctx, router.sck.Timeout())
|
||||
defer cancel()
|
||||
return router.sck.w.write(ctx, msg)
|
||||
}
|
||||
|
@ -119,11 +119,11 @@ func (q *routerQReader) Close() error {
|
|||
}
|
||||
|
||||
func (q *routerQReader) addConn(r *Conn) {
|
||||
go q.listen(q.ctx, r)
|
||||
q.mu.Lock()
|
||||
q.sem.enable()
|
||||
q.rs = append(q.rs, r)
|
||||
q.mu.Unlock()
|
||||
go q.listen(q.ctx, r)
|
||||
}
|
||||
|
||||
func (q *routerQReader) rmConn(r *Conn) {
|
||||
|
|
11
socket.go
11
socket.go
|
@ -40,6 +40,7 @@ type socket struct {
|
|||
log *log.Logger
|
||||
subTopics func() []string
|
||||
autoReconnect bool
|
||||
timeout time.Duration
|
||||
|
||||
mu sync.RWMutex
|
||||
conns []*Conn // ZMTP connections
|
||||
|
@ -67,6 +68,7 @@ func newDefaultSocket(ctx context.Context, sockType SocketType) *socket {
|
|||
typ: sockType,
|
||||
retry: defaultRetry,
|
||||
maxRetries: defaultMaxRetries,
|
||||
timeout: defaultTimeout,
|
||||
sec: nullSecurity{},
|
||||
conns: nil,
|
||||
r: newQReader(ctx),
|
||||
|
@ -147,7 +149,7 @@ func (sck *socket) Close() error {
|
|||
// Send puts the message on the outbound send queue.
|
||||
// Send blocks until the message can be queued or the send deadline expires.
|
||||
func (sck *socket) Send(msg Msg) error {
|
||||
ctx, cancel := context.WithTimeout(sck.ctx, sck.timeout())
|
||||
ctx, cancel := context.WithTimeout(sck.ctx, sck.Timeout())
|
||||
defer cancel()
|
||||
return sck.w.write(ctx, msg)
|
||||
}
|
||||
|
@ -157,7 +159,7 @@ func (sck *socket) Send(msg Msg) error {
|
|||
// The message will be sent as a multipart message.
|
||||
func (sck *socket) SendMulti(msg Msg) error {
|
||||
msg.multipart = true
|
||||
ctx, cancel := context.WithTimeout(sck.ctx, sck.timeout())
|
||||
ctx, cancel := context.WithTimeout(sck.ctx, sck.Timeout())
|
||||
defer cancel()
|
||||
return sck.w.write(ctx, msg)
|
||||
}
|
||||
|
@ -365,9 +367,8 @@ func (sck *socket) SetOption(name string, value interface{}) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (sck *socket) timeout() time.Duration {
|
||||
// FIXME(sbinet): extract from options
|
||||
return defaultTimeout
|
||||
func (sck *socket) Timeout() time.Duration {
|
||||
return sck.timeout
|
||||
}
|
||||
|
||||
func (sck *socket) connReaper() {
|
||||
|
|
49
zmq4_timeout_test.go
Normal file
49
zmq4_timeout_test.go
Normal file
|
@ -0,0 +1,49 @@
|
|||
// Copyright 2023 The go-zeromq Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package zmq4
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestPushTimeout(t *testing.T) {
|
||||
ep := "ipc://@push_timeout_test"
|
||||
push := NewPush(context.Background(), WithTimeout(1*time.Second))
|
||||
defer push.Close()
|
||||
if err := push.Listen(ep); err != nil {
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
pull := NewPull(context.Background())
|
||||
defer pull.Close()
|
||||
if err := pull.Dial(ep); err != nil {
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// The ctx limits overall time of execution
|
||||
// If it gets canceled, that meain tests failed
|
||||
// as write to socket did not genereate timeout error
|
||||
t.Fatalf("test failed before being able to generate timeout error: %+v", ctx.Err())
|
||||
default:
|
||||
}
|
||||
|
||||
err := push.Send(NewMsgString("test string"))
|
||||
if err == nil {
|
||||
continue
|
||||
}
|
||||
if err != context.DeadlineExceeded {
|
||||
t.Fatalf("expected a context.DeadlineExceeded error, got=%+v", err)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
}
|
Loading…
Reference in a new issue