e75c615ba1
Fixes #149 Co-authored-by: Sergey Egorov <sergey.egorov@teleste.com> Co-authored-by: Sebastien Binet <binet@cern.ch>
403 lines
8.7 KiB
Go
403 lines
8.7 KiB
Go
// Copyright 2018 The go-zeromq Authors. All rights reserved.
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
package zmq4
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"log"
|
|
"net"
|
|
"os"
|
|
"sort"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// Defaults applied to every socket built by newDefaultSocket.
const (
	// defaultRetry is the pause between two consecutive dial attempts.
	defaultRetry = 250 * time.Millisecond

	// defaultTimeout bounds both a single dial and a single send.
	defaultTimeout = 5 * time.Minute

	// defaultMaxRetries is how many times a failed dial is retried.
	defaultMaxRetries = 10
)
|
|
|
|
// Sentinel errors of the package.
var (
	// errInvalidAddress is the error value describing an invalid address.
	errInvalidAddress = errors.New("zmq4: invalid address")

	// ErrBadProperty is returned by GetOption when the requested
	// property has never been set on the socket.
	ErrBadProperty = errors.New("zmq4: bad property")
)
|
|
|
|
// socket implements the ZeroMQ socket interface
|
|
type socket struct {
|
|
ep string // socket end-point
|
|
typ SocketType
|
|
id SocketIdentity
|
|
retry time.Duration
|
|
maxRetries int
|
|
sec Security
|
|
log *log.Logger
|
|
subTopics func() []string
|
|
autoReconnect bool
|
|
timeout time.Duration
|
|
|
|
mu sync.RWMutex
|
|
conns []*Conn // ZMTP connections
|
|
r rpool
|
|
w wpool
|
|
|
|
props map[string]interface{} // properties of this socket
|
|
|
|
ctx context.Context // life-line of socket
|
|
cancel context.CancelFunc
|
|
listener net.Listener
|
|
dialer net.Dialer
|
|
|
|
closedConns []*Conn
|
|
reaperCond *sync.Cond
|
|
reaperStarted bool
|
|
}
|
|
|
|
func newDefaultSocket(ctx context.Context, sockType SocketType) *socket {
|
|
if ctx == nil {
|
|
ctx = context.Background()
|
|
}
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
return &socket{
|
|
typ: sockType,
|
|
retry: defaultRetry,
|
|
maxRetries: defaultMaxRetries,
|
|
timeout: defaultTimeout,
|
|
sec: nullSecurity{},
|
|
conns: nil,
|
|
r: newQReader(ctx),
|
|
w: newMWriter(ctx),
|
|
props: make(map[string]interface{}),
|
|
ctx: ctx,
|
|
cancel: cancel,
|
|
dialer: net.Dialer{Timeout: defaultTimeout},
|
|
reaperCond: sync.NewCond(&sync.Mutex{}),
|
|
}
|
|
}
|
|
|
|
func newSocket(ctx context.Context, sockType SocketType, opts ...Option) *socket {
|
|
sck := newDefaultSocket(ctx, sockType)
|
|
for _, opt := range opts {
|
|
opt(sck)
|
|
}
|
|
if len(sck.id) == 0 {
|
|
sck.id = SocketIdentity(newUUID())
|
|
}
|
|
if sck.log == nil {
|
|
sck.log = log.New(os.Stderr, "zmq4: ", 0)
|
|
}
|
|
|
|
return sck
|
|
}
|
|
|
|
func (sck *socket) topics() []string {
|
|
var (
|
|
keys = make(map[string]struct{})
|
|
topics []string
|
|
)
|
|
sck.mu.RLock()
|
|
for _, con := range sck.conns {
|
|
con.mu.RLock()
|
|
for topic := range con.topics {
|
|
if _, dup := keys[topic]; dup {
|
|
continue
|
|
}
|
|
keys[topic] = struct{}{}
|
|
topics = append(topics, topic)
|
|
}
|
|
con.mu.RUnlock()
|
|
}
|
|
sck.mu.RUnlock()
|
|
sort.Strings(topics)
|
|
return topics
|
|
}
|
|
|
|
// Close closes the open Socket
|
|
func (sck *socket) Close() error {
|
|
sck.cancel()
|
|
sck.reaperCond.Signal()
|
|
|
|
if sck.listener != nil {
|
|
defer sck.listener.Close()
|
|
}
|
|
|
|
sck.mu.RLock()
|
|
defer sck.mu.RUnlock()
|
|
|
|
var err error
|
|
for _, conn := range sck.conns {
|
|
e := conn.Close()
|
|
if e != nil && err == nil {
|
|
err = e
|
|
}
|
|
}
|
|
|
|
// Remove the unix socket file if created by net.Listen
|
|
if sck.listener != nil && strings.HasPrefix(sck.ep, "ipc://") {
|
|
os.Remove(sck.ep[len("ipc://"):])
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
// Send puts the message on the outbound send queue.
|
|
// Send blocks until the message can be queued or the send deadline expires.
|
|
func (sck *socket) Send(msg Msg) error {
|
|
ctx, cancel := context.WithTimeout(sck.ctx, sck.Timeout())
|
|
defer cancel()
|
|
return sck.w.write(ctx, msg)
|
|
}
|
|
|
|
// SendMulti puts the message on the outbound send queue.
|
|
// SendMulti blocks until the message can be queued or the send deadline expires.
|
|
// The message will be sent as a multipart message.
|
|
func (sck *socket) SendMulti(msg Msg) error {
|
|
msg.multipart = true
|
|
ctx, cancel := context.WithTimeout(sck.ctx, sck.Timeout())
|
|
defer cancel()
|
|
return sck.w.write(ctx, msg)
|
|
}
|
|
|
|
// Recv receives a complete message.
|
|
func (sck *socket) Recv() (Msg, error) {
|
|
ctx, cancel := context.WithCancel(sck.ctx)
|
|
defer cancel()
|
|
var msg Msg
|
|
err := sck.r.read(ctx, &msg)
|
|
return msg, err
|
|
}
|
|
|
|
// Listen connects a local endpoint to the Socket.
|
|
func (sck *socket) Listen(endpoint string) error {
|
|
sck.ep = endpoint
|
|
network, addr, err := splitAddr(endpoint)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
trans, ok := drivers.get(network)
|
|
if !ok {
|
|
return UnknownTransportError{Name: network}
|
|
}
|
|
|
|
l, err := trans.Listen(sck.ctx, addr)
|
|
if err != nil {
|
|
return fmt.Errorf("zmq4: could not listen to %q: %w", endpoint, err)
|
|
}
|
|
sck.listener = l
|
|
|
|
go sck.accept()
|
|
go sck.connReaper()
|
|
|
|
return nil
|
|
}
|
|
|
|
func (sck *socket) accept() {
|
|
ctx, cancel := context.WithCancel(sck.ctx)
|
|
defer cancel()
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
default:
|
|
conn, err := sck.listener.Accept()
|
|
if err != nil {
|
|
// FIXME(sbinet): maybe bubble up this error to application code?
|
|
//sck.log.Printf("error accepting connection from %q: %+v", sck.ep, err)
|
|
continue
|
|
}
|
|
|
|
zconn, err := Open(conn, sck.sec, sck.typ, sck.id, true, sck.scheduleRmConn)
|
|
if err != nil {
|
|
// FIXME(sbinet): maybe bubble up this error to application code?
|
|
sck.log.Printf("could not open a ZMTP connection with %q: %+v", sck.ep, err)
|
|
continue
|
|
}
|
|
|
|
sck.addConn(zconn)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Dial connects a remote endpoint to the Socket.
|
|
func (sck *socket) Dial(endpoint string) error {
|
|
sck.ep = endpoint
|
|
|
|
network, addr, err := splitAddr(endpoint)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
var (
|
|
conn net.Conn
|
|
trans, ok = drivers.get(network)
|
|
retries = 0
|
|
)
|
|
if !ok {
|
|
return UnknownTransportError{Name: network}
|
|
}
|
|
|
|
connect:
|
|
conn, err = trans.Dial(sck.ctx, &sck.dialer, addr)
|
|
if err != nil {
|
|
// retry if retry count is lower than maximum retry count and context has not been canceled
|
|
if (sck.maxRetries == -1 || retries < sck.maxRetries) && sck.ctx.Err() == nil {
|
|
retries++
|
|
time.Sleep(sck.retry)
|
|
goto connect
|
|
}
|
|
return fmt.Errorf("zmq4: could not dial to %q (retry=%v): %w", endpoint, sck.retry, err)
|
|
}
|
|
|
|
if conn == nil {
|
|
return fmt.Errorf("zmq4: got a nil dial-conn to %q", endpoint)
|
|
}
|
|
|
|
zconn, err := Open(conn, sck.sec, sck.typ, sck.id, false, sck.scheduleRmConn)
|
|
if err != nil {
|
|
return fmt.Errorf("zmq4: could not open a ZMTP connection: %w", err)
|
|
}
|
|
if zconn == nil {
|
|
return fmt.Errorf("zmq4: got a nil ZMTP connection to %q", endpoint)
|
|
}
|
|
|
|
if !sck.reaperStarted {
|
|
go sck.connReaper()
|
|
sck.reaperStarted = true
|
|
}
|
|
sck.addConn(zconn)
|
|
return nil
|
|
}
|
|
|
|
func (sck *socket) addConn(c *Conn) {
|
|
sck.mu.Lock()
|
|
defer sck.mu.Unlock()
|
|
sck.conns = append(sck.conns, c)
|
|
if len(c.Peer.Meta[sysSockID]) == 0 {
|
|
switch c.typ {
|
|
case Router: // TODO: STREAM type when implemented
|
|
// if empty Identity metadata is received from some client
|
|
// need to assign an uuid such that router socket can reply to the correct client
|
|
c.Peer.Meta[sysSockID] = newUUID()
|
|
}
|
|
}
|
|
if sck.w != nil {
|
|
sck.w.addConn(c)
|
|
}
|
|
if sck.r != nil {
|
|
sck.r.addConn(c)
|
|
}
|
|
// resend subscriptions for topics if there are any
|
|
if sck.subTopics != nil {
|
|
for _, topic := range sck.subTopics() {
|
|
_ = sck.Send(NewMsg(append([]byte{1}, topic...)))
|
|
}
|
|
}
|
|
}
|
|
|
|
func (sck *socket) rmConn(c *Conn) {
|
|
sck.mu.Lock()
|
|
defer sck.mu.Unlock()
|
|
|
|
cur := -1
|
|
for i := range sck.conns {
|
|
if sck.conns[i] == c {
|
|
cur = i
|
|
break
|
|
}
|
|
}
|
|
|
|
if cur == -1 {
|
|
return
|
|
}
|
|
|
|
sck.conns = append(sck.conns[:cur], sck.conns[cur+1:]...)
|
|
if sck.r != nil {
|
|
sck.r.rmConn(c)
|
|
}
|
|
if sck.w != nil {
|
|
sck.w.rmConn(c)
|
|
}
|
|
}
|
|
|
|
func (sck *socket) scheduleRmConn(c *Conn) {
|
|
sck.reaperCond.L.Lock()
|
|
sck.closedConns = append(sck.closedConns, c)
|
|
sck.reaperCond.Signal()
|
|
sck.reaperCond.L.Unlock()
|
|
|
|
if sck.autoReconnect {
|
|
sck.Dial(sck.ep)
|
|
}
|
|
}
|
|
|
|
// Type returns the type of this Socket (PUB, SUB, ...)
|
|
func (sck *socket) Type() SocketType {
|
|
return sck.typ
|
|
}
|
|
|
|
// Addr returns the listener's address.
|
|
// Addr returns nil if the socket isn't a listener.
|
|
func (sck *socket) Addr() net.Addr {
|
|
if sck.listener == nil {
|
|
return nil
|
|
}
|
|
return sck.listener.Addr()
|
|
}
|
|
|
|
// GetOption is used to retrieve an option for a socket.
|
|
func (sck *socket) GetOption(name string) (interface{}, error) {
|
|
v, ok := sck.props[name]
|
|
if !ok {
|
|
return nil, ErrBadProperty
|
|
}
|
|
return v, nil
|
|
}
|
|
|
|
// SetOption is used to set an option for a socket.
|
|
func (sck *socket) SetOption(name string, value interface{}) error {
|
|
// FIXME(sbinet) different socket types support different options.
|
|
sck.props[name] = value
|
|
return nil
|
|
}
|
|
|
|
func (sck *socket) Timeout() time.Duration {
|
|
return sck.timeout
|
|
}
|
|
|
|
func (sck *socket) connReaper() {
|
|
sck.reaperCond.L.Lock()
|
|
defer sck.reaperCond.L.Unlock()
|
|
|
|
for {
|
|
for len(sck.closedConns) == 0 && sck.ctx.Err() == nil {
|
|
sck.reaperCond.Wait()
|
|
}
|
|
|
|
if sck.ctx.Err() != nil {
|
|
return
|
|
}
|
|
|
|
// Clone the known closed connections to avoid data race
|
|
// and remove those under reaper unlocked.
|
|
// That should fix the deadlock reported in #149.
|
|
cc := append([]*Conn{}, sck.closedConns...) // clone
|
|
sck.closedConns = sck.closedConns[:0]
|
|
sck.reaperCond.L.Unlock()
|
|
for _, c := range cc {
|
|
sck.rmConn(c)
|
|
}
|
|
sck.reaperCond.L.Lock()
|
|
}
|
|
}
|
|
|
|
var (
|
|
_ Socket = (*socket)(nil)
|
|
)
|