Skip to content

Commit 1d3b3c6

Browse files
committed
make scale() thread-safe
1 parent ad00ce6 commit 1d3b3c6

File tree

1 file changed

+82
-25
lines changed

1 file changed

+82
-25
lines changed

src/ecdsa/ellipticcurve.py

+82-25
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848

4949
from six import python_2_unicode_compatible
5050
from . import numbertheory
51+
from ._rwlock import RWLock
5152

5253

5354
@python_2_unicode_compatible
@@ -145,6 +146,9 @@ def __init__(self, curve, x, y, z, order=None, generator=False):
145146
cause to precompute multiplication table for it
146147
"""
147148
self.__curve = curve
149+
# since it's generally better (faster) to use scaled points vs unscaled
150+
# ones, use writer-biased RWLock for locking:
151+
self._scale_lock = RWLock()
148152
if GMPY:
149153
self.__x = mpz(x)
150154
self.__y = mpz(y)
@@ -171,19 +175,25 @@ def __init__(self, curve, x, y, z, order=None, generator=False):
171175

172176
def __eq__(self, other):
173177
"""Compare two points with each-other."""
174-
if (not self.__y or not self.__z) and other is INFINITY:
175-
return True
176-
if self.__y and self.__z and other is INFINITY:
177-
return False
178+
try:
179+
self._scale_lock.reader_acquire()
180+
if other is INFINITY:
181+
return not self.__y or not self.__z
182+
x1, y1, z1 = self.__x, self.__y, self.__z
183+
finally:
184+
self._scale_lock.reader_release()
178185
if isinstance(other, Point):
179186
x2, y2, z2 = other.x(), other.y(), 1
180187
elif isinstance(other, PointJacobi):
181-
x2, y2, z2 = other.__x, other.__y, other.__z
188+
try:
189+
other._scale_lock.reader_acquire()
190+
x2, y2, z2 = other.__x, other.__y, other.__z
191+
finally:
192+
other._scale_lock.reader_release()
182193
else:
183194
return NotImplemented
184195
if self.__curve != other.curve():
185196
return False
186-
x1, y1, z1 = self.__x, self.__y, self.__z
187197
p = self.__curve.p()
188198

189199
zz1 = z1 * z1 % p
@@ -214,11 +224,17 @@ def x(self):
214224
call x() and y() on the returned instance. Or call `scale()`
215225
and then x() and y() on the returned instance.
216226
"""
217-
if self.__z == 1:
218-
return self.__x
227+
try:
228+
self._scale_lock.reader_acquire()
229+
if self.__z == 1:
230+
return self.__x
231+
x = self.__x
232+
z = self.__z
233+
finally:
234+
self._scale_lock.reader_release()
219235
p = self.__curve.p()
220-
z = numbertheory.inverse_mod(self.__z, p)
221-
return self.__x * z**2 % p
236+
z = numbertheory.inverse_mod(z, p)
237+
return x * z**2 % p
222238

223239
def y(self):
224240
"""
@@ -229,31 +245,54 @@ def y(self):
229245
call x() and y() on the returned instance. Or call `scale()`
230246
and then x() and y() on the returned instance.
231247
"""
232-
if self.__z == 1:
233-
return self.__y
248+
try:
249+
self._scale_lock.reader_acquire()
250+
if self.__z == 1:
251+
return self.__y
252+
y = self.__y
253+
z = self.__z
254+
finally:
255+
self._scale_lock.reader_release()
234256
p = self.__curve.p()
235-
z = numbertheory.inverse_mod(self.__z, p)
236-
return self.__y * z**3 % p
257+
z = numbertheory.inverse_mod(z, p)
258+
return y * z**3 % p
237259

238260
def scale(self):
239261
"""
240262
Return point scaled so that z == 1.
241263
242264
Modifies point in place, returns self.
243265
"""
244-
p = self.__curve.p()
245-
z_inv = numbertheory.inverse_mod(self.__z, p)
246-
zz_inv = z_inv * z_inv % p
247-
self.__x = self.__x * zz_inv % p
248-
self.__y = self.__y * zz_inv * z_inv % p
249-
self.__z = 1
266+
try:
267+
self._scale_lock.reader_acquire()
268+
if self.__z == 1:
269+
return self
270+
finally:
271+
self._scale_lock.reader_release()
272+
273+
try:
274+
self._scale_lock.writer_acquire()
275+
# scaling already scaled point is safe (as inverse of 1 is 1) and
276+
# quick so we don't need to optimise for the unlikely event when
277+
# two threads hit the lock at the same time
278+
p = self.__curve.p()
279+
z_inv = numbertheory.inverse_mod(self.__z, p)
280+
zz_inv = z_inv * z_inv % p
281+
self.__x = self.__x * zz_inv % p
282+
self.__y = self.__y * zz_inv * z_inv % p
283+
# we are setting the z last so that the check above will return true
284+
# only after all values were already updated
285+
self.__z = 1
286+
finally:
287+
self._scale_lock.writer_release()
250288
return self
251289

252290
def to_affine(self):
253291
"""Return point in affine form."""
254292
if not self.__y or not self.__z:
255293
return INFINITY
256294
self.scale()
295+
# after point is scaled, it's immutable, so no need to perform locking
257296
return Point(self.__curve, self.__x,
258297
self.__y, self.__order)
259298

@@ -323,7 +362,11 @@ def double(self):
323362

324363
p, a = self.__curve.p(), self.__curve.a()
325364

326-
X1, Y1, Z1 = self.__x, self.__y, self.__z
365+
try:
366+
self._scale_lock.reader_acquire()
367+
X1, Y1, Z1 = self.__x, self.__y, self.__z
368+
finally:
369+
self._scale_lock.reader_release()
327370

328371
X3, Y3, Z3 = self._double(X1, Y1, Z1, p, a)
329372

@@ -437,8 +480,16 @@ def __add__(self, other):
437480
raise ValueError("The other point is on different curve")
438481

439482
p = self.__curve.p()
440-
X1, Y1, Z1 = self.__x, self.__y, self.__z
441-
X2, Y2, Z2 = other.__x, other.__y, other.__z
483+
try:
484+
self._scale_lock.reader_acquire()
485+
X1, Y1, Z1 = self.__x, self.__y, self.__z
486+
finally:
487+
self._scale_lock.reader_release()
488+
try:
489+
other._scale_lock.reader_acquire()
490+
X2, Y2, Z2 = other.__x, other.__y, other.__z
491+
finally:
492+
other._scale_lock.reader_release()
442493
X3, Y3, Z3 = self._add(X1, Y1, Z1, X2, Y2, Z2, p)
443494

444495
if not Y3 or not Z3:
@@ -497,6 +548,7 @@ def __mul__(self, other):
497548
return self._mul_precompute(other)
498549

499550
self = self.scale()
551+
# once scaled, point is immutable, not need to lock
500552
X2, Y2 = self.__x, self.__y
501553
X3, Y3, Z3 = 0, 0, 1
502554
p, a = self.__curve.p(), self.__curve.a()
@@ -550,6 +602,7 @@ def mul_add(self, self_mul, other, other_mul):
550602
X3, Y3, Z3 = 0, 0, 1
551603
p, a = self.__curve.p(), self.__curve.a()
552604
self = self.scale()
605+
# after scaling, point is immutable, no need for locking
553606
X1, Y1 = self.__x, self.__y
554607
other = other.scale()
555608
X2, Y2 = other.__x, other.__y
@@ -575,8 +628,12 @@ def mul_add(self, self_mul, other, other_mul):
575628

576629
def __neg__(self):
577630
"""Return negated point."""
578-
return PointJacobi(self.__curve, self.__x, -self.__y, self.__z,
579-
self.__order)
631+
try:
632+
self._scale_lock.reader_acquire()
633+
return PointJacobi(self.__curve, self.__x, -self.__y, self.__z,
634+
self.__order)
635+
finally:
636+
self._scale_lock.reader_release()
580637

581638

582639
class Point(object):

0 commit comments

Comments
 (0)