zmq4/socket_test.go

392 lines
10 KiB
Go

// Copyright 2020 The go-zeromq Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package zmq4_test
import (
"context"
"errors"
"fmt"
"io"
"net"
"sync"
"testing"
"time"
"git.gammaspectra.live/P2Pool/zmq4"
"git.gammaspectra.live/P2Pool/zmq4/transport"
"golang.org/x/sync/errgroup"
)
// TestInvalidConn connects a raw TCP client to a listening PUB socket. The
// client reads 64 bytes from the socket (presumably its ZMTP greeting) and
// answers with 64 zero bytes, which is not a valid reply. The test fails only
// if the raw client hits an I/O error or the 20s budget runs out.
func TestInvalidConn(t *testing.T) {
	// t.Parallel()

	ep := must(EndPoint("tcp"))
	cleanUp(ep)

	ctx, timeout := context.WithTimeout(context.Background(), 20*time.Second)
	defer timeout()

	pub := zmq4.NewPub(ctx)
	defer pub.Close()

	err := pub.Listen(ep)
	if err != nil {
		t.Fatalf("could not listen on end-point: %+v", err)
	}

	grp, ctx := errgroup.WithContext(ctx)
	grp.Go(func() error {
		// Dial with the context and apply its deadline to the connection, so
		// the timeout above really bounds the dial and the blocking read below.
		var dialer net.Dialer
		conn, err := dialer.DialContext(ctx, "tcp", ep[len("tcp://"):])
		if err != nil {
			return fmt.Errorf("could not dial %q: %w", ep, err)
		}
		defer conn.Close()

		if deadline, ok := ctx.Deadline(); ok {
			if err := conn.SetDeadline(deadline); err != nil {
				return fmt.Errorf("could not set deadline on %q: %w", ep, err)
			}
		}

		var reply = make([]byte, 64)
		_, err = io.ReadFull(conn, reply)
		if err != nil {
			return fmt.Errorf("could not read reply bytes...: %w", err)
		}

		_, err = conn.Write(make([]byte, 64))
		if err != nil {
			return fmt.Errorf("could not send bytes...: %w", err)
		}

		time.Sleep(1 * time.Second) // FIXME(sbinet): hugly.
		return nil
	})

	if err := grp.Wait(); err != nil {
		t.Fatalf("error: %+v", err)
	}

	// The errgroup context is always canceled once Wait returns. Any other
	// error (e.g. the deadline expiring) is a real failure.
	if err := ctx.Err(); err != nil && !errors.Is(err, context.Canceled) {
		t.Fatalf("error: %+v", err)
	}
}
// TestConnPairs checks, for each socket type, two things against a listening
// socket of that type: dialing it from an incompatible peer fails with the
// expected error, and dialing it from a compatible peer succeeds.
//
// Sockets are built inside each subtest, not in the table. This way only the
// subtests that actually run create sockets, every created socket is closed,
// and the per-subtest timeout context governs them.
func TestConnPairs(t *testing.T) {
	// t.Parallel()

	// newSocket matches the signature of the zmq4.NewXXX socket constructors.
	type newSocket func(context.Context, ...zmq4.Option) zmq4.Socket

	for _, tc := range []struct {
		name  string
		srv   newSocket
		wrong newSocket
		cli   newSocket
	}{
		{name: "pair", srv: zmq4.NewPair, wrong: zmq4.NewSub, cli: zmq4.NewPair},
		{name: "pub", srv: zmq4.NewPub, wrong: zmq4.NewPair, cli: zmq4.NewSub},
		{name: "sub", srv: zmq4.NewSub, wrong: zmq4.NewPair, cli: zmq4.NewPub},
		{name: "req", srv: zmq4.NewReq, wrong: zmq4.NewPair, cli: zmq4.NewRep},
		{name: "rep", srv: zmq4.NewRep, wrong: zmq4.NewPair, cli: zmq4.NewReq},
		{name: "dealer", srv: zmq4.NewDealer, wrong: zmq4.NewPair, cli: zmq4.NewRouter},
		{name: "router", srv: zmq4.NewRouter, wrong: zmq4.NewPair, cli: zmq4.NewDealer},
		{name: "pull", srv: zmq4.NewPull, wrong: zmq4.NewPair, cli: zmq4.NewPush},
		{name: "push", srv: zmq4.NewPush, wrong: zmq4.NewPair, cli: zmq4.NewPull},
		{name: "xpub", srv: zmq4.NewXPub, wrong: zmq4.NewPair, cli: zmq4.NewXSub},
		{name: "xsub", srv: zmq4.NewXSub, wrong: zmq4.NewPair, cli: zmq4.NewXPub},
	} {
		t.Run(tc.name, func(t *testing.T) {
			ep := must(EndPoint("tcp"))
			cleanUp(ep)

			ctx, timeout := context.WithTimeout(context.Background(), 20*time.Second)
			defer timeout()

			srv := tc.srv(ctx, zmq4.WithLogger(zmq4.Devnull))
			defer srv.Close()
			wrong := tc.wrong(ctx, zmq4.WithLogger(zmq4.Devnull))
			defer wrong.Close()
			cli := tc.cli(ctx, zmq4.WithLogger(zmq4.Devnull))
			defer cli.Close()

			err := srv.Listen(ep)
			if err != nil {
				t.Fatalf("could not listen on %q: %+v", ep, err)
			}

			err = wrong.Dial(ep)
			if err == nil {
				t.Fatalf("dialed %q", ep)
			}
			want := fmt.Errorf("zmq4: could not open a ZMTP connection: zmq4: could not initialize ZMTP connection: zmq4: peer=%q not compatible with %q", srv.Type(), wrong.Type())
			if got, want := err.Error(), want.Error(); got != want {
				t.Fatalf("invalid error:\ngot = %v\nwant= %v", got, want)
			}

			err = cli.Dial(ep)
			if err != nil {
				t.Fatalf("could not dial %q: %+v", ep, err)
			}
		})
	}
}
// TestConnReaperDeadlock guards against a deadlock when several clients are
// closed in quick succession. Ten REQ clients with distinct identities connect
// to a ROUTER. The first five are closed from the client side. The ROUTER
// then sends one message to every identity, dead or alive, and send failures
// are only logged.
func TestConnReaperDeadlock(t *testing.T) {
	const (
		nClients = 10
		nEarly   = 5 // clients closed before the server starts sending
	)
	clientID := func(i int) string { return fmt.Sprint("client-", i) }

	ep := must(EndPoint("tcp"))
	defer cleanUp(ep)

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	srv := zmq4.NewRouter(ctx, zmq4.WithLogger(zmq4.Devnull))
	if err := srv.Listen(ep); err != nil {
		t.Fatalf("could not listen on %q: %+v", ep, err)
	}
	defer srv.Close()

	clients := make([]zmq4.Socket, 0, nClients)
	for i := 0; i < nClients; i++ {
		identity := zmq4.SocketIdentity(clientID(i))
		c := zmq4.NewReq(ctx, zmq4.WithLogger(zmq4.Devnull), zmq4.WithID(identity))
		if err := c.Dial(ep); err != nil {
			t.Fatalf("could not dial %q: %+v", ep, err)
		}
		clients = append(clients, c)
	}

	// Disconnect the early clients from the client side only. The server
	// does not know they are gone until it tries to send to them below.
	for _, c := range clients[:nEarly] {
		c.Close()
	}

	// The first frame carries the destination identity and is rewritten for
	// each client.
	msg := zmq4.NewMsgFrom(nil, nil, []byte("payload"))
	for i := range clients {
		id := clientID(i)
		msg.Frames[0] = []byte(id)
		if err := srv.Send(msg); err != nil {
			t.Logf("Send to %s failed: %+v\n", id, err)
		}
	}

	for _, c := range clients[nEarly:] {
		c.Close()
	}
}
// TestSocketSendSubscriptionOnConnect checks the following scenario. A SUB
// socket subscribes to a topic, then listens on an inproc endpoint. A PUB
// socket dials it afterwards. The SUB must still receive a matching message.
// The PUB keeps publishing in the background until the test ends, so the
// message is not lost to the window before the subscription reaches it.
func TestSocketSendSubscriptionOnConnect(t *testing.T) {
	endpoint := "inproc://test-resub"
	message := "test"

	sub := zmq4.NewSub(context.Background())
	defer sub.Close()
	pub := zmq4.NewPub(context.Background())
	defer pub.Close()

	if err := sub.SetOption(zmq4.OptionSubscribe, message); err != nil {
		t.Fatalf("Sub SetOption failed: %v", err)
	}
	if err := sub.Listen(endpoint); err != nil {
		t.Fatalf("Sub Listen failed: %v", err)
	}
	if err := pub.Dial(endpoint); err != nil {
		t.Fatalf("Pub Dial failed: %v", err)
	}

	// Deferred in this order so that cancel runs before wg.Wait (LIFO).
	var wg sync.WaitGroup
	defer wg.Wait()
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	wg.Add(1)
	go func() {
		defer wg.Done()
		for {
			// Best effort: send errors are irrelevant while waiting for the
			// subscription to take effect.
			_ = pub.Send(zmq4.NewMsgFromString([]string{message}))
			if ctx.Err() != nil {
				return
			}
			time.Sleep(1 * time.Millisecond)
		}
	}()

	msg, err := sub.Recv()
	if err != nil {
		t.Fatalf("Recv failed: %v", err)
	}
	if string(msg.Frames[0]) != message {
		t.Fatalf("invalid message received: got '%s', wanted '%s'", msg.Frames[0], message)
	}
}
// transportMock is a stub transport for the dialer retry tests. It counts
// Dial calls and fails each one when errOnDial is set. Otherwise it hands
// back conn unchanged.
type transportMock struct {
	dialCalledCount int      // number of Dial invocations so far
	errOnDial       bool     // when true, every Dial returns an error
	conn            net.Conn // returned by Dial when errOnDial is false
}

// Dial records the call and either fails with a fixed error or returns m.conn.
// The context, dialer and address are ignored.
func (m *transportMock) Dial(_ context.Context, _ transport.Dialer, _ string) (net.Conn, error) {
	m.dialCalledCount++
	if !m.errOnDial {
		return m.conn, nil
	}
	return nil, errors.New("test error")
}

// Listen is a no-op that returns a nil listener and a nil error. Callers must
// not use the returned listener.
func (m *transportMock) Listen(_ context.Context, _ string) (net.Listener, error) {
	return nil, nil
}

// Addr ignores ep and returns an empty address with no error.
func (m *transportMock) Addr(_ string) (addr string, err error) {
	return "", nil
}
// TestConnMaxRetries checks that a socket configured with
// WithDialerMaxRetries(n) calls the transport's Dial exactly n+1 times
// (the initial attempt plus n retries) when every attempt fails, and that
// the socket's Dial then reports an error.
func TestConnMaxRetries(t *testing.T) {
	retryCount := 123
	socket := zmq4.NewSub(context.Background(), zmq4.WithDialerRetry(time.Microsecond), zmq4.WithDialerMaxRetries(retryCount))
	defer socket.Close()

	// Named mock (not "transport") to avoid shadowing the imported package.
	mock := &transportMock{errOnDial: true}

	// A fixed transport name cannot be registered twice in one process, e.g.
	// under `go test -count=2`. Suffix the name with the mock's address. This
	// presumes zmq4 keeps every registered transport reachable for the life of
	// the process, so an address still in use by an older mock is never handed
	// out again. NOTE(review): confirm against zmq4's transport registry.
	transportName := fmt.Sprintf("test-maxretries-%p", mock)
	if err := zmq4.RegisterTransport(transportName, mock); err != nil {
		t.Fatalf("could not register transport %q: %+v", transportName, err)
	}

	err := socket.Dial(transportName + "://test")
	if err == nil {
		t.Fatal("expected error")
	}

	if mock.dialCalledCount != retryCount+1 {
		t.Fatalf("Dial called %d times, expected %d", mock.dialCalledCount, retryCount+1)
	}
}
// TestConnMaxRetriesInfinite checks that WithDialerMaxRetries(-1) keeps
// retrying a failing dial instead of giving up after a fixed count. The socket
// context expires after 1ms, which presumably ends the retry loop. By then
// Dial must have returned an error and the transport must have been dialed at
// least 100 times.
func TestConnMaxRetriesInfinite(t *testing.T) {
	timeout := time.Millisecond
	retryTime := time.Nanosecond
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	socket := zmq4.NewSub(ctx, zmq4.WithDialerRetry(retryTime), zmq4.WithDialerMaxRetries(-1))
	defer socket.Close()

	// Named mock (not "transport") to avoid shadowing the imported package.
	mock := &transportMock{errOnDial: true}

	// A fixed transport name cannot be registered twice in one process, e.g.
	// under `go test -count=2`. Suffix the name with the mock's address. This
	// presumes zmq4 keeps every registered transport reachable for the life of
	// the process, so an address still in use by an older mock is never handed
	// out again. NOTE(review): confirm against zmq4's transport registry.
	transportName := fmt.Sprintf("test-infiniteretries-%p", mock)
	if err := zmq4.RegisterTransport(transportName, mock); err != nil {
		t.Fatalf("could not register transport %q: %+v", transportName, err)
	}

	err := socket.Dial(transportName + "://test")
	if err == nil {
		t.Fatal("expected error")
	}

	// NOTE(review): 100 dials within 1ms relies on fine-grained sleeps and
	// may be flaky on platforms with coarse timer resolution.
	atLeastExpectedRetries := 100
	if mock.dialCalledCount < atLeastExpectedRetries {
		t.Fatalf("Dial called %d times, expected at least %d", mock.dialCalledCount, atLeastExpectedRetries)
	}
}
// TestSocketAutomaticReconnect checks that a SUB socket created with
// WithAutomaticReconnect(true) still receives messages after its PUB peer is
// closed and a new PUB starts listening on the same endpoint.
func TestSocketAutomaticReconnect(t *testing.T) {
	ep, err := EndPoint("tcp")
	if err != nil {
		t.Fatalf("could not find endpoint: %+v", err)
	}
	message := "test"

	// Deferred in this order so that cancel runs before wg.Wait (LIFO).
	var wg sync.WaitGroup
	defer wg.Wait()
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// sendMessages publishes message on socket every millisecond from a
	// background goroutine. The goroutine stops once the returned function
	// is called or the test context is canceled.
	sendMessages := func(socket zmq4.Socket) context.CancelFunc {
		sendCtx, stop := context.WithCancel(ctx)
		wg.Add(1)
		go func() {
			defer wg.Done()
			for {
				// Best effort: sends may fail while the subscriber reconnects.
				_ = socket.Send(zmq4.NewMsgFromString([]string{message}))
				if sendCtx.Err() != nil {
					return
				}
				time.Sleep(1 * time.Millisecond)
			}
		}()
		return stop
	}

	sub := zmq4.NewSub(context.Background(), zmq4.WithAutomaticReconnect(true))
	defer sub.Close()
	if err := sub.SetOption(zmq4.OptionSubscribe, message); err != nil {
		t.Fatalf("Sub SetOption failed: %v", err)
	}

	pub := zmq4.NewPub(context.Background())
	if err := pub.Listen(ep); err != nil {
		t.Fatalf("Pub Listen failed: %v", err)
	}
	if err := sub.Dial(ep); err != nil {
		t.Fatalf("Sub Dial failed: %v", err)
	}
	stopPub := sendMessages(pub)

	// checkConnectionWorking waits for one message on socket. io.EOF results,
	// observed while the connection is down, are skipped.
	checkConnectionWorking := func(socket zmq4.Socket) {
		for {
			msg, err := socket.Recv()
			if errors.Is(err, io.EOF) {
				continue
			}
			if err != nil {
				t.Fatalf("Recv failed: %v", err)
			}
			if string(msg.Frames[0]) != message {
				t.Fatalf("invalid message received: got '%s', wanted '%s'", msg.Frames[0], message)
			}
			return
		}
	}

	checkConnectionWorking(sub)

	// Stop publishing on the first PUB before closing it, so its sender
	// goroutine does not keep spinning on a closed socket.
	stopPub()
	pub.Close()

	pub2 := zmq4.NewPub(context.Background())
	defer pub2.Close()
	if err := pub2.Listen(ep); err != nil {
		t.Fatalf("Pub2 Listen failed: %v", err)
	}
	stopPub2 := sendMessages(pub2)
	defer stopPub2()

	checkConnectionWorking(sub)
}