Resolve lock contention whenever bans occur with multiple clients connected
All checks were successful
continuous-integration/drone/push Build is passing

This commit is contained in:
DataHoarder 2023-05-23 07:36:35 +02:00
parent 099c393235
commit db5a26de4f
Signed by: DataHoarder
SSH key fingerprint: SHA256:OLTRf6Fl87G52SiR7sWLGNzlJt4WOX+tfI2yxo0z7xk
2 changed files with 61 additions and 40 deletions

View file

@ -978,20 +978,22 @@ func (c *Client) Close() bool {
c.Ban(DefaultBanTime, errors.New("disconnected before finishing handshake"))
}
c.Owner.clientsLock.Lock()
defer c.Owner.clientsLock.Unlock()
if c.Owner.fastestPeer == c {
c.Owner.fastestPeer = nil
}
if i := slices.Index(c.Owner.clients, c); i != -1 {
c.Owner.clients = slices.Delete(c.Owner.clients, i, i+1)
if c.IsIncomingConnection {
c.Owner.NumIncomingConnections.Add(-1)
} else {
c.Owner.NumOutgoingConnections.Add(-1)
c.Owner.PendingOutgoingConnections.Replace(c.AddressPort.Addr().String(), "")
func() {
c.Owner.clientsLock.Lock()
defer c.Owner.clientsLock.Unlock()
if c.Owner.fastestPeer == c {
c.Owner.fastestPeer = nil
}
}
if i := slices.Index(c.Owner.clients, c); i != -1 {
c.Owner.clients = slices.Delete(c.Owner.clients, i, i+1)
if c.IsIncomingConnection {
c.Owner.NumIncomingConnections.Add(-1)
} else {
c.Owner.NumOutgoingConnections.Add(-1)
c.Owner.PendingOutgoingConnections.Replace(c.AddressPort.Addr().String(), "")
}
}
}()
_ = c.Connection.Close()
close(c.closeChannel)

View file

@ -273,6 +273,19 @@ func (s *Server) UpdatePeerList() {
}
}
func (s *Server) CleanupBanList() {
s.bansLock.Lock()
defer s.bansLock.Unlock()
currentTime := uint64(time.Now().Unix())
for k, b := range s.bans {
if currentTime >= b.Expiration {
delete(s.bans, k)
}
}
}
func (s *Server) UpdateClientConnections() {
currentTime := uint64(time.Now().Unix())
@ -495,6 +508,14 @@ func (s *Server) Listen() (err error) {
s.RefreshOutgoingIPv6()
}
}()
wg.Add(1)
go func() {
defer wg.Done()
for range utils.ContextTick(s.ctx, time.Minute*5) {
s.CleanupBanList()
}
}()
for !s.close.Load() {
if conn, err := s.listener.AcceptTCP(); err != nil {
return err
@ -675,18 +696,14 @@ func (s *Server) IsBanned(ip netip.Addr) (bool, *BanEntry) {
k := prefix.Addr().As16()
s.bansLock.RLock()
defer s.bansLock.RUnlock()
if b, ok := s.bans[k]; ok == false {
if b, ok := func() (entry BanEntry, ok bool) {
s.bansLock.RLock()
defer s.bansLock.RUnlock()
entry, ok = s.bans[k]
return entry, ok
}(); ok == false {
return false, nil
} else if uint64(time.Now().Unix()) >= b.Expiration {
go func() {
//HACK: delay via goroutine
s.bansLock.Lock()
defer s.bansLock.Unlock()
delete(s.bans, k)
}()
return false, nil
} else {
return true, &b
@ -697,32 +714,34 @@ func (s *Server) Ban(ip netip.Addr, duration time.Duration, err error) {
if ok, _ := s.IsBanned(ip); ok {
return
}
go func() {
log.Printf("[P2PServer] Banned %s for %s: %s", ip.String(), duration.String(), err.Error())
if !ip.IsLoopback() {
ip = ip.Unmap()
var prefix netip.Prefix
if ip.Is6() {
//ban the /64
prefix, _ = ip.Prefix(64)
} else if ip.Is4() {
//ban only a single ip, /32
prefix, _ = ip.Prefix(32)
}
if prefix.IsValid() {
log.Printf("[P2PServer] Banned %s for %s: %s", ip.String(), duration.String(), err.Error())
if !ip.IsLoopback() {
ip = ip.Unmap()
var prefix netip.Prefix
if ip.Is6() {
//ban the /64
prefix, _ = ip.Prefix(64)
} else if ip.Is4() {
//ban only a single ip, /32
prefix, _ = ip.Prefix(32)
}
if prefix.IsValid() {
func() {
s.bansLock.Lock()
defer s.bansLock.Unlock()
s.bans[prefix.Addr().As16()] = BanEntry{
Error: err,
Expiration: uint64(time.Now().Unix()) + uint64(duration.Seconds()),
}
for _, c := range s.GetAddressConnectedPrefix(prefix) {
c.Close()
}
}()
for _, c := range s.GetAddressConnectedPrefix(prefix) {
c.Close()
}
}
}()
}
}
func (s *Server) Close() {