all: introduce transport.Transport interface and Transport plugin mechanism
Updates go-zeromq/zmq4#87.
This commit is contained in:
parent
061cd688ca
commit
c289829eda
35
internal/inproc/transport.go
Normal file
35
internal/inproc/transport.go
Normal file
|
@ -0,0 +1,35 @@
|
|||
// Copyright 2020 The go-zeromq Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package inproc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
|
||||
"github.com/go-zeromq/zmq4/transport"
|
||||
)
|
||||
|
||||
// Transport implements the zmq4 Transport interface for the inproc transport.
|
||||
type Transport struct{}
|
||||
|
||||
// Dial connects to the address on the named network using the provided
|
||||
// context.
|
||||
func (Transport) Dial(ctx context.Context, dialer transport.Dialer, addr string) (net.Conn, error) {
|
||||
return Dial(addr)
|
||||
}
|
||||
|
||||
// Listen announces on the provided network address.
|
||||
func (Transport) Listen(ctx context.Context, addr string) (net.Listener, error) {
|
||||
return Listen(addr)
|
||||
}
|
||||
|
||||
// Addr returns the end-point address.
|
||||
func (Transport) Addr(ep string) (addr string, err error) {
|
||||
return ep, nil
|
||||
}
|
||||
|
||||
var (
|
||||
_ transport.Transport = (*Transport)(nil)
|
||||
)
|
36
socket.go
36
socket.go
|
@ -15,8 +15,6 @@ import (
|
|||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-zeromq/zmq4/internal/inproc"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -180,15 +178,10 @@ func (sck *socket) Listen(endpoint string) error {
|
|||
|
||||
var l net.Listener
|
||||
|
||||
switch network {
|
||||
case "ipc":
|
||||
l, err = net.Listen("unix", addr)
|
||||
case "tcp":
|
||||
l, err = net.Listen("tcp", addr)
|
||||
case "udp":
|
||||
l, err = net.Listen("udp", addr)
|
||||
case "inproc":
|
||||
l, err = inproc.Listen(addr)
|
||||
trans, ok := drivers.get(network)
|
||||
switch {
|
||||
case ok:
|
||||
l, err = trans.Listen(sck.ctx, addr)
|
||||
default:
|
||||
panic("zmq4: unknown protocol " + network)
|
||||
}
|
||||
|
@ -240,18 +233,15 @@ func (sck *socket) Dial(endpoint string) error {
|
|||
return err
|
||||
}
|
||||
|
||||
retries := 0
|
||||
var conn net.Conn
|
||||
var (
|
||||
conn net.Conn
|
||||
trans, ok = drivers.get(network)
|
||||
retries = 0
|
||||
)
|
||||
connect:
|
||||
switch network {
|
||||
case "ipc":
|
||||
conn, err = sck.dialer.DialContext(sck.ctx, "unix", addr)
|
||||
case "tcp":
|
||||
conn, err = sck.dialer.DialContext(sck.ctx, "tcp", addr)
|
||||
case "udp":
|
||||
conn, err = sck.dialer.DialContext(sck.ctx, "udp", addr)
|
||||
case "inproc":
|
||||
conn, err = inproc.Dial(addr)
|
||||
switch {
|
||||
case ok:
|
||||
conn, err = trans.Dial(sck.ctx, &sck.dialer, addr)
|
||||
default:
|
||||
panic("zmq4: unknown protocol " + network)
|
||||
}
|
||||
|
@ -262,7 +252,7 @@ connect:
|
|||
time.Sleep(sck.retry)
|
||||
goto connect
|
||||
}
|
||||
return fmt.Errorf("zmq4: could not dial to %q: %w", endpoint, err)
|
||||
return fmt.Errorf("zmq4: could not dial to %q (retry=%v): %w", endpoint, sck.retry, err)
|
||||
}
|
||||
|
||||
if conn == nil {
|
||||
|
|
72
transport.go
Normal file
72
transport.go
Normal file
|
@ -0,0 +1,72 @@
|
|||
// Copyright 2018 The go-zeromq Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package zmq4
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"sync"
|
||||
|
||||
"github.com/go-zeromq/zmq4/internal/inproc"
|
||||
"github.com/go-zeromq/zmq4/transport"
|
||||
)
|
||||
|
||||
// Transports returns the sorted list of currently registered transports.
|
||||
func Transports() []string {
|
||||
return drivers.names()
|
||||
}
|
||||
|
||||
// RegisterTransport registers a new transport with the zmq4 package.
|
||||
func RegisterTransport(name string, trans transport.Transport) error {
|
||||
return drivers.add(name, trans)
|
||||
}
|
||||
|
||||
type transports struct {
|
||||
sync.RWMutex
|
||||
db map[string]transport.Transport
|
||||
}
|
||||
|
||||
func (ts *transports) get(name string) (transport.Transport, bool) {
|
||||
ts.RLock()
|
||||
defer ts.RUnlock()
|
||||
|
||||
v, ok := ts.db[name]
|
||||
return v, ok
|
||||
}
|
||||
|
||||
func (ts *transports) add(name string, trans transport.Transport) error {
|
||||
ts.Lock()
|
||||
defer ts.Unlock()
|
||||
|
||||
if old, dup := ts.db[name]; dup {
|
||||
return fmt.Errorf("zmq4: duplicate transport %q (%T)", name, old)
|
||||
}
|
||||
|
||||
ts.db[name] = trans
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ts *transports) names() []string {
|
||||
ts.RLock()
|
||||
defer ts.RUnlock()
|
||||
|
||||
o := make([]string, 0, len(ts.db))
|
||||
for k := range ts.db {
|
||||
o = append(o, k)
|
||||
}
|
||||
sort.Strings(o)
|
||||
return o
|
||||
}
|
||||
|
||||
var drivers = transports{
|
||||
db: make(map[string]transport.Transport),
|
||||
}
|
||||
|
||||
func init() {
|
||||
RegisterTransport("ipc", transport.New("unix"))
|
||||
RegisterTransport("tcp", transport.New("tcp"))
|
||||
RegisterTransport("udp", transport.New("udp"))
|
||||
RegisterTransport("inproc", inproc.Transport{})
|
||||
}
|
86
transport/transport.go
Normal file
86
transport/transport.go
Normal file
|
@ -0,0 +1,86 @@
|
|||
// Copyright 2020 The go-zeromq Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package transport defines the Transport interface and provides a net-based
|
||||
// implementation that can be used by zmq4 sockets to exchange messages.
|
||||
package transport // import "github.com/go-zeromq/zmq4/transport"
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
)
|
||||
|
||||
// Dialer is the interface that wraps the DialContext method.
|
||||
type Dialer interface {
|
||||
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||
}
|
||||
|
||||
// Transport is the zmq4 transport interface that wraps
|
||||
// the Dial and Listen methods.
|
||||
type Transport interface {
|
||||
// Dial connects to the address on the named network using the provided
|
||||
// context.
|
||||
Dial(ctx context.Context, dialer Dialer, addr string) (net.Conn, error)
|
||||
|
||||
// Listen announces on the provided network address.
|
||||
Listen(ctx context.Context, addr string) (net.Listener, error)
|
||||
|
||||
// Addr returns the end-point address.
|
||||
Addr(ep string) (addr string, err error)
|
||||
}
|
||||
|
||||
// netTransport implements the Transport interface using the net package.
|
||||
type netTransport struct {
|
||||
prot string
|
||||
}
|
||||
|
||||
// New returns a new net-based transport with the given network (e.g "tcp").
|
||||
func New(network string) Transport {
|
||||
return netTransport{prot: network}
|
||||
}
|
||||
|
||||
// Dial connects to the address on the named network using the provided
|
||||
// context.
|
||||
func (trans netTransport) Dial(ctx context.Context, dialer Dialer, addr string) (net.Conn, error) {
|
||||
return dialer.DialContext(ctx, trans.prot, addr)
|
||||
}
|
||||
|
||||
// Listen announces on the provided network address.
|
||||
func (trans netTransport) Listen(ctx context.Context, addr string) (net.Listener, error) {
|
||||
return net.Listen(trans.prot, addr)
|
||||
}
|
||||
|
||||
// Addr returns the end-point address.
|
||||
func (trans netTransport) Addr(ep string) (addr string, err error) {
|
||||
switch trans.prot {
|
||||
case "tcp", "udp":
|
||||
host, port, err := net.SplitHostPort(ep)
|
||||
if err != nil {
|
||||
return addr, err
|
||||
}
|
||||
switch port {
|
||||
case "0", "*", "":
|
||||
port = "0"
|
||||
}
|
||||
switch host {
|
||||
case "", "*":
|
||||
host = "0.0.0.0"
|
||||
}
|
||||
addr = net.JoinHostPort(host, port)
|
||||
return addr, err
|
||||
|
||||
case "unix":
|
||||
return ep, nil
|
||||
|
||||
default:
|
||||
err = fmt.Errorf("zmq4: unknown protocol %q", trans.prot)
|
||||
}
|
||||
|
||||
return addr, err
|
||||
}
|
||||
|
||||
var (
|
||||
_ Transport = (*netTransport)(nil)
|
||||
)
|
31
transport_test.go
Normal file
31
transport_test.go
Normal file
|
@ -0,0 +1,31 @@
|
|||
// Copyright 2018 The go-zeromq Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package zmq4
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/go-zeromq/zmq4/internal/inproc"
|
||||
)
|
||||
|
||||
func TestTransport(t *testing.T) {
|
||||
if got, want := Transports(), []string{"inproc", "ipc", "tcp", "udp"}; !reflect.DeepEqual(got, want) {
|
||||
t.Fatalf("invalid list of transports.\ngot= %q\nwant=%q", got, want)
|
||||
}
|
||||
|
||||
err := RegisterTransport("tcp", inproc.Transport{})
|
||||
if err == nil {
|
||||
t.Fatalf("expected a duplicate-registration error")
|
||||
}
|
||||
if got, want := err.Error(), "zmq4: duplicate transport \"tcp\" (transport.netTransport)"; got != want {
|
||||
t.Fatalf("invalid duplicate registration error:\ngot= %s\nwant=%s", got, want)
|
||||
}
|
||||
|
||||
err = RegisterTransport("inproc2", inproc.Transport{})
|
||||
if err != nil {
|
||||
t.Fatalf("could not register 'inproc2': %+v", err)
|
||||
}
|
||||
}
|
35
utils.go
35
utils.go
|
@ -9,7 +9,6 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"strings"
|
||||
)
|
||||
|
||||
|
@ -20,39 +19,15 @@ func splitAddr(v string) (network, addr string, err error) {
|
|||
err = errInvalidAddress
|
||||
return network, addr, err
|
||||
}
|
||||
var (
|
||||
host string
|
||||
port string
|
||||
)
|
||||
network = ep[0]
|
||||
switch network {
|
||||
case "tcp", "udp":
|
||||
host, port, err = net.SplitHostPort(ep[1])
|
||||
if err != nil {
|
||||
return network, addr, err
|
||||
}
|
||||
switch port {
|
||||
case "0", "*", "":
|
||||
port = "0"
|
||||
}
|
||||
switch host {
|
||||
case "", "*":
|
||||
host = "0.0.0.0"
|
||||
}
|
||||
addr = net.JoinHostPort(host, port)
|
||||
return network, addr, err
|
||||
|
||||
case "ipc":
|
||||
host = ep[1]
|
||||
port = ""
|
||||
return network, host, nil
|
||||
case "inproc":
|
||||
host = ep[1]
|
||||
return "inproc", host, nil
|
||||
default:
|
||||
err = fmt.Errorf("zmq4: unknown protocol %q", network)
|
||||
trans, ok := drivers.get(network)
|
||||
if !ok {
|
||||
err = fmt.Errorf("zmq4: unknown transport %q", network)
|
||||
return network, addr, err
|
||||
}
|
||||
|
||||
addr, err = trans.Addr(ep[1])
|
||||
return network, addr, err
|
||||
}
|
||||
|
||||
|
|
|
@ -38,6 +38,20 @@ func TestSplitAddr(t *testing.T) {
|
|||
addr: "[::1]:7000",
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "ipc",
|
||||
v: "ipc://some-ep",
|
||||
network: "ipc",
|
||||
addr: "some-ep",
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "inproc",
|
||||
v: "inproc://some-ep",
|
||||
network: "inproc",
|
||||
addr: "some-ep",
|
||||
err: nil,
|
||||
},
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
|
|
Loading…
Reference in a new issue