Faster address derivation via benchmark, lesser cache of derivations

This commit is contained in:
DataHoarder 2022-11-07 15:58:02 +01:00
parent 51bdaf581a
commit 55b39f8535
Signed by: DataHoarder
SSH key fingerprint: SHA256:OLTRf6Fl87G52SiR7sWLGNzlJt4WOX+tfI2yxo0z7xk
9 changed files with 199 additions and 132 deletions

View file

@ -51,3 +51,21 @@ func TestSort(t *testing.T) {
t.Fatalf("expected address2 < address3, got %d", testAddress2.Compare(testAddress3))
}
}
func BenchmarkCoinbaseDerivation(b *testing.B) {
packed := testAddress3.ToPackedAddress()
txKey := crypto.PrivateKeyFromScalar(privateKey)
for i := 0; i < b.N; i++ {
GetEphemeralPublicKey(packed, txKey, uint64(i))
}
}
func BenchmarkCoinbaseDerivationInline(b *testing.B) {
packed := testAddress3.ToPackedAddress()
spendPub, viewPub := packed.SpendPublicKey().AsPoint().Point(), packed.ViewPublicKey().AsPoint().Point()
p := new(edwards25519.Point)
for i := 0; i < b.N; i++ {
getEphemeralPublicKeyInline(spendPub, viewPub, privateKey, uint64(i), p)
}
}

View file

@ -2,6 +2,7 @@ package address
import (
"encoding/binary"
"filippo.io/edwards25519"
"git.gammaspectra.live/P2Pool/moneroutil"
"git.gammaspectra.live/P2Pool/p2pool-observer/monero/crypto"
p2poolcrypto "git.gammaspectra.live/P2Pool/p2pool-observer/p2pool/crypto"
@ -21,6 +22,24 @@ func GetEphemeralPublicKey(a Interface, txKey crypto.PrivateKey, outputIndex uin
return GetPublicKeyForSharedData(a, crypto.GetDerivationSharedDataForOutputIndex(txKey.GetDerivationCofactor(a.ViewPublicKey()), outputIndex))
}
func getEphemeralPublicKeyInline(spendPub, viewPub *edwards25519.Point, txKey *edwards25519.Scalar, outputIndex uint64, p *edwards25519.Point) {
//derivation
p.ScalarMult(txKey, viewPub).MultByCofactor(p)
derivationAsBytes := p.Bytes()
var varIntBuf [binary.MaxVarintLen64]byte
sharedData := crypto.HashToScalar(derivationAsBytes, varIntBuf[:binary.PutUvarint(varIntBuf[:], outputIndex)])
//public key + add
p.ScalarBaseMult(sharedData).Add(p, spendPub)
}
func GetEphemeralPublicKeyAndViewTag(a Interface, txKey crypto.PrivateKey, outputIndex uint64) (crypto.PublicKey, uint8) {
pK, viewTag := crypto.GetDerivationSharedDataAndViewTagForOutputIndex(txKey.GetDerivationCofactor(a.ViewPublicKey()), outputIndex)
return GetPublicKeyForSharedData(a, pK), viewTag
}
func GetTxProofV2(a Interface, txId types.Hash, txKey crypto.PrivateKey, message string) string {
prefixHash := types.Hash(moneroutil.Keccak256(txId[:], []byte(message)))

View file

@ -7,11 +7,11 @@ import (
"unsafe"
)
type PackedAddress [crypto.PublicKeySize * 2]byte
type PackedAddress [2]crypto.PublicKeyBytes
func NewPackedAddressFromBytes(spend, view crypto.PublicKeyBytes) (result PackedAddress) {
copy(result[:], spend[:])
copy(result[crypto.PublicKeySize:], view[:])
copy(result[0][:], spend[:])
copy(result[1][:], view[:])
return
}
@ -19,23 +19,20 @@ func NewPackedAddress(spend, view crypto.PublicKey) (result PackedAddress) {
return NewPackedAddressFromBytes(spend.AsBytes(), view.AsBytes())
}
func (p *PackedAddress) Bytes() []byte {
return (*[crypto.PublicKeySize*2]byte)(unsafe.Pointer(p))[:]
}
func (p *PackedAddress) PublicKeys() (spend, view crypto.PublicKey) {
var s, v crypto.PublicKeyBytes
copy(s[:], (*p)[:])
copy(v[:], (*p)[crypto.PublicKeySize:])
return &s, &v
return &(*p)[0], &(*p)[1]
}
func (p *PackedAddress) SpendPublicKey() crypto.PublicKey {
var s crypto.PublicKeyBytes
copy(s[:], (*p)[:])
return &s
return &(*p)[0]
}
func (p *PackedAddress) ViewPublicKey() crypto.PublicKey {
var v crypto.PublicKeyBytes
copy(v[:], (*p)[crypto.PublicKeySize:])
return &v
return &(*p)[1]
}
func (p *PackedAddress) ToPackedAddress() *PackedAddress {
@ -48,8 +45,8 @@ func (p *PackedAddress) Compare(otherI Interface) int {
//golang might free other otherwise
defer runtime.KeepAlive(other)
defer runtime.KeepAlive(p)
a := unsafe.Slice((*uint64)(unsafe.Pointer(p)), len(*p)/int(unsafe.Sizeof(uint64(0))))
b := unsafe.Slice((*uint64)(unsafe.Pointer(other)), len(*other)/int(unsafe.Sizeof(uint64(0))))
a := (*[(2*crypto.PublicKeySize)/8]uint64)(unsafe.Pointer(p))
b := (*[(2*crypto.PublicKeySize)/8]uint64)(unsafe.Pointer(other))
//compare spend key

View file

@ -6,14 +6,28 @@ import (
)
func GetDerivationSharedDataForOutputIndex(derivation PublicKey, outputIndex uint64) PrivateKey {
return PrivateKeyFromScalar(HashToScalar(derivation.AsSlice(), binary.AppendUvarint(nil, outputIndex)))
var k = derivation.AsBytes()
var varIntBuf [binary.MaxVarintLen64]byte
return PrivateKeyFromScalar(HashToScalar(k[:], varIntBuf[:binary.PutUvarint(varIntBuf[:], outputIndex)]))
}
func GetDerivationViewTagForOutputIndex(derivation PublicKey, outputIndex uint64) uint8 {
h := moneroutil.Keccak256([]byte("view_tag"), derivation.AsSlice(), binary.AppendUvarint(nil, outputIndex))
var k = derivation.AsBytes()
var varIntBuf [binary.MaxVarintLen64]byte
h := moneroutil.Keccak256([]byte("view_tag"), k[:], varIntBuf[:binary.PutUvarint(varIntBuf[:], outputIndex)])
return h[0]
}
func GetDerivationSharedDataAndViewTagForOutputIndex(derivation PublicKey, outputIndex uint64) (PrivateKey, uint8) {
var k = derivation.AsBytes()
var varIntBuf [binary.MaxVarintLen64]byte
n := binary.PutUvarint(varIntBuf[:], outputIndex)
pK := PrivateKeyFromScalar(HashToScalar(k[:], varIntBuf[:n]))
h := moneroutil.Keccak256([]byte("view_tag"), k[:], varIntBuf[:n])
return pK, h[0]
}
func GetKeyImage(pair *KeyPair) PublicKey {
return PublicKeyFromPoint(HashToPoint(pair.PublicKey)).Multiply(pair.PrivateKey.AsScalar())
}

View file

@ -3,6 +3,9 @@ package p2p
import (
"crypto/rand"
"encoding/binary"
"encoding/hex"
"errors"
"fmt"
"git.gammaspectra.live/P2Pool/p2pool-observer/p2pool/sidechain"
"git.gammaspectra.live/P2Pool/p2pool-observer/types"
"git.gammaspectra.live/P2Pool/p2pool-observer/utils"
@ -36,12 +39,14 @@ type Client struct {
BlockPendingRequests int64
ChainTipBlockRequest bool
ExpectedMessage MessageId
expectedMessage MessageId
HandshakeChallenge HandshakeChallenge
handshakeChallenge HandshakeChallenge
messageChannel chan *ClientMessage
closeChannel chan struct{}
blockRequestThrottler <-chan time.Time
}
@ -51,16 +56,17 @@ func NewClient(owner *Server, conn net.Conn) *Client {
Connection: conn,
AddressPort: netip.MustParseAddrPort(conn.RemoteAddr().String()),
LastActive: time.Now(),
ExpectedMessage: MessageHandshakeChallenge,
expectedMessage: MessageHandshakeChallenge,
blockRequestThrottler: time.Tick(time.Second / 50), //maximum 50 per second
messageChannel: make(chan *ClientMessage, 10),
messageChannel: make(chan *ClientMessage, 10),
closeChannel: make(chan struct{}),
}
return c
}
func (c *Client) Ban(duration time.Duration) {
c.Owner.Ban(c.AddressPort.Addr(), duration)
func (c *Client) Ban(duration time.Duration, err error) {
c.Owner.Ban(c.AddressPort.Addr(), duration, err)
c.Close()
}
@ -162,40 +168,46 @@ func (c *Client) OnConnection() {
var missingBlocks []types.Hash
go func() {
for message := range c.messageChannel {
//log.Printf("Sending message %d len %d", message.MessageId, len(message.Buffer))
_, _ = c.Write([]byte{byte(message.MessageId)})
_, _ = c.Write(message.Buffer)
defer close(c.messageChannel)
for {
select {
case <-c.closeChannel:
return
case message := <- c.messageChannel:
//log.Printf("Sending message %d len %d", message.MessageId, len(message.Buffer))
_, _ = c.Write([]byte{byte(message.MessageId)})
_, _ = c.Write(message.Buffer)
}
}
}()
for !c.Closed.Load() {
var messageId MessageId
if err := binary.Read(c, binary.LittleEndian, &messageId); err != nil {
c.Ban(DefaultBanTime)
c.Ban(DefaultBanTime, err)
return
}
if !c.HandshakeComplete && messageId != c.ExpectedMessage {
c.Ban(DefaultBanTime)
if !c.HandshakeComplete && messageId != c.expectedMessage {
c.Ban(DefaultBanTime, fmt.Errorf("unexpected pre-handshake message: got %d, expected %d", messageId, c.expectedMessage))
return
}
switch messageId {
case MessageHandshakeChallenge:
if c.HandshakeComplete {
c.Ban(DefaultBanTime)
c.Ban(DefaultBanTime, errors.New("got HANDSHAKE_CHALLENGE but handshake is complete"))
return
}
var challenge HandshakeChallenge
var peerId uint64
if err := binary.Read(c, binary.LittleEndian, &challenge); err != nil {
c.Ban(DefaultBanTime)
c.Ban(DefaultBanTime, err)
return
}
if err := binary.Read(c, binary.LittleEndian, &peerId); err != nil {
c.Ban(DefaultBanTime)
c.Ban(DefaultBanTime, err)
return
}
@ -224,41 +236,41 @@ func (c *Client) OnConnection() {
c.sendHandshakeSolution(challenge)
c.ExpectedMessage = MessageHandshakeSolution
c.expectedMessage = MessageHandshakeSolution
c.OnAfterHandshake()
case MessageHandshakeSolution:
if c.HandshakeComplete {
c.Ban(DefaultBanTime)
c.Ban(DefaultBanTime, errors.New("got HANDSHAKE_SOLUTION but handshake is complete"))
return
}
var challengeHash types.Hash
var solution uint64
if err := binary.Read(c, binary.LittleEndian, &challengeHash); err != nil {
c.Ban(DefaultBanTime)
c.Ban(DefaultBanTime, err)
return
}
if err := binary.Read(c, binary.LittleEndian, &solution); err != nil {
c.Ban(DefaultBanTime)
c.Ban(DefaultBanTime, err)
return
}
if c.IsIncomingConnection {
if hash, ok := CalculateChallengeHash(c.HandshakeChallenge, c.Owner.Consensus().Id(), solution); !ok {
if hash, ok := CalculateChallengeHash(c.handshakeChallenge, c.Owner.Consensus().Id(), solution); !ok {
//not enough PoW
c.Ban(DefaultBanTime)
c.Ban(DefaultBanTime, fmt.Errorf("not enough PoW on HANDSHAKE_SOLUTION, challenge = %s, solution = %d, calculated hash = %s, expected hash = %s", hex.EncodeToString(c.handshakeChallenge[:]), solution, hash.String(), challengeHash.String()))
return
} else if hash != challengeHash {
//wrong hash
c.Ban(DefaultBanTime)
c.Ban(DefaultBanTime, fmt.Errorf("wrong hash HANDSHAKE_SOLUTION, challenge = %s, solution = %d, calculated hash = %s, expected hash = %s", hex.EncodeToString(c.handshakeChallenge[:]), solution, hash.String(), challengeHash.String()))
return
}
} else {
if hash, _ := CalculateChallengeHash(c.HandshakeChallenge, c.Owner.Consensus().Id(), solution); hash != challengeHash {
if hash, _ := CalculateChallengeHash(c.handshakeChallenge, c.Owner.Consensus().Id(), solution); hash != challengeHash {
//wrong hash
c.Ban(DefaultBanTime)
c.Ban(DefaultBanTime, fmt.Errorf("wrong hash HANDSHAKE_SOLUTION, challenge = %s, solution = %d, calculated hash = %s, expected hash = %s", hex.EncodeToString(c.handshakeChallenge[:]), solution, hash.String(), challengeHash.String()))
return
}
}
@ -266,17 +278,17 @@ func (c *Client) OnConnection() {
case MessageListenPort:
if c.ListenPort != 0 {
c.Ban(DefaultBanTime)
c.Ban(DefaultBanTime, errors.New("got LISTEN_PORT but we already received it"))
return
}
if err := binary.Read(c, binary.LittleEndian, &c.ListenPort); err != nil {
c.Ban(DefaultBanTime)
c.Ban(DefaultBanTime, err)
return
}
if c.ListenPort == 0 || c.ListenPort >= 65536 {
c.Ban(DefaultBanTime)
c.Ban(DefaultBanTime, fmt.Errorf("listen port out of range: %d", c.ListenPort))
return
}
case MessageBlockRequest:
@ -284,7 +296,7 @@ func (c *Client) OnConnection() {
var templateId types.Hash
if err := binary.Read(c, binary.LittleEndian, &templateId); err != nil {
c.Ban(DefaultBanTime)
c.Ban(DefaultBanTime, err)
return
}
@ -306,7 +318,7 @@ func (c *Client) OnConnection() {
var blockSize uint32
if err := binary.Read(c, binary.LittleEndian, &blockSize); err != nil {
//TODO warn
c.Ban(DefaultBanTime)
c.Ban(DefaultBanTime, err)
return
} else if blockSize == 0 {
//NOT found
@ -314,7 +326,7 @@ func (c *Client) OnConnection() {
} else {
if err = block.FromReader(c); err != nil {
//TODO warn
c.Ban(DefaultBanTime)
c.Ban(DefaultBanTime, err)
return
} else {
if c.ChainTipBlockRequest {
@ -330,7 +342,7 @@ func (c *Client) OnConnection() {
go func() {
if missingBlocks, err = c.Owner.SideChain().AddPoolBlockExternal(block); err != nil {
//TODO warn
c.Ban(DefaultBanTime)
c.Ban(DefaultBanTime, err)
return
} else {
for _, id := range missingBlocks {
@ -347,21 +359,21 @@ func (c *Client) OnConnection() {
var blockSize uint32
if err := binary.Read(c, binary.LittleEndian, &blockSize); err != nil {
//TODO warn
c.Ban(DefaultBanTime)
c.Ban(DefaultBanTime, err)
return
} else if blockSize == 0 {
//NOT found
//TODO log
} else if err = block.FromReader(c); err != nil {
//TODO warn
c.Ban(DefaultBanTime)
c.Ban(DefaultBanTime, err)
return
}
c.LastBroadcast = time.Now()
go func() {
if err := c.Owner.SideChain().PreprocessBlock(block); err != nil {
//TODO warn
c.Ban(DefaultBanTime)
c.Ban(DefaultBanTime, err)
return
} else {
//TODO: investigate different monero block mining
@ -369,7 +381,7 @@ func (c *Client) OnConnection() {
block.WantBroadcast.Store(true)
if missingBlocks, err = c.Owner.SideChain().AddPoolBlockExternal(block); err != nil {
//TODO warn
c.Ban(DefaultBanTime)
c.Ban(DefaultBanTime, err)
return
} else {
for _, id := range missingBlocks {
@ -383,10 +395,10 @@ func (c *Client) OnConnection() {
c.SendPeerListResponse(nil)
case MessagePeerListResponse:
if numPeers, err := c.ReadByte(); err != nil {
c.Ban(DefaultBanTime)
c.Ban(DefaultBanTime, err)
return
} else if numPeers > PeerListResponseMaxPeers {
c.Ban(DefaultBanTime)
c.Ban(DefaultBanTime, fmt.Errorf("too many peers on PEER_LIST_RESPONSE num_peers = %d", numPeers))
return
} else {
c.PingTime = utils.Max(time.Now().Sub(c.LastPeerListRequestTime), 0)
@ -395,15 +407,15 @@ func (c *Client) OnConnection() {
for i := uint8(0); i < numPeers; i++ {
if isV6, err := c.ReadByte(); err != nil {
c.Ban(DefaultBanTime)
c.Ban(DefaultBanTime, err)
return
} else {
if _, err = c.Read(rawIp[:]); err != nil {
c.Ban(DefaultBanTime)
c.Ban(DefaultBanTime, err)
return
} else if err = binary.Read(c, binary.LittleEndian, &port); err != nil {
c.Ban(DefaultBanTime)
c.Ban(DefaultBanTime, err)
return
}
if isV6 != 0 {
@ -416,21 +428,21 @@ func (c *Client) OnConnection() {
}
//TODO
default:
c.Ban(DefaultBanTime)
c.Ban(DefaultBanTime, fmt.Errorf("unknown MessageId %d", messageId))
return
}
}
}
func (c *Client) sendHandshakeChallenge() {
if _, err := rand.Read(c.HandshakeChallenge[:]); err != nil {
if _, err := rand.Read(c.handshakeChallenge[:]); err != nil {
log.Printf("[P2PServer] Unable to generate handshake challenge for %s", c.AddressPort.String())
c.Close()
return
}
var buf [HandshakeChallengeSize + int(unsafe.Sizeof(uint64(0)))]byte
copy(buf[:], c.HandshakeChallenge[:])
copy(buf[:], c.handshakeChallenge[:])
binary.LittleEndian.PutUint64(buf[HandshakeChallengeSize:], c.Owner.PeerId())
c.SendMessage(&ClientMessage{
@ -481,7 +493,9 @@ type ClientMessage struct {
}
func (c *Client) SendMessage(message *ClientMessage) {
c.messageChannel <- message
if !c.Closed.Load() {
c.messageChannel <- message
}
}
// ReadByte reads from underlying connection, on error it will Close
@ -515,8 +529,13 @@ func (c *Client) Close() {
c.Owner.clientsLock.Lock()
defer c.Owner.clientsLock.Unlock()
c.Owner.clients = slices.Delete(c.Owner.clients, i, i)
if c.IsIncomingConnection {
c.Owner.NumIncomingConnections.Add(-1)
} else {
c.Owner.NumOutgoingConnections.Add(-1)
}
}
_ = c.Connection.Close()
close(c.messageChannel)
close(c.closeChannel)
}

View file

@ -43,8 +43,8 @@ type Server struct {
MaxOutgoingPeers uint32
MaxIncomingPeers uint32
NumOutgoingConnections atomic.Uint32
NumIncomingConnections atomic.Uint32
NumOutgoingConnections atomic.Int32
NumIncomingConnections atomic.Int32
clientsLock sync.RWMutex
clients []*Client
@ -75,6 +75,14 @@ func NewServer(sidechain *sidechain.SideChain, listenAddress string, maxOutgoing
func (s *Server) AddToPeerList(addressPort netip.AddrPort) {
log.Printf("TODO AddToPeerList %s", addressPort.String())
if uint32(s.NumOutgoingConnections.Load()) < s.MaxOutgoingPeers {
go func() {
if err := s.Connect(addressPort); err != nil {
log.Printf("error connecting to %s: %s", addressPort.String(), err.Error())
}
}()
}
}
func (s *Server) Listen() (err error) {
@ -87,18 +95,14 @@ func (s *Server) Listen() (err error) {
return err
} else {
if err = func() error {
if s.NumIncomingConnections.Load() > s.MaxIncomingPeers {
if uint32(s.NumIncomingConnections.Load()) > s.MaxIncomingPeers {
return errors.New("incoming connections limit was reached")
}
if addrPort, err := netip.ParseAddrPort(conn.RemoteAddr().String()); err != nil {
return err
} else if !addrPort.Addr().IsLoopback() {
s.clientsLock.RLock()
defer s.clientsLock.RUnlock()
for _, c := range s.clients {
if c.AddressPort.Addr().Compare(addrPort.Addr()) == 0 {
return errors.New("peer is already connected as " + c.AddressPort.String())
}
if clients := s.GetAddressConnected(addrPort.Addr()); !addrPort.Addr().IsLoopback() && len(clients) != 0 {
return errors.New("peer is already connected as " + clients[0].AddressPort.String())
}
}
@ -117,6 +121,7 @@ func (s *Server) Listen() (err error) {
client := NewClient(s, conn)
client.IsIncomingConnection = true
s.clients = append(s.clients, client)
s.NumIncomingConnections.Add(1)
go client.OnConnection()
}()
}
@ -127,14 +132,30 @@ func (s *Server) Listen() (err error) {
return nil
}
func (s *Server) Connect(addr netip.AddrPort) error {
if conn, err := net.Dial("tcp", addr.String()); err != nil {
func (s *Server) GetAddressConnected(addr netip.Addr) (result []*Client) {
s.clientsLock.RLock()
defer s.clientsLock.RUnlock()
for _, c := range s.clients {
if c.AddressPort.Addr().Compare(addr) == 0 {
result = append(result, c)
}
}
return result
}
func (s *Server) Connect(addrPort netip.AddrPort) error {
if clients := s.GetAddressConnected(addrPort.Addr()); !addrPort.Addr().IsLoopback() && len(clients) != 0 {
return errors.New("peer is already connected as " + clients[0].AddressPort.String())
}
if conn, err := net.Dial("tcp", addrPort.String()); err != nil {
return err
} else {
s.clientsLock.Lock()
defer s.clientsLock.Unlock()
client := NewClient(s, conn)
s.clients = append(s.clients, client)
s.NumOutgoingConnections.Add(1)
go client.OnConnection()
return nil
}
@ -146,8 +167,16 @@ func (s *Server) Clients() []*Client {
return slices.Clone(s.clients)
}
func (s *Server) Ban(ip netip.Addr, duration time.Duration) {
//TODO
func (s *Server) Ban(ip netip.Addr, duration time.Duration, err error) {
//TODO: banlist
go func() {
log.Printf("[P2PServer] Banned %s for %s: %s", ip.String(), duration.String(), err.Error())
if !ip.IsLoopback() {
for _, c := range s.GetAddressConnected(ip) {
c.Close()
}
}
}()
}
func (s *Server) Close() {
@ -164,6 +193,10 @@ func (s *Server) SideChain() *sidechain.SideChain {
return s.sidechain
}
func (s *Server) updateClients() {
}
func (s *Server) Broadcast(block *sidechain.PoolBlock) {
var message *ClientMessage
if block != nil {

View file

@ -120,7 +120,7 @@ func (c *SideChain) saveBlock(block *PoolBlock) {
if parent == nil || templateBlock == nil || (block.Side.Height % BlockSaveEpochSize) == 0 { //store full blocks every once in a while, or when there is no template block
if parent == nil || templateBlock == nil || block.Side.Height == fullBlockTemplateHeight { //store full blocks every once in a while, or when there is no template block
blockFlags |= BlockSaveOptionTemplate
} else {
if minerAddressOffset > 0 {

View file

@ -8,19 +8,17 @@ import (
"github.com/floatdrop/lru"
)
type derivationCacheKey [crypto.PublicKeySize * 2]byte
type sharedDataCacheKey [crypto.PrivateKeySize + 8]byte
type deterministicTransactionCacheKey [crypto.PublicKeySize + types.HashSize]byte
type ephemeralPublicKeyCacheKey [crypto.PrivateKeySize + crypto.PublicKeySize * 2 + 8]byte
type sharedDataWithTag struct {
SharedData *crypto.PrivateKeyScalar
type ephemeralPublicKeyWithViewTag struct {
PublicKey crypto.PublicKeyBytes
ViewTag uint8
}
type DerivationCache struct {
deterministicKeyCache *lru.LRU[derivationCacheKey, *crypto.KeyPair]
derivationCache *lru.LRU[derivationCacheKey, *crypto.PublicKeyPoint]
sharedDataCache *lru.LRU[sharedDataCacheKey, sharedDataWithTag]
ephemeralPublicKeyCache *lru.LRU[derivationCacheKey, crypto.PublicKeyBytes]
deterministicKeyCache *lru.LRU[deterministicTransactionCacheKey, *crypto.KeyPair]
ephemeralPublicKeyCache *lru.LRU[ephemeralPublicKeyCacheKey, ephemeralPublicKeyWithViewTag]
}
func NewDerivationCache() *DerivationCache {
@ -41,47 +39,27 @@ func (d *DerivationCache) Clear() {
const knownMinersPerPplns = pplnsSize / 4
const outputIdsPerMiner = 2
d.deterministicKeyCache = lru.New[derivationCacheKey, *crypto.KeyPair](cacheForNMinutesOfShares)
d.derivationCache = lru.New[derivationCacheKey, *crypto.PublicKeyPoint](pplnsSize * knownMinersPerPplns)
d.sharedDataCache = lru.New[sharedDataCacheKey, sharedDataWithTag](pplnsSize * knownMinersPerPplns * outputIdsPerMiner)
d.ephemeralPublicKeyCache = lru.New[derivationCacheKey, crypto.PublicKeyBytes](pplnsSize * knownMinersPerPplns * outputIdsPerMiner)
d.deterministicKeyCache = lru.New[deterministicTransactionCacheKey, *crypto.KeyPair](cacheForNMinutesOfShares)
d.ephemeralPublicKeyCache = lru.New[ephemeralPublicKeyCacheKey, ephemeralPublicKeyWithViewTag](pplnsSize * knownMinersPerPplns * outputIdsPerMiner)
}
func (d *DerivationCache) GetEphemeralPublicKey(a address.Interface, txKey crypto.PrivateKey, outputIndex uint64) (crypto.PublicKeyBytes, uint8) {
sharedData, viewTag := d.GetSharedData(a, txKey, outputIndex)
var key derivationCacheKey
copy(key[:], a.SpendPublicKey().AsSlice())
copy(key[types.HashSize:], sharedData.AsSlice())
var key ephemeralPublicKeyCacheKey
copy(key[:], txKey.AsSlice())
copy(key[crypto.PrivateKeySize:], a.ToPackedAddress().Bytes())
binary.LittleEndian.PutUint64(key[crypto.PrivateKeySize + crypto.PublicKeySize*2:], outputIndex)
if ephemeralPubKey := d.ephemeralPublicKeyCache.Get(key); ephemeralPubKey == nil {
ephemeralPubKey := address.GetPublicKeyForSharedData(a, sharedData).AsBytes()
d.ephemeralPublicKeyCache.Set(key, ephemeralPubKey)
return ephemeralPubKey, viewTag
ephemeralPubKey, viewTag := address.GetEphemeralPublicKeyAndViewTag(a, txKey, outputIndex)
pKB := ephemeralPubKey.AsBytes()
d.ephemeralPublicKeyCache.Set(key, ephemeralPublicKeyWithViewTag{PublicKey: pKB, ViewTag: viewTag})
return pKB, viewTag
} else {
return *ephemeralPubKey, viewTag
}
}
func (d *DerivationCache) GetSharedData(a address.Interface, txKey crypto.PrivateKey, outputIndex uint64) (*crypto.PrivateKeyScalar, uint8) {
derivation := d.GetDerivation(a, txKey)
var key sharedDataCacheKey
copy(key[:], derivation.AsSlice())
binary.LittleEndian.PutUint64(key[types.HashSize:], outputIndex)
if sharedData := d.sharedDataCache.Get(key); sharedData == nil {
var data sharedDataWithTag
data.SharedData = crypto.GetDerivationSharedDataForOutputIndex(derivation, outputIndex).AsScalar()
data.ViewTag = crypto.GetDerivationViewTagForOutputIndex(derivation, outputIndex)
d.sharedDataCache.Set(key, data)
return data.SharedData, data.ViewTag
} else {
return sharedData.SharedData, sharedData.ViewTag
return ephemeralPubKey.PublicKey, ephemeralPubKey.ViewTag
}
}
func (d *DerivationCache) GetDeterministicTransactionKey(a address.Interface, prevId types.Hash) *crypto.KeyPair {
var key derivationCacheKey
var key deterministicTransactionCacheKey
copy(key[:], a.SpendPublicKey().AsSlice())
copy(key[types.HashSize:], prevId[:])
@ -93,17 +71,3 @@ func (d *DerivationCache) GetDeterministicTransactionKey(a address.Interface, pr
return *kp
}
}
func (d *DerivationCache) GetDerivation(a address.Interface, txKey crypto.PrivateKey) *crypto.PublicKeyPoint {
var key derivationCacheKey
copy(key[:], a.ViewPublicKey().AsSlice())
copy(key[types.HashSize:], txKey.AsSlice())
if derivation := d.derivationCache.Get(key); derivation == nil {
data := txKey.GetDerivationCofactor(a.ViewPublicKey())
d.derivationCache.Set(key, data.AsPoint())
return data.AsPoint()
} else {
return *derivation
}
}

View file

@ -190,6 +190,7 @@ func (c *SideChain) AddPoolBlockExternal(block *PoolBlock) (missingBlocks []type
}
func (c *SideChain) AddPoolBlock(block *PoolBlock) (err error) {
c.sidechainLock.Lock()
defer c.sidechainLock.Unlock()
if _, ok := c.blocksByTemplateId[block.SideTemplateId(c.Consensus())]; ok {
@ -199,6 +200,8 @@ func (c *SideChain) AddPoolBlock(block *PoolBlock) (err error) {
}
c.blocksByTemplateId[block.SideTemplateId(c.Consensus())] = block
log.Printf("[SideChain] add_block: height = %d, id = %s, mainchain height = %d, verified = %t", block.Side.Height, block.SideTemplateId(c.Consensus()), block.Main.Coinbase.GenHeight, block.Verified.Load())
if l, ok := c.blocksByHeight[block.Side.Height]; ok {
c.blocksByHeight[block.Side.Height] = append(l, block)
} else {
@ -241,7 +244,7 @@ func (c *SideChain) verifyLoop(blockToVerify *PoolBlock) (err error) {
err = invalid
}
} else if verification != nil {
log.Printf("[SideChain] can't verify block at height = %d, id = %s, mainchain height = %d, mined by %s: %s", block.Side.Height, block.SideTemplateId(c.Consensus()), block.Main.Coinbase.GenHeight, block.GetAddress().ToBase58(), verification.Error())
//log.Printf("[SideChain] can't verify block at height = %d, id = %s, mainchain height = %d, mined by %s: %s", block.Side.Height, block.SideTemplateId(c.Consensus()), block.Main.Coinbase.GenHeight, block.GetAddress().ToBase58(), verification.Error())
block.Verified.Store(false)
block.Invalid.Store(false)
} else {