all: introduce transport.Transport interface and Transport plugin mechanism

Updates go-zeromq/zmq4#87.
This commit is contained in:
Sebastien Binet 2020-10-21 09:32:11 +02:00
parent 061cd688ca
commit c289829eda
7 changed files with 256 additions and 53 deletions

View file

@ -0,0 +1,35 @@
// Copyright 2020 The go-zeromq Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package inproc
import (
"context"
"net"
"github.com/go-zeromq/zmq4/transport"
)
// Transport implements the zmq4 Transport interface for the inproc transport.
type Transport struct{}
// Dial connects to the address on the named network using the provided
// context.
func (Transport) Dial(ctx context.Context, dialer transport.Dialer, addr string) (net.Conn, error) {
return Dial(addr)
}
// Listen announces on the provided network address.
func (Transport) Listen(ctx context.Context, addr string) (net.Listener, error) {
return Listen(addr)
}
// Addr returns the end-point address.
func (Transport) Addr(ep string) (addr string, err error) {
return ep, nil
}
var (
_ transport.Transport = (*Transport)(nil)
)

View file

@ -15,8 +15,6 @@ import (
"strings"
"sync"
"time"
"github.com/go-zeromq/zmq4/internal/inproc"
)
const (
@ -180,15 +178,10 @@ func (sck *socket) Listen(endpoint string) error {
var l net.Listener
switch network {
case "ipc":
l, err = net.Listen("unix", addr)
case "tcp":
l, err = net.Listen("tcp", addr)
case "udp":
l, err = net.Listen("udp", addr)
case "inproc":
l, err = inproc.Listen(addr)
trans, ok := drivers.get(network)
switch {
case ok:
l, err = trans.Listen(sck.ctx, addr)
default:
panic("zmq4: unknown protocol " + network)
}
@ -240,18 +233,15 @@ func (sck *socket) Dial(endpoint string) error {
return err
}
retries := 0
var conn net.Conn
var (
conn net.Conn
trans, ok = drivers.get(network)
retries = 0
)
connect:
switch network {
case "ipc":
conn, err = sck.dialer.DialContext(sck.ctx, "unix", addr)
case "tcp":
conn, err = sck.dialer.DialContext(sck.ctx, "tcp", addr)
case "udp":
conn, err = sck.dialer.DialContext(sck.ctx, "udp", addr)
case "inproc":
conn, err = inproc.Dial(addr)
switch {
case ok:
conn, err = trans.Dial(sck.ctx, &sck.dialer, addr)
default:
panic("zmq4: unknown protocol " + network)
}
@ -262,7 +252,7 @@ connect:
time.Sleep(sck.retry)
goto connect
}
return fmt.Errorf("zmq4: could not dial to %q: %w", endpoint, err)
return fmt.Errorf("zmq4: could not dial to %q (retry=%v): %w", endpoint, sck.retry, err)
}
if conn == nil {

72
transport.go Normal file
View file

@ -0,0 +1,72 @@
// Copyright 2018 The go-zeromq Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package zmq4
import (
"fmt"
"sort"
"sync"
"github.com/go-zeromq/zmq4/internal/inproc"
"github.com/go-zeromq/zmq4/transport"
)
// Transports returns the sorted list of currently registered transports.
func Transports() []string {
return drivers.names()
}
// RegisterTransport registers a new transport with the zmq4 package.
func RegisterTransport(name string, trans transport.Transport) error {
return drivers.add(name, trans)
}
type transports struct {
sync.RWMutex
db map[string]transport.Transport
}
func (ts *transports) get(name string) (transport.Transport, bool) {
ts.RLock()
defer ts.RUnlock()
v, ok := ts.db[name]
return v, ok
}
func (ts *transports) add(name string, trans transport.Transport) error {
ts.Lock()
defer ts.Unlock()
if old, dup := ts.db[name]; dup {
return fmt.Errorf("zmq4: duplicate transport %q (%T)", name, old)
}
ts.db[name] = trans
return nil
}
func (ts *transports) names() []string {
ts.RLock()
defer ts.RUnlock()
o := make([]string, 0, len(ts.db))
for k := range ts.db {
o = append(o, k)
}
sort.Strings(o)
return o
}
var drivers = transports{
db: make(map[string]transport.Transport),
}
func init() {
RegisterTransport("ipc", transport.New("unix"))
RegisterTransport("tcp", transport.New("tcp"))
RegisterTransport("udp", transport.New("udp"))
RegisterTransport("inproc", inproc.Transport{})
}

86
transport/transport.go Normal file
View file

@ -0,0 +1,86 @@
// Copyright 2020 The go-zeromq Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package transport defines the Transport interface and provides a net-based
// implementation that can be used by zmq4 sockets to exchange messages.
package transport // import "github.com/go-zeromq/zmq4/transport"
import (
"context"
"fmt"
"net"
)
// Dialer is the interface that wraps the DialContext method.
type Dialer interface {
DialContext(ctx context.Context, network, address string) (net.Conn, error)
}
// Transport is the zmq4 transport interface that wraps
// the Dial and Listen methods.
type Transport interface {
// Dial connects to the address on the named network using the provided
// context.
Dial(ctx context.Context, dialer Dialer, addr string) (net.Conn, error)
// Listen announces on the provided network address.
Listen(ctx context.Context, addr string) (net.Listener, error)
// Addr returns the end-point address.
Addr(ep string) (addr string, err error)
}
// netTransport implements the Transport interface using the net package.
type netTransport struct {
prot string
}
// New returns a new net-based transport with the given network (e.g "tcp").
func New(network string) Transport {
return netTransport{prot: network}
}
// Dial connects to the address on the named network using the provided
// context.
func (trans netTransport) Dial(ctx context.Context, dialer Dialer, addr string) (net.Conn, error) {
return dialer.DialContext(ctx, trans.prot, addr)
}
// Listen announces on the provided network address.
func (trans netTransport) Listen(ctx context.Context, addr string) (net.Listener, error) {
return net.Listen(trans.prot, addr)
}
// Addr returns the end-point address.
func (trans netTransport) Addr(ep string) (addr string, err error) {
switch trans.prot {
case "tcp", "udp":
host, port, err := net.SplitHostPort(ep)
if err != nil {
return addr, err
}
switch port {
case "0", "*", "":
port = "0"
}
switch host {
case "", "*":
host = "0.0.0.0"
}
addr = net.JoinHostPort(host, port)
return addr, err
case "unix":
return ep, nil
default:
err = fmt.Errorf("zmq4: unknown protocol %q", trans.prot)
}
return addr, err
}
var (
_ Transport = (*netTransport)(nil)
)

31
transport_test.go Normal file
View file

@ -0,0 +1,31 @@
// Copyright 2018 The go-zeromq Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package zmq4
import (
"reflect"
"testing"
"github.com/go-zeromq/zmq4/internal/inproc"
)
func TestTransport(t *testing.T) {
if got, want := Transports(), []string{"inproc", "ipc", "tcp", "udp"}; !reflect.DeepEqual(got, want) {
t.Fatalf("invalid list of transports.\ngot= %q\nwant=%q", got, want)
}
err := RegisterTransport("tcp", inproc.Transport{})
if err == nil {
t.Fatalf("expected a duplicate-registration error")
}
if got, want := err.Error(), "zmq4: duplicate transport \"tcp\" (transport.netTransport)"; got != want {
t.Fatalf("invalid duplicate registration error:\ngot= %s\nwant=%s", got, want)
}
err = RegisterTransport("inproc2", inproc.Transport{})
if err != nil {
t.Fatalf("could not register 'inproc2': %+v", err)
}
}

View file

@ -9,7 +9,6 @@ import (
"fmt"
"io"
"log"
"net"
"strings"
)
@ -20,39 +19,15 @@ func splitAddr(v string) (network, addr string, err error) {
err = errInvalidAddress
return network, addr, err
}
var (
host string
port string
)
network = ep[0]
switch network {
case "tcp", "udp":
host, port, err = net.SplitHostPort(ep[1])
if err != nil {
return network, addr, err
}
switch port {
case "0", "*", "":
port = "0"
}
switch host {
case "", "*":
host = "0.0.0.0"
}
addr = net.JoinHostPort(host, port)
return network, addr, err
case "ipc":
host = ep[1]
port = ""
return network, host, nil
case "inproc":
host = ep[1]
return "inproc", host, nil
default:
err = fmt.Errorf("zmq4: unknown protocol %q", network)
trans, ok := drivers.get(network)
if !ok {
err = fmt.Errorf("zmq4: unknown transport %q", network)
return network, addr, err
}
addr, err = trans.Addr(ep[1])
return network, addr, err
}

View file

@ -38,6 +38,20 @@ func TestSplitAddr(t *testing.T) {
addr: "[::1]:7000",
err: nil,
},
{
desc: "ipc",
v: "ipc://some-ep",
network: "ipc",
addr: "some-ep",
err: nil,
},
{
desc: "inproc",
v: "inproc://some-ep",
network: "inproc",
addr: "some-ep",
err: nil,
},
}
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {