/rust/registry/src/index.crates.io-1949cf8c6b5b557f/elliptic-curve-0.13.8/src/scalar/nonzero.rs
Line | Count | Source |
1 | | //! Non-zero scalar type. |
2 | | |
3 | | use crate::{ |
4 | | ops::{Invert, Reduce, ReduceNonZero}, |
5 | | scalar::IsHigh, |
6 | | CurveArithmetic, Error, FieldBytes, PrimeCurve, Scalar, ScalarPrimitive, SecretKey, |
7 | | }; |
8 | | use base16ct::HexDisplay; |
9 | | use core::{ |
10 | | fmt, |
11 | | ops::{Deref, Mul, Neg}, |
12 | | str, |
13 | | }; |
14 | | use crypto_bigint::{ArrayEncoding, Integer}; |
15 | | use ff::{Field, PrimeField}; |
16 | | use generic_array::{typenum::Unsigned, GenericArray}; |
17 | | use rand_core::CryptoRngCore; |
18 | | use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption}; |
19 | | use zeroize::Zeroize; |
20 | | |
21 | | #[cfg(feature = "serde")] |
22 | | use serdect::serde::{de, ser, Deserialize, Serialize}; |
23 | | |
24 | | /// Non-zero scalar type. |
25 | | /// |
26 | | /// This type ensures that its value is not zero, ala `core::num::NonZero*`. |
27 | | /// To do this, the generic `S` type must impl both `Default` and |
28 | | /// `ConstantTimeEq`, with the requirement that `S::default()` returns 0. |
29 | | /// |
30 | | /// In the context of ECC, it's useful for ensuring that scalar multiplication |
31 | | /// cannot result in the point at infinity. |
32 | | #[derive(Clone)] |
33 | | pub struct NonZeroScalar<C> |
34 | | where |
35 | | C: CurveArithmetic, |
36 | | { |
37 | | scalar: Scalar<C>, |
38 | | } |
39 | | |
40 | | impl<C> NonZeroScalar<C> |
41 | | where |
42 | | C: CurveArithmetic, |
43 | | { |
44 | | /// Generate a random `NonZeroScalar`. |
45 | 36.2k | pub fn random(mut rng: &mut impl CryptoRngCore) -> Self { |
46 | | // Use rejection sampling to eliminate zero values. |
47 | | // While this method isn't constant-time, the attacker shouldn't learn |
48 | | // anything about unrelated outputs so long as `rng` is a secure `CryptoRng`. |
49 | | loop { |
50 | 36.2k | if let Some(result) = Self::new(Field::random(&mut rng)).into() { |
51 | 36.2k | break result; |
52 | 0 | } |
53 | | } |
54 | 36.2k | } |
55 | | |
56 | | /// Create a [`NonZeroScalar`] from a scalar. |
57 | 36.2k | pub fn new(scalar: Scalar<C>) -> CtOption<Self> { |
58 | 36.2k | CtOption::new(Self { scalar }, !scalar.is_zero()) |
59 | 36.2k | } |
60 | | |
61 | | /// Decode a [`NonZeroScalar`] from a big endian-serialized field element. |
62 | | pub fn from_repr(repr: FieldBytes<C>) -> CtOption<Self> { |
63 | | Scalar::<C>::from_repr(repr).and_then(Self::new) |
64 | | } |
65 | | |
66 | | /// Create a [`NonZeroScalar`] from a `C::Uint`. |
67 | | pub fn from_uint(uint: C::Uint) -> CtOption<Self> { |
68 | | ScalarPrimitive::new(uint).and_then(|scalar| Self::new(scalar.into())) |
69 | | } |
70 | | } |
71 | | |
72 | | impl<C> AsRef<Scalar<C>> for NonZeroScalar<C> |
73 | | where |
74 | | C: CurveArithmetic, |
75 | | { |
76 | 4.72k | fn as_ref(&self) -> &Scalar<C> { |
77 | 4.72k | &self.scalar |
78 | 4.72k | } |
79 | | } |
80 | | |
81 | | impl<C> ConditionallySelectable for NonZeroScalar<C> |
82 | | where |
83 | | C: CurveArithmetic, |
84 | | { |
85 | | fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self { |
86 | | Self { |
87 | | scalar: Scalar::<C>::conditional_select(&a.scalar, &b.scalar, choice), |
88 | | } |
89 | | } |
90 | | } |
91 | | |
92 | | impl<C> ConstantTimeEq for NonZeroScalar<C> |
93 | | where |
94 | | C: CurveArithmetic, |
95 | | { |
96 | | fn ct_eq(&self, other: &Self) -> Choice { |
97 | | self.scalar.ct_eq(&other.scalar) |
98 | | } |
99 | | } |
100 | | |
101 | | impl<C> Copy for NonZeroScalar<C> where C: CurveArithmetic {} |
102 | | |
103 | | impl<C> Deref for NonZeroScalar<C> |
104 | | where |
105 | | C: CurveArithmetic, |
106 | | { |
107 | | type Target = Scalar<C>; |
108 | | |
109 | 6.06k | fn deref(&self) -> &Scalar<C> { |
110 | 6.06k | &self.scalar |
111 | 6.06k | } |
112 | | } |
113 | | |
114 | | impl<C> From<NonZeroScalar<C>> for FieldBytes<C> |
115 | | where |
116 | | C: CurveArithmetic, |
117 | | { |
118 | | fn from(scalar: NonZeroScalar<C>) -> FieldBytes<C> { |
119 | | Self::from(&scalar) |
120 | | } |
121 | | } |
122 | | |
123 | | impl<C> From<&NonZeroScalar<C>> for FieldBytes<C> |
124 | | where |
125 | | C: CurveArithmetic, |
126 | | { |
127 | | fn from(scalar: &NonZeroScalar<C>) -> FieldBytes<C> { |
128 | | scalar.to_repr() |
129 | | } |
130 | | } |
131 | | |
132 | | impl<C> From<NonZeroScalar<C>> for ScalarPrimitive<C> |
133 | | where |
134 | | C: CurveArithmetic, |
135 | | { |
136 | | #[inline] |
137 | | fn from(scalar: NonZeroScalar<C>) -> ScalarPrimitive<C> { |
138 | | Self::from(&scalar) |
139 | | } |
140 | | } |
141 | | |
142 | | impl<C> From<&NonZeroScalar<C>> for ScalarPrimitive<C> |
143 | | where |
144 | | C: CurveArithmetic, |
145 | | { |
146 | | fn from(scalar: &NonZeroScalar<C>) -> ScalarPrimitive<C> { |
147 | | ScalarPrimitive::from_bytes(&scalar.to_repr()).unwrap() |
148 | | } |
149 | | } |
150 | | |
151 | | impl<C> From<SecretKey<C>> for NonZeroScalar<C> |
152 | | where |
153 | | C: CurveArithmetic, |
154 | | { |
155 | | fn from(sk: SecretKey<C>) -> NonZeroScalar<C> { |
156 | | Self::from(&sk) |
157 | | } |
158 | | } |
159 | | |
160 | | impl<C> From<&SecretKey<C>> for NonZeroScalar<C> |
161 | | where |
162 | | C: CurveArithmetic, |
163 | | { |
164 | 82 | fn from(sk: &SecretKey<C>) -> NonZeroScalar<C> { |
165 | 82 | let scalar = sk.as_scalar_primitive().to_scalar(); |
166 | 82 | debug_assert!(!bool::from(scalar.is_zero())); |
167 | 82 | Self { scalar } |
168 | 82 | } |
169 | | } |
170 | | |
171 | | impl<C> Invert for NonZeroScalar<C> |
172 | | where |
173 | | C: CurveArithmetic, |
174 | | Scalar<C>: Invert<Output = CtOption<Scalar<C>>>, |
175 | | { |
176 | | type Output = Self; |
177 | | |
178 | | fn invert(&self) -> Self { |
179 | | Self { |
180 | | // This will always succeed since `scalar` will never be 0 |
181 | | scalar: Invert::invert(&self.scalar).unwrap(), |
182 | | } |
183 | | } |
184 | | |
185 | | fn invert_vartime(&self) -> Self::Output { |
186 | | Self { |
187 | | // This will always succeed since `scalar` will never be 0 |
188 | | scalar: Invert::invert_vartime(&self.scalar).unwrap(), |
189 | | } |
190 | | } |
191 | | } |
192 | | |
193 | | impl<C> IsHigh for NonZeroScalar<C> |
194 | | where |
195 | | C: CurveArithmetic, |
196 | | { |
197 | | fn is_high(&self) -> Choice { |
198 | | self.scalar.is_high() |
199 | | } |
200 | | } |
201 | | |
202 | | impl<C> Neg for NonZeroScalar<C> |
203 | | where |
204 | | C: CurveArithmetic, |
205 | | { |
206 | | type Output = NonZeroScalar<C>; |
207 | | |
208 | | fn neg(self) -> NonZeroScalar<C> { |
209 | | let scalar = -self.scalar; |
210 | | debug_assert!(!bool::from(scalar.is_zero())); |
211 | | NonZeroScalar { scalar } |
212 | | } |
213 | | } |
214 | | |
215 | | impl<C> Mul<NonZeroScalar<C>> for NonZeroScalar<C> |
216 | | where |
217 | | C: PrimeCurve + CurveArithmetic, |
218 | | { |
219 | | type Output = Self; |
220 | | |
221 | | #[inline] |
222 | | fn mul(self, other: Self) -> Self { |
223 | | Self::mul(self, &other) |
224 | | } |
225 | | } |
226 | | |
227 | | impl<C> Mul<&NonZeroScalar<C>> for NonZeroScalar<C> |
228 | | where |
229 | | C: PrimeCurve + CurveArithmetic, |
230 | | { |
231 | | type Output = Self; |
232 | | |
233 | | fn mul(self, other: &Self) -> Self { |
234 | | // Multiplication is modulo a prime, so the product of two non-zero |
235 | | // scalars is also non-zero. |
236 | | let scalar = self.scalar * other.scalar; |
237 | | debug_assert!(!bool::from(scalar.is_zero())); |
238 | | NonZeroScalar { scalar } |
239 | | } |
240 | | } |
241 | | |
242 | | /// Note: this is a non-zero reduction, as it's impl'd for [`NonZeroScalar`]. |
243 | | impl<C, I> Reduce<I> for NonZeroScalar<C> |
244 | | where |
245 | | C: CurveArithmetic, |
246 | | I: Integer + ArrayEncoding, |
247 | | Scalar<C>: Reduce<I> + ReduceNonZero<I>, |
248 | | { |
249 | | type Bytes = <Scalar<C> as Reduce<I>>::Bytes; |
250 | | |
251 | | fn reduce(n: I) -> Self { |
252 | | let scalar = Scalar::<C>::reduce_nonzero(n); |
253 | | debug_assert!(!bool::from(scalar.is_zero())); |
254 | | Self { scalar } |
255 | | } |
256 | | |
257 | | fn reduce_bytes(bytes: &Self::Bytes) -> Self { |
258 | | let scalar = Scalar::<C>::reduce_nonzero_bytes(bytes); |
259 | | debug_assert!(!bool::from(scalar.is_zero())); |
260 | | Self { scalar } |
261 | | } |
262 | | } |
263 | | |
264 | | /// Note: forwards to the [`Reduce`] impl. |
265 | | impl<C, I> ReduceNonZero<I> for NonZeroScalar<C> |
266 | | where |
267 | | Self: Reduce<I>, |
268 | | C: CurveArithmetic, |
269 | | I: Integer + ArrayEncoding, |
270 | | Scalar<C>: Reduce<I, Bytes = Self::Bytes> + ReduceNonZero<I>, |
271 | | { |
272 | | fn reduce_nonzero(n: I) -> Self { |
273 | | Self::reduce(n) |
274 | | } |
275 | | |
276 | | fn reduce_nonzero_bytes(bytes: &Self::Bytes) -> Self { |
277 | | Self::reduce_bytes(bytes) |
278 | | } |
279 | | } |
280 | | |
281 | | impl<C> TryFrom<&[u8]> for NonZeroScalar<C> |
282 | | where |
283 | | C: CurveArithmetic, |
284 | | { |
285 | | type Error = Error; |
286 | | |
287 | | fn try_from(bytes: &[u8]) -> Result<Self, Error> { |
288 | | if bytes.len() == C::FieldBytesSize::USIZE { |
289 | | Option::from(NonZeroScalar::from_repr(GenericArray::clone_from_slice( |
290 | | bytes, |
291 | | ))) |
292 | | .ok_or(Error) |
293 | | } else { |
294 | | Err(Error) |
295 | | } |
296 | | } |
297 | | } |
298 | | |
299 | | impl<C> Zeroize for NonZeroScalar<C> |
300 | | where |
301 | | C: CurveArithmetic, |
302 | | { |
303 | 36.3k | fn zeroize(&mut self) { |
304 | | // Use zeroize's volatile writes to ensure value is cleared. |
305 | 36.3k | self.scalar.zeroize(); |
306 | | |
307 | | // Write a 1 instead of a 0 to ensure this type's non-zero invariant |
308 | | // is upheld. |
309 | 36.3k | self.scalar = Scalar::<C>::ONE; |
310 | 36.3k | } |
311 | | } |
312 | | |
313 | | impl<C> fmt::Display for NonZeroScalar<C> |
314 | | where |
315 | | C: CurveArithmetic, |
316 | | { |
317 | | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
318 | | write!(f, "{self:X}") |
319 | | } |
320 | | } |
321 | | |
322 | | impl<C> fmt::LowerHex for NonZeroScalar<C> |
323 | | where |
324 | | C: CurveArithmetic, |
325 | | { |
326 | | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
327 | | write!(f, "{:x}", HexDisplay(&self.to_repr())) |
328 | | } |
329 | | } |
330 | | |
331 | | impl<C> fmt::UpperHex for NonZeroScalar<C> |
332 | | where |
333 | | C: CurveArithmetic, |
334 | | { |
335 | | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
336 | | write!(f, "{:}", HexDisplay(&self.to_repr())) |
337 | | } |
338 | | } |
339 | | |
340 | | impl<C> str::FromStr for NonZeroScalar<C> |
341 | | where |
342 | | C: CurveArithmetic, |
343 | | { |
344 | | type Err = Error; |
345 | | |
346 | | fn from_str(hex: &str) -> Result<Self, Error> { |
347 | | let mut bytes = FieldBytes::<C>::default(); |
348 | | |
349 | | if base16ct::mixed::decode(hex, &mut bytes)?.len() == bytes.len() { |
350 | | Option::from(Self::from_repr(bytes)).ok_or(Error) |
351 | | } else { |
352 | | Err(Error) |
353 | | } |
354 | | } |
355 | | } |
356 | | |
357 | | #[cfg(feature = "serde")] |
358 | | impl<C> Serialize for NonZeroScalar<C> |
359 | | where |
360 | | C: CurveArithmetic, |
361 | | { |
362 | | fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> |
363 | | where |
364 | | S: ser::Serializer, |
365 | | { |
366 | | ScalarPrimitive::from(self).serialize(serializer) |
367 | | } |
368 | | } |
369 | | |
370 | | #[cfg(feature = "serde")] |
371 | | impl<'de, C> Deserialize<'de> for NonZeroScalar<C> |
372 | | where |
373 | | C: CurveArithmetic, |
374 | | { |
375 | | fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> |
376 | | where |
377 | | D: de::Deserializer<'de>, |
378 | | { |
379 | | let scalar = ScalarPrimitive::deserialize(deserializer)?; |
380 | | Option::from(Self::new(scalar.into())) |
381 | | .ok_or_else(|| de::Error::custom("expected non-zero scalar")) |
382 | | } |
383 | | } |
384 | | |
385 | | #[cfg(all(test, feature = "dev"))] |
386 | | mod tests { |
387 | | use crate::dev::{NonZeroScalar, Scalar}; |
388 | | use ff::{Field, PrimeField}; |
389 | | use hex_literal::hex; |
390 | | use zeroize::Zeroize; |
391 | | |
392 | | #[test] |
393 | | fn round_trip() { |
394 | | let bytes = hex!("c9afa9d845ba75166b5c215767b1d6934e50c3db36e89b127b8a622b120f6721"); |
395 | | let scalar = NonZeroScalar::from_repr(bytes.into()).unwrap(); |
396 | | assert_eq!(&bytes, scalar.to_repr().as_slice()); |
397 | | } |
398 | | |
399 | | #[test] |
400 | | fn zeroize() { |
401 | | let mut scalar = NonZeroScalar::new(Scalar::from(42u64)).unwrap(); |
402 | | scalar.zeroize(); |
403 | | assert_eq!(*scalar, Scalar::ONE); |
404 | | } |
405 | | } |