refactor and use interfaces!

This commit is contained in:
Jimmy Song 2017-05-11 12:15:06 -07:00
parent 025ef20056
commit bd6487d1f6
5 changed files with 174 additions and 131 deletions

70
key.go Normal file
View file

@ -0,0 +1,70 @@
package moneroutil
import (
"crypto/rand"
)
const (
PointLength = 32
ScalarLength = 32
)
type PubKey [PointLength]byte
type PrivKey [ScalarLength]byte
func (p *PrivKey) FromBytes(b [ScalarLength]byte) {
*p = b
}
func (p *PrivKey) ToBytes() (result [ScalarLength]byte) {
result = [32]byte(*p)
return
}
func (p *PrivKey) PubKey() (pubKey *PubKey) {
secret := p.ToBytes()
point := new(ExtendedGroupElement)
GeScalarMultBase(point, &secret)
pubKeyBytes := new([PointLength]byte)
point.ToBytes(pubKeyBytes)
pubKey = (*PubKey)(pubKeyBytes)
return
}
func (p *PubKey) FromBytes(b [PointLength]byte) {
*p = b
}
func (p *PubKey) ToBytes() (result [PointLength]byte) {
result = [PointLength]byte(*p)
return
}
// Creates a point on the Edwards Curve by hashing the key
func (p *PubKey) HashToEC() (result *ExtendedGroupElement) {
result = new(ExtendedGroupElement)
var p1 ProjectiveGroupElement
var p2 CompletedGroupElement
h := [PointLength]byte(Keccak256(p[:]))
p1.FromBytes(&h)
GeMul8(&p2, &p1)
p2.ToExtended(result)
return
}
func RandomScalar() (result [ScalarLength]byte) {
var reduceFrom [ScalarLength * 2]byte
tmp := make([]byte, ScalarLength*2)
rand.Read(tmp)
copy(reduceFrom[:], tmp)
ScReduce(&result, &reduceFrom)
return
}
func NewKeyPair() (privKey *PrivKey, pubKey *PubKey) {
privKey = new(PrivKey)
pubKey = new(PubKey)
privKey.FromBytes(RandomScalar())
pubKey = privKey.PubKey()
return
}

View file

@ -1,20 +1,11 @@
package moneroutil
import (
"bytes"
"crypto/rand"
"fmt"
mathrand "math/rand"
"io"
"math/rand"
)
const (
PointLength = 32
ScalarLength = 32
)
type PubKey [PointLength]byte
type PrivKey [ScalarLength]byte
type RingSignatureElement struct {
c [ScalarLength]byte
r [ScalarLength]byte
@ -22,80 +13,48 @@ type RingSignatureElement struct {
type RingSignature []*RingSignatureElement
func RandomScalar() (result [ScalarLength]byte) {
var reduceFrom [ScalarLength * 2]byte
tmp := make([]byte, ScalarLength*2)
rand.Read(tmp)
copy(reduceFrom[:], tmp)
ScReduce(&result, &reduceFrom)
return
}
func (p *PrivKey) FromBytes(b [ScalarLength]byte) {
*p = b
}
func (p *PrivKey) ToBytes() (result [ScalarLength]byte) {
result = [32]byte(*p)
return
}
func (p *PrivKey) PubKey() (pubKey *PubKey) {
secret := p.ToBytes()
point := new(ExtendedGroupElement)
GeScalarMultBase(point, &secret)
pubKeyBytes := new([PointLength]byte)
point.ToBytes(pubKeyBytes)
pubKey = (*PubKey)(pubKeyBytes)
return
}
func (p *PubKey) ToBytes() (result [PointLength]byte) {
result = [PointLength]byte(*p)
return
}
func NewKeyPair() (privKey *PrivKey, pubKey *PubKey) {
privKey = new(PrivKey)
pubKey = new(PubKey)
privKey.FromBytes(RandomScalar())
pubKey = privKey.PubKey()
return
}
func (s *RingSignatureElement) Serialize() (result []byte) {
func (r *RingSignatureElement) Serialize() (result []byte) {
result = make([]byte, 2*ScalarLength)
copy(result, s.c[:])
copy(result[ScalarLength:2*ScalarLength], s.r[:])
copy(result[:ScalarLength], r.c[:])
copy(result[ScalarLength:2*ScalarLength], r.r[:])
return
}
func (r *RingSignature) Serialize() (result []byte) {
result = make([]byte, len(*r)*ScalarLength*2)
for i := 0; i < len(*r); i++ {
result = append(result, (*r)[i].Serialize()...)
copy(result[i*ScalarLength*2:(i+1)*ScalarLength*2], (*r)[i].Serialize())
}
return
}
func ParseSignature(buf *bytes.Buffer) (result *RingSignatureElement, err error) {
s := new(RingSignatureElement)
c := buf.Next(ScalarLength)
if len(c) != ScalarLength {
func ParseSignature(buf io.Reader) (result *RingSignatureElement, err error) {
rse := new(RingSignatureElement)
c := make([]byte, ScalarLength)
n, err := buf.Read(c)
if err != nil {
return
}
if n != ScalarLength {
err = fmt.Errorf("Not enough bytes for signature c")
return
}
copy(s.c[:], c)
r := buf.Next(ScalarLength)
if len(r) != ScalarLength {
copy(rse.c[:], c)
r := make([]byte, ScalarLength)
n, err = buf.Read(r)
if err != nil {
return
}
if n != ScalarLength {
err = fmt.Errorf("Not enough bytes for signature r")
return
}
copy(s.r[:], r)
result = s
copy(rse.r[:], r)
result = rse
return
}
func ParseSignatures(mixinLengths []int, buf *bytes.Buffer) (signatures []RingSignature, err error) {
func ParseSignatures(mixinLengths []int, buf io.Reader) (signatures []RingSignature, err error) {
// mixinLengths is the number of mixins at each input position
sigs := make([]RingSignature, len(mixinLengths), len(mixinLengths))
for i, nMixin := range mixinLengths {
@ -111,16 +70,6 @@ func ParseSignatures(mixinLengths []int, buf *bytes.Buffer) (signatures []RingSi
return
}
// hashes a pubkey into an Edwards Curve element
func HashToEC(pk *PubKey, r *ExtendedGroupElement) {
var p1 ProjectiveGroupElement
var p2 CompletedGroupElement
h := [PointLength]byte(Keccak256(pk[:]))
p1.FromBytes(&h)
GeMul8(&p2, &p1)
p2.ToExtended(r)
}
func HashToScalar(data ...[]byte) (result [ScalarLength]byte) {
result = Keccak256(data...)
ScReduce32(&result)
@ -128,8 +77,7 @@ func HashToScalar(data ...[]byte) (result [ScalarLength]byte) {
}
func CreateSignature(prefixHash *Hash, mixins []PubKey, privKey *PrivKey) (keyImage PubKey, pubKeys []PubKey, sig RingSignature) {
point := new(ExtendedGroupElement)
HashToEC(privKey.PubKey(), point)
point := privKey.PubKey().HashToEC()
privKeyBytes := privKey.ToBytes()
keyImagePoint := new(ProjectiveGroupElement)
GeScalarMult(keyImagePoint, &privKeyBytes, point)
@ -144,7 +92,7 @@ func CreateSignature(prefixHash *Hash, mixins []PubKey, privKey *PrivKey) (keyIm
GePrecompute(&keyImagePre, keyImageGe)
k := RandomScalar()
pubKeys = make([]PubKey, len(mixins)+1)
privIndex := mathrand.Intn(len(pubKeys))
privIndex := rand.Intn(len(pubKeys))
pubKeys[privIndex] = *privKey.PubKey()
r := make([]*RingSignatureElement, len(pubKeys))
var sum [ScalarLength]byte
@ -157,7 +105,7 @@ func CreateSignature(prefixHash *Hash, mixins []PubKey, privKey *PrivKey) (keyIm
GeScalarMultBase(tmpE, &k)
tmpE.ToBytes(&tmpEBytes)
toHash = append(toHash, tmpEBytes[:]...)
HashToEC(privKey.PubKey(), tmpE)
tmpE = privKey.PubKey().HashToEC()
GeScalarMult(tmpP, &k, tmpE)
tmpP.ToBytes(&tmpPBytes)
toHash = append(toHash, tmpPBytes[:]...)
@ -176,7 +124,7 @@ func CreateSignature(prefixHash *Hash, mixins []PubKey, privKey *PrivKey) (keyIm
GeDoubleScalarMultVartime(tmpP, &r[i].c, tmpE, &r[i].r)
tmpP.ToBytes(&tmpPBytes)
toHash = append(toHash, tmpPBytes[:]...)
HashToEC(&pubKeys[i], tmpE)
tmpE = pubKeys[i].HashToEC()
GeDoubleScalarMultPrecompVartime(tmpP, &r[i].r, tmpE, &r[i].c, &keyImagePre)
tmpP.ToBytes(&tmpPBytes)
toHash = append(toHash, tmpPBytes[:]...)
@ -220,7 +168,7 @@ func VerifySignature(prefixHash *Hash, keyImage *PubKey, pubKeys []PubKey, ringS
GeDoubleScalarMultVartime(tmpP, &rse.c, tmpE, &rse.r)
tmpP.ToBytes(&tmpPBytes)
toHash = append(toHash, tmpPBytes[:]...)
HashToEC(&pubKey, tmpE)
tmpE = pubKey.HashToEC()
tmpE.ToBytes(&tmpEBytes)
GeDoubleScalarMultPrecompVartime(tmpP, &rse.r, tmpE, &rse.c, &keyImagePre)
tmpP.ToBytes(&tmpPBytes)

View file

@ -1799,10 +1799,10 @@ func TestHashToEC(t *testing.T) {
}
for _, test := range tests {
pubkeyBytes := HexToBytes(test.pubkeyHex)
pubKey := PubKey(pubkeyBytes)
pubKey := new(PubKey)
pubKey.FromBytes(pubkeyBytes)
want := HexToBytes(test.extendedHex)
ecPoint := new(ExtendedGroupElement)
HashToEC(&pubKey, ecPoint)
ecPoint := pubKey.HashToEC()
var got [32]byte
ecPoint.ToBytes(&got)
if want != got {

View file

@ -1,9 +1,8 @@
package moneroutil
import (
"bytes"
"errors"
"fmt"
"io"
)
const (
@ -213,7 +212,7 @@ func (t *Transaction) Serialize() (result []byte) {
return
}
func ParseTxInGen(buf *bytes.Buffer) (txIn *txInGen, err error) {
func ParseTxInGen(buf io.Reader) (txIn *txInGen, err error) {
t := new(txInGen)
t.height, err = ReadVarInt(buf)
if err != nil {
@ -223,17 +222,17 @@ func ParseTxInGen(buf *bytes.Buffer) (txIn *txInGen, err error) {
return
}
func ParseTxInToScript(buf *bytes.Buffer) (txIn *txInToScript, err error) {
err = errors.New("Unimplemented")
func ParseTxInToScript(buf io.Reader) (txIn *txInToScript, err error) {
err = fmt.Errorf("Unimplemented")
return
}
func ParseTxInToScriptHash(buf *bytes.Buffer) (txIn *txInToScriptHash, err error) {
err = errors.New("Unimplemented")
func ParseTxInToScriptHash(buf io.Reader) (txIn *txInToScriptHash, err error) {
err = fmt.Errorf("Unimplemented")
return
}
func ParseTxInToKey(buf *bytes.Buffer) (txIn *txInToKey, err error) {
func ParseTxInToKey(buf io.Reader) (txIn *txInToKey, err error) {
t := new(txInToKey)
t.amount, err = ReadVarInt(buf)
if err != nil {
@ -250,9 +249,13 @@ func ParseTxInToKey(buf *bytes.Buffer) (txIn *txInToKey, err error) {
return
}
}
pubKey := buf.Next(PointLength)
if len(pubKey) != PointLength {
err = errors.New("Buffer not long enough for public key")
pubKey := make([]byte, PointLength)
n, err := buf.Read(pubKey)
if err != nil {
return
}
if n != PointLength {
err = fmt.Errorf("Buffer not long enough for public key")
return
}
copy(t.keyImage[:], pubKey)
@ -260,39 +263,48 @@ func ParseTxInToKey(buf *bytes.Buffer) (txIn *txInToKey, err error) {
return
}
func ParseTxIn(buf *bytes.Buffer) (txIn TxInSerializer, err error) {
marker, err := buf.ReadByte()
func ParseTxIn(buf io.Reader) (txIn TxInSerializer, err error) {
marker := make([]byte, 1)
n, err := buf.Read(marker)
if n != 1 {
err = fmt.Errorf("Buffer not enough for TxIn")
return
}
if err != nil {
return
}
switch {
case marker == txInGenMarker:
case marker[0] == txInGenMarker:
txIn, err = ParseTxInGen(buf)
case marker == txInToScriptMarker:
case marker[0] == txInToScriptMarker:
txIn, err = ParseTxInToScript(buf)
case marker == txInToScriptHashMarker:
case marker[0] == txInToScriptHashMarker:
txIn, err = ParseTxInToScriptHash(buf)
case marker == txInToKeyMarker:
case marker[0] == txInToKeyMarker:
txIn, err = ParseTxInToKey(buf)
}
return
}
func ParseTxOutToScript(buf *bytes.Buffer) (txOutTarget *txOutToScript, err error) {
err = errors.New("Unimplemented")
func ParseTxOutToScript(buf io.Reader) (txOutTarget *txOutToScript, err error) {
err = fmt.Errorf("Unimplemented")
return
}
func ParseTxOutToScriptHash(buf *bytes.Buffer) (txOutTarget *txOutToScriptHash, err error) {
err = errors.New("Unimplemented")
func ParseTxOutToScriptHash(buf io.Reader) (txOutTarget *txOutToScriptHash, err error) {
err = fmt.Errorf("Unimplemented")
return
}
func ParseTxOutToKey(buf *bytes.Buffer) (txOutTarget *txOutToKey, err error) {
func ParseTxOutToKey(buf io.Reader) (txOutTarget *txOutToKey, err error) {
t := new(txOutToKey)
pubKey := buf.Next(PointLength)
if len(pubKey) != PointLength {
err = errors.New("Buffer not long enough for public key")
pubKey := make([]byte, PointLength)
n, err := buf.Read(pubKey)
if err != nil {
return
}
if n != PointLength {
err = fmt.Errorf("Buffer not long enough for public key")
return
}
copy(t.key[:], pubKey)
@ -300,25 +312,30 @@ func ParseTxOutToKey(buf *bytes.Buffer) (txOutTarget *txOutToKey, err error) {
return
}
func ParseTxOut(buf *bytes.Buffer) (txOut *TxOut, err error) {
func ParseTxOut(buf io.Reader) (txOut *TxOut, err error) {
t := new(TxOut)
t.amount, err = ReadVarInt(buf)
if err != nil {
return
}
marker, err := buf.ReadByte()
marker := make([]byte, 1)
n, err := buf.Read(marker)
if err != nil {
return
}
if n != 1 {
err = fmt.Errorf("Buffer not long enough for TxOut")
return
}
switch {
case marker == txOutToScriptMarker:
case marker[0] == txOutToScriptMarker:
t.target, err = ParseTxOutToScript(buf)
case marker == txOutToScriptHashMarker:
case marker[0] == txOutToScriptHashMarker:
t.target, err = ParseTxOutToScriptHash(buf)
case marker == txOutToKeyMarker:
case marker[0] == txOutToKeyMarker:
t.target, err = ParseTxOutToKey(buf)
default:
err = errors.New("Bad Marker")
err = fmt.Errorf("Bad Marker")
return
}
if err != nil {
@ -328,21 +345,25 @@ func ParseTxOut(buf *bytes.Buffer) (txOut *TxOut, err error) {
return
}
func ParseExtra(buf *bytes.Buffer) (extra []byte, err error) {
func ParseExtra(buf io.Reader) (extra []byte, err error) {
length, err := ReadVarInt(buf)
if err != nil {
return
}
e := buf.Next(int(length))
if len(e) != int(length) {
err = errors.New("Not enough bytes for extra")
e := make([]byte, int(length))
n, err := buf.Read(e)
if err != nil {
return
}
if n != int(length) {
err = fmt.Errorf("Not enough bytes for extra")
return
}
extra = e
return
}
func ParseTransaction(buf *bytes.Buffer) (transaction *Transaction, err error) {
func ParseTransaction(buf io.Reader) (transaction *Transaction, err error) {
t := new(Transaction)
version, err := ReadVarInt(buf)
if err != nil {
@ -388,10 +409,6 @@ func ParseTransaction(buf *bytes.Buffer) (transaction *Transaction, err error) {
if err != nil {
return
}
if buf.Len() != 0 {
err = errors.New("Buffer has extra data")
return
}
transaction = t
return
}

View file

@ -1,21 +1,29 @@
package moneroutil
import (
"bytes"
"fmt"
"io"
)
func ReadVarInt(buf *bytes.Buffer) (result uint64, err error) {
var b byte
func ReadVarInt(buf io.Reader) (result uint64, err error) {
b := make([]byte, 1)
var r uint64
var n int
for i := 0; ; i++ {
b, err = buf.ReadByte()
n, err = buf.Read(b)
if err != nil {
return
}
result += (uint64(b) & 0x7f) << uint(i*7)
if uint64(b)&0x80 == 0 {
if n != 1 {
err = fmt.Errorf("Buffer ended prematurely for varint")
return
}
r += (uint64(b[0]) & 0x7f) << uint(i*7)
if uint64(b[0])&0x80 == 0 {
break
}
}
result = r
return
}