zmq4/zmq4_pushpull_test.go

132 lines
2.7 KiB
Go

// Copyright 2018 The go-zeromq Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package zmq4_test
import (
"context"
"fmt"
"reflect"
"testing"
"time"
"git.gammaspectra.live/P2Pool/zmq4"
"golang.org/x/sync/errgroup"
)
// pushpulls is the table of transports exercised by TestPushPull. Every case
// gets its own freshly constructed PUSH and PULL sockets, which meet on the
// case's endpoint (tcp, ipc and inproc respectively).
var pushpulls = []testCasePushPull{
	{name: "tcp-push-pull", endpoint: must(EndPoint("tcp")), push: zmq4.NewPush(bkg), pull: zmq4.NewPull(bkg)},
	{name: "ipc-push-pull", endpoint: "ipc://ipc-push-pull", push: zmq4.NewPush(bkg), pull: zmq4.NewPull(bkg)},
	{name: "inproc-push-pull", endpoint: "inproc://push-pull", push: zmq4.NewPush(bkg), pull: zmq4.NewPull(bkg)},
}
// testCasePushPull describes one PUSH/PULL round-trip scenario used by
// TestPushPull.
type testCasePushPull struct {
	name       string      // subtest name handed to t.Run
	skip       bool        // when set, the subtest is skipped
	endpoint   string      // address the push side listens on and the pull side dials
	push, pull zmq4.Socket // sending and receiving ends of the pipeline
}
// TestPushPull runs a PUSH/PULL round trip over every transport listed in
// pushpulls. In each subtest the PUSH socket listens on the endpoint and sends
// a single-frame "HELLO WORLD" message followed by a two-frame "GOOD"/"BYE"
// message. Meanwhile the PULL socket dials the same endpoint and checks that
// both messages arrive in order and unchanged. A subtest fails if the exchange
// has not finished within 20 seconds.
func TestPushPull(t *testing.T) {
	var (
		hello = zmq4.NewMsg([]byte("HELLO WORLD"))
		bye   = zmq4.NewMsgFrom([]byte("GOOD"), []byte("BYE"))
	)

	for i := range pushpulls {
		tc := pushpulls[i] // per-iteration copy captured by the subtest closure
		t.Run(tc.name, func(t *testing.T) {
			// Registered first so the sockets are closed on every exit path,
			// including Skip, Fatal and the timeout branch below.
			defer tc.pull.Close()
			defer tc.push.Close()

			ep := tc.endpoint
			cleanUp(ep)

			if tc.skip {
				// t.Skip rather than t.Skipf: the name is data, not a format
				// string (go vet's printf check rejects the latter).
				t.Skip(tc.name)
			}

			ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
			defer cancel()

			var grp errgroup.Group

			// Sender: listen, then push both messages.
			grp.Go(func() error {
				err := tc.push.Listen(ep)
				if err != nil {
					return fmt.Errorf("could not listen: %w", err)
				}

				if addr := tc.push.Addr(); addr == nil {
					return fmt.Errorf("listener with nil Addr")
				}

				err = tc.push.Send(hello)
				if err != nil {
					return fmt.Errorf("could not send %v: %w", hello, err)
				}

				err = tc.push.Send(bye)
				if err != nil {
					return fmt.Errorf("could not send %v: %w", bye, err)
				}
				return nil
			})

			// Receiver: dial, then pull and compare both messages in order.
			grp.Go(func() error {
				err := tc.pull.Dial(ep)
				if err != nil {
					return fmt.Errorf("could not dial: %w", err)
				}

				if addr := tc.pull.Addr(); addr != nil {
					return fmt.Errorf("dialer with non-nil Addr")
				}

				msg, err := tc.pull.Recv()
				if err != nil {
					return fmt.Errorf("could not recv %v: %w", hello, err)
				}

				if got, want := msg, hello; !reflect.DeepEqual(got, want) {
					return fmt.Errorf("recv1: got = %v, want= %v", got, want)
				}

				msg, err = tc.pull.Recv()
				if err != nil {
					return fmt.Errorf("could not recv %v: %w", bye, err)
				}

				if got, want := msg, bye; !reflect.DeepEqual(got, want) {
					return fmt.Errorf("recv2: got = %v, want= %v", got, want)
				}
				return nil
			})

			// errgroup.Group.Wait never looks at ctx, so wait in the background
			// and race it against the deadline; otherwise a stuck Dial/Recv
			// would hang the whole test binary instead of failing this subtest.
			// The channel is buffered so the waiter can always deliver its
			// result and exit, even after we have stopped listening.
			done := make(chan error, 1)
			go func() { done <- grp.Wait() }()

			select {
			case err := <-done:
				if err != nil {
					t.Fatalf("error: %+v", err)
				}
			case <-ctx.Done():
				// t.Fatalf runs the deferred Closes above.
				// NOTE(review): presumably that also interrupts a goroutine
				// still blocked in Dial/Recv; confirm against zmq4's Close.
				t.Fatalf("push/pull exchange on %q did not complete: %v", ep, ctx.Err())
			}
		})
	}
}