Coverage Report

Created: 2025-11-16 06:36

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/rust/registry/src/index.crates.io-1949cf8c6b5b557f/elliptic-curve-0.13.8/src/scalar/nonzero.rs
Line
Count
Source
1
//! Non-zero scalar type.
2
3
use crate::{
4
    ops::{Invert, Reduce, ReduceNonZero},
5
    scalar::IsHigh,
6
    CurveArithmetic, Error, FieldBytes, PrimeCurve, Scalar, ScalarPrimitive, SecretKey,
7
};
8
use base16ct::HexDisplay;
9
use core::{
10
    fmt,
11
    ops::{Deref, Mul, Neg},
12
    str,
13
};
14
use crypto_bigint::{ArrayEncoding, Integer};
15
use ff::{Field, PrimeField};
16
use generic_array::{typenum::Unsigned, GenericArray};
17
use rand_core::CryptoRngCore;
18
use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption};
19
use zeroize::Zeroize;
20
21
#[cfg(feature = "serde")]
22
use serdect::serde::{de, ser, Deserialize, Serialize};
23
24
/// Non-zero scalar type.
25
///
26
/// This type ensures that its value is not zero, ala `core::num::NonZero*`.
27
/// To do this, the generic `S` type must impl both `Default` and
28
/// `ConstantTimeEq`, with the requirement that `S::default()` returns 0.
29
///
30
/// In the context of ECC, it's useful for ensuring that scalar multiplication
31
/// cannot result in the point at infinity.
32
#[derive(Clone)]
33
pub struct NonZeroScalar<C>
34
where
35
    C: CurveArithmetic,
36
{
37
    scalar: Scalar<C>,
38
}
39
40
impl<C> NonZeroScalar<C>
41
where
42
    C: CurveArithmetic,
43
{
44
    /// Generate a random `NonZeroScalar`.
45
36.2k
    pub fn random(mut rng: &mut impl CryptoRngCore) -> Self {
46
        // Use rejection sampling to eliminate zero values.
47
        // While this method isn't constant-time, the attacker shouldn't learn
48
        // anything about unrelated outputs so long as `rng` is a secure `CryptoRng`.
49
        loop {
50
36.2k
            if let Some(result) = Self::new(Field::random(&mut rng)).into() {
51
36.2k
                break result;
52
0
            }
53
        }
54
36.2k
    }
55
56
    /// Create a [`NonZeroScalar`] from a scalar.
57
36.2k
    pub fn new(scalar: Scalar<C>) -> CtOption<Self> {
58
36.2k
        CtOption::new(Self { scalar }, !scalar.is_zero())
59
36.2k
    }
60
61
    /// Decode a [`NonZeroScalar`] from a big endian-serialized field element.
62
    pub fn from_repr(repr: FieldBytes<C>) -> CtOption<Self> {
63
        Scalar::<C>::from_repr(repr).and_then(Self::new)
64
    }
65
66
    /// Create a [`NonZeroScalar`] from a `C::Uint`.
67
    pub fn from_uint(uint: C::Uint) -> CtOption<Self> {
68
        ScalarPrimitive::new(uint).and_then(|scalar| Self::new(scalar.into()))
69
    }
70
}
71
72
impl<C> AsRef<Scalar<C>> for NonZeroScalar<C>
73
where
74
    C: CurveArithmetic,
75
{
76
4.72k
    fn as_ref(&self) -> &Scalar<C> {
77
4.72k
        &self.scalar
78
4.72k
    }
79
}
80
81
impl<C> ConditionallySelectable for NonZeroScalar<C>
82
where
83
    C: CurveArithmetic,
84
{
85
    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
86
        Self {
87
            scalar: Scalar::<C>::conditional_select(&a.scalar, &b.scalar, choice),
88
        }
89
    }
90
}
91
92
impl<C> ConstantTimeEq for NonZeroScalar<C>
93
where
94
    C: CurveArithmetic,
95
{
96
    fn ct_eq(&self, other: &Self) -> Choice {
97
        self.scalar.ct_eq(&other.scalar)
98
    }
99
}
100
101
impl<C> Copy for NonZeroScalar<C> where C: CurveArithmetic {}
102
103
impl<C> Deref for NonZeroScalar<C>
104
where
105
    C: CurveArithmetic,
106
{
107
    type Target = Scalar<C>;
108
109
6.06k
    fn deref(&self) -> &Scalar<C> {
110
6.06k
        &self.scalar
111
6.06k
    }
112
}
113
114
impl<C> From<NonZeroScalar<C>> for FieldBytes<C>
115
where
116
    C: CurveArithmetic,
117
{
118
    fn from(scalar: NonZeroScalar<C>) -> FieldBytes<C> {
119
        Self::from(&scalar)
120
    }
121
}
122
123
impl<C> From<&NonZeroScalar<C>> for FieldBytes<C>
124
where
125
    C: CurveArithmetic,
126
{
127
    fn from(scalar: &NonZeroScalar<C>) -> FieldBytes<C> {
128
        scalar.to_repr()
129
    }
130
}
131
132
impl<C> From<NonZeroScalar<C>> for ScalarPrimitive<C>
133
where
134
    C: CurveArithmetic,
135
{
136
    #[inline]
137
    fn from(scalar: NonZeroScalar<C>) -> ScalarPrimitive<C> {
138
        Self::from(&scalar)
139
    }
140
}
141
142
impl<C> From<&NonZeroScalar<C>> for ScalarPrimitive<C>
143
where
144
    C: CurveArithmetic,
145
{
146
    fn from(scalar: &NonZeroScalar<C>) -> ScalarPrimitive<C> {
147
        ScalarPrimitive::from_bytes(&scalar.to_repr()).unwrap()
148
    }
149
}
150
151
impl<C> From<SecretKey<C>> for NonZeroScalar<C>
152
where
153
    C: CurveArithmetic,
154
{
155
    fn from(sk: SecretKey<C>) -> NonZeroScalar<C> {
156
        Self::from(&sk)
157
    }
158
}
159
160
impl<C> From<&SecretKey<C>> for NonZeroScalar<C>
161
where
162
    C: CurveArithmetic,
163
{
164
82
    fn from(sk: &SecretKey<C>) -> NonZeroScalar<C> {
165
82
        let scalar = sk.as_scalar_primitive().to_scalar();
166
82
        debug_assert!(!bool::from(scalar.is_zero()));
167
82
        Self { scalar }
168
82
    }
169
}
170
171
impl<C> Invert for NonZeroScalar<C>
172
where
173
    C: CurveArithmetic,
174
    Scalar<C>: Invert<Output = CtOption<Scalar<C>>>,
175
{
176
    type Output = Self;
177
178
    fn invert(&self) -> Self {
179
        Self {
180
            // This will always succeed since `scalar` will never be 0
181
            scalar: Invert::invert(&self.scalar).unwrap(),
182
        }
183
    }
184
185
    fn invert_vartime(&self) -> Self::Output {
186
        Self {
187
            // This will always succeed since `scalar` will never be 0
188
            scalar: Invert::invert_vartime(&self.scalar).unwrap(),
189
        }
190
    }
191
}
192
193
impl<C> IsHigh for NonZeroScalar<C>
194
where
195
    C: CurveArithmetic,
196
{
197
    fn is_high(&self) -> Choice {
198
        self.scalar.is_high()
199
    }
200
}
201
202
impl<C> Neg for NonZeroScalar<C>
203
where
204
    C: CurveArithmetic,
205
{
206
    type Output = NonZeroScalar<C>;
207
208
    fn neg(self) -> NonZeroScalar<C> {
209
        let scalar = -self.scalar;
210
        debug_assert!(!bool::from(scalar.is_zero()));
211
        NonZeroScalar { scalar }
212
    }
213
}
214
215
impl<C> Mul<NonZeroScalar<C>> for NonZeroScalar<C>
216
where
217
    C: PrimeCurve + CurveArithmetic,
218
{
219
    type Output = Self;
220
221
    #[inline]
222
    fn mul(self, other: Self) -> Self {
223
        Self::mul(self, &other)
224
    }
225
}
226
227
impl<C> Mul<&NonZeroScalar<C>> for NonZeroScalar<C>
228
where
229
    C: PrimeCurve + CurveArithmetic,
230
{
231
    type Output = Self;
232
233
    fn mul(self, other: &Self) -> Self {
234
        // Multiplication is modulo a prime, so the product of two non-zero
235
        // scalars is also non-zero.
236
        let scalar = self.scalar * other.scalar;
237
        debug_assert!(!bool::from(scalar.is_zero()));
238
        NonZeroScalar { scalar }
239
    }
240
}
241
242
/// Note: this is a non-zero reduction, as it's impl'd for [`NonZeroScalar`].
243
impl<C, I> Reduce<I> for NonZeroScalar<C>
244
where
245
    C: CurveArithmetic,
246
    I: Integer + ArrayEncoding,
247
    Scalar<C>: Reduce<I> + ReduceNonZero<I>,
248
{
249
    type Bytes = <Scalar<C> as Reduce<I>>::Bytes;
250
251
    fn reduce(n: I) -> Self {
252
        let scalar = Scalar::<C>::reduce_nonzero(n);
253
        debug_assert!(!bool::from(scalar.is_zero()));
254
        Self { scalar }
255
    }
256
257
    fn reduce_bytes(bytes: &Self::Bytes) -> Self {
258
        let scalar = Scalar::<C>::reduce_nonzero_bytes(bytes);
259
        debug_assert!(!bool::from(scalar.is_zero()));
260
        Self { scalar }
261
    }
262
}
263
264
/// Note: forwards to the [`Reduce`] impl.
265
impl<C, I> ReduceNonZero<I> for NonZeroScalar<C>
266
where
267
    Self: Reduce<I>,
268
    C: CurveArithmetic,
269
    I: Integer + ArrayEncoding,
270
    Scalar<C>: Reduce<I, Bytes = Self::Bytes> + ReduceNonZero<I>,
271
{
272
    fn reduce_nonzero(n: I) -> Self {
273
        Self::reduce(n)
274
    }
275
276
    fn reduce_nonzero_bytes(bytes: &Self::Bytes) -> Self {
277
        Self::reduce_bytes(bytes)
278
    }
279
}
280
281
impl<C> TryFrom<&[u8]> for NonZeroScalar<C>
282
where
283
    C: CurveArithmetic,
284
{
285
    type Error = Error;
286
287
    fn try_from(bytes: &[u8]) -> Result<Self, Error> {
288
        if bytes.len() == C::FieldBytesSize::USIZE {
289
            Option::from(NonZeroScalar::from_repr(GenericArray::clone_from_slice(
290
                bytes,
291
            )))
292
            .ok_or(Error)
293
        } else {
294
            Err(Error)
295
        }
296
    }
297
}
298
299
impl<C> Zeroize for NonZeroScalar<C>
300
where
301
    C: CurveArithmetic,
302
{
303
36.3k
    fn zeroize(&mut self) {
304
        // Use zeroize's volatile writes to ensure value is cleared.
305
36.3k
        self.scalar.zeroize();
306
307
        // Write a 1 instead of a 0 to ensure this type's non-zero invariant
308
        // is upheld.
309
36.3k
        self.scalar = Scalar::<C>::ONE;
310
36.3k
    }
311
}
312
313
impl<C> fmt::Display for NonZeroScalar<C>
314
where
315
    C: CurveArithmetic,
316
{
317
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
318
        write!(f, "{self:X}")
319
    }
320
}
321
322
impl<C> fmt::LowerHex for NonZeroScalar<C>
323
where
324
    C: CurveArithmetic,
325
{
326
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
327
        write!(f, "{:x}", HexDisplay(&self.to_repr()))
328
    }
329
}
330
331
impl<C> fmt::UpperHex for NonZeroScalar<C>
332
where
333
    C: CurveArithmetic,
334
{
335
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
336
        write!(f, "{:}", HexDisplay(&self.to_repr()))
337
    }
338
}
339
340
impl<C> str::FromStr for NonZeroScalar<C>
341
where
342
    C: CurveArithmetic,
343
{
344
    type Err = Error;
345
346
    fn from_str(hex: &str) -> Result<Self, Error> {
347
        let mut bytes = FieldBytes::<C>::default();
348
349
        if base16ct::mixed::decode(hex, &mut bytes)?.len() == bytes.len() {
350
            Option::from(Self::from_repr(bytes)).ok_or(Error)
351
        } else {
352
            Err(Error)
353
        }
354
    }
355
}
356
357
#[cfg(feature = "serde")]
358
impl<C> Serialize for NonZeroScalar<C>
359
where
360
    C: CurveArithmetic,
361
{
362
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
363
    where
364
        S: ser::Serializer,
365
    {
366
        ScalarPrimitive::from(self).serialize(serializer)
367
    }
368
}
369
370
#[cfg(feature = "serde")]
371
impl<'de, C> Deserialize<'de> for NonZeroScalar<C>
372
where
373
    C: CurveArithmetic,
374
{
375
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
376
    where
377
        D: de::Deserializer<'de>,
378
    {
379
        let scalar = ScalarPrimitive::deserialize(deserializer)?;
380
        Option::from(Self::new(scalar.into()))
381
            .ok_or_else(|| de::Error::custom("expected non-zero scalar"))
382
    }
383
}
384
385
#[cfg(all(test, feature = "dev"))]
386
mod tests {
387
    use crate::dev::{NonZeroScalar, Scalar};
388
    use ff::{Field, PrimeField};
389
    use hex_literal::hex;
390
    use zeroize::Zeroize;
391
392
    #[test]
393
    fn round_trip() {
394
        let bytes = hex!("c9afa9d845ba75166b5c215767b1d6934e50c3db36e89b127b8a622b120f6721");
395
        let scalar = NonZeroScalar::from_repr(bytes.into()).unwrap();
396
        assert_eq!(&bytes, scalar.to_repr().as_slice());
397
    }
398
399
    #[test]
400
    fn zeroize() {
401
        let mut scalar = NonZeroScalar::new(Scalar::from(42u64)).unwrap();
402
        scalar.zeroize();
403
        assert_eq!(*scalar, Scalar::ONE);
404
    }
405
}