zmq4/zmq4_pair_test.go

174 lines
3.6 KiB
Go

// Copyright 2020 The go-zeromq Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package zmq4_test
import (
"bytes"
"context"
"fmt"
"sync"
"testing"
"time"
"git.gammaspectra.live/P2Pool/zmq4"
"golang.org/x/sync/errgroup"
)
// pairs enumerates the PAIR-to-PAIR scenarios run by TestPair, one per
// transport (tcp, ipc, inproc). Every entry gets its own srv socket (the
// listener) and cli socket (the dialer), both built via zmq4.NewPair(bkg).
var pairs = []testCasePair{
	{name: "tcp-pair-pair", endpoint: must(EndPoint("tcp")), srv: zmq4.NewPair(bkg), cli: zmq4.NewPair(bkg)},
	{name: "ipc-pair-pair", endpoint: "ipc://ipc-pair-pair", srv: zmq4.NewPair(bkg), cli: zmq4.NewPair(bkg)},
	{name: "inproc-pair-pair", endpoint: "inproc://inproc-pair-pair", srv: zmq4.NewPair(bkg), cli: zmq4.NewPair(bkg)},
}
// testCasePair describes one PAIR scenario for TestPair: a socket that
// listens on endpoint and a socket that dials it.
type testCasePair struct {
	name     string      // subtest name handed to t.Run
	skip     bool        // when true, the subtest is skipped
	endpoint string      // address srv listens on and cli dials
	srv, cli zmq4.Socket // listening side and dialing side
}
// TestPair runs a ping/reply exchange over every case in pairs. The server
// sends each message in msgs and expects the client to answer with
// "reply: <msg>"; after all messages the client sends a final "QUIT".
//
// NOTE(review): if Listen or Dial fails, the failing goroutine returns before
// calling Done on its WaitGroup, so the peer stays parked in Wait and
// grp.Wait never returns (the test hangs until `go test -timeout`). Also, the
// sockets are created from bkg at package init rather than from ctx here, so
// the 20s timeout presumably does not interrupt Send/Recv — confirm. A
// channel-based handshake with select on ctx.Done() would address the first.
func TestPair(t *testing.T) {
	var (
		msg0 = zmq4.NewMsgString("")
		msg1 = zmq4.NewMsgString("MSG 1")
		msg2 = zmq4.NewMsgString("msg 2")
		msgs = []zmq4.Msg{
			msg0,
			msg1,
			msg2,
		}
	)

	for i := range pairs {
		tc := pairs[i]
		t.Run(tc.name, func(t *testing.T) {
			defer tc.srv.Close()
			defer tc.cli.Close()

			ep := tc.endpoint
			cleanUp(ep)

			if tc.skip {
				// Skip rather than Skipf: tc.name is data, not a format string.
				t.Skip(tc.name)
			}
			// t.Parallel()

			// Handshake: wg1 is released once the client's Dial has returned,
			// wg2 once the server has observed that, so the server does not
			// start sending before the client has dialed.
			var (
				wg1 sync.WaitGroup
				wg2 sync.WaitGroup
			)
			wg1.Add(1)
			wg2.Add(1)

			ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
			defer cancel()

			grp, ctx := errgroup.WithContext(ctx)

			// Server: listen, wait for the client, then drive the exchange.
			grp.Go(func() error {
				if err := tc.srv.Listen(ep); err != nil {
					return fmt.Errorf("could not listen: %w", err)
				}
				if addr := tc.srv.Addr(); addr == nil {
					return fmt.Errorf("listener with nil Addr")
				}

				wg1.Wait()
				wg2.Done()

				// i is the message index; it must be declared here, otherwise
				// the error below would report the enclosing pairs index.
				for i, msg := range msgs {
					if err := tc.srv.Send(msg); err != nil {
						return fmt.Errorf("could not send message %v: %w", msg, err)
					}
					reply, err := tc.srv.Recv()
					if err != nil {
						return fmt.Errorf("could not recv reply to %v: %w", msg, err)
					}
					if got, want := reply, zmq4.NewMsgString("reply: "+string(msg.Bytes())); !bytes.Equal(got.Bytes(), want.Bytes()) {
						return fmt.Errorf("invalid cli reply for msg #%d: got=%v, want=%v", i, got, want)
					}
				}

				quit, err := tc.srv.Recv()
				if err != nil {
					return fmt.Errorf("could not recv QUIT message: %w", err)
				}
				if got, want := quit, zmq4.NewMsgString("QUIT"); !bytes.Equal(got.Bytes(), want.Bytes()) {
					return fmt.Errorf("invalid QUIT message from cli: got=%v, want=%v", got, want)
				}
				return nil
			})

			// Client: dial, signal the server, then echo each message back.
			grp.Go(func() error {
				if err := tc.cli.Dial(ep); err != nil {
					return fmt.Errorf("could not dial: %w", err)
				}

				wg1.Done()
				wg2.Wait()

				for i := range msgs {
					msg, err := tc.cli.Recv()
					if err != nil {
						return fmt.Errorf("could not recv #%d msg from srv: %w", i, err)
					}
					if !bytes.Equal(msg.Bytes(), msgs[i].Bytes()) {
						return fmt.Errorf("invalid #%d msg from srv: got=%v, want=%v",
							i, msg, msgs[i],
						)
					}
					if err := tc.cli.Send(zmq4.NewMsgString("reply: " + string(msg.Bytes()))); err != nil {
						return fmt.Errorf("could not send message %v: %w", msg, err)
					}
				}

				if err := tc.cli.Send(zmq4.NewMsgString("QUIT")); err != nil {
					return fmt.Errorf("could not send QUIT message: %w", err)
				}
				return nil
			})

			if err := grp.Wait(); err != nil {
				t.Fatalf("error: %+v", err)
			}
			// Wait cancels the errgroup context, so context.Canceled is the
			// expected state here; any other error (e.g. DeadlineExceeded)
			// means the timeout fired.
			if err := ctx.Err(); err != nil && err != context.Canceled {
				t.Fatalf("error: %+v", err)
			}
		})
	}
}