zmq4: add timeout support on send

Add internal/errgroup package to support cancellable error groups.

Fixes #147.

Authored-by: Sergey Egorov <sergey.egorov@teleste.com>
This commit is contained in:
Sergey Egorov 2023-12-15 10:37:49 +02:00 committed by GitHub
parent e16dc3e41e
commit 16ca7c091b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 306 additions and 16 deletions

View file

@ -0,0 +1,107 @@
// Copyright 2023 The go-zeromq Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package errgroup is bit more advanced than golang.org/x/sync/errgroup.
// Major difference is that when error group is created with WithContext
// the parent context would implicitly cancel all functions called by Go method.
package errgroup
import (
"context"
"golang.org/x/sync/errgroup"
)
// The Group is superior errgroup.Group which aborts whole group
// execution when parent context is cancelled
type Group struct {
grp *errgroup.Group
ctx context.Context
}
// WithContext creates Group and store inside parent context
// so the Go method would respect parent context cancellation
func WithContext(ctx context.Context) (*Group, context.Context) {
grp, child_ctx := errgroup.WithContext(ctx)
return &Group{grp: grp, ctx: ctx}, child_ctx
}
// Go runs the provided f function in a dedicated goroutine and waits for its
// completion or for the parent context cancellation.
func (g *Group) Go(f func() error) {
g.getErrGroup().Go(g.wrap(f))
}
// Wait blocks until all function calls from the Go method have returned, then
// returns the first non-nil error (if any) from them.
// If the error group was created via WithContext then the Wait returns error
// of cancelled parent context prior any functions calls complete.
func (g *Group) Wait() error {
return g.getErrGroup().Wait()
}
// SetLimit limits the number of active goroutines in this group to at most n.
// A negative value indicates no limit.
//
// Any subsequent call to the Go method will block until it can add an active
// goroutine without exceeding the configured limit.
//
// The limit must not be modified while any goroutines in the group are active.
func (g *Group) SetLimit(n int) {
g.getErrGroup().SetLimit(n)
}
// TryGo calls the given function in a new goroutine only if the number of
// active goroutines in the group is currently below the configured limit.
//
// The return value reports whether the goroutine was started.
func (g *Group) TryGo(f func() error) bool {
return g.getErrGroup().TryGo(g.wrap(f))
}
func (g *Group) wrap(f func() error) func() error {
if g.ctx == nil {
return f
}
return func() error {
// If parent context is canceled,
// just return its error and do not call func f
select {
case <-g.ctx.Done():
return g.ctx.Err()
default:
}
// Create return channel and call func f
// Buffered channel is used as the following select
// may be exiting by context cancellation
// and in such case the write to channel can be block
// and cause the go routine leak
ch := make(chan error, 1)
go func() {
ch <- f()
}()
// Wait func f complete or
// parent context to be cancelled,
select {
case err := <-ch:
return err
case <-g.ctx.Done():
return g.ctx.Err()
}
}
}
// The getErrGroup returns actual x/sync/errgroup.Group.
// If the group is not allocated it would implicitly allocate it.
// Thats allows the internal/errgroup.Group be fully
// compatible to x/sync/errgroup.Group
func (g *Group) getErrGroup() *errgroup.Group {
if g.grp == nil {
g.grp = &errgroup.Group{}
}
return g.grp
}

View file

@ -0,0 +1,124 @@
// Copyright 2023 The go-zeromq Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package errgroup
import (
"context"
"fmt"
"testing"
"golang.org/x/sync/errgroup"
)
// TestRegularErrGroupDoesNotRespectParentContext checks regular errgroup behavior
// where errgroup.WithContext does not respect the parent context
func TestRegularErrGroupDoesNotRespectParentContext(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
eg, _ := errgroup.WithContext(ctx)
what := fmt.Errorf("func generated error")
ch := make(chan error)
eg.Go(func() error { return <-ch })
cancel() // abort parent context
ch <- what // signal the func in regular errgroup to fail
err := eg.Wait()
// The error shall be one returned by the function
// as regular errgroup.WithContext does not respect parent context
if err != what {
t.Errorf("invalid error. got=%+v, want=%+v", err, what)
}
}
// TestErrGroupWithContextCanCallFunctions checks the errgroup operations
// are fine working and errgroup called function can return error
func TestErrGroupWithContextCanCallFunctions(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
eg, _ := WithContext(ctx)
what := fmt.Errorf("func generated error")
ch := make(chan error)
eg.Go(func() error { return <-ch })
ch <- what // signal the func in errgroup to fail
err := eg.Wait() // wait errgroup complete and read error
// The error shall be one returned by the function
if err != what {
t.Errorf("invalid error. got=%+v, want=%+v", err, what)
}
}
// TestErrGroupWithContextDoesRespectParentContext checks the errgroup operations
// are cancellable by parent context
func TestErrGroupWithContextDoesRespectParentContext(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
eg, _ := WithContext(ctx)
s1 := make(chan struct{})
s2 := make(chan struct{})
eg.Go(func() error {
s1 <- struct{}{}
<-s2
return fmt.Errorf("func generated error")
})
// We have no set limit to errgroup so
// shall be able to start function via TryGo
if ok := eg.TryGo(func() error { return nil }); !ok {
t.Errorf("Expected TryGo to be able start function!!!")
}
<-s1 // wait for function to start
cancel() // abort parent context
eg.Go(func() error {
t.Errorf("The parent context was already cancelled and this function shall not be called!!!")
return nil
})
s2 <- struct{}{} // signal the func in regular errgroup to fail
err := eg.Wait() // wait errgroup complete and read error
// The error shall be one returned by the function
// as regular errgroup.WithContext does not respect parent context
if err != context.Canceled {
t.Errorf("expected a context.Canceled error, got=%+v", err)
}
}
// TestErrGroupFallback tests fallback logic to be compatible with x/sync/errgroup
func TestErrGroupFallback(t *testing.T) {
eg := Group{}
eg.SetLimit(2)
ch1 := make(chan error)
eg.Go(func() error { return <-ch1 })
ch2 := make(chan error)
ok := eg.TryGo(func() error { return <-ch2 })
if !ok {
t.Errorf("Expected errgroup.TryGo to success!!!")
}
// The limit set to 2, so 3rd function shall not be possible to call
ok = eg.TryGo(func() error {
t.Errorf("This function is unexpected to be called!!!")
return nil
})
if ok {
t.Errorf("Expected errgroup.TryGo to fail!!!")
}
ch1 <- nil
ch2 <- nil
err := eg.Wait()
if err != nil {
t.Errorf("expected a nil error, got=%+v", err)
}
}

View file

@ -9,6 +9,7 @@ import (
"io"
"sync"
errgrp "github.com/go-zeromq/zmq4/internal/errgroup"
"golang.org/x/sync/errgroup"
)
@ -63,11 +64,11 @@ func (q *qreader) Close() error {
}
func (q *qreader) addConn(r *Conn) {
go q.listen(q.ctx, r)
q.mu.Lock()
q.sem.enable()
q.rs = append(q.rs, r)
q.mu.Unlock()
go q.listen(q.ctx, r)
}
func (q *qreader) rmConn(r *Conn) {
@ -167,7 +168,7 @@ func (mw *mwriter) rmConn(w *Conn) {
func (w *mwriter) write(ctx context.Context, msg Msg) error {
w.sem.lock(ctx)
grp, _ := errgroup.WithContext(ctx)
grp, _ := errgrp.WithContext(ctx)
w.mu.Lock()
for i := range w.ws {
ww := w.ws[i]

View file

@ -44,6 +44,13 @@ func WithDialerTimeout(timeout time.Duration) Option {
}
}
// WithTimeout sets the timeout value for socket operations
func WithTimeout(timeout time.Duration) Option {
return func(s *socket) {
s.timeout = timeout
}
}
// WithLogger sets a dedicated log.Logger for the socket.
func WithLogger(msg *log.Logger) Option {
return func(s *socket) {

6
pub.go
View file

@ -39,7 +39,7 @@ func (pub *pubSocket) Close() error {
// Send puts the message on the outbound send queue.
// Send blocks until the message can be queued or the send deadline expires.
func (pub *pubSocket) Send(msg Msg) error {
ctx, cancel := context.WithTimeout(pub.sck.ctx, pub.sck.timeout())
ctx, cancel := context.WithTimeout(pub.sck.ctx, pub.sck.Timeout())
defer cancel()
return pub.sck.w.write(ctx, msg)
}
@ -49,7 +49,7 @@ func (pub *pubSocket) Send(msg Msg) error {
// The message will be sent as a multipart message.
func (pub *pubSocket) SendMulti(msg Msg) error {
msg.multipart = true
ctx, cancel := context.WithTimeout(pub.sck.ctx, pub.sck.timeout())
ctx, cancel := context.WithTimeout(pub.sck.ctx, pub.sck.Timeout())
defer cancel()
return pub.sck.w.write(ctx, msg)
}
@ -149,11 +149,11 @@ func (q *pubQReader) Close() error {
}
func (q *pubQReader) addConn(r *Conn) {
go q.listen(q.ctx, r)
q.mu.Lock()
q.sem.enable()
q.rs = append(q.rs, r)
q.mu.Unlock()
go q.listen(q.ctx, r)
}
func (q *pubQReader) rmConn(r *Conn) {

4
rep.go
View file

@ -34,7 +34,7 @@ func (rep *repSocket) Close() error {
// Send puts the message on the outbound send queue.
// Send blocks until the message can be queued or the send deadline expires.
func (rep *repSocket) Send(msg Msg) error {
ctx, cancel := context.WithTimeout(rep.sck.ctx, rep.sck.timeout())
ctx, cancel := context.WithTimeout(rep.sck.ctx, rep.sck.Timeout())
defer cancel()
return rep.sck.w.write(ctx, msg)
}
@ -44,7 +44,7 @@ func (rep *repSocket) Send(msg Msg) error {
// The message will be sent as a multipart message.
func (rep *repSocket) SendMulti(msg Msg) error {
msg.multipart = true
ctx, cancel := context.WithTimeout(rep.sck.ctx, rep.sck.timeout())
ctx, cancel := context.WithTimeout(rep.sck.ctx, rep.sck.Timeout())
defer cancel()
return rep.sck.w.write(ctx, msg)
}

View file

@ -134,6 +134,7 @@ func TestCancellation(t *testing.T) {
defer wg.Done()
repCtx, cancel := context.WithCancel(context.Background())
defer cancel()
rep := zmq4.NewRep(repCtx)
defer rep.Close()

4
req.go
View file

@ -35,7 +35,7 @@ func (req *reqSocket) Close() error {
// Send puts the message on the outbound send queue.
// Send blocks until the message can be queued or the send deadline expires.
func (req *reqSocket) Send(msg Msg) error {
ctx, cancel := context.WithTimeout(req.sck.ctx, req.sck.timeout())
ctx, cancel := context.WithTimeout(req.sck.ctx, req.sck.Timeout())
defer cancel()
return req.sck.w.write(ctx, msg)
}
@ -45,7 +45,7 @@ func (req *reqSocket) Send(msg Msg) error {
// The message will be sent as a multipart message.
func (req *reqSocket) SendMulti(msg Msg) error {
msg.multipart = true
ctx, cancel := context.WithTimeout(req.sck.ctx, req.sck.timeout())
ctx, cancel := context.WithTimeout(req.sck.ctx, req.sck.Timeout())
defer cancel()
return req.sck.w.write(ctx, msg)
}

View file

@ -35,7 +35,7 @@ func (router *routerSocket) Close() error {
// Send puts the message on the outbound send queue.
// Send blocks until the message can be queued or the send deadline expires.
func (router *routerSocket) Send(msg Msg) error {
ctx, cancel := context.WithTimeout(router.sck.ctx, router.sck.timeout())
ctx, cancel := context.WithTimeout(router.sck.ctx, router.sck.Timeout())
defer cancel()
return router.sck.w.write(ctx, msg)
}
@ -119,11 +119,11 @@ func (q *routerQReader) Close() error {
}
func (q *routerQReader) addConn(r *Conn) {
go q.listen(q.ctx, r)
q.mu.Lock()
q.sem.enable()
q.rs = append(q.rs, r)
q.mu.Unlock()
go q.listen(q.ctx, r)
}
func (q *routerQReader) rmConn(r *Conn) {

View file

@ -40,6 +40,7 @@ type socket struct {
log *log.Logger
subTopics func() []string
autoReconnect bool
timeout time.Duration
mu sync.RWMutex
conns []*Conn // ZMTP connections
@ -67,6 +68,7 @@ func newDefaultSocket(ctx context.Context, sockType SocketType) *socket {
typ: sockType,
retry: defaultRetry,
maxRetries: defaultMaxRetries,
timeout: defaultTimeout,
sec: nullSecurity{},
conns: nil,
r: newQReader(ctx),
@ -147,7 +149,7 @@ func (sck *socket) Close() error {
// Send puts the message on the outbound send queue.
// Send blocks until the message can be queued or the send deadline expires.
func (sck *socket) Send(msg Msg) error {
ctx, cancel := context.WithTimeout(sck.ctx, sck.timeout())
ctx, cancel := context.WithTimeout(sck.ctx, sck.Timeout())
defer cancel()
return sck.w.write(ctx, msg)
}
@ -157,7 +159,7 @@ func (sck *socket) Send(msg Msg) error {
// The message will be sent as a multipart message.
func (sck *socket) SendMulti(msg Msg) error {
msg.multipart = true
ctx, cancel := context.WithTimeout(sck.ctx, sck.timeout())
ctx, cancel := context.WithTimeout(sck.ctx, sck.Timeout())
defer cancel()
return sck.w.write(ctx, msg)
}
@ -365,9 +367,8 @@ func (sck *socket) SetOption(name string, value interface{}) error {
return nil
}
func (sck *socket) timeout() time.Duration {
// FIXME(sbinet): extract from options
return defaultTimeout
func (sck *socket) Timeout() time.Duration {
return sck.timeout
}
func (sck *socket) connReaper() {

49
zmq4_timeout_test.go Normal file
View file

@ -0,0 +1,49 @@
// Copyright 2023 The go-zeromq Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package zmq4
import (
"context"
"testing"
"time"
)
func TestPushTimeout(t *testing.T) {
ep := "ipc://@push_timeout_test"
push := NewPush(context.Background(), WithTimeout(1*time.Second))
defer push.Close()
if err := push.Listen(ep); err != nil {
t.FailNow()
}
pull := NewPull(context.Background())
defer pull.Close()
if err := pull.Dial(ep); err != nil {
t.FailNow()
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
for {
select {
case <-ctx.Done():
// The ctx limits overall time of execution
// If it gets canceled, that meain tests failed
// as write to socket did not genereate timeout error
t.Fatalf("test failed before being able to generate timeout error: %+v", ctx.Err())
default:
}
err := push.Send(NewMsgString("test string"))
if err == nil {
continue
}
if err != context.DeadlineExceeded {
t.Fatalf("expected a context.DeadlineExceeded error, got=%+v", err)
}
break
}
}