1# coding: utf-8
2from .point import Point
3
4
class Math:

    @classmethod
    def modularSquareRoot(cls, value, prime):
        """
        Modular square root via the Tonelli-Shanks algorithm.

        ``value`` is reduced modulo ``prime`` first, so negative inputs and
        non-zero multiples of ``prime`` are accepted (the latter return 0).

        :param value: Integer whose square root is wanted
        :param prime: Prime modulus (2 or any odd prime)
        :return: r with r*r == value (mod prime). For prime = 3 (mod 4) the
                 shortcut value^((prime+1)/4) is returned without checking that
                 value is a residue, so callers must verify the result there.
        :raises ValueError: when prime = 1 (mod 4) and value is not a quadratic
                            residue modulo prime
        """
        # Reducing first keeps t = value^Q non-zero below; without it a
        # non-zero multiple of prime made the inner search loop forever.
        value %= prime
        if value == 0:
            return 0
        if prime == 2:
            return value

        # Factor out powers of 2: prime - 1 = Q * 2^S with Q odd.
        Q = prime - 1
        S = 0
        while Q % 2 == 0:
            Q //= 2
            S += 1

        if S == 1:  # prime = 3 (mod 4)
            return pow(value, (prime + 1) // 4, prime)

        # Find a quadratic non-residue z (Euler's criterion gives -1 for it).
        z = 2
        while pow(z, (prime - 1) // 2, prime) != prime - 1:
            z += 1

        M = S
        c = pow(z, Q, prime)
        t = pow(value, Q, prime)
        R = pow(value, (Q + 1) // 2, prime)

        while t != 1:
            # Find the least i in [1, M) such that t^(2^i) = 1 (mod prime).
            # The search is bounded by M so it always terminates.
            i = 1
            temp = (t * t) % prime
            while temp != 1 and i < M:
                temp = (temp * temp) % prime
                i += 1
            if i >= M:
                # The least such i can only reach M when value is a
                # non-residue (t then has the full order 2^M).
                raise ValueError("value is not a quadratic residue modulo prime")

            b = pow(c, 1 << (M - i - 1), prime)
            M = i
            c = (b * b) % prime
            t = (t * c) % prime
            R = (R * b) % prime
        return R

    @classmethod
    def multiplyGenerator(cls, curve, n):
        """
        Fast scalar multiplication n*G using a precomputed affine table of
        powers-of-two multiples of G and the width-2 NAF of n. Every non-zero
        NAF digit triggers one mixed add and zero doublings, trading the ~256
        doublings of a windowed method for ~86 adds on average — a large net
        reduction in field multiplications for 256-bit scalars.

        Not constant-time: the number of additions and the table entries used
        depend on the NAF digits of n.

        :param curve: Elliptic curve with generator G
        :param n: Scalar multiplier (reduced modulo curve.N)
        :return: Point n*G in affine form; Point(0, 0, 0) when n = 0 (mod N)
        """
        if n < 0 or n >= curve.N:
            n = n % curve.N
        if n == 0:
            return Point(0, 0, 0)

        table = cls._generatorPowersTable(curve)
        A, P = curve.A, curve.P
        _add = cls._jacobianAdd

        r = Point(0, 0, 1)
        i = 0
        k = n
        while k > 0:
            if k & 1:
                # k & 3 == 1 -> digit +1, k & 3 == 3 -> digit -1 (NAF rule).
                digit = 2 - (k & 3)
                k -= digit
                g = table[i]  # 2^i * G, affine with z=1
                if digit == 1:
                    r = _add(r, g, A, P)
                else:
                    r = _add(r, Point(g.x, P - g.y, 1), A, P)
            k >>= 1
            i += 1
        return cls._fromJacobian(r, P)

    @classmethod
    def _generatorPowersTable(cls, curve):
        """
        Build [G, 2G, 4G, ..., 2^nBitLength * G] in affine (z=1) form, so each
        add in multiplyGenerator hits the mixed-add fast path. The table is
        cached on the curve object as ``_generatorPowersTable_``.

        :param curve: Elliptic curve with generator G
        :return: List of nBitLength + 1 points
        """
        cached = getattr(curve, "_generatorPowersTable_", None)
        if cached is not None:
            return cached
        A, P = curve.A, curve.P
        current = Point(curve.G.x, curve.G.y, 1)
        table = [current]
        # NAF of an nBitLength-bit scalar can be up to nBitLength+1 digits.
        for _ in range(curve.nBitLength):
            doubled = cls._jacobianDouble(current, A, P)
            if doubled.y == 0:
                current = doubled
            else:
                # Normalize back to z=1 so later adds use the mixed path.
                zInv = cls.inv(doubled.z, P)
                zInv2 = (zInv * zInv) % P
                zInv3 = (zInv2 * zInv) % P
                current = Point((doubled.x * zInv2) % P, (doubled.y * zInv3) % P, 1)
            table.append(current)
        curve._generatorPowersTable_ = table
        return table

    @classmethod
    def multiply(cls, p, n, N, A, P):
        """
        Fast way to multiply a point by a scalar in elliptic curves

        :param p: Point to multiply (affine)
        :param n: Scalar to multiply by (reduced modulo N)
        :param N: Order of the elliptic curve
        :param A: Coefficient of the first-order term of the equation Y^2 = X^3 + A*X + B (mod p)
        :param P: Prime number in the module of the equation Y^2 = X^3 + A*X + B (mod p)
        :return: Point n*p in affine form (Point(0, 0, 0) for the point at infinity)
        """
        return cls._fromJacobian(
            cls._jacobianMultiply(cls._toJacobian(p), n, N, A, P), P
        )

    @classmethod
    def add(cls, p, q, A, P):
        """
        Fast way to add two points in elliptic curves

        :param p: First Point you want to add (affine)
        :param q: Second Point you want to add (affine)
        :param A: Coefficient of the first-order term of the equation Y^2 = X^3 + A*X + B (mod p)
        :param P: Prime number in the module of the equation Y^2 = X^3 + A*X + B (mod p)
        :return: Point p + q in affine form (Point(0, 0, 0) for the point at infinity)
        """
        return cls._fromJacobian(
            cls._jacobianAdd(cls._toJacobian(p), cls._toJacobian(q), A, P), P,
        )

    @classmethod
    def multiplyAndAdd(cls, p1, n1, p2, n2, N=None, A=None, P=None, curve=None):
        """
        Compute n1*p1 + n2*p2. If ``curve`` is given and exposes ``glvParams``
        (e.g. secp256k1), uses the GLV endomorphism to split both scalars into
        ~128-bit halves and run a 4-scalar simultaneous multi-exponentiation.
        Otherwise falls back to Shamir's trick with JSF. Not constant-time —
        use only with public scalars (e.g. verification).

        :param p1: First point
        :param n1: First scalar
        :param p2: Second point
        :param n2: Second scalar
        :param N: Order of the elliptic curve (ignored when ``curve`` is given)
        :param A: Coefficient of the first-order term (ignored when ``curve`` is given)
        :param P: Prime defining the field (ignored when ``curve`` is given)
        :param curve: Optional curve object; enables GLV if ``curve.glvParams`` is set
        :return: Point n1*p1 + n2*p2
        """
        if curve is not None:
            N, A, P = curve.N, curve.A, curve.P
            # getattr: curves without GLV data need not define the attribute.
            if getattr(curve, "glvParams", None) is not None:
                return cls._glvMultiplyAndAdd(p1, n1, p2, n2, curve)
        return cls._fromJacobian(
            cls._shamirMultiply(
                cls._toJacobian(p1), n1,
                cls._toJacobian(p2), n2,
                N, A, P,
            ), P,
        )

    @classmethod
    def _glvMultiplyAndAdd(cls, p1, n1, p2, n2, curve):
        """
        Compute n1*p1 + n2*p2 using the GLV endomorphism. Splits each 256-bit
        scalar into two ~128-bit scalars via k ≡ k1 + k2·λ (mod N), then runs
        a 4-scalar simultaneous double-and-add over (p1, φ(p1), p2, φ(p2))
        with a 16-entry precomputed table of subset sums. Halves the loop
        length versus the plain Shamir path.

        :param p1: First point (affine)
        :param n1: First scalar
        :param p2: Second point (affine)
        :param n2: Second scalar
        :param curve: Curve object whose ``glvParams`` holds beta and the basis
        :return: Point n1*p1 + n2*p2 in affine form
        """
        glv = curve.glvParams
        N, A, P = curve.N, curve.A, curve.P
        beta = glv["beta"]

        k1, k2 = cls._glvDecompose(n1 % N, glv, N)
        k3, k4 = cls._glvDecompose(n2 % N, glv, N)

        # Base points (affine, z=1) — φ((x,y)) = (β·x mod P, y).
        bases = [
            Point(p1.x, p1.y, 1),
            Point((beta * p1.x) % P, p1.y, 1),
            Point(p2.x, p2.y, 1),
            Point((beta * p2.x) % P, p2.y, 1),
        ]
        scalars = [k1, k2, k3, k4]
        for i in range(4):
            if scalars[i] < 0:
                # (-k)*Q = k*(-Q). Keep y == 0 (infinity here) at 0, as the
                # ``neg`` helper in _shamirMultiply does: P - 0 would give
                # y == P, which _jacobianAdd treats as a finite point.
                scalars[i] = -scalars[i]
                base = bases[i]
                bases[i] = Point(base.x, 0 if base.y == 0 else P - base.y, 1)

        # Precompute table[idx] = sum of bases[i] selected by bits of idx.
        _add = cls._jacobianAdd
        table = [Point(0, 0, 1)] * 16
        for idx in range(1, 16):
            low = idx & -idx  # lowest set bit of idx
            i = low.bit_length() - 1
            table[idx] = _add(table[idx ^ low], bases[i], A, P)

        _double = cls._jacobianDouble
        maxLen = max(s.bit_length() for s in scalars)
        r = Point(0, 0, 1)
        s0, s1, s2, s3 = scalars
        for bit in range(maxLen - 1, -1, -1):
            r = _double(r, A, P)
            idx = (
                ((s0 >> bit) & 1)
                | (((s1 >> bit) & 1) << 1)
                | (((s2 >> bit) & 1) << 2)
                | (((s3 >> bit) & 1) << 3)
            )
            if idx:
                r = _add(r, table[idx], A, P)

        return cls._fromJacobian(r, P)

    @staticmethod
    def _glvDecompose(k, glv, N):
        """
        Decompose k into (k1, k2) with k ≡ k1 + k2·λ (mod N) and
        |k1|, |k2| ~ √N. Babai rounding against the precomputed basis
        {(a1, b1), (a2, b2)}; k1 and k2 may be negative.
        """
        a1, b1, a2, b2 = glv["a1"], glv["b1"], glv["a2"], glv["b2"]
        halfN = N // 2
        # (x + N//2) // N rounds x / N to the nearest integer.
        c1 = (b2 * k + halfN) // N
        c2 = (-b1 * k + halfN) // N
        k1 = k - c1 * a1 - c2 * a2
        k2 = -c1 * b1 - c2 * b2
        return k1, k2

    @classmethod
    def inv(cls, x, n):
        """
        Modular inverse. Uses the C-level ``pow(x, -1, n)`` when available
        (CPython 3.8+) and otherwise falls back to a pure-Python Extended
        Euclidean Algorithm, keeping compatibility with Python 2.7+ and 3.x.

        :param x: Divisor (must be coprime to n)
        :param n: Mod for division
        :return: Value y in [0, n) with x*y == 1 (mod n)
        :raises ValueError: when x is 0 mod n or shares a factor with n
                            (no inverse exists)
        """
        if x % n == 0:
            raise ValueError("0 has no modular inverse")

        try:
            return pow(x, -1, n)
        except (TypeError, ValueError):
            # TypeError/ValueError: negative exponent unsupported on this
            # Python, or (3.8+) x is not invertible — the fallback decides.
            pass

        lm, hm = 1, 0
        low, high = x % n, n
        while low > 1:
            r = high // low
            lm, hm = hm - lm * r, lm
            low, high = high - low * r, low
        # The loop stops at low == 1 (gcd 1) or low == 0 (gcd > 1).
        if low != 1:
            raise ValueError("x has no modular inverse modulo n")
        return lm % n

    @classmethod
    def _toJacobian(cls, p):
        """
        Convert point to Jacobian coordinates

        :param p: Point in affine form
        :return: Point (p.x, p.y, 1) in Jacobian coordinates
        """
        return Point(p.x, p.y, 1)

    @classmethod
    def _fromJacobian(cls, p, P):
        """
        Convert point back from Jacobian coordinates

        :param p: Point in Jacobian coordinates (y == 0 means infinity)
        :param P: Prime number in the module of the equation Y^2 = X^3 + A*X + B (mod p)
        :return: Affine Point(x, y, 0); Point(0, 0, 0) for the point at infinity
        """
        if p.y == 0:
            return Point(0, 0, 0)

        zInv = cls.inv(p.z, P)
        # Reduce the powers of z^-1 instead of building 768-bit integers.
        zInv2 = (zInv * zInv) % P
        x = (p.x * zInv2) % P
        y = (p.y * zInv2 * zInv) % P

        return Point(x, y, 0)

    @classmethod
    def _jacobianDouble(cls, p, A, P):
        """
        Double a point in elliptic curves

        :param p: Point you want to double (Jacobian; y == 0 means infinity)
        :param A: Coefficient of the first-order term of the equation Y^2 = X^3 + A*X + B (mod p)
        :param P: Prime number in the module of the equation Y^2 = X^3 + A*X + B (mod p)
        :return: Point 2*p in Jacobian coordinates (Point(0, 0, 0) when p.y == 0)
        """
        py = p.y
        if py == 0:
            return Point(0, 0, 0)

        px, pz = p.x, p.z
        ysq = (py * py) % P
        S = (4 * px * ysq) % P
        pz2 = (pz * pz) % P
        # M = 3*x^2 + A*z^4, specialized for the common A = 0 and A = -3.
        if A == 0:
            M = (3 * px * px) % P
        elif A == -3 or A == P - 3:
            M = (3 * (px - pz2) * (px + pz2)) % P
        else:
            M = (3 * px * px + A * pz2 * pz2) % P
        nx = (M * M - 2 * S) % P
        ny = (M * (S - nx) - 8 * ysq * ysq) % P
        nz = (2 * py * pz) % P

        return Point(nx, ny, nz)

    @classmethod
    def _jacobianAdd(cls, p, q, A, P):
        """
        Add two points in elliptic curves

        Points with y == 0 are treated as the point at infinity. When q has
        z == 1 the cheaper mixed affine+Jacobian formulas are used.

        :param p: First Point you want to add (Jacobian)
        :param q: Second Point you want to add (Jacobian)
        :param A: Coefficient of the first-order term of the equation Y^2 = X^3 + A*X + B (mod p)
        :param P: Prime number in the module of the equation Y^2 = X^3 + A*X + B (mod p)
        :return: Point p + q in Jacobian coordinates
        """
        if p.y == 0:
            return q
        if q.y == 0:
            return p

        px, py, pz = p.x, p.y, p.z
        qx, qy, qz = q.x, q.y, q.z

        pz2 = (pz * pz) % P
        U2 = (qx * pz2) % P
        S2 = (qy * pz2 * pz) % P

        if qz == 1:
            # Mixed affine+Jacobian add: qz²=qz³=1 saves four multiplications.
            U1 = px
            S1 = py
        else:
            qz2 = (qz * qz) % P
            U1 = (px * qz2) % P
            S1 = (py * qz2 * qz) % P

        if U1 == U2:
            if S1 != S2:
                # p == -q: the sum is the point at infinity.
                return Point(0, 0, 1)
            return cls._jacobianDouble(p, A, P)

        H = U2 - U1
        R = S2 - S1
        H2 = (H * H) % P
        H3 = (H * H2) % P
        U1H2 = (U1 * H2) % P
        nx = (R * R - H3 - 2 * U1H2) % P
        ny = (R * (U1H2 - nx) - S1 * H3) % P
        nz = (H * pz) % P if qz == 1 else (H * pz * qz) % P

        return Point(nx, ny, nz)

    @classmethod
    def _jacobianMultiply(cls, p, n, N, A, P):
        """
        Multiply point and scalar in elliptic curves using a branch-balanced
        Montgomery ladder: each scalar bit triggers exactly one add and one
        double in swapped order, masking simple branch-timing leaks. Note:
        Python's bignum arithmetic is NOT constant-time per operation, so
        total execution time still leaks through bignum-op duration. True
        constant-time ECDSA is not achievable in pure Python.

        :param p: Point to multiply (Jacobian)
        :param n: Scalar to multiply by (reduced modulo N)
        :param N: Order of the elliptic curve
        :param A: Coefficient of the first-order term of the equation Y^2 = X^3 + A*X + B (mod p)
        :param P: Prime number in the module of the equation Y^2 = X^3 + A*X + B (mod p)
        :return: Point n*p in Jacobian coordinates (Point(0, 0, 1) for infinity)
        """
        if p.y == 0 or n == 0:
            return Point(0, 0, 1)

        if n < 0 or n >= N:
            n = n % N

        if n == 0:
            return Point(0, 0, 1)

        _add = cls._jacobianAdd
        _double = cls._jacobianDouble

        # Montgomery ladder: always performs one add and one double per bit.
        # Invariant: r1 - r0 == p.
        r0 = Point(0, 0, 1)
        r1 = Point(p.x, p.y, p.z)

        for i in range(n.bit_length() - 1, -1, -1):
            if (n >> i) & 1 == 0:
                r1 = _add(r0, r1, A, P)
                r0 = _double(r0, A, P)
            else:
                r0 = _add(r0, r1, A, P)
                r1 = _double(r1, A, P)

        return r0

    @classmethod
    def _shamirMultiply(cls, jp1, n1, jp2, n2, N, A, P):
        """
        Compute n1*p1 + n2*p2 using Shamir's trick with Joint Sparse Form
        (Solinas 2001). JSF picks signed digits in {-1, 0, 1} so at most ~l/2
        digit pairs are non-zero, versus ~3l/4 for the raw binary form. Not
        constant-time — use only with public scalars (e.g. verification).

        :param jp1: First point in Jacobian coordinates
        :param n1: First scalar
        :param jp2: Second point in Jacobian coordinates
        :param n2: Second scalar
        :param N: Order of the elliptic curve
        :param A: Coefficient of the first-order term of the equation Y^2 = X^3 + A*X + B (mod p)
        :param P: Prime number in the module of the equation Y^2 = X^3 + A*X + B (mod p)
        :return: Point n1*p1 + n2*p2 in Jacobian coordinates
        """
        if n1 < 0 or n1 >= N:
            n1 = n1 % N
        if n2 < 0 or n2 >= N:
            n2 = n2 % N

        if n1 == 0 and n2 == 0:
            return Point(0, 0, 1)

        _add = cls._jacobianAdd
        _double = cls._jacobianDouble

        def neg(pt):
            # Negate y, keeping y == 0 (infinity) at 0.
            return Point(pt.x, 0 if pt.y == 0 else P - pt.y, pt.z)

        jp1p2 = _add(jp1, jp2, A, P)
        jp1mp2 = _add(jp1, neg(jp2), A, P)
        # Maps a JSF digit pair (u0, u1) to u0*p1 + u1*p2.
        addTable = {
            (1, 0): jp1,
            (-1, 0): neg(jp1),
            (0, 1): jp2,
            (0, -1): neg(jp2),
            (1, 1): jp1p2,
            (-1, -1): neg(jp1p2),
            (1, -1): jp1mp2,
            (-1, 1): neg(jp1mp2),
        }

        digits = cls._jsfDigits(n1, n2)
        r = Point(0, 0, 1)
        for u0, u1 in digits:
            r = _double(r, A, P)
            if u0 or u1:
                r = _add(r, addTable[(u0, u1)], A, P)

        return r

    @staticmethod
    def _jsfDigits(k0, k1):
        """
        Joint Sparse Form of (k0, k1): list of signed-digit pairs (u0, u1) in
        {-1, 0, 1}, ordered MSB-first. At most one of any two consecutive pairs
        is non-zero, giving density ~1/2 instead of ~3/4 from raw binary.

        :param k0: First non-negative scalar
        :param k1: Second non-negative scalar
        :return: List of (u0, u1) with sum(u * 2^pos) reproducing (k0, k1)
        """
        digits = []
        d0 = 0
        d1 = 0
        while k0 + d0 != 0 or k1 + d1 != 0:
            a0 = k0 + d0
            a1 = k1 + d1
            if a0 & 1:
                u0 = 1 if (a0 & 3) == 1 else -1
                if (a0 & 7) in (3, 5) and (a1 & 3) == 2:
                    u0 = -u0
            else:
                u0 = 0
            if a1 & 1:
                u1 = 1 if (a1 & 3) == 1 else -1
                if (a1 & 7) in (3, 5) and (a0 & 3) == 2:
                    u1 = -u1
            else:
                u1 = 0
            digits.append((u0, u1))
            # Carry updates: d flips when the chosen digit was +1.
            if 2 * d0 == 1 + u0:
                d0 = 1 - d0
            if 2 * d1 == 1 + u1:
                d1 = 1 - d1
            k0 >>= 1
            k1 >>= 1
        digits.reverse()
        return digits