scalar: clean up fiat wrapper

This commit is contained in:
Filippo Valsorda 2022-07-31 20:07:46 +02:00 committed by Filippo Valsorda
parent 755954a498
commit 50a0a9e22d
3 changed files with 46 additions and 42 deletions

View file

@ -107,7 +107,7 @@ func TestScalarInvert(t *testing.T) {
var check Scalar var check Scalar
check.Multiply((*Scalar)(&x), &xInv) check.Multiply((*Scalar)(&x), &xInv)
return check.Equal(&scOne) == 1 && isReduced(xInv.Bytes()) return check.Equal(scOne) == 1 && isReduced(xInv.Bytes())
} }
if err := quick.Check(invertWorks, quickCheckConfig32); err != nil { if err := quick.Check(invertWorks, quickCheckConfig32); err != nil {
@ -119,7 +119,7 @@ func TestScalarInvert(t *testing.T) {
var check Scalar var check Scalar
check.Multiply(&randomScalar, randomInverse) check.Multiply(&randomScalar, randomInverse)
if check.Equal(&scOne) == 0 || !isReduced(randomInverse.Bytes()) { if check.Equal(scOne) == 0 || !isReduced(randomInverse.Bytes()) {
t.Error("inversion did not work") t.Error("inversion did not work")
} }

View file

@ -5,7 +5,6 @@
package edwards25519 package edwards25519
import ( import (
"crypto/subtle"
"encoding/binary" "encoding/binary"
"errors" "errors"
) )
@ -21,20 +20,11 @@ import (
// //
// The zero value is a valid zero element. // The zero value is a valid zero element.
type Scalar struct { type Scalar struct {
// A Scalar is an integer modulo l = 2^252 + 27742317777372353535851937790883648493. // s is the scalar in the Montgomery domain, in the format of the
// Internally, this implementation keeps the scalar in the Montgomery domain. // fiat-crypto implementation.
s fiat_sc255_montgomery_domain_field_element s fiat_sc255_montgomery_domain_field_element
} }
var (
scZeroBytes = [32]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
scOneBytes = [32]byte{1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
scMinusOneBytes = [32]byte{236, 211, 245, 92, 26, 99, 18, 88, 214, 156, 247, 162, 222, 249, 222, 20, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16}
scOne = Scalar{[4]uint64{0xd6ec31748d98951d, 0xc6ef5bf4737dcf70, 0xfffffffffffffffe, 0xfffffffffffffff}}
scMinusOne = Scalar{[4]uint64{0x812631a5cf5d3ed0, 0x4def9dea2f79cd65, 1, 0}}
)
// NewScalar returns a new zero Scalar. // NewScalar returns a new zero Scalar.
func NewScalar() *Scalar { func NewScalar() *Scalar {
return &Scalar{} return &Scalar{}
@ -92,12 +82,10 @@ func (s *Scalar) SetUniformBytes(x []byte) (*Scalar, error) {
if len(x) != 64 { if len(x) != 64 {
return nil, errors.New("edwards25519: invalid SetUniformBytes input length") return nil, errors.New("edwards25519: invalid SetUniformBytes input length")
} }
var wideBytes [64]byte
copy(wideBytes[:], x[:])
// TODO: We should deprecate scReduce as well, but we retain it here for consistent behavior // TODO: replace scReduce with a limbed reduction.
var reduced [32]byte var reduced [32]byte
scReduce(&reduced, &wideBytes) scReduce(&reduced, (*[64]byte)(x))
fiat_sc255_from_bytes((*[4]uint64)(&s.s), &reduced) fiat_sc255_from_bytes((*[4]uint64)(&s.s), &reduced)
fiat_sc255_to_montgomery(&s.s, (*fiat_sc255_non_montgomery_domain_field_element)(&s.s)) fiat_sc255_to_montgomery(&s.s, (*fiat_sc255_non_montgomery_domain_field_element)(&s.s))
@ -112,21 +100,21 @@ func (s *Scalar) SetCanonicalBytes(x []byte) (*Scalar, error) {
if len(x) != 32 { if len(x) != 32 {
return nil, errors.New("invalid scalar length") return nil, errors.New("invalid scalar length")
} }
if !isReduced(x) {
// Use bytes here because the original logic assumed the old 32-byte LE representation
ss := [32]byte{}
copy(ss[:], x)
if !isReduced(ss[:]) {
return nil, errors.New("invalid scalar encoding") return nil, errors.New("invalid scalar encoding")
} }
fiat_sc255_from_bytes((*[4]uint64)(&s.s), &ss) fiat_sc255_from_bytes((*[4]uint64)(&s.s), (*[32]byte)(x))
fiat_sc255_to_montgomery(&s.s, (*fiat_sc255_non_montgomery_domain_field_element)(&s.s)) fiat_sc255_to_montgomery(&s.s, (*fiat_sc255_non_montgomery_domain_field_element)(&s.s))
return s, nil return s, nil
} }
// isReduced returns whether the given scalar in 32-byte little endian encoded form is reduced modulo l. // scalarMinusOneBytes is l - 1 in little endian.
var scalarMinusOneBytes = [32]byte{236, 211, 245, 92, 26, 99, 18, 88, 214, 156, 247, 162, 222, 249, 222, 20, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16}
// isReduced returns whether the given scalar in 32-byte little endian encoded
// form is reduced modulo l.
func isReduced(s []byte) bool { func isReduced(s []byte) bool {
if len(s) != 32 { if len(s) != 32 {
return false return false
@ -134,9 +122,9 @@ func isReduced(s []byte) bool {
for i := len(s) - 1; i >= 0; i-- { for i := len(s) - 1; i >= 0; i-- {
switch { switch {
case s[i] > scMinusOneBytes[i]: case s[i] > scalarMinusOneBytes[i]:
return false return false
case s[i] < scMinusOneBytes[i]: case s[i] < scalarMinusOneBytes[i]:
return true return true
} }
} }
@ -162,42 +150,54 @@ func (s *Scalar) SetBytesWithClamping(x []byte) (*Scalar, error) {
if len(x) != 32 { if len(x) != 32 {
return nil, errors.New("edwards25519: invalid SetBytesWithClamping input length") return nil, errors.New("edwards25519: invalid SetBytesWithClamping input length")
} }
var wideBytes [64]byte var wideBytes [64]byte
copy(wideBytes[:], x[:]) copy(wideBytes[:], x[:])
wideBytes[0] &= 248 wideBytes[0] &= 248
wideBytes[31] &= 63 wideBytes[31] &= 63
wideBytes[31] |= 64 wideBytes[31] |= 64
// TODO: replace scReduce with a limbed reduction.
var reduced [32]byte var reduced [32]byte
scReduce(&reduced, &wideBytes) scReduce(&reduced, &wideBytes)
fiat_sc255_from_bytes((*[4]uint64)(&s.s), &reduced) fiat_sc255_from_bytes((*[4]uint64)(&s.s), &reduced)
fiat_sc255_to_montgomery(&s.s, (*fiat_sc255_non_montgomery_domain_field_element)(&s.s)) fiat_sc255_to_montgomery(&s.s, (*fiat_sc255_non_montgomery_domain_field_element)(&s.s))
return s, nil return s, nil
} }
// Bytes returns the canonical 32-byte little-endian encoding of s. // Bytes returns the canonical 32-byte little-endian encoding of s.
func (s *Scalar) Bytes() []byte { func (s *Scalar) Bytes() []byte {
// This pattern, called "outlining", allows this function to inline so the // This function is outlined to make the allocations inline in the caller
// allocations can occur on the caller stack rather than escaping to the heap. // rather than happen on the heap.
// See https://blog.filippo.io/efficient-go-apis-with-the-inliner for more details.
var encoded [32]byte var encoded [32]byte
return s.bytes(&encoded) return s.bytes(&encoded)
} }
func (s *Scalar) bytes(out *[32]byte) []byte { func (s *Scalar) bytes(out *[32]byte) []byte {
var limbs fiat_sc255_non_montgomery_domain_field_element var ss fiat_sc255_non_montgomery_domain_field_element
fiat_sc255_from_montgomery(&limbs, &s.s) fiat_sc255_from_montgomery(&ss, &s.s)
fiat_sc255_to_bytes(out, (*[4]uint64)(&limbs)) fiat_sc255_to_bytes(out, (*[4]uint64)(&ss))
return out[:] return out[:]
} }
// Equal returns 1 if s and t are equal, and 0 otherwise. // Equal returns 1 if s and t are equal, and 0 otherwise.
func (s *Scalar) Equal(t *Scalar) int { func (s *Scalar) Equal(t *Scalar) int {
st := t.Bytes() var diff fiat_sc255_montgomery_domain_field_element
ss := s.Bytes() fiat_sc255_sub(&diff, &s.s, &t.s)
return subtle.ConstantTimeCompare(ss[:], st[:]) var nonzero uint64
fiat_sc255_nonzero(&nonzero, (*[4]uint64)(&diff))
nonzero |= nonzero >> 32
nonzero |= nonzero >> 16
nonzero |= nonzero >> 8
nonzero |= nonzero >> 4
nonzero |= nonzero >> 2
nonzero |= nonzero >> 1
return int(^nonzero) & 1
} }
// scMulAdd and scReduce are ported from the public domain, “ref10” // scReduce is ported from the public domain, “ref10”
// implementation of ed25519 from SUPERCOP. // implementation of ed25519 from SUPERCOP.
func load3(in []byte) int64 { func load3(in []byte) int64 {

View file

@ -14,17 +14,21 @@ import (
"testing/quick" "testing/quick"
) )
var scOneBytes = [32]byte{1}
var scOne, _ = new(Scalar).SetCanonicalBytes(scOneBytes[:])
var scMinusOne, _ = new(Scalar).SetCanonicalBytes(scalarMinusOneBytes[:])
// Generate returns a valid (reduced modulo l) Scalar with a distribution // Generate returns a valid (reduced modulo l) Scalar with a distribution
// weighted towards high, low, and edge values. // weighted towards high, low, and edge values.
func (Scalar) Generate(rand *mathrand.Rand, size int) reflect.Value { func (Scalar) Generate(rand *mathrand.Rand, size int) reflect.Value {
s := scZeroBytes var s [32]byte
diceRoll := rand.Intn(100) diceRoll := rand.Intn(100)
switch { switch {
case diceRoll == 0: case diceRoll == 0:
case diceRoll == 1: case diceRoll == 1:
s = scOneBytes s = scOneBytes
case diceRoll == 2: case diceRoll == 2:
s = scMinusOneBytes s = scalarMinusOneBytes
case diceRoll < 5: case diceRoll < 5:
// Generate a low scalar in [0, 2^125). // Generate a low scalar in [0, 2^125).
rand.Read(s[:16]) rand.Read(s[:16])
@ -86,7 +90,7 @@ func TestScalarSetCanonicalBytes(t *testing.T) {
t.Errorf("failed scalar->bytes->scalar round-trip: %v", err) t.Errorf("failed scalar->bytes->scalar round-trip: %v", err)
} }
b := scMinusOneBytes b := scalarMinusOneBytes
b[31] += 1 b[31] += 1
s := scOne s := scOne
if out, err := s.SetCanonicalBytes(b[:]); err == nil { if out, err := s.SetCanonicalBytes(b[:]); err == nil {
@ -236,10 +240,10 @@ func (notZeroScalar) Generate(rand *mathrand.Rand, size int) reflect.Value {
} }
func TestScalarEqual(t *testing.T) { func TestScalarEqual(t *testing.T) {
if scOne.Equal(&scMinusOne) == 1 { if scOne.Equal(scMinusOne) == 1 {
t.Errorf("scOne.Equal(&scMinusOne) is true") t.Errorf("scOne.Equal(&scMinusOne) is true")
} }
if scMinusOne.Equal(&scMinusOne) == 0 { if scMinusOne.Equal(scMinusOne) == 0 {
t.Errorf("scMinusOne.Equal(&scMinusOne) is false") t.Errorf("scMinusOne.Equal(&scMinusOne) is false")
} }
} }