zmq4: fix another connection reaper deadlock
Fixes #149 Co-authored-by: Sergey Egorov <sergey.egorov@teleste.com> Co-authored-by: Sebastien Binet <binet@cern.ch>
This commit is contained in:
parent
16ca7c091b
commit
e75c615ba1
102
reaper_test.go
Normal file
102
reaper_test.go
Normal file
|
@ -0,0 +1,102 @@
|
|||
// Copyright 2024 The go-zeromq Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package zmq4
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestConnReaperDeadlock2(t *testing.T) {
|
||||
ep := must(EndPoint("tcp"))
|
||||
defer cleanUp(ep)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// Bind the server.
|
||||
srv := NewRouter(ctx, WithLogger(Devnull)).(*routerSocket)
|
||||
if err := srv.Listen(ep); err != nil {
|
||||
t.Fatalf("could not listen on %q: %+v", ep, err)
|
||||
}
|
||||
defer srv.Close()
|
||||
|
||||
// Add modified clients connection to server
|
||||
// so any send to client will trigger context switch
|
||||
// and be failing.
|
||||
// Idea is that while srv.Send is progressing,
|
||||
// the connection will be closed and assigned
|
||||
// for connection reaper, and reaper will try to remove those
|
||||
id := "client-x"
|
||||
srv.sck.mu.Lock()
|
||||
rmw := srv.sck.w.(*routerMWriter)
|
||||
for i := 0; i < 2; i++ {
|
||||
w := &Conn{}
|
||||
w.Peer.Meta = make(Metadata)
|
||||
w.Peer.Meta[sysSockID] = id
|
||||
w.rw = &sockSendEOF{}
|
||||
w.onCloseErrorCB = srv.sck.scheduleRmConn
|
||||
// Do not to call srv.addConn as we dont want to have listener on this fake socket
|
||||
rmw.addConn(w)
|
||||
srv.sck.conns = append(srv.sck.conns, w)
|
||||
}
|
||||
srv.sck.mu.Unlock()
|
||||
|
||||
// Now try to send a message from the server to all clients.
|
||||
msg := NewMsgFrom(nil, nil, []byte("payload"))
|
||||
msg.Frames[0] = []byte(id)
|
||||
if err := srv.Send(msg); err != nil {
|
||||
t.Logf("Send to %s failed: %+v\n", id, err)
|
||||
}
|
||||
}
|
||||
|
||||
// sockSendEOF is a fake network connection whose writes always fail with
// io.EOF. Every other method is a no-op returning zero values.
type sockSendEOF struct {
}

// Compile-time check that the fake can stand in for a real connection.
var _ net.Conn = (*sockSendEOF)(nil)

// a counts the calls to sockSendEOF.Write across all instances.
var a atomic.Int32

// Write never writes anything and always returns (0, io.EOF).
func (r *sockSendEOF) Write(b []byte) (n int, err error) {
	// Each odd write fails asap.
	// Each even write fails after sleep.
	// This way we ensure the short write failure
	// will cause the socket to be assigned to the connection reaper
	// while srv.Send is still in progress due to long writes.
	if x := a.Add(1); x&1 == 0 {
		time.Sleep(1 * time.Second)
	}
	return 0, io.EOF
}

// Read reads nothing and reports no error.
func (r *sockSendEOF) Read(b []byte) (int, error) {
	return 0, nil
}

// Close is a no-op.
func (r *sockSendEOF) Close() error {
	return nil
}

// LocalAddr always returns nil.
func (r *sockSendEOF) LocalAddr() net.Addr {
	return nil
}

// RemoteAddr always returns nil.
func (r *sockSendEOF) RemoteAddr() net.Addr {
	return nil
}

// SetDeadline ignores the deadline and returns nil.
func (r *sockSendEOF) SetDeadline(t time.Time) error {
	return nil
}

// SetReadDeadline ignores the deadline and returns nil.
func (r *sockSendEOF) SetReadDeadline(t time.Time) error {
	return nil
}

// SetWriteDeadline ignores the deadline and returns nil.
func (r *sockSendEOF) SetWriteDeadline(t time.Time) error {
	return nil
}
|
10
socket.go
10
socket.go
|
@ -384,10 +384,16 @@ func (sck *socket) connReaper() {
|
|||
return
|
||||
}
|
||||
|
||||
for _, c := range sck.closedConns {
|
||||
// Clone the known closed connections to avoid data race
|
||||
// and remove them while the reaper lock is released.
|
||||
// That should fix the deadlock reported in #149.
|
||||
cc := append([]*Conn{}, sck.closedConns...) // clone
|
||||
sck.closedConns = sck.closedConns[:0]
|
||||
sck.reaperCond.L.Unlock()
|
||||
for _, c := range cc {
|
||||
sck.rmConn(c)
|
||||
}
|
||||
sck.closedConns = nil
|
||||
sck.reaperCond.L.Lock()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
31
zall_test.go
31
zall_test.go
|
@ -5,10 +5,41 @@
|
|||
package zmq4
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
)
|
||||
|
||||
var (
	// Devnull is a logger that discards everything written to it.
	// Its prefix is "zmq4: " and it has no flags set.
	Devnull = log.New(io.Discard, "zmq4: ", 0)
)
|
||||
|
||||
// must returns str unchanged when err is nil and panics with err otherwise.
// It lets a (string, error) call such as EndPoint be used inline.
func must(str string, err error) string {
	if err != nil {
		panic(err)
	}
	return str
}
|
||||
|
||||
func EndPoint(transport string) (string, error) {
|
||||
switch transport {
|
||||
case "tcp":
|
||||
addr, err := net.ResolveTCPAddr("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
l, err := net.ListenTCP("tcp", addr)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer l.Close()
|
||||
return fmt.Sprintf("tcp://%s", l.Addr()), nil
|
||||
case "ipc":
|
||||
return "ipc://tmp-" + newUUID(), nil
|
||||
case "inproc":
|
||||
return "inproc://tmp-" + newUUID(), nil
|
||||
default:
|
||||
panic("invalid transport: [" + transport + "]")
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue