Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions src/potr/compatcrypto/pycrypto.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from Crypto.Hash import SHA as _SHA1
from Crypto.Hash import HMAC as _HMAC
from Crypto.PublicKey import DSA
from Crypto.Random import random
import Crypto.Random.random
from numbers import Number

from potr.compatcrypto import common
Expand Down Expand Up @@ -101,7 +101,7 @@ def fingerprint(self):

def sign(self, data):
# 2 <= K <= q
K = random.randrange(2, self.priv.q)
K = randrange(2, self.priv.q)
r, s = self.priv.sign(data, K)
return long_to_bytes(r, 20) + long_to_bytes(s, 20)

Expand Down Expand Up @@ -136,3 +136,9 @@ def parsePayload(cls, data, private=False):
x, data = read_mpi(data)
return cls((y, g, p, q, x), private=True), data
return cls((y, g, p, q), private=False), data

def getrandbits(k):
return Crypto.Random.random.getrandbits(k)

def randrange(start, stop):
return Crypto.Random.random.randrange(start, stop)
26 changes: 13 additions & 13 deletions src/potr/crypt.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@


from potr.compatcrypto import SHA256, SHA1, SHA1HMAC, SHA256HMAC, \
Counter, AESCTR, PK, random
Counter, AESCTR, PK, getrandbits, randrange
from potr.utils import bytes_to_long, long_to_bytes, pack_mpi, read_mpi
from potr import proto

Expand Down Expand Up @@ -58,7 +58,7 @@ def set_params(cls, prime, gen):
cls.gen = gen

def __init__(self):
self.priv = random.randrange(2, 2**320)
self.priv = randrange(2, 2**320)
self.pub = pow(self.gen, self.priv, self.prime)

DH.set_params(DH_MODULUS, DH_GENERATOR)
Expand Down Expand Up @@ -350,7 +350,7 @@ def __init__(self, privkey, onSuccess):
self.lastmsg = None

def startAKE(self):
self.r = long_to_bytes(random.getrandbits(128), 16)
self.r = long_to_bytes(getrandbits(128), 16)

gxmpi = pack_mpi(self.dh.pub)

Expand Down Expand Up @@ -549,8 +549,8 @@ def handle(self, tlv, appdata=None):

self.g3o = msg[3]

self.x2 = random.randrange(2, DH_MAX)
self.x3 = random.randrange(2, DH_MAX)
self.x2 = randrange(2, DH_MAX)
self.x3 = randrange(2, DH_MAX)

self.g2 = pow(msg[0], self.x2, DH_MODULUS)
self.g3 = pow(msg[3], self.x3, DH_MODULUS)
Expand Down Expand Up @@ -586,7 +586,7 @@ def handle(self, tlv, appdata=None):
self.abort(appdata=appdata)
return

r = random.randrange(2, DH_MAX)
r = randrange(2, DH_MAX)
self.p = pow(self.g3, r, DH_MODULUS)
msg = [self.p]
qa1 = pow(self.g1, r, DH_MODULUS)
Expand Down Expand Up @@ -689,8 +689,8 @@ def gotSecret(self, secret, question=None, appdata=None):

self.secret = bytes_to_long(combSecret)

self.x2 = random.randrange(2, DH_MAX)
self.x3 = random.randrange(2, DH_MAX)
self.x2 = randrange(2, DH_MAX)
self.x3 = randrange(2, DH_MAX)

msg = [pow(self.g1, self.x2, DH_MODULUS)]
msg += proof_known_log(self.g1, self.x2, 1)
Expand All @@ -715,7 +715,7 @@ def gotSecret(self, secret, question=None, appdata=None):
msg.append(pow(self.g1, self.x3, DH_MODULUS))
msg += proof_known_log(self.g1, self.x3, 4)

r = random.randrange(2, DH_MAX)
r = randrange(2, DH_MAX)

self.p = pow(self.g3, r, DH_MODULUS)
msg.append(self.p)
Expand All @@ -731,8 +731,8 @@ def gotSecret(self, secret, question=None, appdata=None):
self.sendTLV(proto.SMP2TLV(msg), appdata=appdata)

def proof_equal_coords(self, r, v):
r1 = random.randrange(2, DH_MAX)
r2 = random.randrange(2, DH_MAX)
r1 = randrange(2, DH_MAX)
r2 = randrange(2, DH_MAX)
temp2 = pow(self.g1, r1, DH_MODULUS) \
* pow(self.g2, r2, DH_MODULUS) % DH_MODULUS
temp1 = pow(self.g3, r1, DH_MODULUS)
Expand Down Expand Up @@ -761,7 +761,7 @@ def check_equal_coords(self, coords, v):
return long_to_bytes(c, 32) == cprime

def proof_equal_logs(self, v):
r = random.randrange(2, DH_MAX)
r = randrange(2, DH_MAX)
temp1 = pow(self.g1, r, DH_MODULUS)
temp2 = pow(self.qab, r, DH_MODULUS)

Expand All @@ -783,7 +783,7 @@ def check_equal_logs(self, logs, v):
return long_to_bytes(c, 32) == cprime

def proof_known_log(g, x, v):
r = random.randrange(2, DH_MAX)
r = randrange(2, DH_MAX)
c = bytes_to_long(SHA256(struct.pack(b'B', v) + pack_mpi(pow(g, r, DH_MODULUS))))
temp = x * c % SM_ORDER
return c, (r-temp) % SM_ORDER
Expand Down
43 changes: 43 additions & 0 deletions tests/test_compatcrypto.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,46 @@ def test_SHA256HMAC(self):
self.assertEqual(
to_hex(potr.compatcrypto.SHA256HMAC(b'key', b'this is a test')),
b'a85e8284b3aabd90add3da46176bce8e10eff8eafd7d096d8ba7d9396623b894')

def test_AESCTR_default_counter(self):
key = potr.utils.long_to_bytes(
potr.compatcrypto.getrandbits(128), 16)

aes_encrypter = potr.compatcrypto.AESCTR(key)
ciphertext = aes_encrypter.encrypt(b'setec astronomy')

aes_decrypter = potr.compatcrypto.AESCTR(key)
self.assertEqual(aes_decrypter.decrypt(ciphertext), b'setec astronomy')

def test_AESCTR_number_counter(self):
key = potr.utils.long_to_bytes(
potr.compatcrypto.getrandbits(128), 16)

aes_encrypter = potr.compatcrypto.AESCTR(key, 2010)
ciphertext = aes_encrypter.encrypt(b'setec astronomy')

aes_decrypter = potr.compatcrypto.AESCTR(key, 2010)
self.assertEqual(aes_decrypter.decrypt(ciphertext), b'setec astronomy')

def test_AESCTR_counter_counter(self):
key = potr.utils.long_to_bytes(
potr.compatcrypto.getrandbits(128), 16)

aes_encrypter = potr.compatcrypto.AESCTR(key, potr.compatcrypto.Counter(2013))
ciphertext = aes_encrypter.encrypt(b'setec astronomy')

aes_decrypter = potr.compatcrypto.AESCTR(key, potr.compatcrypto.Counter(2013))
self.assertEqual(aes_decrypter.decrypt(ciphertext), b'setec astronomy')

def test_getrandbits(self):
bits = potr.compatcrypto.getrandbits(128)
byts = potr.utils.long_to_bytes(bits, 16)
self.assertEquals(len(byts), 16)

def test_randrange(self):
pick = potr.compatcrypto.randrange(7, 8)
self.assertEqual(pick, 7)

pick = potr.compatcrypto.randrange(0, 10000)
self.assertGreaterEqual(pick, 0)
self.assertLess(pick, 10000)