zmq4/rep.go
Sergey Egorov 16ca7c091b
zmq4: add timeout support on send
Add internal/errgroup package to support cancellable error groups.

Fixes #147.

Authored-by: Sergey Egorov <sergey.egorov@teleste.com>
2023-12-15 09:37:49 +01:00

341 lines
6.6 KiB
Go

// Copyright 2018 The go-zeromq Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package zmq4
import (
"context"
"fmt"
"net"
"sync"
)
// NewRep creates a REP ZeroMQ socket configured by opts.
// The socket is neither bound nor connected when returned; use Listen or
// Dial before exchanging messages. The REP reader and writer share a single
// state value so that each reply is routed to the sender of the last request.
func NewRep(ctx context.Context, opts ...Option) Socket {
	sck := newSocket(ctx, Rep, opts...)
	state := newRepState()
	sck.w = newRepWriter(sck.ctx, state)
	sck.r = newRepReader(sck.ctx, state)
	return &repSocket{sck: sck}
}
// repSocket implements Socket for the REP pattern. It wraps a generic socket
// whose reader and writer are the REP-specific ones installed by NewRep.
type repSocket struct {
	sck *socket
}

// Close closes the open Socket.
func (rep *repSocket) Close() error { return rep.sck.Close() }
// Send queues msg as the reply to the most recently received request.
// It blocks until the message is queued, the socket's configured timeout
// elapses, or the socket's context is cancelled.
func (rep *repSocket) Send(msg Msg) error {
	sendCtx, cancelSend := context.WithTimeout(rep.sck.ctx, rep.sck.Timeout())
	defer cancelSend()
	return rep.sck.w.write(sendCtx, msg)
}

// SendMulti behaves like Send, but the message is flagged to be sent as a
// multipart message. It blocks under the same conditions as Send.
func (rep *repSocket) SendMulti(msg Msg) error {
	msg.multipart = true // msg is a copy; the caller's value is untouched
	return rep.Send(msg)
}
// Recv blocks until a complete request message is available and returns it,
// with the request envelope already stripped by the REP reader.
func (rep *repSocket) Recv() (Msg, error) {
	readCtx, stop := context.WithCancel(rep.sck.ctx)
	defer stop()

	var msg Msg
	if err := rep.sck.r.read(readCtx, &msg); err != nil {
		return msg, err
	}
	return msg, nil
}
// Listen makes the Socket listen on the local endpoint ep.
func (rep *repSocket) Listen(ep string) error { return rep.sck.Listen(ep) }

// Dial connects the Socket to the remote endpoint ep.
func (rep *repSocket) Dial(ep string) error { return rep.sck.Dial(ep) }

// Type returns the type of this Socket (Rep for sockets built by NewRep).
func (rep *repSocket) Type() SocketType { return rep.sck.Type() }

// Addr returns the listener's address, or nil if the socket isn't a listener.
func (rep *repSocket) Addr() net.Addr { return rep.sck.Addr() }

// GetOption returns the current value of the socket option called name.
func (rep *repSocket) GetOption(name string) (interface{}, error) {
	return rep.sck.GetOption(name)
}

// SetOption sets the socket option called name to value.
func (rep *repSocket) SetOption(name string, value interface{}) error {
	return rep.sck.SetOption(name, value)
}
// repMsg pairs a message received from a peer with the connection it arrived
// on, so that the eventual reply can be routed back to that connection.
// Field order matters: values are built with unkeyed literals.
type repMsg struct {
	conn *Conn
	msg  Msg
}
// repReader is the receiving half of a REP socket. Each registered connection
// has a listen goroutine that feeds msgCh; read consumes from msgCh and records
// the sender in the state shared with the REP writer.
type repReader struct {
	ctx   context.Context
	state *repState

	mu    sync.Mutex // guards conns
	conns []*Conn

	msgCh chan repMsg
}

// newRepReader returns a repReader bound to ctx that records reply routing in
// state. Up to queueSize received messages are buffered before the
// per-connection listeners block on msgCh.
func newRepReader(ctx context.Context, state *repState) *repReader {
	const queueSize = 10
	r := new(repReader)
	r.ctx = ctx
	r.state = state
	r.msgCh = make(chan repMsg, queueSize)
	return r
}
// addConn registers c with the reader. It then starts a goroutine that reads
// from c until the reader's context is done or a read on c fails.
func (r *repReader) addConn(c *Conn) {
	func() {
		r.mu.Lock()
		defer r.mu.Unlock()
		r.conns = append(r.conns, c)
	}()

	go r.listen(r.ctx, c)
}
// rmConn unregisters conn from the reader. Only the first matching entry is
// removed, and the call is a no-op if conn is not registered.
func (r *repReader) rmConn(conn *Conn) {
	r.mu.Lock()
	defer r.mu.Unlock()

	for i, c := range r.conns {
		if c != conn {
			continue
		}
		last := len(r.conns) - 1
		copy(r.conns[i:], r.conns[i+1:])
		// Nil out the vacated tail slot. Shrinking the slice alone would leave
		// a *Conn reachable through the backing array, so it could not be collected.
		r.conns[last] = nil
		r.conns = r.conns[:last]
		return
	}
}
// read waits for the next message received on any connection. It stores the
// payload (the frames after the envelope delimiter) in msg. It then saves the
// sender's connection and envelope in the shared state, so the next reply is
// routed back to that sender.
//
// It returns ctx.Err() if ctx is done first, or the message's own error if
// the message carries one. It returns an error if the message has no empty
// delimiter frame; in that case msg and the shared state are left unchanged.
func (r *repReader) read(ctx context.Context, msg *Msg) error {
	var in repMsg
	select {
	case <-ctx.Done():
		return ctx.Err()
	case in = <-r.msgCh:
	}

	if err := in.msg.err; err != nil {
		return err
	}

	envelope, payload := splitReq(in.msg)
	if envelope == nil {
		return fmt.Errorf("zmq4: invalid REP message")
	}
	*msg = payload
	r.state.Set(in.conn, envelope)
	return nil
}
// listen reads messages from conn and forwards them, tagged with conn, to
// msgCh. It stops when ctx is done or conn.read reports an error. On exit it
// closes conn and unregisters it from the reader.
//
// The forward to msgCh must also watch ctx. msgCh has a small fixed buffer,
// and once nobody calls read anymore (e.g. after the socket's context is
// cancelled) a plain send would block this goroutine forever.
func (r *repReader) listen(ctx context.Context, conn *Conn) {
	defer r.rmConn(conn)
	defer conn.Close()

	for {
		msg := conn.read()
		if ctx.Err() != nil || msg.err != nil {
			// Read errors end this listener; they are not forwarded to read.
			return
		}
		select {
		case <-ctx.Done():
			return
		case r.msgCh <- repMsg{conn, msg}:
		}
	}
}
// Close closes every connection registered with the reader and then forgets
// them. All connections are closed even if some fail; the first error
// encountered is returned.
func (r *repReader) Close() error {
	r.mu.Lock()
	defer r.mu.Unlock()

	var firstErr error
	for _, c := range r.conns {
		if err := c.Close(); err != nil && firstErr == nil {
			firstErr = err
		}
	}
	r.conns = nil
	return firstErr
}
// splitReq splits a received REP message at its first empty (delimiter)
// frame.
//   - preamble holds the routing frames up to and including the delimiter.
//   - msg holds the frames after it, or is the zero Msg if there are none.
//   - If envelope contains no empty frame, preamble is nil.
//
// The split must happen at the FIRST empty frame. In the ZeroMQ REQ-REP
// envelope, routing identities are non-empty, so any later empty frame belongs
// to the application payload. It must stay in msg rather than being echoed
// back as part of the envelope.
func splitReq(envelope Msg) (preamble [][]byte, msg Msg) {
	for i, frame := range envelope.Frames {
		if len(frame) != 0 {
			continue
		}
		n := i + 1
		// Cap the preamble at the delimiter so that appending to it can never
		// overwrite payload frames sharing the same backing array.
		preamble = envelope.Frames[:n:n]
		if n < len(envelope.Frames) {
			msg = NewMsgFrom(envelope.Frames[n:]...)
		}
		return preamble, msg
	}
	return nil, msg
}
// repSendPayload is one reply queued by repWriter.write. It carries the
// destination connection, the request envelope to echo back, and the reply
// body. Field order matters: values are built with unkeyed literals.
type repSendPayload struct {
	conn     *Conn
	preamble [][]byte // routing envelope, including the empty delimiter
	msg      Msg
}
// repWriter is the sending half of a REP socket. write hands each reply to
// the run goroutine over sendCh. run delivers it to the connection recorded in
// the repState shared with the REP reader.
type repWriter struct {
	ctx    context.Context
	state  *repState
	mu     sync.Mutex // guards conns
	conns  []*Conn
	sendCh chan repSendPayload

	// closed is closed exactly once (through closeOnce) by Close. It stops run
	// and makes later writes fail. sendCh itself is never closed: a send on a
	// closed channel panics, so closing it would crash any write that races
	// with, or follows, Close.
	closed    chan struct{}
	closeOnce sync.Once
}

// buildReplyMsg returns the wire message for a reply: the saved request
// envelope (including its delimiter) followed by the reply frames.
func (r repSendPayload) buildReplyMsg() Msg {
	frames := make([][]byte, 0, len(r.preamble)+len(r.msg.Frames))
	frames = append(frames, r.preamble...)
	frames = append(frames, r.msg.Frames...)
	return NewMsgFrom(frames...)
}

// newRepWriter creates a repWriter bound to ctx that routes replies using
// state, and starts its run goroutine. That goroutine exits when ctx is done
// or Close is called.
func newRepWriter(ctx context.Context, state *repState) *repWriter {
	r := &repWriter{
		ctx:    ctx,
		state:  state,
		sendCh: make(chan repSendPayload),
		closed: make(chan struct{}),
	}
	go r.run()
	return r
}

// addConn registers w as a possible reply destination.
func (r *repWriter) addConn(w *Conn) {
	r.mu.Lock()
	r.conns = append(r.conns, w)
	r.mu.Unlock()
}

// rmConn unregisters conn. Only the first matching entry is removed, and the
// call is a no-op if conn is not registered.
func (r *repWriter) rmConn(conn *Conn) {
	r.mu.Lock()
	defer r.mu.Unlock()

	for i, c := range r.conns {
		if c != conn {
			continue
		}
		last := len(r.conns) - 1
		copy(r.conns[i:], r.conns[i+1:])
		// Nil out the vacated tail slot so the backing array does not keep a
		// *Conn reachable after it has been removed.
		r.conns[last] = nil
		r.conns = r.conns[:last]
		return
	}
}

// write queues msg as the reply to the most recently read request.
//   - It returns nil once run has accepted the reply.
//   - It returns ctx.Err() or the writer's own context error if either is
//     done first.
//   - It returns an error if the writer has been closed.
//
// NOTE(review): if no request has been read yet, the destination conn is nil.
// sendPayload then presumably finds no match and silently drops the reply.
func (r *repWriter) write(ctx context.Context, msg Msg) error {
	conn, preamble := r.state.Get()
	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-r.ctx.Done(): // run terminates on this, so the send below would never complete
		return r.ctx.Err()
	case <-r.closed: // likewise: run exits once Close has been called
		return fmt.Errorf("zmq4: write on closed REP socket")
	case r.sendCh <- repSendPayload{conn, preamble, msg}:
		return nil
	}
}

// run delivers queued replies one at a time until the writer's context is
// done or Close is called.
func (r *repWriter) run() {
	for {
		select {
		case <-r.ctx.Done():
			return
		case <-r.closed:
			return
		case payload, ok := <-r.sendCh:
			if !ok {
				return
			}
			r.sendPayload(payload)
		}
	}
}

// sendPayload sends the reply on payload.conn if that connection is still
// registered; otherwise the reply is dropped.
func (r *repWriter) sendPayload(payload repSendPayload) {
	r.mu.Lock()
	defer r.mu.Unlock()
	for _, conn := range r.conns {
		if conn != payload.conn {
			continue
		}
		// Best effort: write has already returned, so there is no caller to
		// report a send failure to. Perhaps log it?
		_ = conn.SendMsg(payload.buildReplyMsg())
		return
	}
}

// Close stops the run goroutine, then closes every registered connection and
// forgets them, returning the first close error. Later writes return an error
// instead of panicking. A write racing with Close may still be accepted and
// then dropped. Calling Close more than once is safe.
func (r *repWriter) Close() error {
	r.closeOnce.Do(func() { close(r.closed) })

	r.mu.Lock()
	defer r.mu.Unlock()
	var err error
	for _, conn := range r.conns {
		if e := conn.Close(); e != nil && err == nil {
			err = e
		}
	}
	r.conns = nil
	return err
}
// repState records where the next reply must go: the connection and request
// envelope of the most recently read message. The REP reader sets it and the
// REP writer reads it, so all access goes through mu.
type repState struct {
	mu       sync.Mutex // guards conn and preamble
	conn     *Conn
	preamble [][]byte // includes delimiter
}

// newRepState returns an empty repState. No reply destination is recorded
// until the first message is read.
func newRepState() *repState {
	return new(repState)
}
// Get returns the recorded reply destination: the connection and its request
// envelope. Both are nil if nothing has been recorded yet.
func (r *repState) Get() (conn *Conn, preamble [][]byte) {
	r.mu.Lock()
	defer r.mu.Unlock()
	return r.conn, r.preamble
}

// Set records conn and its request envelope pre as the reply destination,
// replacing any previous one.
func (r *repState) Set(conn *Conn, pre [][]byte) {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.conn, r.preamble = conn, pre
}
// Compile-time check that *repSocket satisfies the Socket interface.
var _ Socket = (*repSocket)(nil)