package util

import (
	"log"
	"net"
	"sync/atomic"
	"time"
)

// maxDeadlineThreshold is the largest value the int32 threshold fields of
// Conn can hold.
const maxDeadlineThreshold = 1<<31 - 1

// deadlineThreshold converts a timeout into the number of bytes that must be
// transferred before the connection deadline is pushed forward again:
// 1024 bytes for every second of timeout.
//
// Non-positive timeouts yield 0. Timeouts whose byte count does not fit in an
// int32 are clamped to maxDeadlineThreshold instead of silently wrapping.
func deadlineThreshold(timeout time.Duration) int32 {
	if timeout <= 0 {
		return 0
	}
	// Compare against the pre-divided limit so that the multiplication below
	// can overflow neither int64 nor int32.
	if timeout > maxDeadlineThreshold*time.Second/1024 {
		return maxDeadlineThreshold
	}
	return int32(timeout * 1024 / time.Second)
}

// listener wraps a net.Listener and hands out connections wrapped in Conn,
// configured with the listener's read and write timeouts.
type listener struct {
	net.Listener
	ReadTimeout  time.Duration
	WriteTimeout time.Duration
}

// Accept waits for the next connection from the wrapped listener, tunes the
// socket when it is a TCP connection, logs the remote address and returns the
// connection wrapped in a *Conn.
//
// An error from the wrapped listener is returned unchanged.
func (l *listener) Accept() (net.Conn, error) {
	c, err := l.Listener.Accept()
	if err != nil {
		return nil, err
	}
	if tcpConn, ok := c.(*net.TCPConn); ok {
		// Socket tuning is best effort: a failure to apply an option must not
		// reject an otherwise usable connection, so errors are ignored.
		_ = tcpConn.SetReadBuffer(64 * 1024)
		_ = tcpConn.SetWriteBuffer(64 * 1024)
		_ = tcpConn.SetKeepAlive(true)
		// Only override the keep-alive period when the write timeout gives a
		// meaningful (positive) value; otherwise keep the system default.
		if period := l.WriteTimeout / 4; period > 0 {
			_ = tcpConn.SetKeepAlivePeriod(period)
		}
		_ = tcpConn.SetNoDelay(true)
		// Linger 0: unsent data is discarded when the connection is closed.
		_ = tcpConn.SetLinger(0)
	}
	log.Printf("accepted new connection from %s\n", c.RemoteAddr().String())
	return &Conn{
		Conn:           c,
		ReadTimeout:    l.ReadTimeout,
		WriteTimeout:   l.WriteTimeout,
		ReadThreshold:  deadlineThreshold(l.ReadTimeout),
		WriteThreshold: deadlineThreshold(l.WriteTimeout),
	}, nil
}

// Conn wraps a net.Conn and refreshes the connection deadline as data flows.
//
// Conn does not set a deadline on every operation and does not set an initial
// deadline. Instead, once more than ReadThreshold bytes have been read (or
// more than WriteThreshold bytes written) since the last refresh, the next
// Read (or Write) moves both the read and the write deadline to now plus
// ReadTimeout (or WriteTimeout).
//
// A ReadTimeout or WriteTimeout that is zero or negative disables deadline
// handling for that direction.
type Conn struct {
	net.Conn
	ReadTimeout  time.Duration
	WriteTimeout time.Duration
	// ReadThreshold and WriteThreshold are the byte counts that must be
	// exceeded before the deadline is refreshed.
	ReadThreshold  int32
	WriteThreshold int32
	// BytesReadFromDeadline and BytesWrittenFromDeadline count the bytes
	// transferred since the deadline was last refreshed.
	BytesReadFromDeadline    atomic.Int32
	BytesWrittenFromDeadline atomic.Int32
}

// Read reads from the wrapped connection into b and adds the number of bytes
// read to BytesReadFromDeadline.
//
// Before reading, if ReadTimeout is positive and more than ReadThreshold
// bytes were read since the last refresh, the counter is reset and the
// deadline is extended by ReadTimeout. If extending the deadline fails, that
// error is returned and nothing is read.
func (c *Conn) Read(b []byte) (n int, err error) {
	// A non-positive timeout would place the deadline at or before "now" and
	// make every following operation fail, so it is treated as "disabled".
	if c.ReadTimeout > 0 && c.BytesReadFromDeadline.Load() > c.ReadThreshold {
		c.BytesReadFromDeadline.Store(0)
		// we set both read and write deadlines here otherwise after the request
		// is read writing the response fails with an i/o timeout error
		if err = c.Conn.SetDeadline(time.Now().Add(c.ReadTimeout)); err != nil {
			return 0, err
		}
	}
	n, err = c.Conn.Read(b)
	c.BytesReadFromDeadline.Add(int32(n))
	return n, err
}

// Write writes b to the wrapped connection and adds the number of bytes
// written to BytesWrittenFromDeadline.
//
// Before writing, if WriteTimeout is positive and more than WriteThreshold
// bytes were written since the last refresh, the counter is reset and the
// deadline is extended by WriteTimeout. If extending the deadline fails, that
// error is returned and nothing is written.
func (c *Conn) Write(b []byte) (n int, err error) {
	// See Read: a non-positive timeout disables deadline handling.
	if c.WriteTimeout > 0 && c.BytesWrittenFromDeadline.Load() > c.WriteThreshold {
		c.BytesWrittenFromDeadline.Store(0)
		// we extend the read deadline too, not sure it's necessary,
		// but it doesn't hurt
		if err = c.Conn.SetDeadline(time.Now().Add(c.WriteTimeout)); err != nil {
			return 0, err
		}
	}
	n, err = c.Conn.Write(b)
	c.BytesWrittenFromDeadline.Add(int32(n))
	return n, err
}

// NewTimeoutListener listens on the given network and address and returns a
// net.Listener whose accepted connections are *Conn values configured with
// readTimeout and writeTimeout.
//
// An error from net.Listen is returned unchanged.
func NewTimeoutListener(network, addr string, readTimeout, writeTimeout time.Duration) (net.Listener, error) {
	l, err := net.Listen(network, addr)
	if err != nil {
		return nil, err
	}
	return &listener{
		Listener:     l,
		ReadTimeout:  readTimeout,
		WriteTimeout: writeTimeout,
	}, nil
}