From bd6487d1f64aa34cccd65dce35ae103a60f7f3a5 Mon Sep 17 00:00:00 2001 From: Jimmy Song Date: Thu, 11 May 2017 12:15:06 -0700 Subject: [PATCH] refactor and use interfaces! --- key.go | 70 ++++++++++++++++++++++++++ ringsignature.go | 112 +++++++++++------------------------------- ringsignature_test.go | 6 +-- transaction.go | 97 +++++++++++++++++++++--------------- varint.go | 20 +++++--- 5 files changed, 174 insertions(+), 131 deletions(-) create mode 100644 key.go diff --git a/key.go b/key.go new file mode 100644 index 0000000..f0c7041 --- /dev/null +++ b/key.go @@ -0,0 +1,70 @@ +package moneroutil + +import ( + "crypto/rand" +) + +const ( + PointLength = 32 + ScalarLength = 32 +) + +type PubKey [PointLength]byte +type PrivKey [ScalarLength]byte + +func (p *PrivKey) FromBytes(b [ScalarLength]byte) { + *p = b +} + +func (p *PrivKey) ToBytes() (result [ScalarLength]byte) { + result = [32]byte(*p) + return +} + +func (p *PrivKey) PubKey() (pubKey *PubKey) { + secret := p.ToBytes() + point := new(ExtendedGroupElement) + GeScalarMultBase(point, &secret) + pubKeyBytes := new([PointLength]byte) + point.ToBytes(pubKeyBytes) + pubKey = (*PubKey)(pubKeyBytes) + return +} + +func (p *PubKey) FromBytes(b [PointLength]byte) { + *p = b +} + +func (p *PubKey) ToBytes() (result [PointLength]byte) { + result = [PointLength]byte(*p) + return +} + +// Creates a point on the Edwards Curve by hashing the key +func (p *PubKey) HashToEC() (result *ExtendedGroupElement) { + result = new(ExtendedGroupElement) + var p1 ProjectiveGroupElement + var p2 CompletedGroupElement + h := [PointLength]byte(Keccak256(p[:])) + p1.FromBytes(&h) + GeMul8(&p2, &p1) + p2.ToExtended(result) + return +} + +func RandomScalar() (result [ScalarLength]byte) { + var reduceFrom [ScalarLength * 2]byte + tmp := make([]byte, ScalarLength*2) + rand.Read(tmp) + copy(reduceFrom[:], tmp) + ScReduce(&result, &reduceFrom) + return +} + +func NewKeyPair() (privKey *PrivKey, pubKey *PubKey) { + privKey = new(PrivKey) + pubKey = new(PubKey) + privKey.FromBytes(RandomScalar()) + pubKey = privKey.PubKey() + return +} diff --git a/ringsignature.go b/ringsignature.go index 94d9a89..e2664cf 100644 --- a/ringsignature.go +++ b/ringsignature.go @@ -1,20 +1,11 @@ package moneroutil import ( - "bytes" - "crypto/rand" "fmt" - mathrand "math/rand" + "io" + "math/rand" ) -const ( - PointLength = 32 - ScalarLength = 32 -) - -type PubKey [PointLength]byte -type PrivKey [ScalarLength]byte - type RingSignatureElement struct { c [ScalarLength]byte r [ScalarLength]byte @@ -22,80 +13,48 @@ type RingSignatureElement struct { type RingSignature []*RingSignatureElement -func RandomScalar() (result [ScalarLength]byte) { - var reduceFrom [ScalarLength * 2]byte - tmp := make([]byte, ScalarLength*2) - rand.Read(tmp) - copy(reduceFrom[:], tmp) - ScReduce(&result, &reduceFrom) - return -} - -func (p *PrivKey) FromBytes(b [ScalarLength]byte) { - *p = b -} - -func (p *PrivKey) ToBytes() (result [ScalarLength]byte) { - result = [32]byte(*p) - return -} - -func (p *PrivKey) PubKey() (pubKey *PubKey) { - secret := p.ToBytes() - point := new(ExtendedGroupElement) - GeScalarMultBase(point, &secret) - pubKeyBytes := new([PointLength]byte) - point.ToBytes(pubKeyBytes) - pubKey = (*PubKey)(pubKeyBytes) - return -} - -func (p *PubKey) ToBytes() (result [PointLength]byte) { - result = [PointLength]byte(*p) - return -} - -func NewKeyPair() (privKey *PrivKey, pubKey *PubKey) { - privKey = new(PrivKey) - pubKey = new(PubKey) - privKey.FromBytes(RandomScalar()) - pubKey = privKey.PubKey() - return -} - -func (s *RingSignatureElement) Serialize() (result []byte) { +func (r *RingSignatureElement) Serialize() (result []byte) { result = make([]byte, 2*ScalarLength) - copy(result, s.c[:]) - copy(result[ScalarLength:2*ScalarLength], s.r[:]) + copy(result[:ScalarLength], r.c[:]) + copy(result[ScalarLength:2*ScalarLength], r.r[:]) return } func (r *RingSignature) Serialize() (result []byte) { + result = make([]byte, len(*r)*ScalarLength*2) for i := 0; i < len(*r); i++ { - result = append(result, (*r)[i].Serialize()...) + copy(result[i*ScalarLength*2:(i+1)*ScalarLength*2], (*r)[i].Serialize()) } return } -func ParseSignature(buf *bytes.Buffer) (result *RingSignatureElement, err error) { - s := new(RingSignatureElement) - c := buf.Next(ScalarLength) - if len(c) != ScalarLength { +func ParseSignature(buf io.Reader) (result *RingSignatureElement, err error) { + rse := new(RingSignatureElement) + c := make([]byte, ScalarLength) + n, err := buf.Read(c) + if err != nil { + return + } + if n != ScalarLength { err = fmt.Errorf("Not enough bytes for signature c") return } - copy(s.c[:], c) - r := buf.Next(ScalarLength) - if len(r) != ScalarLength { + copy(rse.c[:], c) + r := make([]byte, ScalarLength) + n, err = buf.Read(r) + if err != nil { + return + } + if n != ScalarLength { err = fmt.Errorf("Not enough bytes for signature r") return } - copy(s.r[:], r) - result = s + copy(rse.r[:], r) + result = rse return } -func ParseSignatures(mixinLengths []int, buf *bytes.Buffer) (signatures []RingSignature, err error) { +func ParseSignatures(mixinLengths []int, buf io.Reader) (signatures []RingSignature, err error) { // mixinLengths is the number of mixins at each input position sigs := make([]RingSignature, len(mixinLengths), len(mixinLengths)) for i, nMixin := range mixinLengths { @@ -111,16 +70,6 @@ func ParseSignatures(mixinLengths []int, buf *bytes.Buffer) (signatures []RingSi return } -// hashes a pubkey into an Edwards Curve element -func HashToEC(pk *PubKey, r *ExtendedGroupElement) { - var p1 ProjectiveGroupElement - var p2 CompletedGroupElement - h := [PointLength]byte(Keccak256(pk[:])) - p1.FromBytes(&h) - GeMul8(&p2, &p1) - p2.ToExtended(r) -} - func HashToScalar(data ...[]byte) (result [ScalarLength]byte) { result = Keccak256(data...) ScReduce32(&result) @@ -128,8 +77,7 @@ func HashToScalar(data ...[]byte) (result [ScalarLength]byte) { } func CreateSignature(prefixHash *Hash, mixins []PubKey, privKey *PrivKey) (keyImage PubKey, pubKeys []PubKey, sig RingSignature) { - point := new(ExtendedGroupElement) - HashToEC(privKey.PubKey(), point) + point := privKey.PubKey().HashToEC() privKeyBytes := privKey.ToBytes() keyImagePoint := new(ProjectiveGroupElement) GeScalarMult(keyImagePoint, &privKeyBytes, point) @@ -144,7 +92,7 @@ func CreateSignature(prefixHash *Hash, mixins []PubKey, privKey *PrivKey) (keyIm GePrecompute(&keyImagePre, keyImageGe) k := RandomScalar() pubKeys = make([]PubKey, len(mixins)+1) - privIndex := mathrand.Intn(len(pubKeys)) + privIndex := rand.Intn(len(pubKeys)) pubKeys[privIndex] = *privKey.PubKey() r := make([]*RingSignatureElement, len(pubKeys)) var sum [ScalarLength]byte @@ -157,7 +105,7 @@ func CreateSignature(prefixHash *Hash, mixins []PubKey, privKey *PrivKey) (keyIm GeScalarMultBase(tmpE, &k) tmpE.ToBytes(&tmpEBytes) toHash = append(toHash, tmpEBytes[:]...) - HashToEC(privKey.PubKey(), tmpE) + tmpE = privKey.PubKey().HashToEC() GeScalarMult(tmpP, &k, tmpE) tmpP.ToBytes(&tmpPBytes) toHash = append(toHash, tmpPBytes[:]...) @@ -176,7 +124,7 @@ func CreateSignature(prefixHash *Hash, mixins []PubKey, privKey *PrivKey) (keyIm GeDoubleScalarMultVartime(tmpP, &r[i].c, tmpE, &r[i].r) tmpP.ToBytes(&tmpPBytes) toHash = append(toHash, tmpPBytes[:]...) - HashToEC(&pubKeys[i], tmpE) + tmpE = pubKeys[i].HashToEC() GeDoubleScalarMultPrecompVartime(tmpP, &r[i].r, tmpE, &r[i].c, &keyImagePre) tmpP.ToBytes(&tmpPBytes) toHash = append(toHash, tmpPBytes[:]...) @@ -220,7 +168,7 @@ func VerifySignature(prefixHash *Hash, keyImage *PubKey, pubKeys []PubKey, ringS GeDoubleScalarMultVartime(tmpP, &rse.c, tmpE, &rse.r) tmpP.ToBytes(&tmpPBytes) toHash = append(toHash, tmpPBytes[:]...) - HashToEC(&pubKey, tmpE) + tmpE = pubKey.HashToEC() tmpE.ToBytes(&tmpEBytes) GeDoubleScalarMultPrecompVartime(tmpP, &rse.r, tmpE, &rse.c, &keyImagePre) tmpP.ToBytes(&tmpPBytes) diff --git a/ringsignature_test.go b/ringsignature_test.go index 6171349..6b79cbc 100644 --- a/ringsignature_test.go +++ b/ringsignature_test.go @@ -1799,10 +1799,10 @@ func TestHashToEC(t *testing.T) { } for _, test := range tests { pubkeyBytes := HexToBytes(test.pubkeyHex) - pubKey := PubKey(pubkeyBytes) + pubKey := new(PubKey) + pubKey.FromBytes(pubkeyBytes) want := HexToBytes(test.extendedHex) - ecPoint := new(ExtendedGroupElement) - HashToEC(&pubKey, ecPoint) + ecPoint := pubKey.HashToEC() var got [32]byte ecPoint.ToBytes(&got) if want != got { diff --git a/transaction.go b/transaction.go index a9570be..36bf482 100644 --- a/transaction.go +++ b/transaction.go @@ -1,9 +1,8 @@ package moneroutil import ( - "bytes" - "errors" "fmt" + "io" ) const ( @@ -213,7 +212,7 @@ func (t *Transaction) Serialize() (result []byte) { return } -func ParseTxInGen(buf *bytes.Buffer) (txIn *txInGen, err error) { +func ParseTxInGen(buf io.Reader) (txIn *txInGen, err error) { t := new(txInGen) t.height, err = ReadVarInt(buf) if err != nil { @@ -223,17 +222,17 @@ func ParseTxInGen(buf *bytes.Buffer) (txIn *txInGen, err error) { return } -func ParseTxInToScript(buf *bytes.Buffer) (txIn *txInToScript, err error) { - err = errors.New("Unimplemented") +func ParseTxInToScript(buf io.Reader) (txIn *txInToScript, err error) { + err = fmt.Errorf("Unimplemented") return } -func ParseTxInToScriptHash(buf *bytes.Buffer) (txIn *txInToScriptHash, err error) { - err = errors.New("Unimplemented") +func ParseTxInToScriptHash(buf io.Reader) (txIn *txInToScriptHash, err error) { + err = fmt.Errorf("Unimplemented") return } -func ParseTxInToKey(buf *bytes.Buffer) (txIn *txInToKey, err error) { +func ParseTxInToKey(buf io.Reader) (txIn *txInToKey, err error) { t := new(txInToKey) t.amount, err = ReadVarInt(buf) if err != nil { @@ -250,9 +249,13 @@ func ParseTxInToKey(buf *bytes.Buffer) (txIn *txInToKey, err error) { return } } - pubKey := buf.Next(PointLength) - if len(pubKey) != PointLength { - err = errors.New("Buffer not long enough for public key") + pubKey := make([]byte, PointLength) + n, err := buf.Read(pubKey) + if err != nil { + return + } + if n != PointLength { + err = fmt.Errorf("Buffer not long enough for public key") return } copy(t.keyImage[:], pubKey) @@ -260,39 +263,48 @@ func ParseTxInToKey(buf *bytes.Buffer) (txIn *txInToKey, err error) { return } -func ParseTxIn(buf *bytes.Buffer) (txIn TxInSerializer, err error) { - marker, err := buf.ReadByte() +func ParseTxIn(buf io.Reader) (txIn TxInSerializer, err error) { + marker := make([]byte, 1) + n, err := buf.Read(marker) + if n != 1 { + err = fmt.Errorf("Buffer not enough for TxIn") + return + } if err != nil { return } switch { - case marker == txInGenMarker: + case marker[0] == txInGenMarker: txIn, err = ParseTxInGen(buf) - case marker == txInToScriptMarker: + case marker[0] == txInToScriptMarker: txIn, err = ParseTxInToScript(buf) - case marker == txInToScriptHashMarker: + case marker[0] == txInToScriptHashMarker: txIn, err = ParseTxInToScriptHash(buf) - case marker == txInToKeyMarker: + case marker[0] == txInToKeyMarker: txIn, err = ParseTxInToKey(buf) } return } -func ParseTxOutToScript(buf *bytes.Buffer) (txOutTarget *txOutToScript, err error) { - err = errors.New("Unimplemented") +func ParseTxOutToScript(buf io.Reader) (txOutTarget *txOutToScript, err error) { + err = fmt.Errorf("Unimplemented") return } -func ParseTxOutToScriptHash(buf *bytes.Buffer) (txOutTarget *txOutToScriptHash, err error) { - err = errors.New("Unimplemented") +func ParseTxOutToScriptHash(buf io.Reader) (txOutTarget *txOutToScriptHash, err error) { + err = fmt.Errorf("Unimplemented") return } -func ParseTxOutToKey(buf *bytes.Buffer) (txOutTarget *txOutToKey, err error) { +func ParseTxOutToKey(buf io.Reader) (txOutTarget *txOutToKey, err error) { t := new(txOutToKey) - pubKey := buf.Next(PointLength) - if len(pubKey) != PointLength { - err = errors.New("Buffer not long enough for public key") + pubKey := make([]byte, PointLength) + n, err := buf.Read(pubKey) + if err != nil { + return + } + if n != PointLength { + err = fmt.Errorf("Buffer not long enough for public key") return } copy(t.key[:], pubKey) @@ -300,25 +312,30 @@ func ParseTxOutToKey(buf *bytes.Buffer) (txOutTarget *txOutToKey, err error) { return } -func ParseTxOut(buf *bytes.Buffer) (txOut *TxOut, err error) { +func ParseTxOut(buf io.Reader) (txOut *TxOut, err error) { t := new(TxOut) t.amount, err = ReadVarInt(buf) if err != nil { return } - marker, err := buf.ReadByte() + marker := make([]byte, 1) + n, err := buf.Read(marker) if err != nil { return } + if n != 1 { + err = fmt.Errorf("Buffer not long enough for TxOut") + return + } switch { - case marker == txOutToScriptMarker: + case marker[0] == txOutToScriptMarker: t.target, err = ParseTxOutToScript(buf) - case marker == txOutToScriptHashMarker: + case marker[0] == txOutToScriptHashMarker: t.target, err = ParseTxOutToScriptHash(buf) - case marker == txOutToKeyMarker: + case marker[0] == txOutToKeyMarker: t.target, err = ParseTxOutToKey(buf) default: - err = errors.New("Bad Marker") + err = fmt.Errorf("Bad Marker") return } if err != nil { @@ -328,21 +345,25 @@ func ParseTxOut(buf *bytes.Buffer) (txOut *TxOut, err error) { return } -func ParseExtra(buf *bytes.Buffer) (extra []byte, err error) { +func ParseExtra(buf io.Reader) (extra []byte, err error) { length, err := ReadVarInt(buf) if err != nil { return } - e := buf.Next(int(length)) - if len(e) != int(length) { - err = errors.New("Not enough bytes for extra") + e := make([]byte, int(length)) + n, err := buf.Read(e) + if err != nil { + return + } + if n != int(length) { + err = fmt.Errorf("Not enough bytes for extra") return } extra = e return } -func ParseTransaction(buf *bytes.Buffer) (transaction *Transaction, err error) { +func ParseTransaction(buf io.Reader) (transaction *Transaction, err error) { t := new(Transaction) version, err := ReadVarInt(buf) if err != nil { @@ -388,10 +409,6 @@ func ParseTransaction(buf *bytes.Buffer) (transaction *Transaction, err error) { if err != nil { return } - if buf.Len() != 0 { - err = errors.New("Buffer has extra data") - return - } transaction = t return } diff --git a/varint.go b/varint.go index 2386b18..6745733 100644 --- a/varint.go +++ b/varint.go @@ -1,21 +1,29 @@ package moneroutil import ( - "bytes" + "fmt" + "io" ) -func ReadVarInt(buf *bytes.Buffer) (result uint64, err error) { - var b byte +func ReadVarInt(buf io.Reader) (result uint64, err error) { + b := make([]byte, 1) + var r uint64 + var n int for i := 0; ; i++ { - b, err = buf.ReadByte() + n, err = buf.Read(b) if err != nil { return } - result += (uint64(b) & 0x7f) << uint(i*7) - if uint64(b)&0x80 == 0 { + if n != 1 { + err = fmt.Errorf("Buffer ended prematurely for varint") + return + } + r += (uint64(b[0]) & 0x7f) << uint(i*7) + if uint64(b[0])&0x80 == 0 { break } } + result = r return }