all: sync with landed standard library upstream

This commit is contained in:
Filippo Valsorda 2021-05-26 18:06:35 +02:00
parent 8e7780424d
commit dd0c73fa20
28 changed files with 840 additions and 802 deletions

20
doc.go Normal file
View file

@ -0,0 +1,20 @@
// Copyright (c) 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package edwards25519 implements group logic for the twisted Edwards curve
//
// -x^2 + y^2 = 1 + -(121665/121666)*x^2*y^2
//
// This is better known as the Edwards curve equivalent to Curve25519, and is
// the curve used by the Ed25519 signature scheme.
//
// Most users don't need this package, and should instead use crypto/ed25519 for
// signatures, golang.org/x/crypto/curve25519 for Diffie-Hellman, or
// github.com/gtank/ristretto255 for prime order group logic.
//
// However, developers who do need to interact with low-level edwards25519
// operations can use this package, which is an extended version of
// crypto/ed25519/internal/edwards25519 from the standard library repackaged as
// an importable module.
package edwards25519

View file

@ -2,31 +2,22 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package edwards25519 implements group logic for the twisted Edwards curve
//
// -x^2 + y^2 = 1 + -(121665/121666)*x^2*y^2
//
// This is better known as the Edwards curve equivalent to Curve25519, and is
// the curve used by the Ed25519 signature scheme.
//
// Most users don't need this package, and should instead use crypto/ed25519 for
// signatures, golang.org/x/crypto/curve25519 for Diffie-Hellman, or
// github.com/gtank/ristretto255 for prime order group logic. However, for
// anyone currently using a fork of crypto/ed25519/internal/edwards25519 or
// github.com/agl/edwards25519, this package should be a safer, faster, and more
// powerful alternative.
package edwards25519
import "errors"
import (
"errors"
"filippo.io/edwards25519/field"
)
// Point types.
type projP1xP1 struct {
X, Y, Z, T fieldElement
X, Y, Z, T field.Element
}
type projP2 struct {
X, Y, Z fieldElement
X, Y, Z field.Element
}
// Point represents a point on the edwards25519 curve.
@ -38,27 +29,29 @@ type projP2 struct {
type Point struct {
// The point is internally represented in extended coordinates (X, Y, Z, T)
// where x = X/Z, y = Y/Z, and xy = T/Z per https://eprint.iacr.org/2008/522.
x, y, z, t fieldElement
x, y, z, t field.Element
// Make the type not comparable with bradfitz's device, since equal points
// can be represented by different Go values.
_ [0]func()
// Make the type not comparable (i.e. used with == or as a map key), as
// equivalent points can be represented by different Go values.
_ incomparable
}
type incomparable [0]func()
func checkInitialized(points ...*Point) {
for _, p := range points {
if p.x == (fieldElement{}) && p.y == (fieldElement{}) {
if p.x == (field.Element{}) && p.y == (field.Element{}) {
panic("edwards25519: use of uninitialized Point")
}
}
}
type projCached struct {
YplusX, YminusX, Z, T2d fieldElement
YplusX, YminusX, Z, T2d field.Element
}
type affineCached struct {
YplusX, YminusX, T2d fieldElement
YplusX, YminusX, T2d field.Element
}
// Constructors.
@ -70,27 +63,27 @@ func (v *projP2) Zero() *projP2 {
return v
}
// identity is the point at infinity.
var identity, _ = new(Point).SetBytes([]byte{
1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0})
// NewIdentityPoint returns a new Point set to the identity.
func NewIdentityPoint() *Point {
return &Point{
x: fieldElement{0, 0, 0, 0, 0},
y: fieldElement{1, 0, 0, 0, 0},
z: fieldElement{1, 0, 0, 0, 0},
t: fieldElement{0, 0, 0, 0, 0},
}
return new(Point).Set(identity)
}
// generator is the canonical curve basepoint. See TestGenerator for the
// correspondence of this encoding with the values in RFC 8032.
var generator, _ = new(Point).SetBytes([]byte{
0x58, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66,
0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66,
0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66,
0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66})
// NewGeneratorPoint returns a new Point set to the canonical generator.
func NewGeneratorPoint() *Point {
return &Point{
x: fieldElement{1738742601995546, 1146398526822698,
2070867633025821, 562264141797630, 587772402128613},
y: fieldElement{1801439850948184, 1351079888211148,
450359962737049, 900719925474099, 1801439850948198},
z: fieldElement{1, 0, 0, 0, 0},
t: fieldElement{1841354044333475, 16398895984059,
755974180946558, 900171276175154, 1821297809914039},
}
return new(Point).Set(generator)
}
func (v *projCached) Zero() *projCached {
@ -118,7 +111,7 @@ func (v *Point) Set(u *Point) *Point {
// Encoding.
// Bytes returns the canonical 32 bytes encoding of v, according to RFC 8032,
// Bytes returns the canonical 32-byte encoding of v, according to RFC 8032,
// Section 5.1.2.
func (v *Point) Bytes() []byte {
// This function is outlined to make the allocations inline in the caller
@ -130,17 +123,19 @@ func (v *Point) Bytes() []byte {
func (v *Point) bytes(buf *[32]byte) []byte {
checkInitialized(v)
var recip, x, y fieldElement
recip.Invert(&v.z)
x.Multiply(&v.x, &recip) // x = X / Z
y.Multiply(&v.y, &recip) // y = Y / Z
var zInv, x, y field.Element
zInv.Invert(&v.z) // zInv = 1 / Z
x.Multiply(&v.x, &zInv) // x = X / Z
y.Multiply(&v.y, &zInv) // y = Y / Z
out := copyFieldElement(buf, &y)
out[31] |= byte(x.IsNegative() << 7)
return out
}
// SetBytes sets v = x, where x is a 32 bytes encoding of v. If x does not
var feOne = new(field.Element).One()
// SetBytes sets v = x, where x is a 32-byte encoding of v. If x does not
// represent a valid point on the curve, SetBytes returns nil and an error and
// the receiver is unchanged. Otherwise, SetBytes returns v.
//
@ -150,7 +145,7 @@ func (v *Point) bytes(buf *[32]byte) []byte {
func (v *Point) SetBytes(x []byte) (*Point, error) {
// Specifically, the non-canonical encodings that are accepted are
// 1) the ones where the field element is not reduced (see the
// (*fieldElement).SetBytes docs) and
// (*field.Element).SetBytes docs) and
// 2) the ones where the x-coordinate is zero and the sign bit is set.
//
// This is consistent with crypto/ed25519/internal/edwards25519. Read more
@ -160,28 +155,29 @@ func (v *Point) SetBytes(x []byte) (*Point, error) {
if len(x) != 32 {
return nil, errors.New("edwards25519: invalid point encoding length")
}
y := (&fieldElement{}).SetBytes(x)
y := new(field.Element).SetBytes(x)
// -x² + y² = 1 + dx²y²
// x² + dx²y² = x²(dy² + 1) = y² - 1
// x² = (y² - 1) / (dy² + 1)
// u = y² - 1
y2 := (&fieldElement{}).Square(y)
u := (&fieldElement{}).Subtract(y2, feOne)
y2 := new(field.Element).Square(y)
u := new(field.Element).Subtract(y2, feOne)
// v = dy² + 1
vv := (&fieldElement{}).Multiply(y2, d)
vv := new(field.Element).Multiply(y2, d)
vv = vv.Add(vv, feOne)
// x = +√(u/v)
xx, wasSquare := (&fieldElement{}).SqrtRatio(u, vv)
xx, wasSquare := new(field.Element).SqrtRatio(u, vv)
if wasSquare == 0 {
return nil, errors.New("edwards25519: invalid point encoding")
}
// Select the negative square root if the sign bit is set.
xx = xx.condNeg(xx, int(x[31]>>7))
xxNeg := new(field.Element).Negate(xx)
xx = xx.Select(xxNeg, xx, int(x[31]>>7))
v.x.Set(xx)
v.y.Set(y)
@ -191,42 +187,8 @@ func (v *Point) SetBytes(x []byte) (*Point, error) {
return v, nil
}
// BytesMontgomery converts v to a point on the birationally-equivalent
// Curve25519 Montgomery curve, and returns its canonical 32 bytes encoding
// according to RFC 7748.
//
// Note that BytesMontgomery only encodes the u-coordinate, so v and -v encode
// to the same value. If v is the identity point, BytesMontgomery returns 32
// zero bytes, analogously to the X25519 function.
func (v *Point) BytesMontgomery() []byte {
// This function is outlined to make the allocations inline in the caller
// rather than happen on the heap.
var buf [32]byte
return v.bytesMontgomery(&buf)
}
func (v *Point) bytesMontgomery(buf *[32]byte) []byte {
checkInitialized(v)
// RFC 7748, Section 4.1 provides the bilinear map to calculate the
// Montgomery u-coordinate
//
// u = (1 + y) / (1 - y)
//
// where y = Y / Z.
var y, recip, u fieldElement
y.Multiply(&v.y, y.Invert(&v.z)) // y = Y / Z
recip.Invert(recip.Subtract(feOne, &y)) // r = 1/(1 - y)
u.Multiply(u.Add(feOne, &y), &recip) // u = (1 + y)*r
return copyFieldElement(buf, &u)
}
func copyFieldElement(buf *[32]byte, v *fieldElement) []byte {
out := v.Bytes()
copy(buf[:], out)
func copyFieldElement(buf *[32]byte, v *field.Element) []byte {
copy(buf[:], v.Bytes())
return buf[:]
}
@ -263,9 +225,12 @@ func (v *Point) fromP2(p *projP2) *Point {
}
// d is a constant in the curve equation.
var d = &fieldElement{929955233495203, 466365720129213,
1662059464998953, 2033849074728123, 1442794654840575}
var d2 = new(fieldElement).Add(d, d)
var d = new(field.Element).SetBytes([]byte{
0xa3, 0x78, 0x59, 0x13, 0xca, 0x4d, 0xeb, 0x75,
0xab, 0xd8, 0x41, 0x41, 0x4d, 0x0a, 0x70, 0x00,
0x98, 0xe8, 0x79, 0x77, 0x79, 0x40, 0xc7, 0x8c,
0x73, 0xfe, 0x6f, 0x2b, 0xee, 0x6c, 0x03, 0x52})
var d2 = new(field.Element).Add(d, d)
func (v *projCached) FromP3(p *Point) *projCached {
v.YplusX.Add(&p.y, &p.x)
@ -280,7 +245,7 @@ func (v *affineCached) FromP3(p *Point) *affineCached {
v.YminusX.Subtract(&p.y, &p.x)
v.T2d.Multiply(&p.t, d2)
var invZ fieldElement
var invZ field.Element
invZ.Invert(&p.z)
v.YplusX.Multiply(&v.YplusX, &invZ)
v.YminusX.Multiply(&v.YminusX, &invZ)
@ -293,21 +258,21 @@ func (v *affineCached) FromP3(p *Point) *affineCached {
// Add sets v = p + q, and returns v.
func (v *Point) Add(p, q *Point) *Point {
checkInitialized(p, q)
qCached := (&projCached{}).FromP3(q)
result := (&projP1xP1{}).Add(p, qCached)
qCached := new(projCached).FromP3(q)
result := new(projP1xP1).Add(p, qCached)
return v.fromP1xP1(result)
}
// Subtract sets v = p - q, and returns v.
func (v *Point) Subtract(p, q *Point) *Point {
checkInitialized(p, q)
qCached := (&projCached{}).FromP3(q)
result := (&projP1xP1{}).Sub(p, qCached)
qCached := new(projCached).FromP3(q)
result := new(projP1xP1).Sub(p, qCached)
return v.fromP1xP1(result)
}
func (v *projP1xP1) Add(p *Point, q *projCached) *projP1xP1 {
var YplusX, YminusX, PP, MM, TT2d, ZZ2 fieldElement
var YplusX, YminusX, PP, MM, TT2d, ZZ2 field.Element
YplusX.Add(&p.y, &p.x)
YminusX.Subtract(&p.y, &p.x)
@ -327,7 +292,7 @@ func (v *projP1xP1) Add(p *Point, q *projCached) *projP1xP1 {
}
func (v *projP1xP1) Sub(p *Point, q *projCached) *projP1xP1 {
var YplusX, YminusX, PP, MM, TT2d, ZZ2 fieldElement
var YplusX, YminusX, PP, MM, TT2d, ZZ2 field.Element
YplusX.Add(&p.y, &p.x)
YminusX.Subtract(&p.y, &p.x)
@ -347,7 +312,7 @@ func (v *projP1xP1) Sub(p *Point, q *projCached) *projP1xP1 {
}
func (v *projP1xP1) AddAffine(p *Point, q *affineCached) *projP1xP1 {
var YplusX, YminusX, PP, MM, TT2d, Z2 fieldElement
var YplusX, YminusX, PP, MM, TT2d, Z2 field.Element
YplusX.Add(&p.y, &p.x)
YminusX.Subtract(&p.y, &p.x)
@ -366,7 +331,7 @@ func (v *projP1xP1) AddAffine(p *Point, q *affineCached) *projP1xP1 {
}
func (v *projP1xP1) SubAffine(p *Point, q *affineCached) *projP1xP1 {
var YplusX, YminusX, PP, MM, TT2d, Z2 fieldElement
var YplusX, YminusX, PP, MM, TT2d, Z2 field.Element
YplusX.Add(&p.y, &p.x)
YminusX.Subtract(&p.y, &p.x)
@ -387,7 +352,7 @@ func (v *projP1xP1) SubAffine(p *Point, q *affineCached) *projP1xP1 {
// Doubling.
func (v *projP1xP1) Double(p *projP2) *projP1xP1 {
var XX, YY, ZZ2, XplusYsq fieldElement
var XX, YY, ZZ2, XplusYsq field.Element
XX.Square(&p.X)
YY.Square(&p.Y)
@ -404,19 +369,6 @@ func (v *projP1xP1) Double(p *projP2) *projP1xP1 {
return v
}
// MultByCofactor sets v = 8 * p, and returns v.
func (v *Point) MultByCofactor(p *Point) *Point {
checkInitialized(p)
result := projP1xP1{}
pp := (&projP2{}).FromP3(p)
result.Double(pp)
pp.FromP1xP1(&result)
result.Double(pp)
pp.FromP1xP1(&result)
result.Double(pp)
return v.fromP1xP1(&result)
}
// Negation.
// Negate sets v = -p, and returns v.
@ -433,7 +385,7 @@ func (v *Point) Negate(p *Point) *Point {
func (v *Point) Equal(u *Point) int {
checkInitialized(v, u)
var t1, t2, t3, t4 fieldElement
var t1, t2, t3, t4 field.Element
t1.Multiply(&v.x, &u.z)
t2.Multiply(&u.x, &v.z)
t3.Multiply(&v.y, &u.z)
@ -464,13 +416,13 @@ func (v *affineCached) Select(a, b *affineCached, cond int) *affineCached {
// CondNeg negates v if cond == 1 and leaves it unchanged if cond == 0.
func (v *projCached) CondNeg(cond int) *projCached {
v.YplusX.Swap(&v.YminusX, cond)
v.T2d.condNeg(&v.T2d, cond)
v.T2d.Select(new(field.Element).Negate(&v.T2d), &v.T2d, cond)
return v
}
// CondNeg negates v if cond == 1 and leaves it unchanged if cond == 0.
func (v *affineCached) CondNeg(cond int) *affineCached {
v.YplusX.Swap(&v.YminusX, cond)
v.T2d.condNeg(&v.T2d, cond)
v.T2d.Select(new(field.Element).Negate(&v.T2d), &v.T2d, cond)
return v
}

View file

@ -6,9 +6,12 @@ package edwards25519
import (
"encoding/hex"
"os"
"reflect"
"strings"
"testing"
"testing/quick"
"filippo.io/edwards25519/field"
)
var B = NewGeneratorPoint()
@ -17,7 +20,7 @@ var I = NewIdentityPoint()
func checkOnCurve(t *testing.T, points ...*Point) {
t.Helper()
for i, p := range points {
var XX, YY, ZZ, ZZZZ fieldElement
var XX, YY, ZZ, ZZZZ field.Element
XX.Square(&p.x)
YY.Square(&p.y)
ZZ.Square(&p.z)
@ -26,7 +29,7 @@ func checkOnCurve(t *testing.T, points ...*Point) {
// -(X/Z)² + (Y/Z)² = 1 + d(X/Z)²(Y/Z)²
// (-X² + Y²)/Z² = 1 + (dX²Y²)/Z⁴
// (-X² + Y²)*Z² = Z⁴ + dX²Y²
var lhs, rhs fieldElement
var lhs, rhs field.Element
lhs.Subtract(&YY, &XX).Multiply(&lhs, &ZZ)
rhs.Multiply(d, &XX).Multiply(&rhs, &YY).Add(&rhs, &ZZZZ)
if lhs.Equal(&rhs) != 1 {
@ -41,12 +44,30 @@ func checkOnCurve(t *testing.T, points ...*Point) {
}
}
func TestGenerator(t *testing.T) {
// These are the coordinates of B from RFC 8032, Section 5.1, converted to
// little endian hex.
x := "1ad5258f602d56c9b2a7259560c72c695cdcd6fd31e2a4c0fe536ecdd3366921"
y := "5866666666666666666666666666666666666666666666666666666666666666"
if got := hex.EncodeToString(B.x.Bytes()); got != x {
t.Errorf("wrong B.x: got %s, expected %s", got, x)
}
if got := hex.EncodeToString(B.y.Bytes()); got != y {
t.Errorf("wrong B.y: got %s, expected %s", got, y)
}
if B.z.Equal(feOne) != 1 {
t.Errorf("wrong B.z: got %v, expected 1", B.z)
}
// Check that t is correct.
checkOnCurve(t, B)
}
func TestAddSubNegOnBasePoint(t *testing.T) {
checkLhs, checkRhs := &Point{}, &Point{}
checkLhs.Add(B, B)
tmpP2 := (&projP2{}).FromP3(B)
tmpP1xP1 := (&projP1xP1{}).Double(tmpP2)
tmpP2 := new(projP2).FromP3(B)
tmpP1xP1 := new(projP1xP1).Double(tmpP2)
checkRhs.fromP1xP1(tmpP1xP1)
if checkLhs.Equal(checkRhs) != 1 {
t.Error("B + B != [2]B")
@ -54,7 +75,7 @@ func TestAddSubNegOnBasePoint(t *testing.T) {
checkOnCurve(t, checkLhs, checkRhs)
checkLhs.Subtract(B, B)
Bneg := (&Point{}).Negate(B)
Bneg := new(Point).Negate(B)
checkRhs.Add(B, Bneg)
if checkLhs.Equal(checkRhs) != 1 {
t.Error("B - B != B + (-B)")
@ -239,11 +260,11 @@ func TestNonCanonicalPoints(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p1, err := (&Point{}).SetBytes(decodeHex(tt.encoding))
p1, err := new(Point).SetBytes(decodeHex(tt.encoding))
if err != nil {
t.Fatalf("error decoding non-canonical point: %v", err)
}
p2, err := (&Point{}).SetBytes(decodeHex(tt.canonical))
p2, err := new(Point).SetBytes(decodeHex(tt.canonical))
if err != nil {
t.Fatalf("error decoding canonical point: %v", err)
}
@ -258,107 +279,20 @@ func TestNonCanonicalPoints(t *testing.T) {
}
}
// TestBytesMontgomery tests the SetBytesWithClamping+BytesMontgomery path
// equivalence to curve25519.X25519 for basepoint scalar multiplications.
//
// Note that you can't actually implement X25519 with this package because
// there is no SetBytesMontgomery, and it would not be possible to implement
// it properly: points on the twist would get rejected, and the Scalar returned
// by SetBytesWithClamping does not preserve its cofactor-clearing properties.
//
// Disabled to avoid the golang.org/x/crypto module dependency.
/* func TestBytesMontgomery(t *testing.T) {
f := func(scalar [32]byte) bool {
s := NewScalar().SetBytesWithClamping(scalar[:])
p := (&Point{}).ScalarBaseMult(s)
got := p.BytesMontgomery()
want, _ := curve25519.X25519(scalar[:], curve25519.Basepoint)
return bytes.Equal(got, want)
}
if err := quick.Check(f, nil); err != nil {
t.Error(err)
}
} */
func TestBytesMontgomerySodium(t *testing.T) {
// Generated with libsodium.js 1.0.18
// crypto_sign_keypair().publicKey
publicKey := "3bf918ffc2c955dc895bf145f566fb96623c1cadbe040091175764b5fde322c0"
p, err := (&Point{}).SetBytes(decodeHex(publicKey))
if err != nil {
t.Fatal(err)
}
// crypto_sign_ed25519_pk_to_curve25519(publicKey)
want := "efc6c9d0738e9ea18d738ad4a2653631558931b0f1fde4dd58c436d19686dc28"
if got := hex.EncodeToString(p.BytesMontgomery()); got != want {
t.Errorf("got %q, want %q", got, want)
}
}
func TestBytesMontgomeryInfinity(t *testing.T) {
p := NewIdentityPoint()
want := "0000000000000000000000000000000000000000000000000000000000000000"
if got := hex.EncodeToString(p.BytesMontgomery()); got != want {
t.Errorf("got %q, want %q", got, want)
}
}
func TestMultByCofactor(t *testing.T) {
lowOrderBytes := "26e8958fc2b227b045c3f489f2ef98f0d5dfac05d3c63339b13802886d53fc85"
lowOrder, err := (&Point{}).SetBytes(decodeHex(lowOrderBytes))
if err != nil {
t.Fatal(err)
}
if p := (&Point{}).MultByCofactor(lowOrder); p.Equal(NewIdentityPoint()) != 1 {
t.Errorf("expected low order point * cofactor to be the identity")
}
f := func(scalar [64]byte) bool {
s := NewScalar().SetUniformBytes(scalar[:])
p := (&Point{}).ScalarBaseMult(s)
p8 := (&Point{}).MultByCofactor(p)
checkOnCurve(t, p8)
// 8 * p == (8 * s) * B
s.Multiply(s, &Scalar{[32]byte{8}})
pp := (&Point{}).ScalarBaseMult(s)
if p8.Equal(pp) != 1 {
return false
}
// 8 * p == 8 * (lowOrder + p)
pp.Add(p, lowOrder)
pp.MultByCofactor(pp)
if p8.Equal(pp) != 1 {
return false
}
// 8 * p == p + p + p + p + p + p + p + p
pp.Set(NewIdentityPoint())
for i := 0; i < 8; i++ {
pp.Add(pp, p)
}
return p8.Equal(pp) == 1
}
if err := quick.Check(f, nil); err != nil {
t.Error(err)
}
}
var testAllocationsSink byte
func TestAllocations(t *testing.T) {
allocs := testing.AllocsPerRun(100, func() {
if strings.HasSuffix(os.Getenv("GO_BUILDER_NAME"), "-noopt") {
t.Skip("skipping allocations test without relevant optimizations")
}
if allocs := testing.AllocsPerRun(100, func() {
p := NewIdentityPoint()
p.Add(p, NewGeneratorPoint())
s := NewScalar()
testAllocationsSink ^= s.Bytes()[0]
testAllocationsSink ^= p.Bytes()[0]
testAllocationsSink ^= p.BytesMontgomery()[0]
})
if allocs := int(allocs); allocs != 0 {
t.Errorf("expected zero allocations, got %d", allocs)
}); allocs > 0 {
t.Errorf("expected zero allocations, got %0.1v", allocs)
}
}

284
extra.go Normal file
View file

@ -0,0 +1,284 @@
// Copyright (c) 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package edwards25519
// This file contains additional functionality that is not included in the
// upstream crypto/ed25519/internal/edwards25519 package.
import "filippo.io/edwards25519/field"
// BytesMontgomery converts v to a point on the birationally-equivalent
// Curve25519 Montgomery curve, and returns its canonical 32 bytes encoding
// according to RFC 7748.
//
// Note that BytesMontgomery only encodes the u-coordinate, so v and -v encode
// to the same value. If v is the identity point, BytesMontgomery returns 32
// zero bytes, analogously to the X25519 function.
func (v *Point) BytesMontgomery() []byte {
// This function is outlined to make the allocations inline in the caller
// rather than happen on the heap.
var buf [32]byte
return v.bytesMontgomery(&buf)
}
func (v *Point) bytesMontgomery(buf *[32]byte) []byte {
checkInitialized(v)
// RFC 7748, Section 4.1 provides the bilinear map to calculate the
// Montgomery u-coordinate
//
// u = (1 + y) / (1 - y)
//
// where y = Y / Z.
var y, recip, u field.Element
y.Multiply(&v.y, y.Invert(&v.z)) // y = Y / Z
recip.Invert(recip.Subtract(feOne, &y)) // r = 1/(1 - y)
u.Multiply(u.Add(feOne, &y), &recip) // u = (1 + y)*r
return copyFieldElement(buf, &u)
}
// MultByCofactor sets v = 8 * p, and returns v.
func (v *Point) MultByCofactor(p *Point) *Point {
checkInitialized(p)
result := projP1xP1{}
pp := (&projP2{}).FromP3(p)
result.Double(pp)
pp.FromP1xP1(&result)
result.Double(pp)
pp.FromP1xP1(&result)
result.Double(pp)
return v.fromP1xP1(&result)
}
// Given k > 0, set s = s**(2*i).
func (s *Scalar) pow2k(k int) {
for i := 0; i < k; i++ {
s.Multiply(s, s)
}
}
// Invert sets s to the inverse of a nonzero scalar v, and returns s.
//
// If t is zero, Invert will panic.
func (s *Scalar) Invert(t *Scalar) *Scalar {
if t.s == [32]byte{} {
panic("edwards25519: zero Scalar passed to Invert")
}
// Uses a hardcoded sliding window of width 4.
var table [8]Scalar
var tt Scalar
tt.Multiply(t, t)
table[0] = *t
for i := 0; i < 7; i++ {
table[i+1].Multiply(&table[i], &tt)
}
// Now table = [t**1, t**3, t**7, t**11, t**13, t**15]
// so t**k = t[k/2] for odd k
// To compute the sliding window digits, use the following Sage script:
// sage: import itertools
// sage: def sliding_window(w,k):
// ....: digits = []
// ....: while k > 0:
// ....: if k % 2 == 1:
// ....: kmod = k % (2**w)
// ....: digits.append(kmod)
// ....: k = k - kmod
// ....: else:
// ....: digits.append(0)
// ....: k = k // 2
// ....: return digits
// Now we can compute s roughly as follows:
// sage: s = 1
// sage: for coeff in reversed(sliding_window(4,l-2)):
// ....: s = s*s
// ....: if coeff > 0 :
// ....: s = s*t**coeff
// This works on one bit at a time, with many runs of zeros.
// The digits can be collapsed into [(count, coeff)] as follows:
// sage: [(len(list(group)),d) for d,group in itertools.groupby(sliding_window(4,l-2))]
// Entries of the form (k, 0) turn into pow2k(k)
// Entries of the form (1, coeff) turn into a squaring and then a table lookup.
// We can fold the squaring into the previous pow2k(k) as pow2k(k+1).
*s = table[1/2]
s.pow2k(127 + 1)
s.Multiply(s, &table[1/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[9/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[11/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[13/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[15/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[7/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[15/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[5/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[1/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[15/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[15/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[7/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[3/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[11/2])
s.pow2k(5 + 1)
s.Multiply(s, &table[11/2])
s.pow2k(9 + 1)
s.Multiply(s, &table[9/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[3/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[3/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[3/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[9/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[7/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[3/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[13/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[7/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[9/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[15/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[11/2])
return s
}
// MultiScalarMult sets v = sum(scalars[i] * points[i]), and returns v.
//
// Execution time depends only on the lengths of the two slices, which must match.
func (v *Point) MultiScalarMult(scalars []*Scalar, points []*Point) *Point {
if len(scalars) != len(points) {
panic("edwards25519: called MultiScalarMult with different size inputs")
}
checkInitialized(points...)
// Proceed as in the single-base case, but share doublings
// between each point in the multiscalar equation.
// Build lookup tables for each point
tables := make([]projLookupTable, len(points))
for i := range tables {
tables[i].FromP3(points[i])
}
// Compute signed radix-16 digits for each scalar
digits := make([][64]int8, len(scalars))
for i := range digits {
digits[i] = scalars[i].signedRadix16()
}
// Unwrap first loop iteration to save computing 16*identity
multiple := &projCached{}
tmp1 := &projP1xP1{}
tmp2 := &projP2{}
// Lookup-and-add the appropriate multiple of each input point
for j := range tables {
tables[j].SelectInto(multiple, digits[j][63])
tmp1.Add(v, multiple) // tmp1 = v + x_(j,63)*Q in P1xP1 coords
v.fromP1xP1(tmp1) // update v
}
tmp2.FromP3(v) // set up tmp2 = v in P2 coords for next iteration
for i := 62; i >= 0; i-- {
tmp1.Double(tmp2) // tmp1 = 2*(prev) in P1xP1 coords
tmp2.FromP1xP1(tmp1) // tmp2 = 2*(prev) in P2 coords
tmp1.Double(tmp2) // tmp1 = 4*(prev) in P1xP1 coords
tmp2.FromP1xP1(tmp1) // tmp2 = 4*(prev) in P2 coords
tmp1.Double(tmp2) // tmp1 = 8*(prev) in P1xP1 coords
tmp2.FromP1xP1(tmp1) // tmp2 = 8*(prev) in P2 coords
tmp1.Double(tmp2) // tmp1 = 16*(prev) in P1xP1 coords
v.fromP1xP1(tmp1) // v = 16*(prev) in P3 coords
// Lookup-and-add the appropriate multiple of each input point
for j := range tables {
tables[j].SelectInto(multiple, digits[j][i])
tmp1.Add(v, multiple) // tmp1 = v + x_(j,i)*Q in P1xP1 coords
v.fromP1xP1(tmp1) // update v
}
tmp2.FromP3(v) // set up tmp2 = v in P2 coords for next iteration
}
return v
}
// VarTimeMultiScalarMult sets v = sum(scalars[i] * points[i]), and returns v.
//
// Execution time depends on the inputs.
func (v *Point) VarTimeMultiScalarMult(scalars []*Scalar, points []*Point) *Point {
if len(scalars) != len(points) {
panic("edwards25519: called VarTimeMultiScalarMult with different size inputs")
}
checkInitialized(points...)
// Generalize double-base NAF computation to arbitrary sizes.
// Here all the points are dynamic, so we only use the smaller
// tables.
// Build lookup tables for each point
tables := make([]nafLookupTable5, len(points))
for i := range tables {
tables[i].FromP3(points[i])
}
// Compute a NAF for each scalar
nafs := make([][256]int8, len(scalars))
for i := range nafs {
nafs[i] = scalars[i].nonAdjacentForm(5)
}
multiple := &projCached{}
tmp1 := &projP1xP1{}
tmp2 := &projP2{}
tmp2.Zero()
// Move from high to low bits, doubling the accumulator
// at each iteration and checking whether there is a nonzero
// coefficient to look up a multiple of.
//
// Skip trying to find the first nonzero coefficent, because
// searching might be more work than a few extra doublings.
for i := 255; i >= 0; i-- {
tmp1.Double(tmp2)
for j := range nafs {
if nafs[j][i] > 0 {
v.fromP1xP1(tmp1)
tables[j].SelectInto(multiple, nafs[j][i])
tmp1.Add(v, multiple)
} else if nafs[j][i] < 0 {
v.fromP1xP1(tmp1)
tables[j].SelectInto(multiple, -nafs[j][i])
tmp1.Sub(v, multiple)
}
}
tmp2.FromP1xP1(tmp1)
}
v.fromP2(tmp2)
return v
}

162
extra_test.go Normal file
View file

@ -0,0 +1,162 @@
// Copyright (c) 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package edwards25519
import (
"encoding/hex"
"testing"
"testing/quick"
)
// TestBytesMontgomery tests the SetBytesWithClamping+BytesMontgomery path
// equivalence to curve25519.X25519 for basepoint scalar multiplications.
//
// Note that you can't actually implement X25519 with this package because
// there is no SetBytesMontgomery, and it would not be possible to implement
// it properly: points on the twist would get rejected, and the Scalar returned
// by SetBytesWithClamping does not preserve its cofactor-clearing properties.
//
// Disabled to avoid the golang.org/x/crypto module dependency.
/* func TestBytesMontgomery(t *testing.T) {
f := func(scalar [32]byte) bool {
s := NewScalar().SetBytesWithClamping(scalar[:])
p := (&Point{}).ScalarBaseMult(s)
got := p.BytesMontgomery()
want, _ := curve25519.X25519(scalar[:], curve25519.Basepoint)
return bytes.Equal(got, want)
}
if err := quick.Check(f, nil); err != nil {
t.Error(err)
}
} */
func TestBytesMontgomerySodium(t *testing.T) {
// Generated with libsodium.js 1.0.18
// crypto_sign_keypair().publicKey
publicKey := "3bf918ffc2c955dc895bf145f566fb96623c1cadbe040091175764b5fde322c0"
p, err := (&Point{}).SetBytes(decodeHex(publicKey))
if err != nil {
t.Fatal(err)
}
// crypto_sign_ed25519_pk_to_curve25519(publicKey)
want := "efc6c9d0738e9ea18d738ad4a2653631558931b0f1fde4dd58c436d19686dc28"
if got := hex.EncodeToString(p.BytesMontgomery()); got != want {
t.Errorf("got %q, want %q", got, want)
}
}
func TestBytesMontgomeryInfinity(t *testing.T) {
p := NewIdentityPoint()
want := "0000000000000000000000000000000000000000000000000000000000000000"
if got := hex.EncodeToString(p.BytesMontgomery()); got != want {
t.Errorf("got %q, want %q", got, want)
}
}
func TestMultByCofactor(t *testing.T) {
lowOrderBytes := "26e8958fc2b227b045c3f489f2ef98f0d5dfac05d3c63339b13802886d53fc85"
lowOrder, err := (&Point{}).SetBytes(decodeHex(lowOrderBytes))
if err != nil {
t.Fatal(err)
}
if p := (&Point{}).MultByCofactor(lowOrder); p.Equal(NewIdentityPoint()) != 1 {
t.Errorf("expected low order point * cofactor to be the identity")
}
f := func(scalar [64]byte) bool {
s := NewScalar().SetUniformBytes(scalar[:])
p := (&Point{}).ScalarBaseMult(s)
p8 := (&Point{}).MultByCofactor(p)
checkOnCurve(t, p8)
// 8 * p == (8 * s) * B
s.Multiply(s, &Scalar{[32]byte{8}})
pp := (&Point{}).ScalarBaseMult(s)
if p8.Equal(pp) != 1 {
return false
}
// 8 * p == 8 * (lowOrder + p)
pp.Add(p, lowOrder)
pp.MultByCofactor(pp)
if p8.Equal(pp) != 1 {
return false
}
// 8 * p == p + p + p + p + p + p + p + p
pp.Set(NewIdentityPoint())
for i := 0; i < 8; i++ {
pp.Add(pp, p)
}
return p8.Equal(pp) == 1
}
if err := quick.Check(f, nil); err != nil {
t.Error(err)
}
}
func TestScalarInvert(t *testing.T) {
invertWorks := func(xInv Scalar, x notZeroScalar) bool {
xInv.Invert((*Scalar)(&x))
var check Scalar
check.Multiply((*Scalar)(&x), &xInv)
return check == scOne && isReduced(&xInv)
}
if err := quick.Check(invertWorks, quickCheckConfig32); err != nil {
t.Error(err)
}
}
func TestMultiScalarMultMatchesBaseMult(t *testing.T) {
multiScalarMultMatchesBaseMult := func(x, y, z Scalar) bool {
var p, q1, q2, q3, check Point
p.MultiScalarMult([]*Scalar{&x, &y, &z}, []*Point{B, B, B})
q1.ScalarBaseMult(&x)
q2.ScalarBaseMult(&y)
q3.ScalarBaseMult(&z)
check.Add(&q1, &q2).Add(&check, &q3)
checkOnCurve(t, &p, &check, &q1, &q2, &q3)
return p.Equal(&check) == 1
}
if err := quick.Check(multiScalarMultMatchesBaseMult, quickCheckConfig32); err != nil {
t.Error(err)
}
}
func TestVarTimeMultiScalarMultMatchesBaseMult(t *testing.T) {
varTimeMultiScalarMultMatchesBaseMult := func(x, y, z Scalar) bool {
var p, q1, q2, q3, check Point
p.VarTimeMultiScalarMult([]*Scalar{&x, &y, &z}, []*Point{B, B, B})
q1.ScalarBaseMult(&x)
q2.ScalarBaseMult(&y)
q3.ScalarBaseMult(&z)
check.Add(&q1, &q2).Add(&check, &q3)
checkOnCurve(t, &p, &check, &q1, &q2, &q3)
return p.Equal(&check) == 1
}
if err := quick.Check(varTimeMultiScalarMultMatchesBaseMult, quickCheckConfig32); err != nil {
t.Error(err)
}
}
func BenchmarkMultiScalarMultSize8(t *testing.B) {
var p Point
x := dalekScalar
for i := 0; i < t.N; i++ {
p.MultiScalarMult([]*Scalar{&x, &x, &x, &x, &x, &x, &x, &x},
[]*Point{B, B, B, B, B, B, B, B})
}
}

View file

@ -12,13 +12,13 @@ import (
. "github.com/mmcloughlin/avo/operand"
. "github.com/mmcloughlin/avo/reg"
_ "filippo.io/edwards25519"
_ "filippo.io/edwards25519/field"
)
//go:generate go run fe_amd64_asm.go -out ../fe_amd64.s -stubs ../fe_amd64.go -pkg edwards25519
//go:generate go run . -out ../fe_amd64.s -stubs ../fe_amd64.go -pkg field
func main() {
Package("filippo.io/edwards25519")
Package("filippo.io/edwards25519/field")
ConstraintExpr("amd64,gc,!purego")
feMul()
feSquare()
@ -40,7 +40,7 @@ type uint128 struct {
func (c uint128) String() string { return c.name }
func feSquare() {
TEXT("feSquare", NOSPLIT, "func(out, a *fieldElement)")
TEXT("feSquare", NOSPLIT, "func(out, a *Element)")
Doc("feSquare sets out = a * a. It works like feSquareGeneric.")
Pragma("noescape")
@ -129,7 +129,7 @@ func feSquare() {
}
func feMul() {
TEXT("feMul", NOSPLIT, "func(out, a, b *fieldElement)")
TEXT("feMul", NOSPLIT, "func(out, a, b *Element)")
Doc("feMul sets out = a * b. It works like feMulGeneric.")
Pragma("noescape")

View file

@ -1,4 +1,4 @@
module filippo.io/edwards25519/asm
module asm
go 1.16
@ -7,4 +7,4 @@ require (
github.com/mmcloughlin/avo v0.2.0
)
replace filippo.io/edwards25519 => ../
replace filippo.io/edwards25519 v0.0.0 => ../..

View file

@ -2,7 +2,8 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package edwards25519
// Package field implements fast arithmetic modulo 2^255-19.
package field
import (
"crypto/subtle"
@ -10,15 +11,15 @@ import (
"math/bits"
)
// fieldElement represents an element of the field GF(2^255-19). Note that this
// Element represents an element of the field GF(2^255-19). Note that this
// is not a cryptographically secure group, and should only be used to interact
// with Point coordinates.
// with edwards25519.Point coordinates.
//
// This type works similarly to math/big.Int, and all arguments and receivers
// are allowed to alias.
//
// The zero value is a valid zero element.
type fieldElement struct {
type Element struct {
// An element t represents the integer
// t.l0 + t.l1*2^51 + t.l2*2^102 + t.l3*2^153 + t.l4*2^204
//
@ -32,27 +33,24 @@ type fieldElement struct {
const maskLow51Bits uint64 = (1 << 51) - 1
var (
feZero = &fieldElement{0, 0, 0, 0, 0}
feOne = &fieldElement{1, 0, 0, 0, 0}
feTwo = &fieldElement{2, 0, 0, 0, 0}
feMinusOne = new(fieldElement).Negate(feOne)
)
var feZero = &Element{0, 0, 0, 0, 0}
// Zero sets v = 0, and returns v.
func (v *fieldElement) Zero() *fieldElement {
func (v *Element) Zero() *Element {
*v = *feZero
return v
}
var feOne = &Element{1, 0, 0, 0, 0}
// One sets v = 1, and returns v.
func (v *fieldElement) One() *fieldElement {
func (v *Element) One() *Element {
*v = *feOne
return v
}
// reduce reduces v modulo 2^255 - 19 and returns it.
func (v *fieldElement) reduce() *fieldElement {
func (v *Element) reduce() *Element {
v.carryPropagate()
// After the light reduction we now have a field element representation
@ -85,7 +83,7 @@ func (v *fieldElement) reduce() *fieldElement {
}
// Add sets v = a + b, and returns v.
func (v *fieldElement) Add(a, b *fieldElement) *fieldElement {
func (v *Element) Add(a, b *Element) *Element {
v.l0 = a.l0 + b.l0
v.l1 = a.l1 + b.l1
v.l2 = a.l2 + b.l2
@ -99,7 +97,7 @@ func (v *fieldElement) Add(a, b *fieldElement) *fieldElement {
}
// Subtract sets v = a - b, and returns v.
func (v *fieldElement) Subtract(a, b *fieldElement) *fieldElement {
func (v *Element) Subtract(a, b *Element) *Element {
// We first add 2 * p, to guarantee the subtraction won't underflow, and
// then subtract b (which can be up to 2^255 + 2^13 * 19).
v.l0 = (a.l0 + 0xFFFFFFFFFFFDA) - b.l0
@ -111,17 +109,17 @@ func (v *fieldElement) Subtract(a, b *fieldElement) *fieldElement {
}
// Negate sets v = -a, and returns v.
func (v *fieldElement) Negate(a *fieldElement) *fieldElement {
func (v *Element) Negate(a *Element) *Element {
return v.Subtract(feZero, a)
}
// Invert sets v = 1/z mod p, and returns v.
//
// If z == 0, Invert returns v = 0.
func (v *fieldElement) Invert(z *fieldElement) *fieldElement {
func (v *Element) Invert(z *Element) *Element {
// Inversion is implemented as exponentiation with exponent p 2. It uses the
// same sequence of 255 squarings and 11 multiplications as [Curve25519].
var z2, z9, z11, z2_5_0, z2_10_0, z2_20_0, z2_50_0, z2_100_0, t fieldElement
var z2, z9, z11, z2_5_0, z2_10_0, z2_20_0, z2_50_0, z2_100_0, t Element
z2.Square(z) // 2
t.Square(&z2) // 4
@ -129,7 +127,7 @@ func (v *fieldElement) Invert(z *fieldElement) *fieldElement {
z9.Multiply(&t, z) // 9
z11.Multiply(&z9, &z2) // 11
t.Square(&z11) // 22
z2_5_0.Multiply(&t, &z9) // 2^5 - 2^0 = 31
z2_5_0.Multiply(&t, &z9) // 31 = 2^5 - 2^0
t.Square(&z2_5_0) // 2^6 - 2^1
for i := 0; i < 4; i++ {
@ -183,17 +181,17 @@ func (v *fieldElement) Invert(z *fieldElement) *fieldElement {
}
// Set sets v = a, and returns v.
func (v *fieldElement) Set(a *fieldElement) *fieldElement {
func (v *Element) Set(a *Element) *Element {
*v = *a
return v
}
// SetBytes sets v to x, which must be a 32 bytes little-endian encoding.
// SetBytes sets v to x, which must be a 32-byte little-endian encoding.
//
// Consistently with RFC 7748, the most significant bit (the high bit of the
// Consistent with RFC 7748, the most significant bit (the high bit of the
// last byte) is ignored, and non-canonical values (2^255-19 through 2^255-1)
// are accepted. Note that this is laxer than specified by RFC 8032.
func (v *fieldElement) SetBytes(x []byte) *fieldElement {
func (v *Element) SetBytes(x []byte) *Element {
if len(x) != 32 {
panic("edwards25519: invalid field element input size")
}
@ -218,15 +216,15 @@ func (v *fieldElement) SetBytes(x []byte) *fieldElement {
return v
}
// Bytes returns the canonical 32 bytes little-endian encoding of v.
func (v *fieldElement) Bytes() []byte {
// Bytes returns the canonical 32-byte little-endian encoding of v.
func (v *Element) Bytes() []byte {
// This function is outlined to make the allocations inline in the caller
// rather than happen on the heap.
var out [32]byte
return v.bytes(&out)
}
func (v *fieldElement) bytes(out *[32]byte) []byte {
func (v *Element) bytes(out *[32]byte) []byte {
t := *v
t.reduce()
@ -247,16 +245,17 @@ func (v *fieldElement) bytes(out *[32]byte) []byte {
}
// Equal returns 1 if v and u are equal, and 0 otherwise.
func (v *fieldElement) Equal(u *fieldElement) int {
func (v *Element) Equal(u *Element) int {
sa, sv := u.Bytes(), v.Bytes()
return subtle.ConstantTimeCompare(sa, sv)
}
const mask64Bits uint64 = (1 << 64) - 1
// mask64Bits returns 0xffffffff if cond is 1, and 0 otherwise.
func mask64Bits(cond int) uint64 { return ^(uint64(cond) - 1) }
// Select sets v to a if cond == 1, and to b if cond == 0.
func (v *fieldElement) Select(a, b *fieldElement, cond int) *fieldElement {
m := uint64(cond) * mask64Bits
func (v *Element) Select(a, b *Element, cond int) *Element {
m := mask64Bits(cond)
v.l0 = (m & a.l0) | (^m & b.l0)
v.l1 = (m & a.l1) | (^m & b.l1)
v.l2 = (m & a.l2) | (^m & b.l2)
@ -266,8 +265,8 @@ func (v *fieldElement) Select(a, b *fieldElement, cond int) *fieldElement {
}
// Swap swaps v and u if cond == 1 or leaves them unchanged if cond == 0, and returns v.
func (v *fieldElement) Swap(u *fieldElement, cond int) {
m := uint64(cond) * mask64Bits
func (v *Element) Swap(u *Element, cond int) {
m := mask64Bits(cond)
t := m & (v.l0 ^ u.l0)
v.l0 ^= t
u.l0 ^= t
@ -285,37 +284,30 @@ func (v *fieldElement) Swap(u *fieldElement, cond int) {
u.l4 ^= t
}
// condNeg sets v to -u if cond == 1, and to u if cond == 0.
func (v *fieldElement) condNeg(u *fieldElement, cond int) *fieldElement {
tmp := new(fieldElement).Negate(u)
return v.Select(tmp, u, cond)
}
// IsNegative returns 1 if v is negative, and 0 otherwise.
func (v *fieldElement) IsNegative() int {
b := v.Bytes()
return int(b[0] & 1)
func (v *Element) IsNegative() int {
return int(v.Bytes()[0] & 1)
}
// Absolute sets v to |u|, and returns v.
func (v *fieldElement) Absolute(u *fieldElement) *fieldElement {
return v.condNeg(u, u.IsNegative())
func (v *Element) Absolute(u *Element) *Element {
return v.Select(new(Element).Negate(u), u, u.IsNegative())
}
// Multiply sets v = x * y, and returns v.
func (v *fieldElement) Multiply(x, y *fieldElement) *fieldElement {
func (v *Element) Multiply(x, y *Element) *Element {
feMul(v, x, y)
return v
}
// Square sets v = x * x, and returns v.
func (v *fieldElement) Square(x *fieldElement) *fieldElement {
func (v *Element) Square(x *Element) *Element {
feSquare(v, x)
return v
}
// Mult32 sets v = x * y, and returns v.
func (v *fieldElement) Mult32(x *fieldElement, y uint32) *fieldElement {
func (v *Element) Mult32(x *Element, y uint32) *Element {
x0lo, x0hi := mul51(x.l0, y)
x1lo, x1hi := mul51(x.l1, y)
x2lo, x2hi := mul51(x.l2, y)
@ -340,8 +332,8 @@ func mul51(a uint64, b uint32) (lo uint64, hi uint64) {
}
// Pow22523 set v = x^((p-5)/8), and returns v. (p-5)/8 is 2^252-3.
func (v *fieldElement) Pow22523(x *fieldElement) *fieldElement {
var t0, t1, t2 fieldElement
func (v *Element) Pow22523(x *Element) *Element {
var t0, t1, t2 Element
t0.Square(x) // x^2
t1.Square(&t0) // x^4
@ -391,7 +383,7 @@ func (v *fieldElement) Pow22523(x *fieldElement) *fieldElement {
}
// sqrtM1 is 2^((p-1)/4), which squared is equal to -1 by Euler's Criterion.
var sqrtM1 = &fieldElement{1718705420411056, 234908883556509,
var sqrtM1 = &Element{1718705420411056, 234908883556509,
2233514472574048, 2117202627021982, 765476049583133}
// SqrtRatio sets r to the non-negative square root of the ratio of u and v.
@ -399,8 +391,8 @@ var sqrtM1 = &fieldElement{1718705420411056, 234908883556509,
// If u/v is square, SqrtRatio returns r and 1. If u/v is not square, SqrtRatio
// sets r according to Section 4.3 of draft-irtf-cfrg-ristretto255-decaf448-00,
// and returns r and 0.
func (r *fieldElement) SqrtRatio(u, v *fieldElement) (rr *fieldElement, wasSquare int) {
var a, b fieldElement
func (r *Element) SqrtRatio(u, v *Element) (rr *Element, wasSquare int) {
var a, b Element
// r = (u * v3) * (u * v7)^((p-5)/8)
v2 := a.Square(v)

View file

@ -2,15 +2,15 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package edwards25519
package field
import (
"testing"
"testing/quick"
)
func checkAliasingOneArg(f func(v, x *fieldElement) *fieldElement) func(v, x fieldElement) bool {
return func(v, x fieldElement) bool {
func checkAliasingOneArg(f func(v, x *Element) *Element) func(v, x Element) bool {
return func(v, x Element) bool {
x1, v1 := x, x
// Calculate a reference f(x) without aliasing.
@ -28,9 +28,9 @@ func checkAliasingOneArg(f func(v, x *fieldElement) *fieldElement) func(v, x fie
}
}
func checkAliasingTwoArgs(f func(v, x, y *fieldElement) *fieldElement) func(v, x, y fieldElement) bool {
return func(v, x, y fieldElement) bool {
x1, y1, v1 := x, y, fieldElement{}
func checkAliasingTwoArgs(f func(v, x, y *Element) *Element) func(v, x, y Element) bool {
return func(v, x, y Element) bool {
x1, y1, v1 := x, y, Element{}
// Calculate a reference f(x, y) without aliasing.
if out := f(&v, &x, &y); out != &v && isInBounds(out) {
@ -74,43 +74,41 @@ func checkAliasingTwoArgs(f func(v, x, y *fieldElement) *fieldElement) func(v, x
}
}
// TestAliasing checks that receivers and arguments can alias each other without
// leading to incorrect results. That is, it ensures that it's safe to write
//
// v.Invert(v)
//
// or
//
// v.Add(v, v)
//
// without any of the inputs getting clobbered by the output being written.
func TestAliasing(t *testing.T) {
type target struct {
name string
oneArgF func(v, x *fieldElement) *fieldElement
twoArgsF func(v, x, y *fieldElement) *fieldElement
oneArgF func(v, x *Element) *Element
twoArgsF func(v, x, y *Element) *Element
}
for _, tt := range []target{
{name: "Abs", oneArgF: (*fieldElement).Absolute},
{name: "Invert", oneArgF: (*fieldElement).Invert},
{name: "Neg", oneArgF: (*fieldElement).Negate},
{name: "Set", oneArgF: (*fieldElement).Set},
{name: "Square", oneArgF: (*fieldElement).Square},
{
name: "CondNeg0",
oneArgF: func(v, x *fieldElement) *fieldElement {
return (*fieldElement).condNeg(v, x, 0)
},
},
{
name: "CondNeg1",
oneArgF: func(v, x *fieldElement) *fieldElement {
return (*fieldElement).condNeg(v, x, 1)
},
},
{name: "Mul", twoArgsF: (*fieldElement).Multiply},
{name: "Add", twoArgsF: (*fieldElement).Add},
{name: "Sub", twoArgsF: (*fieldElement).Subtract},
{name: "Absolute", oneArgF: (*Element).Absolute},
{name: "Invert", oneArgF: (*Element).Invert},
{name: "Negate", oneArgF: (*Element).Negate},
{name: "Set", oneArgF: (*Element).Set},
{name: "Square", oneArgF: (*Element).Square},
{name: "Multiply", twoArgsF: (*Element).Multiply},
{name: "Add", twoArgsF: (*Element).Add},
{name: "Subtract", twoArgsF: (*Element).Subtract},
{
name: "Select0",
twoArgsF: func(v, x, y *fieldElement) *fieldElement {
return (*fieldElement).Select(v, x, y, 0)
twoArgsF: func(v, x, y *Element) *Element {
return (*Element).Select(v, x, y, 0)
},
},
{
name: "Select1",
twoArgsF: func(v, x, y *fieldElement) *fieldElement {
return (*fieldElement).Select(v, x, y, 1)
twoArgsF: func(v, x, y *Element) *Element {
return (*Element).Select(v, x, y, 1)
},
},
} {

View file

@ -1,13 +1,13 @@
// Code generated by command: go run fe_amd64_asm.go -out ../fe_amd64.s -stubs ../fe_amd64.go -pkg edwards25519. DO NOT EDIT.
// Code generated by command: go run fe_amd64_asm.go -out ../fe_amd64.s -stubs ../fe_amd64.go -pkg field. DO NOT EDIT.
// +build amd64,gc,!purego
package edwards25519
package field
// feMul sets out = a * b. It works like feMulGeneric.
//go:noescape
func feMul(out *fieldElement, a *fieldElement, b *fieldElement)
func feMul(out *Element, a *Element, b *Element)
// feSquare sets out = a * a. It works like feSquareGeneric.
//go:noescape
func feSquare(out *fieldElement, a *fieldElement)
func feSquare(out *Element, a *Element)

View file

@ -1,10 +1,10 @@
// Code generated by command: go run fe_amd64_asm.go -out ../fe_amd64.s -stubs ../fe_amd64.go -pkg edwards25519. DO NOT EDIT.
// Code generated by command: go run fe_amd64_asm.go -out ../fe_amd64.s -stubs ../fe_amd64.go -pkg field. DO NOT EDIT.
// +build amd64,gc,!purego
#include "textflag.h"
// func feMul(out *fieldElement, a *fieldElement, b *fieldElement)
// func feMul(out *Element, a *Element, b *Element)
TEXT ·feMul(SB), NOSPLIT, $0-24
MOVQ a+8(FP), CX
MOVQ b+16(FP), BX
@ -220,7 +220,7 @@ TEXT ·feMul(SB), NOSPLIT, $0-24
MOVQ R15, 32(AX)
RET
// func feSquare(out *fieldElement, a *fieldElement)
// func feSquare(out *Element, a *Element)
TEXT ·feSquare(SB), NOSPLIT, $0-16
MOVQ a+8(FP), CX

View file

@ -2,10 +2,11 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !amd64 || !gc || purego
// +build !amd64 !gc purego
package edwards25519
package field
func feMul(v, x, y *fieldElement) { feMulGeneric(v, x, y) }
func feMul(v, x, y *Element) { feMulGeneric(v, x, y) }
func feSquare(v, x *fieldElement) { feSquareGeneric(v, x) }
func feSquare(v, x *Element) { feSquareGeneric(v, x) }

View file

@ -2,14 +2,15 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build arm64 && gc && !purego
// +build arm64,gc,!purego
package edwards25519
package field
//go:noescape
func carryPropagate(v *fieldElement)
func carryPropagate(v *Element)
func (v *fieldElement) carryPropagate() *fieldElement {
func (v *Element) carryPropagate() *Element {
carryPropagate(v)
return v
}

View file

@ -12,7 +12,7 @@
//
// See https://golang.org/issues/43145 for the main compiler issue.
//
// func carryPropagate(v *fieldElement)
// func carryPropagate(v *Element)
TEXT ·carryPropagate(SB),NOFRAME|NOSPLIT,$0-8
MOVD v+0(FP), R20

View file

@ -2,10 +2,11 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !arm64 || !gc || purego
// +build !arm64 !gc purego
package edwards25519
package field
func (v *fieldElement) carryPropagate() *fieldElement {
func (v *Element) carryPropagate() *Element {
return v.carryPropagateGeneric()
}

View file

@ -2,12 +2,12 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package edwards25519
package field
import "testing"
func BenchmarkAdd(b *testing.B) {
var x, y fieldElement
var x, y Element
x.One()
y.Add(feOne, feOne)
b.ResetTimer()
@ -16,8 +16,8 @@ func BenchmarkAdd(b *testing.B) {
}
}
func BenchmarkMul(b *testing.B) {
var x, y fieldElement
func BenchmarkMultiply(b *testing.B) {
var x, y Element
x.One()
y.Add(feOne, feOne)
b.ResetTimer()
@ -26,8 +26,8 @@ func BenchmarkMul(b *testing.B) {
}
}
func BenchmarkMul32(b *testing.B) {
var x fieldElement
func BenchmarkMult32(b *testing.B) {
var x Element
x.One()
b.ResetTimer()
for i := 0; i < b.N; i++ {

View file

@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package edwards25519
package field
import "math/bits"
@ -31,7 +31,7 @@ func shiftRightBy51(a uint128) uint64 {
return (a.hi << (64 - 51)) | (a.lo >> 51)
}
func feMulGeneric(v, a, b *fieldElement) {
func feMulGeneric(v, a, b *Element) {
a0 := a.l0
a1 := a.l1
a2 := a.l2
@ -118,7 +118,7 @@ func feMulGeneric(v, a, b *fieldElement) {
// After the multiplication, we need to reduce (carry) the five coefficients
// to obtain a result with limbs that are at most slightly larger than 2⁵¹,
// to respect the fieldElement invariant.
// to respect the Element invariant.
//
// Overall, the reduction works the same as carryPropagate, except with
// wider inputs: we take the carry for each coefficient by shifting it right
@ -156,13 +156,13 @@ func feMulGeneric(v, a, b *fieldElement) {
rr4 := r4.lo&maskLow51Bits + c3
// Now all coefficients fit into 64-bit registers but are still too large to
// be passed around as a fieldElement. We therefore do one last carry chain,
// be passed around as a Element. We therefore do one last carry chain,
// where the carries will be small enough to fit in the wiggle room above 2⁵¹.
*v = fieldElement{rr0, rr1, rr2, rr3, rr4}
*v = Element{rr0, rr1, rr2, rr3, rr4}
v.carryPropagate()
}
func feSquareGeneric(v, a *fieldElement) {
func feSquareGeneric(v, a *Element) {
l0 := a.l0
l1 := a.l1
l2 := a.l2
@ -241,13 +241,13 @@ func feSquareGeneric(v, a *fieldElement) {
rr3 := r3.lo&maskLow51Bits + c2
rr4 := r4.lo&maskLow51Bits + c3
*v = fieldElement{rr0, rr1, rr2, rr3, rr4}
*v = Element{rr0, rr1, rr2, rr3, rr4}
v.carryPropagate()
}
// carryPropagate brings the limbs below 52 bits by applying the reduction
// identity (a * 2²⁵⁵ + b = a * 19 + b) to the l4 carry.
func (v *fieldElement) carryPropagateGeneric() *fieldElement {
func (v *Element) carryPropagateGeneric() *Element {
c0 := v.l0 >> 51
c1 := v.l1 >> 51
c2 := v.l2 >> 51

View file

@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package edwards25519
package field
import (
"bytes"
@ -17,7 +17,7 @@ import (
"testing/quick"
)
func (v fieldElement) String() string {
func (v Element) String() string {
return hex.EncodeToString(v.Bytes())
}
@ -25,9 +25,9 @@ func (v fieldElement) String() string {
// times. The default value of -quickchecks is 100.
var quickCheckConfig1024 = &quick.Config{MaxCountScale: 1 << 10}
func generateFieldElement(rand *mathrand.Rand) fieldElement {
func generateFieldElement(rand *mathrand.Rand) Element {
const maskLow52Bits = (1 << 52) - 1
return fieldElement{
return Element{
rand.Uint64() & maskLow52Bits,
rand.Uint64() & maskLow52Bits,
rand.Uint64() & maskLow52Bits,
@ -70,8 +70,8 @@ var (
}
)
func generateWeirdFieldElement(rand *mathrand.Rand) fieldElement {
return fieldElement{
func generateWeirdFieldElement(rand *mathrand.Rand) Element {
return Element{
weirdLimbs52[rand.Intn(len(weirdLimbs52))],
weirdLimbs51[rand.Intn(len(weirdLimbs51))],
weirdLimbs51[rand.Intn(len(weirdLimbs51))],
@ -80,7 +80,7 @@ func generateWeirdFieldElement(rand *mathrand.Rand) fieldElement {
}
}
func (fieldElement) Generate(rand *mathrand.Rand, size int) reflect.Value {
func (Element) Generate(rand *mathrand.Rand, size int) reflect.Value {
if rand.Intn(2) == 0 {
return reflect.ValueOf(generateWeirdFieldElement(rand))
}
@ -89,7 +89,7 @@ func (fieldElement) Generate(rand *mathrand.Rand, size int) reflect.Value {
// isInBounds returns whether the element is within the expected bit size bounds
// after a light reduction.
func isInBounds(x *fieldElement) bool {
func isInBounds(x *Element) bool {
return bits.Len64(x.l0) <= 52 &&
bits.Len64(x.l1) <= 52 &&
bits.Len64(x.l2) <= 52 &&
@ -97,16 +97,16 @@ func isInBounds(x *fieldElement) bool {
bits.Len64(x.l4) <= 52
}
func TestMulDistributesOverAdd(t *testing.T) {
mulDistributesOverAdd := func(x, y, z fieldElement) bool {
func TestMultiplyDistributesOverAdd(t *testing.T) {
multiplyDistributesOverAdd := func(x, y, z Element) bool {
// Compute t1 = (x+y)*z
t1 := new(fieldElement)
t1 := new(Element)
t1.Add(&x, &y)
t1.Multiply(t1, &z)
// Compute t2 = x*z + y*z
t2 := new(fieldElement)
t3 := new(fieldElement)
t2 := new(Element)
t3 := new(Element)
t2.Multiply(&x, &z)
t3.Multiply(&y, &z)
t2.Add(t2, t3)
@ -114,7 +114,7 @@ func TestMulDistributesOverAdd(t *testing.T) {
return t1.Equal(t2) == 1 && isInBounds(t1) && isInBounds(t2)
}
if err := quick.Check(mulDistributesOverAdd, quickCheckConfig1024); err != nil {
if err := quick.Check(multiplyDistributesOverAdd, quickCheckConfig1024); err != nil {
t.Error(err)
}
}
@ -147,26 +147,20 @@ func TestMul64to128(t *testing.T) {
}
func TestSetBytesRoundTrip(t *testing.T) {
f1 := func(in [32]byte, fe fieldElement) bool {
f1 := func(in [32]byte, fe Element) bool {
fe.SetBytes(in[:])
// Mask the most significant bit as it's ignored by SetBytes. (Now
// instead of earlier so we check the masking in SetBytes is working.)
in[len(in)-1] &= (1 << 7) - 1
// TODO: values in the range [2^255-19, 2^255-1] will still fail the
// comparison as they will have been reduced in the round-trip, but the
// current quickcheck generation strategy will never hit them, which is
// not good. We should have a weird generator that aims for edge cases,
// and we'll know it works when this test breaks.
return bytes.Equal(in[:], fe.Bytes()) && isInBounds(&fe)
}
if err := quick.Check(f1, nil); err != nil {
t.Errorf("failed bytes->FE->bytes round-trip: %v", err)
}
f2 := func(fe, r fieldElement) bool {
f2 := func(fe, r Element) bool {
r.SetBytes(fe.Bytes())
// Intentionally not using Equal not to go through Bytes again.
@ -182,23 +176,23 @@ func TestSetBytesRoundTrip(t *testing.T) {
// Check some fixed vectors from dalek
type feRTTest struct {
fe fieldElement
fe Element
b []byte
}
var tests = []feRTTest{
{
fe: fieldElement{358744748052810, 1691584618240980, 977650209285361, 1429865912637724, 560044844278676},
fe: Element{358744748052810, 1691584618240980, 977650209285361, 1429865912637724, 560044844278676},
b: []byte{74, 209, 69, 197, 70, 70, 161, 222, 56, 226, 229, 19, 112, 60, 25, 92, 187, 74, 222, 56, 50, 153, 51, 233, 40, 74, 57, 6, 160, 185, 213, 31},
},
{
fe: fieldElement{84926274344903, 473620666599931, 365590438845504, 1028470286882429, 2146499180330972},
fe: Element{84926274344903, 473620666599931, 365590438845504, 1028470286882429, 2146499180330972},
b: []byte{199, 23, 106, 112, 61, 77, 216, 79, 186, 60, 11, 118, 13, 16, 103, 15, 42, 32, 83, 250, 44, 57, 204, 198, 78, 199, 253, 119, 146, 172, 3, 122},
},
}
for _, tt := range tests {
b := tt.fe.Bytes()
if !bytes.Equal(b, tt.b) || new(fieldElement).SetBytes(tt.b).Equal(&tt.fe) != 1 {
if !bytes.Equal(b, tt.b) || new(Element).SetBytes(tt.b).Equal(&tt.fe) != 1 {
t.Errorf("Failed fixed roundtrip: %v", tt)
}
}
@ -212,7 +206,7 @@ func swapEndianness(buf []byte) []byte {
}
func TestBytesBigEquivalence(t *testing.T) {
f1 := func(in [32]byte, fe, fe1 fieldElement) bool {
f1 := func(in [32]byte, fe, fe1 Element) bool {
fe.SetBytes(in[:])
in[len(in)-1] &= (1 << 7) - 1 // mask the most significant bit
@ -234,7 +228,7 @@ func TestBytesBigEquivalence(t *testing.T) {
}
// fromBig sets v = n, and returns v. The bit length of n must not exceed 256.
func (v *fieldElement) fromBig(n *big.Int) *fieldElement {
func (v *Element) fromBig(n *big.Int) *Element {
if n.BitLen() > 32*8 {
panic("edwards25519: invalid field element input size")
}
@ -253,7 +247,7 @@ func (v *fieldElement) fromBig(n *big.Int) *fieldElement {
return v.SetBytes(buf[:32])
}
func (v *fieldElement) fromDecimal(s string) *fieldElement {
func (v *Element) fromDecimal(s string) *Element {
n, ok := new(big.Int).SetString(s, 10)
if !ok {
panic("not a valid decimal: " + s)
@ -262,7 +256,7 @@ func (v *fieldElement) fromDecimal(s string) *fieldElement {
}
// toBig returns v as a big.Int.
func (v *fieldElement) toBig() *big.Int {
func (v *Element) toBig() *big.Int {
buf := v.Bytes()
words := make([]big.Word, 32*8/bits.UintSize)
@ -281,13 +275,14 @@ func (v *fieldElement) toBig() *big.Int {
func TestDecimalConstants(t *testing.T) {
sqrtM1String := "19681161376707505956807079304988542015446066515923890162744021073123829784752"
if exp := (&fieldElement{}).fromDecimal(sqrtM1String); sqrtM1.Equal(exp) != 1 {
if exp := new(Element).fromDecimal(sqrtM1String); sqrtM1.Equal(exp) != 1 {
t.Errorf("sqrtM1 is %v, expected %v", sqrtM1, exp)
}
dString := "37095705934669439343138083508754565189542113879843219016388785533085940283555"
if exp := (&fieldElement{}).fromDecimal(dString); d.Equal(exp) != 1 {
t.Errorf("d is %v, expected %v", d, exp)
}
// d is in the parent package, and we don't want to expose d or fromDecimal.
// dString := "37095705934669439343138083508754565189542113879843219016388785533085940283555"
// if exp := new(Element).fromDecimal(dString); d.Equal(exp) != 1 {
// t.Errorf("d is %v, expected %v", d, exp)
// }
}
func TestSetBytesRoundTripEdgeCases(t *testing.T) {
@ -296,21 +291,14 @@ func TestSetBytesRoundTripEdgeCases(t *testing.T) {
// behavior, and that Bytes reduces them.
}
// Tests self-consistency between FeMul and FeSquare.
func TestSanity(t *testing.T) {
var x fieldElement
var x2, x2sq fieldElement
// var x2Go, x2sqGo fieldElement
// Tests self-consistency between Multiply and Square.
func TestConsistency(t *testing.T) {
var x Element
var x2, x2sq Element
x = fieldElement{1, 1, 1, 1, 1}
x = Element{1, 1, 1, 1, 1}
x2.Multiply(&x, &x)
// FeMulGo(&x2Go, &x, &x)
x2sq.Square(&x)
// FeSquareGo(&x2sqGo, &x)
// if !vartimeEqual(x2, x2Go) || !vartimeEqual(x2sq, x2sqGo) || !vartimeEqual(x2, x2sq) {
// t.Fatalf("all ones failed\nmul.s: %d\nmul.g: %d\nsqr.s: %d\nsqr.g: %d\n", x2, x2Go, x2sq, x2sqGo)
// }
if x2 != x2sq {
t.Fatalf("all ones failed\nmul: %x\nsqr: %x\n", x2, x2sq)
@ -325,13 +313,7 @@ func TestSanity(t *testing.T) {
x.SetBytes(bytes[:])
x2.Multiply(&x, &x)
// FeMulGo(&x2Go, &x, &x)
x2sq.Square(&x)
// FeSquareGo(&x2sqGo, &x)
// if !vartimeEqual(x2, x2Go) || !vartimeEqual(x2sq, x2sqGo) || !vartimeEqual(x2, x2sq) {
// t.Fatalf("random field element failed\nfe: %x\n\nmul.s: %x\nmul.g: %x\nsqr.s: %x\nsqr.g: %x\n", x, x2, x2Go, x2sq, x2sqGo)
// }
if x2 != x2sq {
t.Fatalf("all ones failed\nmul: %x\nsqr: %x\n", x2, x2sq)
@ -339,8 +321,8 @@ func TestSanity(t *testing.T) {
}
func TestEqual(t *testing.T) {
x := fieldElement{1, 1, 1, 1, 1}
y := fieldElement{5, 4, 3, 2, 1}
x := Element{1, 1, 1, 1, 1}
y := Element{5, 4, 3, 2, 1}
eq := x.Equal(&x)
if eq != 1 {
@ -354,9 +336,9 @@ func TestEqual(t *testing.T) {
}
func TestInvert(t *testing.T) {
x := fieldElement{1, 1, 1, 1, 1}
one := fieldElement{1, 0, 0, 0, 0}
var xinv, r fieldElement
x := Element{1, 1, 1, 1, 1}
one := Element{1, 0, 0, 0, 0}
var xinv, r Element
xinv.Invert(&x)
r.Multiply(&x, &xinv)
@ -382,7 +364,7 @@ func TestInvert(t *testing.T) {
t.Errorf("random inversion identity failed, got: %x for field element %x", r, x)
}
zero := fieldElement{}
zero := Element{}
x.Set(&zero)
if xx := xinv.Invert(&x); xx != &xinv {
t.Errorf("inverting zero did not return the receiver")
@ -392,10 +374,10 @@ func TestInvert(t *testing.T) {
}
func TestSelectSwap(t *testing.T) {
a := fieldElement{358744748052810, 1691584618240980, 977650209285361, 1429865912637724, 560044844278676}
b := fieldElement{84926274344903, 473620666599931, 365590438845504, 1028470286882429, 2146499180330972}
a := Element{358744748052810, 1691584618240980, 977650209285361, 1429865912637724, 560044844278676}
b := Element{84926274344903, 473620666599931, 365590438845504, 1028470286882429, 2146499180330972}
var c, d fieldElement
var c, d Element
c.Select(&a, &b, 1)
d.Select(&a, &b, 0)
@ -417,17 +399,17 @@ func TestSelectSwap(t *testing.T) {
}
}
func TestMul32(t *testing.T) {
mul32EquivalentToMul := func(x fieldElement, y uint32) bool {
t1 := new(fieldElement)
func TestMult32(t *testing.T) {
mult32EquivalentToMul := func(x Element, y uint32) bool {
t1 := new(Element)
for i := 0; i < 100; i++ {
t1.Mult32(&x, y)
}
ty := new(fieldElement)
ty := new(Element)
ty.l0 = uint64(y)
t2 := new(fieldElement)
t2 := new(Element)
for i := 0; i < 100; i++ {
t2.Multiply(&x, ty)
}
@ -435,7 +417,7 @@ func TestMul32(t *testing.T) {
return t1.Equal(t2) == 1 && isInBounds(t1) && isInBounds(t2)
}
if err := quick.Check(mul32EquivalentToMul, quickCheckConfig1024); err != nil {
if err := quick.Check(mult32EquivalentToMul, quickCheckConfig1024); err != nil {
t.Error(err)
}
}
@ -489,10 +471,10 @@ func TestSqrtRatio(t *testing.T) {
}
for i, tt := range tests {
u := (&fieldElement{}).SetBytes(decodeHex(tt.u))
v := (&fieldElement{}).SetBytes(decodeHex(tt.v))
want := (&fieldElement{}).SetBytes(decodeHex(tt.r))
got, wasSquare := (&fieldElement{}).SqrtRatio(u, v)
u := new(Element).SetBytes(decodeHex(tt.u))
v := new(Element).SetBytes(decodeHex(tt.v))
want := new(Element).SetBytes(decodeHex(tt.r))
got, wasSquare := new(Element).SqrtRatio(u, v)
if got.Equal(want) == 0 || wasSquare != tt.wasSquare {
t.Errorf("%d: got (%v, %v), want (%v, %v)", i, got, wasSquare, want, tt.wasSquare)
}
@ -501,8 +483,8 @@ func TestSqrtRatio(t *testing.T) {
func TestCarryPropagate(t *testing.T) {
asmLikeGeneric := func(a [5]uint64) bool {
t1 := &fieldElement{a[0], a[1], a[2], a[3], a[4]}
t2 := &fieldElement{a[0], a[1], a[2], a[3], a[4]}
t1 := &Element{a[0], a[1], a[2], a[3], a[4]}
t2 := &Element{a[0], a[1], a[2], a[3], a[4]}
t1.carryPropagate()
t2.carryPropagateGeneric()
@ -524,7 +506,7 @@ func TestCarryPropagate(t *testing.T) {
}
func TestFeSquare(t *testing.T) {
asmLikeGeneric := func(a fieldElement) bool {
asmLikeGeneric := func(a Element) bool {
t1 := a
t2 := a
@ -544,7 +526,7 @@ func TestFeSquare(t *testing.T) {
}
func TestFeMul(t *testing.T) {
asmLikeGeneric := func(a, b fieldElement) bool {
asmLikeGeneric := func(a, b Element) bool {
a1 := a
a2 := a
b1 := b
@ -566,3 +548,11 @@ func TestFeMul(t *testing.T) {
t.Error(err)
}
}
func decodeHex(s string) []byte {
b, err := hex.DecodeString(s)
if err != nil {
panic(err)
}
return b
}

2
go.mod
View file

@ -1,3 +1,3 @@
module filippo.io/edwards25519
go 1.14
go 1.17

140
scalar.go
View file

@ -31,8 +31,6 @@ var (
scOne = Scalar{[32]byte{1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}}
// sage: l = GF(2**252 + 27742317777372353535851937790883648493)
// sage: l(-1).lift().digits(256)
scMinusOne = Scalar{[32]byte{236, 211, 245, 92, 26, 99, 18, 88, 214, 156, 247, 162, 222, 249, 222, 20, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16}}
)
@ -50,25 +48,29 @@ func (s *Scalar) MultiplyAdd(x, y, z *Scalar) *Scalar {
// Add sets s = x + y mod l, and returns s.
func (s *Scalar) Add(x, y *Scalar) *Scalar {
// s = 1 * x + y mod l
return s.MultiplyAdd(&scOne, x, y)
scMulAdd(&s.s, &scOne.s, &x.s, &y.s)
return s
}
// Subtract sets s = x - y mod l, and returns s.
func (s *Scalar) Subtract(x, y *Scalar) *Scalar {
// s = -1 * y + x mod l
return s.MultiplyAdd(&scMinusOne, y, x)
scMulAdd(&s.s, &scMinusOne.s, &y.s, &x.s)
return s
}
// Negate sets s = -x mod l, and returns s.
func (s *Scalar) Negate(x *Scalar) *Scalar {
// s = -1 * x + 0 mod l
return s.MultiplyAdd(&scMinusOne, x, &scZero)
scMulAdd(&s.s, &scMinusOne.s, &x.s, &scZero.s)
return s
}
// Multiply sets s = x * y mod l, and returns s.
func (s *Scalar) Multiply(x, y *Scalar) *Scalar {
// s = x * y + 0 mod l
return s.MultiplyAdd(x, y, &scZero)
scMulAdd(&s.s, &x.s, &y.s, &scZero.s)
return s
}
// Set sets s = x, and returns s.
@ -89,9 +91,9 @@ func (s *Scalar) SetUniformBytes(x []byte) *Scalar {
return s
}
// SetCanonicalBytes sets s = x, where x is a 32 bytes little-endian encoding of
// SetCanonicalBytes sets s = x, where x is a 32-byte little-endian encoding of
// s, and returns s. If x is not a canonical encoding of s, SetCanonicalBytes
// returns nil and an error and the receiver is unchanged.
// returns nil and an error, and the receiver is unchanged.
func (s *Scalar) SetCanonicalBytes(x []byte) (*Scalar, error) {
if len(x) != 32 {
return nil, errors.New("invalid scalar length")
@ -145,7 +147,7 @@ func (s *Scalar) SetBytesWithClamping(x []byte) *Scalar {
return s
}
// Bytes returns the canonical 32 bytes little-endian encoding of s.
// Bytes returns the canonical 32-byte little-endian encoding of s.
func (s *Scalar) Bytes() []byte {
buf := make([]byte, 32)
copy(buf, s.s[:])
@ -157,6 +159,9 @@ func (s *Scalar) Equal(t *Scalar) int {
return subtle.ConstantTimeCompare(s.s[:], t.s[:])
}
// scMulAdd and scReduce are ported from the public domain, “ref10”
// implementation of ed25519 from SUPERCOP.
func load3(in []byte) int64 {
r := int64(in[0])
r |= int64(in[1]) << 8
@ -1018,120 +1023,3 @@ func (s *Scalar) signedRadix16() [64]int8 {
return digits
}
// Given k > 0, set s = s**(2*i).
func (s *Scalar) pow2k(k int) {
for i := 0; i < k; i++ {
s.Multiply(s, s)
}
}
// Invert sets s to the inverse of a nonzero scalar v, and returns s.
//
// If t is zero, Invert will panic.
func (s *Scalar) Invert(t *Scalar) *Scalar {
if t.s == [32]byte{} {
panic("edwards25519: zero Scalar passed to Invert")
}
// Uses a hardcoded sliding window of width 4.
var table [8]Scalar
var tt Scalar
tt.Multiply(t, t)
table[0] = *t
for i := 0; i < 7; i++ {
table[i+1].Multiply(&table[i], &tt)
}
// Now table = [t**1, t**3, t**7, t**11, t**13, t**15]
// so t**k = t[k/2] for odd k
// To compute the sliding window digits, use the following Sage script:
// sage: import itertools
// sage: def sliding_window(w,k):
// ....: digits = []
// ....: while k > 0:
// ....: if k % 2 == 1:
// ....: kmod = k % (2**w)
// ....: digits.append(kmod)
// ....: k = k - kmod
// ....: else:
// ....: digits.append(0)
// ....: k = k // 2
// ....: return digits
// Now we can compute s roughly as follows:
// sage: s = 1
// sage: for coeff in reversed(sliding_window(4,l-2)):
// ....: s = s*s
// ....: if coeff > 0 :
// ....: s = s*t**coeff
// This works on one bit at a time, with many runs of zeros.
// The digits can be collapsed into [(count, coeff)] as follows:
// sage: [(len(list(group)),d) for d,group in itertools.groupby(sliding_window(4,l-2))]
// Entries of the form (k, 0) turn into pow2k(k)
// Entries of the form (1, coeff) turn into a squaring and then a table lookup.
// We can fold the squaring into the previous pow2k(k) as pow2k(k+1).
*s = table[1/2]
s.pow2k(127 + 1)
s.Multiply(s, &table[1/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[9/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[11/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[13/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[15/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[7/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[15/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[5/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[1/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[15/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[15/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[7/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[3/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[11/2])
s.pow2k(5 + 1)
s.Multiply(s, &table[11/2])
s.pow2k(9 + 1)
s.Multiply(s, &table[9/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[3/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[3/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[3/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[9/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[7/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[3/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[13/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[7/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[9/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[15/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[11/2])
return s
}

View file

@ -72,9 +72,6 @@ func TestScalarAliasing(t *testing.T) {
}
for name, f := range map[string]interface{}{
"Invert": func(v Scalar, x notZeroScalar) bool {
return checkAliasingOneArg((*Scalar).Invert, v, Scalar(x))
},
"Negate": func(v, x Scalar) bool {
return checkAliasingOneArg((*Scalar).Negate, v, x)
},

View file

@ -44,6 +44,10 @@ func (Scalar) Generate(rand *mathrand.Rand, size int) reflect.Value {
return reflect.ValueOf(s)
}
// quickCheckConfig1024 will make each quickcheck test run (1024 * -quickchecks)
// times. The default value of -quickchecks is 100.
var quickCheckConfig1024 = &quick.Config{MaxCountScale: 1 << 10}
func TestScalarGenerate(t *testing.T) {
f := func(sc Scalar) bool {
return isReduced(&sc)
@ -109,24 +113,24 @@ func TestScalarSetBytesWithClamping(t *testing.T) {
// Generated with libsodium.js 1.0.18 crypto_scalarmult_ed25519_base.
random := "633d368491364dc9cd4c1bf891b1d59460face1644813240a313e61f2c88216e"
s := (&Scalar{}).SetBytesWithClamping(decodeHex(random))
p := (&Point{}).ScalarBaseMult(s)
s := new(Scalar).SetBytesWithClamping(decodeHex(random))
p := new(Point).ScalarBaseMult(s)
want := "1d87a9026fd0126a5736fe1628c95dd419172b5b618457e041c9c861b2494a94"
if got := hex.EncodeToString(p.Bytes()); got != want {
t.Errorf("random: got %q, want %q", got, want)
}
zero := "0000000000000000000000000000000000000000000000000000000000000000"
s = (&Scalar{}).SetBytesWithClamping(decodeHex(zero))
p = (&Point{}).ScalarBaseMult(s)
s = new(Scalar).SetBytesWithClamping(decodeHex(zero))
p = new(Point).ScalarBaseMult(s)
want = "693e47972caf527c7883ad1b39822f026f47db2ab0e1919955b8993aa04411d1"
if got := hex.EncodeToString(p.Bytes()); got != want {
t.Errorf("zero: got %q, want %q", got, want)
}
one := "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"
s = (&Scalar{}).SetBytesWithClamping(decodeHex(one))
p = (&Point{}).ScalarBaseMult(s)
s = new(Scalar).SetBytesWithClamping(decodeHex(one))
p = new(Point).ScalarBaseMult(s)
want = "12e9a68b73fd5aacdbcaf3e88c46fea6ebedb1aa84eed1842f07f8edab65e3a7"
if got := hex.EncodeToString(p.Bytes()); got != want {
t.Errorf("one: got %q, want %q", got, want)
@ -141,8 +145,8 @@ func bigIntFromLittleEndianBytes(b []byte) *big.Int {
return new(big.Int).SetBytes(bb)
}
func TestScalarMulDistributesOverScalarAdd(t *testing.T) {
mulDistributesOverAdd := func(x, y, z Scalar) bool {
func TestScalarMultiplyDistributesOverAdd(t *testing.T) {
multiplyDistributesOverAdd := func(x, y, z Scalar) bool {
// Compute t1 = (x+y)*z
var t1 Scalar
t1.Add(&x, &y)
@ -158,7 +162,7 @@ func TestScalarMulDistributesOverScalarAdd(t *testing.T) {
return t1 == t2 && isReduced(&t1) && isReduced(&t3)
}
if err := quick.Check(mulDistributesOverAdd, quickCheckConfig1024); err != nil {
if err := quick.Check(multiplyDistributesOverAdd, quickCheckConfig1024); err != nil {
t.Error(err)
}
}
@ -219,19 +223,6 @@ func (notZeroScalar) Generate(rand *mathrand.Rand, size int) reflect.Value {
return reflect.ValueOf(notZeroScalar(s))
}
func TestScalarInvert(t *testing.T) {
invertWorks := func(xInv Scalar, x notZeroScalar) bool {
xInv.Invert((*Scalar)(&x))
var check Scalar
check.Multiply((*Scalar)(&x), &xInv)
return check == scOne && isReduced(&xInv)
}
if err := quick.Check(invertWorks, quickCheckConfig32); err != nil {
t.Error(err)
}
}
func TestScalarEqual(t *testing.T) {
if scOne.Equal(&scMinusOne) == 1 {
t.Errorf("scOne.Equal(&scMinusOne) is true")

View file

@ -4,11 +4,35 @@
package edwards25519
import "sync"
// basepointTable is a set of 32 affineLookupTables, where table i is generated
// from 256i * basepoint. It is precomputed the first time it's used.
func basepointTable() *[32]affineLookupTable {
basepointTablePrecomp.initOnce.Do(func() {
p := NewGeneratorPoint()
for i := 0; i < 32; i++ {
basepointTablePrecomp.table[i].FromP3(p)
for j := 0; j < 8; j++ {
p.Add(p, p)
}
}
})
return &basepointTablePrecomp.table
}
var basepointTablePrecomp struct {
table [32]affineLookupTable
initOnce sync.Once
}
// ScalarBaseMult sets v = x * B, where B is the canonical generator, and
// returns v.
//
// The scalar multiplication is done in constant time.
func (v *Point) ScalarBaseMult(x *Scalar) *Point {
basepointTable := basepointTable()
// Write x = sum(x_i * 16^i) so x*B = sum( B*x_i*16^i )
// as described in the Ed25519 paper
//
@ -98,58 +122,18 @@ func (v *Point) ScalarMult(x *Scalar, q *Point) *Point {
return v
}
// MultiScalarMult sets v = sum(scalars[i] * points[i]), and returns v.
//
// Execution time depends only on the lengths of the two slices, which must match.
func (v *Point) MultiScalarMult(scalars []*Scalar, points []*Point) *Point {
if len(scalars) != len(points) {
panic("edwards25519: called MultiScalarMult with different size inputs")
}
checkInitialized(points...)
// basepointNafTable is the nafLookupTable8 for the basepoint.
// It is precomputed the first time it's used.
func basepointNafTable() *nafLookupTable8 {
basepointNafTablePrecomp.initOnce.Do(func() {
basepointNafTablePrecomp.table.FromP3(NewGeneratorPoint())
})
return &basepointNafTablePrecomp.table
}
// Proceed as in the single-base case, but share doublings
// between each point in the multiscalar equation.
// Build lookup tables for each point
tables := make([]projLookupTable, len(points))
for i := range tables {
tables[i].FromP3(points[i])
}
// Compute signed radix-16 digits for each scalar
digits := make([][64]int8, len(scalars))
for i := range digits {
digits[i] = scalars[i].signedRadix16()
}
// Unwrap first loop iteration to save computing 16*identity
multiple := &projCached{}
tmp1 := &projP1xP1{}
tmp2 := &projP2{}
// Lookup-and-add the appropriate multiple of each input point
for j := range tables {
tables[j].SelectInto(multiple, digits[j][63])
tmp1.Add(v, multiple) // tmp1 = v + x_(j,63)*Q in P1xP1 coords
v.fromP1xP1(tmp1) // update v
}
tmp2.FromP3(v) // set up tmp2 = v in P2 coords for next iteration
for i := 62; i >= 0; i-- {
tmp1.Double(tmp2) // tmp1 = 2*(prev) in P1xP1 coords
tmp2.FromP1xP1(tmp1) // tmp2 = 2*(prev) in P2 coords
tmp1.Double(tmp2) // tmp1 = 4*(prev) in P1xP1 coords
tmp2.FromP1xP1(tmp1) // tmp2 = 4*(prev) in P2 coords
tmp1.Double(tmp2) // tmp1 = 8*(prev) in P1xP1 coords
tmp2.FromP1xP1(tmp1) // tmp2 = 8*(prev) in P2 coords
tmp1.Double(tmp2) // tmp1 = 16*(prev) in P1xP1 coords
v.fromP1xP1(tmp1) // v = 16*(prev) in P3 coords
// Lookup-and-add the appropriate multiple of each input point
for j := range tables {
tables[j].SelectInto(multiple, digits[j][i])
tmp1.Add(v, multiple) // tmp1 = v + x_(j,i)*Q in P1xP1 coords
v.fromP1xP1(tmp1) // update v
}
tmp2.FromP3(v) // set up tmp2 = v in P2 coords for next iteration
}
return v
var basepointNafTablePrecomp struct {
table nafLookupTable8
initOnce sync.Once
}
// VarTimeDoubleScalarBaseMult sets v = a * A + b * B, where B is the canonical
@ -173,6 +157,7 @@ func (v *Point) VarTimeDoubleScalarBaseMult(a *Scalar, A *Point, b *Scalar) *Poi
// "mass" of the scalar onto sparse coefficients (meaning
// fewer additions).
basepointNafTable := basepointNafTable()
var aTable nafLookupTable5
aTable.FromP3(A)
// Because the basepoint is fixed, we can use a wider NAF
@ -227,60 +212,3 @@ func (v *Point) VarTimeDoubleScalarBaseMult(a *Scalar, A *Point, b *Scalar) *Poi
v.fromP2(tmp2)
return v
}
// VarTimeMultiScalarMult sets v = sum(scalars[i] * points[i]), and returns v.
//
// Execution time depends on the inputs.
func (v *Point) VarTimeMultiScalarMult(scalars []*Scalar, points []*Point) *Point {
if len(scalars) != len(points) {
panic("edwards25519: called VarTimeMultiScalarMult with different size inputs")
}
checkInitialized(points...)
// Generalize double-base NAF computation to arbitrary sizes.
// Here all the points are dynamic, so we only use the smaller
// tables.
// Build lookup tables for each point
tables := make([]nafLookupTable5, len(points))
for i := range tables {
tables[i].FromP3(points[i])
}
// Compute a NAF for each scalar
nafs := make([][256]int8, len(scalars))
for i := range nafs {
nafs[i] = scalars[i].nonAdjacentForm(5)
}
multiple := &projCached{}
tmp1 := &projP1xP1{}
tmp2 := &projP2{}
tmp2.Zero()
// Move from high to low bits, doubling the accumulator
// at each iteration and checking whether there is a nonzero
// coefficient to look up a multiple of.
//
// Skip trying to find the first nonzero coefficent, because
// searching might be more work than a few extra doublings.
for i := 255; i >= 0; i-- {
tmp1.Double(tmp2)
for j := range nafs {
if nafs[j][i] > 0 {
v.fromP1xP1(tmp1)
tables[j].SelectInto(multiple, nafs[j][i])
tmp1.Add(v, multiple)
} else if nafs[j][i] < 0 {
v.fromP1xP1(tmp1)
tables[j].SelectInto(multiple, -nafs[j][i])
tmp1.Sub(v, multiple)
}
}
tmp2.FromP1xP1(tmp1)
}
v.fromP2(tmp2)
return v
}

View file

@ -17,12 +17,7 @@ var (
// a random scalar generated using dalek.
dalekScalar = Scalar{[32]byte{219, 106, 114, 9, 174, 249, 155, 89, 69, 203, 201, 93, 92, 116, 234, 187, 78, 115, 103, 172, 182, 98, 62, 103, 187, 136, 13, 100, 248, 110, 12, 4}}
// the above, times the edwards25519 basepoint.
dalekScalarBasepoint = Point{
x: fieldElement{778774234987948, 1589187156384239, 1213330452914652, 186161118421127, 2186284806803213},
y: fieldElement{1241255309069369, 1115278942994853, 1016511918109334, 1303231926552315, 1801448517689873},
z: fieldElement{353337085654440, 1327844406437681, 2207296012811921, 707394926933424, 917408459573183},
t: fieldElement{585487439439725, 1792815221887900, 946062846079052, 1954901232609667, 1418300670001780},
}
dalekScalarBasepoint, _ = new(Point).SetBytes([]byte{0xf4, 0xef, 0x7c, 0xa, 0x34, 0x55, 0x7b, 0x9f, 0x72, 0x3b, 0xb6, 0x1e, 0xf9, 0x46, 0x9, 0x91, 0x1c, 0xb9, 0xc0, 0x6c, 0x17, 0x28, 0x2d, 0x8b, 0x43, 0x2b, 0x5, 0x18, 0x6a, 0x54, 0x3e, 0x48})
)
func TestScalarMultSmallScalars(t *testing.T) {
@ -51,7 +46,7 @@ func TestScalarMultVsDalek(t *testing.T) {
checkOnCurve(t, &p)
}
func TestBasepointMulVsDalek(t *testing.T) {
func TestBaseMultVsDalek(t *testing.T) {
var p Point
p.ScalarBaseMult(&dalekScalar)
if dalekScalarBasepoint.Equal(&p) != 1 {
@ -60,23 +55,23 @@ func TestBasepointMulVsDalek(t *testing.T) {
checkOnCurve(t, &p)
}
func TestVartimeDoubleBaseMulVsDalek(t *testing.T) {
func TestVarTimeDoubleBaseMultVsDalek(t *testing.T) {
var p Point
var z Scalar
p.VarTimeDoubleScalarBaseMult(&dalekScalar, B, &z)
if dalekScalarBasepoint.Equal(&p) != 1 {
t.Error("VartimeDoubleBaseMul fails with b=0")
t.Error("VarTimeDoubleScalarBaseMult fails with b=0")
}
checkOnCurve(t, &p)
p.VarTimeDoubleScalarBaseMult(&z, B, &dalekScalar)
if dalekScalarBasepoint.Equal(&p) != 1 {
t.Error("VartimeDoubleBaseMul fails with a=0")
t.Error("VarTimeDoubleScalarBaseMult fails with a=0")
}
checkOnCurve(t, &p)
}
func TestScalarMulDistributesOverAdd(t *testing.T) {
scalarMulDistributesOverAdd := func(x, y Scalar) bool {
func TestScalarMultDistributesOverAdd(t *testing.T) {
scalarMultDistributesOverAdd := func(x, y Scalar) bool {
var z Scalar
z.Add(&x, &y)
var p, q, r, check Point
@ -88,16 +83,16 @@ func TestScalarMulDistributesOverAdd(t *testing.T) {
return check.Equal(&r) == 1
}
if err := quick.Check(scalarMulDistributesOverAdd, quickCheckConfig32); err != nil {
if err := quick.Check(scalarMultDistributesOverAdd, quickCheckConfig32); err != nil {
t.Error(err)
}
}
func TestScalarMulNonIdentityPoint(t *testing.T) {
func TestScalarMultNonIdentityPoint(t *testing.T) {
// Check whether p.ScalarMult and q.ScalaBaseMult give the same,
// when p and q are originally set to the base point.
scalarMulNonIdentityPoint := func(x Scalar) bool {
scalarMultNonIdentityPoint := func(x Scalar) bool {
var p, q Point
p.Set(B)
q.Set(B)
@ -110,7 +105,7 @@ func TestScalarMulNonIdentityPoint(t *testing.T) {
return p.Equal(&q) == 1
}
if err := quick.Check(scalarMulNonIdentityPoint, quickCheckConfig32); err != nil {
if err := quick.Check(scalarMultNonIdentityPoint, quickCheckConfig32); err != nil {
t.Error(err)
}
}
@ -118,6 +113,7 @@ func TestScalarMulNonIdentityPoint(t *testing.T) {
func TestBasepointTableGeneration(t *testing.T) {
// The basepoint table is 32 affineLookupTables,
// corresponding to (16^2i)*B for table i.
basepointTable := basepointTable()
tmp1 := &projP1xP1{}
tmp2 := &projP2{}
@ -144,8 +140,8 @@ func TestBasepointTableGeneration(t *testing.T) {
}
}
func TestScalarMulMatchesBasepointMul(t *testing.T) {
scalarMulMatchesBasepointMul := func(x Scalar) bool {
func TestScalarMultMatchesBaseMult(t *testing.T) {
scalarMultMatchesBaseMult := func(x Scalar) bool {
var p, q Point
p.ScalarMult(&x, B)
q.ScalarBaseMult(&x)
@ -153,27 +149,7 @@ func TestScalarMulMatchesBasepointMul(t *testing.T) {
return p.Equal(&q) == 1
}
if err := quick.Check(scalarMulMatchesBasepointMul, quickCheckConfig32); err != nil {
t.Error(err)
}
}
func TestMultiScalarMulMatchesBasepointMul(t *testing.T) {
multiScalarMulMatchesBasepointMul := func(x, y, z Scalar) bool {
var p, q1, q2, q3, check Point
p.MultiScalarMult([]*Scalar{&x, &y, &z}, []*Point{B, B, B})
q1.ScalarBaseMult(&x)
q2.ScalarBaseMult(&y)
q3.ScalarBaseMult(&z)
check.Add(&q1, &q2).Add(&check, &q3)
checkOnCurve(t, &p, &check, &q1, &q2, &q3)
return p.Equal(&check) == 1
}
if err := quick.Check(multiScalarMulMatchesBasepointMul, quickCheckConfig32); err != nil {
if err := quick.Check(scalarMultMatchesBaseMult, quickCheckConfig32); err != nil {
t.Error(err)
}
}
@ -182,13 +158,13 @@ func TestBasepointNafTableGeneration(t *testing.T) {
var table nafLookupTable8
table.FromP3(B)
if table != basepointNafTable {
if table != *basepointNafTable() {
t.Error("BasepointNafTable does not match")
}
}
func TestVartimeDoubleBaseMulMatchesBasepointMul(t *testing.T) {
vartimeDoubleBaseMulMatchesBasepointMul := func(x, y Scalar) bool {
func TestVarTimeDoubleBaseMultMatchesBaseMult(t *testing.T) {
varTimeDoubleBaseMultMatchesBaseMult := func(x, y Scalar) bool {
var p, q1, q2, check Point
p.VarTimeDoubleScalarBaseMult(&x, B, &y)
@ -201,34 +177,14 @@ func TestVartimeDoubleBaseMulMatchesBasepointMul(t *testing.T) {
return p.Equal(&check) == 1
}
if err := quick.Check(vartimeDoubleBaseMulMatchesBasepointMul, quickCheckConfig32); err != nil {
t.Error(err)
}
}
func TestVartimeMultiScalarMulMatchesBasepointMul(t *testing.T) {
vartimeMultiScalarMulMatchesBasepointMul := func(x, y, z Scalar) bool {
var p, q1, q2, q3, check Point
p.VarTimeMultiScalarMult([]*Scalar{&x, &y, &z}, []*Point{B, B, B})
q1.ScalarBaseMult(&x)
q2.ScalarBaseMult(&y)
q3.ScalarBaseMult(&z)
check.Add(&q1, &q2).Add(&check, &q3)
checkOnCurve(t, &p, &check, &q1, &q2, &q3)
return p.Equal(&check) == 1
}
if err := quick.Check(vartimeMultiScalarMulMatchesBasepointMul, quickCheckConfig32); err != nil {
if err := quick.Check(varTimeDoubleBaseMultMatchesBaseMult, quickCheckConfig32); err != nil {
t.Error(err)
}
}
// Benchmarks.
func BenchmarkBasepointMul(t *testing.B) {
func BenchmarkScalarBaseMult(t *testing.B) {
var p Point
for i := 0; i < t.N; i++ {
@ -236,7 +192,7 @@ func BenchmarkBasepointMul(t *testing.B) {
}
}
func BenchmarkScalarMul(t *testing.B) {
func BenchmarkScalarMult(t *testing.B) {
var p Point
for i := 0; i < t.N; i++ {
@ -244,22 +200,10 @@ func BenchmarkScalarMul(t *testing.B) {
}
}
func BenchmarkVartimeDoubleBaseMul(t *testing.B) {
func BenchmarkVarTimeDoubleScalarBaseMult(t *testing.B) {
var p Point
for i := 0; i < t.N; i++ {
p.VarTimeDoubleScalarBaseMult(&dalekScalar, B, &dalekScalar)
}
}
func BenchmarkMultiscalarMulSize8(t *testing.B) {
var p Point
x := dalekScalar
for i := 0; i < t.N; i++ {
p.MultiScalarMult([]*Scalar{&x, &x, &x, &x, &x, &x, &x, &x}, []*Point{B, B, B, B, B, B, B, B})
}
}
// TODO: add BenchmarkVartimeMultiscalarMulSize8 (need to have
// different scalars & points to measure cache effects).

File diff suppressed because one or more lines are too long

View file

@ -45,7 +45,7 @@ func (v *projLookupTable) FromP3(q *Point) {
}
}
// This is not optimised for speed; affine tables should be precomputed.
// This is not optimised for speed; fixed-base tables should be precomputed.
func (v *affineLookupTable) FromP3(q *Point) {
// Goal: v.points[i] = (i+1)*Q, i.e., Q, 2Q, ..., 8Q
// This allows lookup of -8Q, ..., -Q, 0, Q, ..., 8Q
@ -72,7 +72,7 @@ func (v *nafLookupTable5) FromP3(q *Point) {
}
}
// This is not optimised for speed; affine tables should be precomputed.
// This is not optimised for speed; fixed-base tables should be precomputed.
func (v *nafLookupTable8) FromP3(q *Point) {
v.points[0].FromP3(q)
q2 := Point{}

View file

@ -29,7 +29,7 @@ func TestProjLookupTable(t *testing.T) {
accP3.fromP1xP1(&accP1xP1)
if accP3.Equal(I) != 1 {
t.Errorf("Sanity check on ProjLookupTable.SelectInto failed! %x %x %x", tmp1, tmp2, tmp3)
t.Errorf("Consistency check on ProjLookupTable.SelectInto failed! %x %x %x", tmp1, tmp2, tmp3)
}
}
@ -54,7 +54,7 @@ func TestAffineLookupTable(t *testing.T) {
accP3.fromP1xP1(&accP1xP1)
if accP3.Equal(I) != 1 {
t.Errorf("Sanity check on ProjLookupTable.SelectInto failed! %x %x %x", tmp1, tmp2, tmp3)
t.Errorf("Consistency check on ProjLookupTable.SelectInto failed! %x %x %x", tmp1, tmp2, tmp3)
}
}
@ -84,7 +84,7 @@ func TestNafLookupTable5(t *testing.T) {
rhs.fromP1xP1(&accP1xP1)
if lhs.Equal(rhs) != 1 {
t.Errorf("Sanity check on nafLookupTable5 failed")
t.Errorf("Consistency check on nafLookupTable5 failed")
}
}
@ -114,6 +114,6 @@ func TestNafLookupTable8(t *testing.T) {
rhs.fromP1xP1(&accP1xP1)
if lhs.Equal(rhs) != 1 {
t.Errorf("Sanity check on nafLookupTable8 failed")
t.Errorf("Consistency check on nafLookupTable8 failed")
}
}