From 4ee46b4a26f7c64e50b6c93986b5aa0cc3e63509 Mon Sep 17 00:00:00 2001 From: Jimmy Song Date: Sat, 13 May 2017 17:36:32 -0700 Subject: [PATCH] Refactor some key stuff --- edwards25519.go | 131 +++++++++++++++++------------------ edwards25519_test.go | 48 ++++++------- internal_test.go | 8 ++- keccak_test.go | 5 +- key.go | 51 +++++--------- ringct.go | 5 -- ringsignature.go | 100 +++++++++++++------------- ringsignature_test.go | 16 ++--- ringsignature_verify_test.go | 8 +-- transaction.go | 16 ++--- transaction_test.go | 6 +- 11 files changed, 190 insertions(+), 204 deletions(-) diff --git a/edwards25519.go b/edwards25519.go index e717e38..43db761 100644 --- a/edwards25519.go +++ b/edwards25519.go @@ -25,14 +25,18 @@ var FeFffb4 = FieldElement{-21786234, -12173074, 21573800, 4524538, -4645904, 16 var FeSqrtM1 = FieldElement{-32595792, -7943725, 9377950, 3500415, 12389472, -272473, -25146209, -2005654, 326686, 11406482} /* sqrt(-1) */ var zero FieldElement +var one FieldElement -func FeZero(fe *FieldElement) { - copy(fe[:], zero[:]) +func init() { + one[0] = 1 } -func FeOne(fe *FieldElement) { - FeZero(fe) - fe[0] = 1 +func (f *FieldElement) Zero() { + copy(f[:], zero[:]) +} + +func (f *FieldElement) One() { + copy(f[:], one[:]) } func FeAdd(dst, a, b *FieldElement) { @@ -83,24 +87,17 @@ func FeCMove(f, g *FieldElement, b int32) { f[9] ^= b & (f[9] ^ g[9]) } -func load3(in []byte) int64 { - var r int64 - r = int64(in[0]) - r |= int64(in[1]) << 8 - r |= int64(in[2]) << 16 - return r +func load3(in []byte) (result int64) { + result = int64(in[0]) | (int64(in[1]) << 8) | (int64(in[2]) << 16) + return } -func load4(in []byte) int64 { - var r int64 - r = int64(in[0]) - r |= int64(in[1]) << 8 - r |= int64(in[2]) << 16 - r |= int64(in[3]) << 24 - return r +func load4(in []byte) (result int64) { + result = int64(in[0]) | (int64(in[1]) << 8) | (int64(in[2]) << 16) | (int64(in[3]) << 24) + return } -func FeFromBytes(dst *FieldElement, src *[32]byte) { +func FeFromBytes(dst *FieldElement, src *Key) { h0 := load4(src[:]) h1 := load3(src[4:]) << 6 h2 := load3(src[7:]) << 5 @@ -138,7 +135,7 @@ func FeFromBytes(dst *FieldElement, src *[32]byte) { // // Have q+2^(-255)x = 2^(-255)(h + 19 2^(-25) h9 + 2^(-1)) // so floor(2^(-255)(h + 19 2^(-25) h9 + 2^(-1))) = q. -func FeToBytes(s *[32]byte, h *FieldElement) { +func FeToBytes(s *Key, h *FieldElement) { var carry [10]int32 q := (19*h[9] + (1 << 24)) >> 25 @@ -227,14 +224,14 @@ func FeToBytes(s *[32]byte, h *FieldElement) { s[31] = byte(h[9] >> 18) } -func FeIsNegative(f *FieldElement) byte { - var s [32]byte +func (f *FieldElement) IsNegative() byte { + var s Key FeToBytes(&s, f) return s[0] & 1 } -func FeIsNonZero(f *FieldElement) int32 { - var s [32]byte +func (f *FieldElement) IsNonZero() int32 { + var s Key FeToBytes(&s, f) var x uint8 for _, b := range s { @@ -674,9 +671,9 @@ type CachedGroupElement struct { } func (p *ProjectiveGroupElement) Zero() { - FeZero(&p.X) - FeOne(&p.Y) - FeOne(&p.Z) + p.X.Zero() + p.Y.One() + p.Z.One() } func (p *ProjectiveGroupElement) Double(r *CompletedGroupElement) { @@ -693,17 +690,17 @@ func (p *ProjectiveGroupElement) Double(r *CompletedGroupElement) { FeSub(&r.T, &r.T, &r.Z) } -func (p *ProjectiveGroupElement) ToBytes(s *[32]byte) { +func (p *ProjectiveGroupElement) ToBytes(s *Key) { var recip, x, y FieldElement FeInvert(&recip, &p.Z) FeMul(&x, &p.X, &recip) FeMul(&y, &p.Y, &recip) FeToBytes(s, &y) - s[31] ^= FeIsNegative(&x) << 7 + s[31] ^= x.IsNegative() << 7 } -func (p *ProjectiveGroupElement) FromBytes(s *[32]byte) { +func (p *ProjectiveGroupElement) FromBytes(s *Key) { h0 := load4(s[:]) h1 := load3(s[4:]) << 6 h2 := load3(s[7:]) << 5 @@ -759,7 +756,7 @@ func (p *ProjectiveGroupElement) FromBytes(s *[32]byte) { u[8] = int32(h8) u[9] = int32(h9) FeSquare2(&v, &u) /* 2 * u^2 */ - FeOne(&w) + w.One() FeAdd(&w, &v, &w) /* w = 2 * u^2 + 1 */ FeSquare(&x, &w) /* w^2 */ FeMul(&y, &FeMa2, &v) /* -2 * A^2 * u^2 */ @@ -771,9 +768,9 @@ func (p *ProjectiveGroupElement) FromBytes(s *[32]byte) { FeCopy(&z, &FeMa) isNegative := false var sign byte - if FeIsNonZero(&y) != 0 { + if y.IsNonZero() != 0 { FeAdd(&y, &w, &x) - if FeIsNonZero(&y) != 0 { + if y.IsNonZero() != 0 { isNegative = true } else { FeMul(&p.X, &p.X, &FeFffb1) @@ -784,7 +781,7 @@ func (p *ProjectiveGroupElement) FromBytes(s *[32]byte) { if isNegative { FeMul(&x, &x, &FeSqrtM1) FeSub(&y, &w, &x) - if FeIsNonZero(&y) != 0 { + if y.IsNonZero() != 0 { FeAdd(&y, &w, &x) FeMul(&p.X, &p.X, &FeFffb3) } else { @@ -798,7 +795,7 @@ func (p *ProjectiveGroupElement) FromBytes(s *[32]byte) { FeMul(&z, &z, &v) /* -2 * A * u^2 */ sign = 0 } - if FeIsNegative(&p.X) != sign { + if p.X.IsNegative() != sign { FeNeg(&p.X, &p.X) } FeAdd(&p.Z, &z, &w) @@ -807,10 +804,10 @@ func (p *ProjectiveGroupElement) FromBytes(s *[32]byte) { } func (p *ExtendedGroupElement) Zero() { - FeZero(&p.X) - FeOne(&p.Y) - FeOne(&p.Z) - FeZero(&p.T) + p.X.Zero() + p.Y.One() + p.Z.One() + p.T.Zero() } func (p *ExtendedGroupElement) Double(r *CompletedGroupElement) { @@ -832,21 +829,21 @@ func (p *ExtendedGroupElement) ToProjective(r *ProjectiveGroupElement) { FeCopy(&r.Z, &p.Z) } -func (p *ExtendedGroupElement) ToBytes(s *[32]byte) { +func (p *ExtendedGroupElement) ToBytes(s *Key) { var recip, x, y FieldElement FeInvert(&recip, &p.Z) FeMul(&x, &p.X, &recip) FeMul(&y, &p.Y, &recip) FeToBytes(s, &y) - s[31] ^= FeIsNegative(&x) << 7 + s[31] ^= x.IsNegative() << 7 } -func (p *ExtendedGroupElement) FromBytes(s *[32]byte) bool { +func (p *ExtendedGroupElement) FromBytes(s *Key) bool { var u, v, v3, vxx, check FieldElement FeFromBytes(&p.Y, s) - FeOne(&p.Z) + p.Z.One() FeSquare(&u, &p.Y) FeMul(&v, &u, &d) FeSub(&u, &u, &p.Z) // y = y^2-1 @@ -862,14 +859,14 @@ func (p *ExtendedGroupElement) FromBytes(s *[32]byte) bool { FeMul(&p.X, &p.X, &v3) FeMul(&p.X, &p.X, &u) // x = uv^3(uv^7)^((q-5)/8) - var tmpX, tmp2 [32]byte + var tmpX, tmp2 Key FeSquare(&vxx, &p.X) FeMul(&vxx, &vxx, &v) FeSub(&check, &vxx, &u) // vx^2-u - if FeIsNonZero(&check) == 1 { + if check.IsNonZero() == 1 { FeAdd(&check, &vxx, &u) // vx^2+u - if FeIsNonZero(&check) == 1 { + if check.IsNonZero() == 1 { return false } FeMul(&p.X, &p.X, &SqrtM1) @@ -880,7 +877,7 @@ func (p *ExtendedGroupElement) FromBytes(s *[32]byte) bool { } } - if FeIsNegative(&p.X) != (s[31] >> 7) { + if p.X.IsNegative() != (s[31] >> 7) { FeNeg(&p.X, &p.X) } @@ -902,16 +899,16 @@ func (p *CompletedGroupElement) ToExtended(r *ExtendedGroupElement) { } func (p *PreComputedGroupElement) Zero() { - FeOne(&p.yPlusX) - FeOne(&p.yMinusX) - FeZero(&p.xy2d) + p.yPlusX.One() + p.yMinusX.One() + p.xy2d.Zero() } func (c *CachedGroupElement) Zero() { - FeOne(&c.yPlusX) - FeOne(&c.yMinusX) - FeOne(&c.Z) - FeZero(&c.T2d) + c.yPlusX.One() + c.yMinusX.One() + c.Z.One() + c.T2d.Zero() } func geAdd(r *CompletedGroupElement, p *ExtendedGroupElement, q *CachedGroupElement) { @@ -1000,7 +997,7 @@ func GePrecompute(r *[8]CachedGroupElement, s *ExtendedGroupElement) { } } -func slide(r *[256]int8, a *[32]byte) { +func slide(r *[256]int8, a *Key) { for i := range r { r[i] = int8(1 & (a[i>>3] >> uint(i&7))) } @@ -1034,7 +1031,7 @@ func slide(r *[256]int8, a *[32]byte) { // where a = a[0]+256*a[1]+...+256^31 a[31]. // and b = b[0]+256*b[1]+...+256^31 b[31]. // B is the Ed25519 base point (x,4/5) with x positive. -func GeDoubleScalarMultVartime(r *ProjectiveGroupElement, a *[32]byte, A *ExtendedGroupElement, b *[32]byte) { +func GeDoubleScalarMultVartime(r *ProjectiveGroupElement, a *Key, A *ExtendedGroupElement, b *Key) { var aSlide, bSlide [256]int8 var Ai [8]CachedGroupElement // A,3A,5A,7A,9A,11A,13A,15A var t CompletedGroupElement @@ -1079,7 +1076,7 @@ func GeDoubleScalarMultVartime(r *ProjectiveGroupElement, a *[32]byte, A *Extend // sets r = a*A + b*B // where Bi is the [8]CachedGroupElement consisting of // B,3B,5B,7B,9B,11B,13B,15B -func GeDoubleScalarMultPrecompVartime(r *ProjectiveGroupElement, a *[32]byte, A *ExtendedGroupElement, b *[32]byte, Bi *[8]CachedGroupElement) { +func GeDoubleScalarMultPrecompVartime(r *ProjectiveGroupElement, a *Key, A *ExtendedGroupElement, b *Key, Bi *[8]CachedGroupElement) { var aSlide, bSlide [256]int8 var Ai [8]CachedGroupElement // A,3A,5A,7A,9A,11A,13A,15A var t CompletedGroupElement @@ -1161,7 +1158,7 @@ func selectPoint(t *PreComputedGroupElement, pos int32, b int32) { // // Preconditions: // a[31] <= 127 -func GeScalarMult(r *ProjectiveGroupElement, a *[32]byte, A *ExtendedGroupElement) { +func GeScalarMult(r *ProjectiveGroupElement, a *Key, A *ExtendedGroupElement) { var e [64]int32 var carry, carry2 int32 for i := 0; i < 31; i++ { @@ -1220,7 +1217,7 @@ func GeScalarMult(r *ProjectiveGroupElement, a *[32]byte, A *ExtendedGroupElemen // // Preconditions: // a[31] <= 127 -func GeScalarMultBase(h *ExtendedGroupElement, a *[32]byte) { +func GeScalarMultBase(h *ExtendedGroupElement, a *Key) { var e [64]int8 for i, v := range a { @@ -1266,7 +1263,7 @@ func GeScalarMultBase(h *ExtendedGroupElement, a *[32]byte) { } } -func ScAdd(s, a, b *[32]byte) { +func ScAdd(s, a, b *Key) { a0 := 2097151 & load3(a[:]) a1 := 2097151 & (load4(a[2:]) >> 5) a2 := 2097151 & (load3(a[5:]) >> 2) @@ -1464,7 +1461,7 @@ func ScAdd(s, a, b *[32]byte) { s[31] = byte(s11 >> 17) } -func ScSub(s, a, b *[32]byte) { +func ScSub(s, a, b *Key) { a0 := 2097151 & load3(a[:]) a1 := 2097151 & (load4(a[2:]) >> 5) a2 := 2097151 & (load3(a[5:]) >> 2) @@ -1666,7 +1663,7 @@ func signum(a int64) int64 { return a>>63 - ((-a) >> 63) } -func ScValid(s *[32]byte) bool { +func ScValid(s *Key) bool { s0 := load4(s[:]) s1 := load4(s[4:]) s2 := load4(s[8:]) @@ -1679,7 +1676,7 @@ func ScValid(s *[32]byte) bool { } -func ScIsZero(s *[32]byte) bool { +func ScIsZero(s *Key) bool { return ((int(s[0]|s[1]|s[2]|s[3]|s[4]|s[5]|s[6]|s[7]|s[8]| s[9]|s[10]|s[11]|s[12]|s[13]|s[14]|s[15]|s[16]|s[17]| s[18]|s[19]|s[20]|s[21]|s[22]|s[23]|s[24]|s[25]|s[26]| @@ -1696,7 +1693,7 @@ func ScIsZero(s *[32]byte) bool { // Output: // s[0]+256*s[1]+...+256^31*s[31] = (ab+c) mod l // where l = 2^252 + 27742317777372353535851937790883648493. -func ScMulAdd(s, a, b, c *[32]byte) { +func ScMulAdd(s, a, b, c *Key) { a0 := 2097151 & load3(a[:]) a1 := 2097151 & (load4(a[2:]) >> 5) a2 := 2097151 & (load3(a[5:]) >> 2) @@ -2129,7 +2126,7 @@ func ScMulAdd(s, a, b, c *[32]byte) { // Output: // s[0]+256*s[1]+...+256^31*s[31] = (c-ab) mod l // where l = 2^252 + 27742317777372353535851937790883648493. -func ScMulSub(s, a, b, c *[32]byte) { +func ScMulSub(s, a, b, c *Key) { a0 := 2097151 & load3(a[:]) a1 := 2097151 & (load4(a[2:]) >> 5) a2 := 2097151 & (load3(a[5:]) >> 2) @@ -2560,7 +2557,7 @@ func ScMulSub(s, a, b, c *[32]byte) { // Output: // s[0]+256*s[1]+...+256^31*s[31] = s mod l // where l = 2^252 + 27742317777372353535851937790883648493. -func ScReduce(out *[32]byte, s *[64]byte) { +func ScReduce(out *Key, s *[64]byte) { s0 := 2097151 & load3(s[:]) s1 := 2097151 & (load4(s[2:]) >> 5) s2 := 2097151 & (load3(s[5:]) >> 2) @@ -2878,7 +2875,7 @@ func ScReduce(out *[32]byte, s *[64]byte) { out[31] = byte(s11 >> 17) } -func ScReduce32(s *[32]byte) { +func ScReduce32(s *Key) { s0 := 2097151 & load3(s[:]) s1 := 2097151 & (load4(s[2:]) >> 5) s2 := 2097151 & (load3(s[5:]) >> 2) diff --git a/edwards25519_test.go b/edwards25519_test.go index 94afc1d..1de33dc 100644 --- a/edwards25519_test.go +++ b/edwards25519_test.go @@ -35,11 +35,11 @@ func TestScMulSub(t *testing.T) { }, } for _, test := range tests { - a := HexToBytes(test.aHex) - b := HexToBytes(test.bHex) - c := HexToBytes(test.cHex) - want := HexToBytes(test.wantHex) - var got [32]byte + a := HexToKey(test.aHex) + b := HexToKey(test.bHex) + c := HexToKey(test.cHex) + want := HexToKey(test.wantHex) + var got Key ScMulSub(&got, &a, &b, &c) if want != got { t.Errorf("%s: want %x, got %x", test.name, want, got) @@ -80,14 +80,14 @@ func TestScalarMult(t *testing.T) { }, } for _, test := range tests { - scalarBytes := HexToBytes(test.scalarHex) - pointBytes := HexToBytes(test.pointHex) - want := HexToBytes(test.wantHex) + scalarBytes := HexToKey(test.scalarHex) + pointBytes := HexToKey(test.pointHex) + want := HexToKey(test.wantHex) point := new(ExtendedGroupElement) point.FromBytes(&pointBytes) result := new(ProjectiveGroupElement) GeScalarMult(result, &scalarBytes, point) - var got [32]byte + var got Key result.ToBytes(&got) if want != got { t.Errorf("%s: want %x, got %x", test.name, want, got) @@ -113,15 +113,15 @@ func TestGeMul8(t *testing.T) { }, } for _, test := range tests { - pointBytes := HexToBytes(test.pointHex) - want := HexToBytes(test.wantHex) + pointBytes := HexToKey(test.pointHex) + want := HexToKey(test.wantHex) tmp := new(ExtendedGroupElement) tmp.FromBytes(&pointBytes) point := new(ProjectiveGroupElement) tmp.ToProjective(point) tmp2 := new(CompletedGroupElement) result := new(ExtendedGroupElement) - var got [32]byte + var got Key GeMul8(tmp2, point) tmp2.ToExtended(result) result.ToBytes(&got) @@ -169,15 +169,15 @@ func TestGeDoubleScalarMultVartime(t *testing.T) { }, } for _, test := range tests { - pointBytes := HexToBytes(test.pointHex) - a := HexToBytes(test.scalar1Hex) - b := HexToBytes(test.scalar2Hex) - want := HexToBytes(test.wantHex) + pointBytes := HexToKey(test.pointHex) + a := HexToKey(test.scalar1Hex) + b := HexToKey(test.scalar2Hex) + want := HexToKey(test.wantHex) point := new(ExtendedGroupElement) point.FromBytes(&pointBytes) result := new(ProjectiveGroupElement) GeDoubleScalarMultVartime(result, &a, point, &b) - var got [32]byte + var got Key result.ToBytes(&got) if want != got { t.Errorf("%s: want %x, got %x", test.name, want, got) @@ -228,11 +228,11 @@ func TestGeDoubleScalarMultPrecompVartime(t *testing.T) { }, } for _, test := range tests { - point1Bytes := HexToBytes(test.point1Hex) - point2Bytes := HexToBytes(test.point2Hex) - a := HexToBytes(test.scalar1Hex) - b := HexToBytes(test.scalar2Hex) - want := HexToBytes(test.wantHex) + point1Bytes := HexToKey(test.point1Hex) + point2Bytes := HexToKey(test.point2Hex) + a := HexToKey(test.scalar1Hex) + b := HexToKey(test.scalar2Hex) + want := HexToKey(test.wantHex) point1 := new(ExtendedGroupElement) point1.FromBytes(&point1Bytes) point2 := new(ExtendedGroupElement) @@ -241,7 +241,7 @@ func TestGeDoubleScalarMultPrecompVartime(t *testing.T) { GePrecompute(&point2Precomp, point2) result := new(ProjectiveGroupElement) GeDoubleScalarMultPrecompVartime(result, &a, point1, &b, &point2Precomp) - var got [32]byte + var got Key result.ToBytes(&got) if want != got { t.Errorf("%s: want %x, got %x", test.name, want, got) @@ -1605,7 +1605,7 @@ func TestScValid(t *testing.T) { }, } for _, test := range tests { - scalar := HexToBytes(test.scalarHex) + scalar := HexToKey(test.scalarHex) got := ScValid(&scalar) if test.valid != got { t.Errorf("%x: want %t, got %t", scalar, test.valid, got) diff --git a/internal_test.go b/internal_test.go index 23c87c3..d33debc 100644 --- a/internal_test.go +++ b/internal_test.go @@ -4,7 +4,13 @@ import ( "encoding/hex" ) -func HexToBytes(h string) (result [32]byte) { +func HexToKey(h string) (result Key) { + byteSlice, _ := hex.DecodeString(h) + copy(result[:], byteSlice) + return +} + +func HexToHash(h string) (result Hash) { byteSlice, _ := hex.DecodeString(h) copy(result[:], byteSlice) return diff --git a/keccak_test.go b/keccak_test.go index 30f6e98..f2acf02 100644 --- a/keccak_test.go +++ b/keccak_test.go @@ -1,7 +1,6 @@ package moneroutil import ( - "bytes" "encoding/hex" "testing" ) @@ -36,8 +35,8 @@ func TestKeccak256(t *testing.T) { for _, test := range tests { message, _ := hex.DecodeString(test.messageHex) got := Keccak256(message) - want := HexToBytes(test.wantHex) - if bytes.Compare(want[:], got[:]) != 0 { + want := HexToHash(test.wantHex) + if want != got { t.Errorf("want %x, got %x", want, got) } } diff --git a/key.go b/key.go index f0c7041..08a8a6d 100644 --- a/key.go +++ b/key.go @@ -5,66 +5,53 @@ import ( ) const ( - PointLength = 32 - ScalarLength = 32 + KeyLength = 32 ) -type PubKey [PointLength]byte -type PrivKey [ScalarLength]byte +// Key can be a Scalar or a Point +type Key [KeyLength]byte -func (p *PrivKey) FromBytes(b [ScalarLength]byte) { +func (p *Key) FromBytes(b [KeyLength]byte) { *p = b } -func (p *PrivKey) ToBytes() (result [ScalarLength]byte) { - result = [32]byte(*p) +func (p *Key) ToBytes() (result [KeyLength]byte) { + result = [KeyLength]byte(*p) return } -func (p *PrivKey) PubKey() (pubKey *PubKey) { - secret := p.ToBytes() +func (p *Key) PubKey() (pubKey *Key) { point := new(ExtendedGroupElement) - GeScalarMultBase(point, &secret) - pubKeyBytes := new([PointLength]byte) - point.ToBytes(pubKeyBytes) - pubKey = (*PubKey)(pubKeyBytes) - return -} - -func (p *PubKey) FromBytes(b [PointLength]byte) { - *p = b -} - -func (p *PubKey) ToBytes() (result [PointLength]byte) { - result = [PointLength]byte(*p) + GeScalarMultBase(point, p) + pubKey = new(Key) + point.ToBytes(pubKey) return } // Creates a point on the Edwards Curve by hashing the key -func (p *PubKey) HashToEC() (result *ExtendedGroupElement) { +func (p *Key) HashToEC() (result *ExtendedGroupElement) { result = new(ExtendedGroupElement) var p1 ProjectiveGroupElement var p2 CompletedGroupElement - h := [PointLength]byte(Keccak256(p[:])) + h := Key(Keccak256(p[:])) p1.FromBytes(&h) GeMul8(&p2, &p1) p2.ToExtended(result) return } -func RandomScalar() (result [ScalarLength]byte) { - var reduceFrom [ScalarLength * 2]byte - tmp := make([]byte, ScalarLength*2) +func RandomScalar() (result *Key) { + result = new(Key) + var reduceFrom [KeyLength * 2]byte + tmp := make([]byte, KeyLength*2) rand.Read(tmp) copy(reduceFrom[:], tmp) - ScReduce(&result, &reduceFrom) + ScReduce(result, &reduceFrom) return } -func NewKeyPair() (privKey *PrivKey, pubKey *PubKey) { - privKey = new(PrivKey) - pubKey = new(PubKey) - privKey.FromBytes(RandomScalar()) +func NewKeyPair() (privKey *Key, pubKey *Key) { + privKey = RandomScalar() pubKey = privKey.PubKey() return } diff --git a/ringct.go b/ringct.go index a9e91e1..5a5f0e8 100644 --- a/ringct.go +++ b/ringct.go @@ -9,13 +9,8 @@ const ( RCTTypeNull = iota RCTTypeFull RCTTypeSimple - - KeyLength = 32 ) -// Key for Confidential Transactions, can be private or public -type Key [KeyLength]byte - // V = Vector, M = Matrix type KeyV []Key type KeyM []KeyV diff --git a/ringsignature.go b/ringsignature.go index f1d1178..f5bdafd 100644 --- a/ringsignature.go +++ b/ringsignature.go @@ -7,45 +7,53 @@ import ( ) type RingSignatureElement struct { - c [ScalarLength]byte - r [ScalarLength]byte + c *Key + r *Key } type RingSignature []*RingSignatureElement func (r *RingSignatureElement) Serialize() (result []byte) { - result = make([]byte, 2*ScalarLength) - copy(result[:ScalarLength], r.c[:]) - copy(result[ScalarLength:2*ScalarLength], r.r[:]) + result = make([]byte, 2*KeyLength) + copy(result[:KeyLength], r.c[:]) + copy(result[KeyLength:2*KeyLength], r.r[:]) return } func (r *RingSignature) Serialize() (result []byte) { - result = make([]byte, len(*r)*ScalarLength*2) + result = make([]byte, len(*r)*KeyLength*2) for i := 0; i < len(*r); i++ { - copy(result[i*ScalarLength*2:(i+1)*ScalarLength*2], (*r)[i].Serialize()) + copy(result[i*KeyLength*2:(i+1)*KeyLength*2], (*r)[i].Serialize()) + } + return +} + +func NewRingSignatureElement() (r *RingSignatureElement) { + r = &RingSignatureElement{ + c: new(Key), + r: new(Key), } return } func ParseSignature(buf io.Reader) (result *RingSignatureElement, err error) { - rse := new(RingSignatureElement) - c := make([]byte, ScalarLength) + rse := NewRingSignatureElement() + c := make([]byte, KeyLength) n, err := buf.Read(c) if err != nil { return } - if n != ScalarLength { + if n != KeyLength { err = fmt.Errorf("Not enough bytes for signature c") return } copy(rse.c[:], c) - r := make([]byte, ScalarLength) + r := make([]byte, KeyLength) n, err = buf.Read(r) if err != nil { return } - if n != ScalarLength { + if n != KeyLength { err = fmt.Errorf("Not enough bytes for signature r") return } @@ -70,43 +78,41 @@ func ParseSignatures(mixinLengths []int, buf io.Reader) (signatures []RingSignat return } -func HashToScalar(data ...[]byte) (result [ScalarLength]byte) { - result = Keccak256(data...) - ScReduce32(&result) +func HashToScalar(data ...[]byte) (result *Key) { + result = new(Key) + *result = Key(Keccak256(data...)) + ScReduce32(result) return } -func CreateSignature(prefixHash *Hash, mixins []PubKey, privKey *PrivKey) (keyImage PubKey, pubKeys []PubKey, sig RingSignature) { +func CreateSignature(prefixHash *Hash, mixins []Key, privKey *Key) (keyImage Key, pubKeys []Key, sig RingSignature) { point := privKey.PubKey().HashToEC() - privKeyBytes := privKey.ToBytes() keyImagePoint := new(ProjectiveGroupElement) - GeScalarMult(keyImagePoint, &privKeyBytes, point) - var keyImageBytes [PointLength]byte + GeScalarMult(keyImagePoint, privKey, point) // convert key Image point from Projective to Extended // in order to precompute - keyImagePoint.ToBytes(&keyImageBytes) + keyImagePoint.ToBytes(&keyImage) keyImageGe := new(ExtendedGroupElement) - keyImageGe.FromBytes(&keyImageBytes) - keyImage = PubKey(keyImageBytes) + keyImageGe.FromBytes(&keyImage) var keyImagePre [8]CachedGroupElement GePrecompute(&keyImagePre, keyImageGe) k := RandomScalar() - pubKeys = make([]PubKey, len(mixins)+1) + pubKeys = make([]Key, len(mixins)+1) privIndex := rand.Intn(len(pubKeys)) pubKeys[privIndex] = *privKey.PubKey() r := make([]*RingSignatureElement, len(pubKeys)) - var sum [ScalarLength]byte + sum := new(Key) toHash := prefixHash[:] for i := 0; i < len(pubKeys); i++ { tmpE := new(ExtendedGroupElement) tmpP := new(ProjectiveGroupElement) - var tmpEBytes, tmpPBytes [PointLength]byte + var tmpEBytes, tmpPBytes Key if i == privIndex { - GeScalarMultBase(tmpE, &k) + GeScalarMultBase(tmpE, k) tmpE.ToBytes(&tmpEBytes) toHash = append(toHash, tmpEBytes[:]...) tmpE = privKey.PubKey().HashToEC() - GeScalarMult(tmpP, &k, tmpE) + GeScalarMult(tmpP, k, tmpE) tmpP.ToBytes(&tmpPBytes) toHash = append(toHash, tmpPBytes[:]...) } else { @@ -119,64 +125,60 @@ func CreateSignature(prefixHash *Hash, mixins []PubKey, privKey *PrivKey) (keyIm c: RandomScalar(), r: RandomScalar(), } - pubKeyBytes := pubKeys[i].ToBytes() - tmpE.FromBytes(&pubKeyBytes) - GeDoubleScalarMultVartime(tmpP, &r[i].c, tmpE, &r[i].r) + tmpE.FromBytes(&pubKeys[i]) + GeDoubleScalarMultVartime(tmpP, r[i].c, tmpE, r[i].r) tmpP.ToBytes(&tmpPBytes) toHash = append(toHash, tmpPBytes[:]...) tmpE = pubKeys[i].HashToEC() - GeDoubleScalarMultPrecompVartime(tmpP, &r[i].r, tmpE, &r[i].c, &keyImagePre) + GeDoubleScalarMultPrecompVartime(tmpP, r[i].r, tmpE, r[i].c, &keyImagePre) tmpP.ToBytes(&tmpPBytes) toHash = append(toHash, tmpPBytes[:]...) - ScAdd(&sum, &sum, &r[i].c) + ScAdd(sum, sum, r[i].c) } } h := HashToScalar(toHash) - r[privIndex] = new(RingSignatureElement) - ScSub(&r[privIndex].c, &h, &sum) - scalar := privKey.ToBytes() - ScMulSub(&r[privIndex].r, &r[privIndex].c, &scalar, &k) + r[privIndex] = NewRingSignatureElement() + ScSub(r[privIndex].c, h, sum) + ScMulSub(r[privIndex].r, r[privIndex].c, privKey, k) sig = r return } -func VerifySignature(prefixHash *Hash, keyImage *PubKey, pubKeys []PubKey, ringSignature RingSignature) (result bool) { +func VerifySignature(prefixHash *Hash, keyImage *Key, pubKeys []Key, ringSignature RingSignature) (result bool) { keyImageGe := new(ExtendedGroupElement) - keyImageBytes := [PointLength]byte(*keyImage) - if !keyImageGe.FromBytes(&keyImageBytes) { + if !keyImageGe.FromBytes(keyImage) { result = false return } var keyImagePre [8]CachedGroupElement GePrecompute(&keyImagePre, keyImageGe) toHash := prefixHash[:] - var tmpS, sum [ScalarLength]byte + tmpS, sum := new(Key), new(Key) for i, pubKey := range pubKeys { rse := ringSignature[i] - if !ScValid(&rse.c) || !ScValid(&rse.r) { + if !ScValid(rse.c) || !ScValid(rse.r) { result = false return } tmpE := new(ExtendedGroupElement) tmpP := new(ProjectiveGroupElement) - pubKeyBytes := [PointLength]byte(pubKey) - if !tmpE.FromBytes(&pubKeyBytes) { + if !tmpE.FromBytes(&pubKey) { result = false return } - var tmpPBytes, tmpEBytes [PointLength]byte - GeDoubleScalarMultVartime(tmpP, &rse.c, tmpE, &rse.r) + var tmpPBytes, tmpEBytes Key + GeDoubleScalarMultVartime(tmpP, rse.c, tmpE, rse.r) tmpP.ToBytes(&tmpPBytes) toHash = append(toHash, tmpPBytes[:]...) tmpE = pubKey.HashToEC() tmpE.ToBytes(&tmpEBytes) - GeDoubleScalarMultPrecompVartime(tmpP, &rse.r, tmpE, &rse.c, &keyImagePre) + GeDoubleScalarMultPrecompVartime(tmpP, rse.r, tmpE, rse.c, &keyImagePre) tmpP.ToBytes(&tmpPBytes) toHash = append(toHash, tmpPBytes[:]...) - ScAdd(&sum, &sum, &rse.c) + ScAdd(sum, sum, rse.c) } tmpS = HashToScalar(toHash) - ScSub(&sum, &tmpS, &sum) - result = ScIsZero(&sum) + ScSub(sum, tmpS, sum) + result = ScIsZero(sum) return } diff --git a/ringsignature_test.go b/ringsignature_test.go index 6b79cbc..ead55f0 100644 --- a/ringsignature_test.go +++ b/ringsignature_test.go @@ -759,9 +759,9 @@ func TestHashToScalar(t *testing.T) { } for _, test := range tests { toHash, _ := hex.DecodeString(test.hashHex) - want := HexToBytes(test.scalarHex) + want := HexToKey(test.scalarHex) got := HashToScalar(toHash) - if want != got { + if want != *got { t.Errorf("%x, want %x, got %x", toHash, want, got) } } @@ -1798,12 +1798,12 @@ func TestHashToEC(t *testing.T) { }, } for _, test := range tests { - pubkeyBytes := HexToBytes(test.pubkeyHex) - pubKey := new(PubKey) + pubkeyBytes := HexToKey(test.pubkeyHex) + pubKey := new(Key) pubKey.FromBytes(pubkeyBytes) - want := HexToBytes(test.extendedHex) + want := HexToKey(test.extendedHex) ecPoint := pubKey.HashToEC() - var got [32]byte + var got Key ecPoint.ToBytes(&got) if want != got { t.Errorf("%x: want %x, got %x", pubkeyBytes, want, got) @@ -1815,9 +1815,9 @@ func TestCreateSignature(t *testing.T) { numTries := 50 numMixins := 10 for i := 0; i < numTries; i++ { - hash := Hash(RandomScalar()) + hash := Hash(*RandomScalar()) privKey, _ := NewKeyPair() - mixins := make([]PubKey, numMixins) + mixins := make([]Key, numMixins) for j := 0; j < numMixins; j++ { _, pk := NewKeyPair() mixins[j] = *pk diff --git a/ringsignature_verify_test.go b/ringsignature_verify_test.go index 1e8182d..233a573 100644 --- a/ringsignature_verify_test.go +++ b/ringsignature_verify_test.go @@ -40763,11 +40763,11 @@ func TestVerifySignature(t *testing.T) { }, } for j, test := range tests { - prefixHash := Hash(HexToBytes(test.prefixHashHex)) - keyImage := PubKey(HexToBytes(test.keyImageHex)) - pubKeys := make([]PubKey, len(test.pubKeys)) + prefixHash := HexToHash(test.prefixHashHex) + keyImage := HexToKey(test.keyImageHex) + pubKeys := make([]Key, len(test.pubKeys)) for i, pubKeyHex := range test.pubKeys { - pubKeys[i] = PubKey(HexToBytes(pubKeyHex)) + pubKeys[i] = HexToKey(pubKeyHex) } ringSignature := make([]*RingSignatureElement, len(test.pubKeys)) ringSignatureBytes, _ := hex.DecodeString(test.ringSignature) diff --git a/transaction.go b/transaction.go index 434c7fe..2df1c4a 100644 --- a/transaction.go +++ b/transaction.go @@ -18,7 +18,7 @@ const ( var UnimplementedError = fmt.Errorf("Unimplemented") type txOutToScript struct { - pubKeys []PubKey + pubKeys []Key script []byte } @@ -27,7 +27,7 @@ type txOutToScriptHash struct { } type txOutToKey struct { - key PubKey + key Key } type TxOutTargetSerializer interface { @@ -55,7 +55,7 @@ type txInToScriptHash struct { type txInToKey struct { amount uint64 keyOffsets []uint64 - keyImage PubKey + keyImage Key } type TxInSerializer interface { @@ -87,7 +87,7 @@ func (h *Hash) Serialize() (result []byte) { return } -func (p *PubKey) Serialize() (result []byte) { +func (p *Key) Serialize() (result []byte) { result = p[:] return } @@ -281,12 +281,12 @@ func ParseTxInToKey(buf io.Reader) (txIn *txInToKey, err error) { return } } - pubKey := make([]byte, PointLength) + pubKey := make([]byte, KeyLength) n, err := buf.Read(pubKey) if err != nil { return } - if n != PointLength { + if n != KeyLength { err = fmt.Errorf("Buffer not long enough for public key") return } @@ -330,12 +330,12 @@ func ParseTxOutToScriptHash(buf io.Reader) (txOutTarget *txOutToScriptHash, err func ParseTxOutToKey(buf io.Reader) (txOutTarget *txOutToKey, err error) { t := new(txOutToKey) - pubKey := make([]byte, PointLength) + pubKey := make([]byte, KeyLength) n, err := buf.Read(pubKey) if err != nil { return } - if n != PointLength { + if n != KeyLength { err = fmt.Errorf("Buffer not long enough for public key") return } diff --git a/transaction_test.go b/transaction_test.go index 1fa2533..8ef11c5 100644 --- a/transaction_test.go +++ b/transaction_test.go @@ -113,7 +113,7 @@ func TestCoinbaseTransaction(t *testing.T) { if err != nil { t.Errorf("%s: error parsing tx: %s", test.name, err) } - wantHash := HexToBytes(test.hashHex) + wantHash := HexToHash(test.hashHex) gotHash := transaction.GetHash() if wantHash != gotHash { t.Errorf("%s: want %x, got %x", test.name, wantHash, gotHash) @@ -331,13 +331,13 @@ func TestTransaction(t *testing.T) { if err != nil { t.Errorf("%s: error parsing tx: %s", test.name, err) } - wantHash := HexToBytes(test.hashHex) + wantHash := HexToHash(test.hashHex) gotHash := transaction.GetHash() if wantHash != gotHash { t.Errorf("%s: want %x, got %x", test.name, wantHash, gotHash) } gotHash = transaction.PrefixHash() - wantHash = HexToBytes(test.prefixHashHex) + wantHash = HexToHash(test.prefixHashHex) if wantHash != gotHash { t.Errorf("%s: want %x, got %x", test.name, wantHash, gotHash) }