Skip to content

Commit 3ad5c55

Browse files
committed
Remove pysaml2 dependencies from tests.
1 parent 83ad073 commit 3ad5c55

File tree

9 files changed

+963
-604
lines changed

9 files changed

+963
-604
lines changed

src/satosa/state.py

Lines changed: 126 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,20 @@
33
server.
44
"""
55
import base64
6+
from collections import UserDict
67
import copy
7-
import hashlib
88
import json
99
import logging
10-
from collections import UserDict
11-
from satosa.cookies import SimpleCookie
10+
from lzma import LZMACompressor
11+
from lzma import LZMADecompressor
1212
from uuid import uuid4
1313

14-
from lzma import LZMACompressor, LZMADecompressor
14+
# from cryptography.hazmat.primitives.ciphers.algorithms import AES
15+
from cryptojwt.jwe.aes import AES_GCMEncrypter
1516

16-
from Cryptodome import Random
17-
from Cryptodome.Cipher import AES
18-
19-
import satosa.logging_util as lu
17+
from satosa.cookies import SimpleCookie
2018
from satosa.exception import SATOSAStateError
21-
19+
import satosa.logging_util as lu
2220

2321
logger = logging.getLogger(__name__)
2422

@@ -27,8 +25,8 @@
2725

2826
class State(UserDict):
2927
"""
30-
This class holds a state attribute object. A state object must be able to be converted to
31-
a json string, otherwise will an exception be raised.
28+
This class holds a state attribute object. A state object must be possible to convert to
29+
a json string, otherwise an exception will be raised.
3230
"""
3331

3432
def __init__(self, urlstate_data=None, encryption_key=None):
@@ -52,27 +50,9 @@ def __init__(self, urlstate_data=None, encryption_key=None):
5250
raise ValueError("If an 'urlstate_data' is supplied 'encrypt_key' must be specified.")
5351

5452
if urlstate_data:
55-
try:
56-
urlstate_data_bytes = urlstate_data.encode("utf-8")
57-
urlstate_data_b64decoded = base64.urlsafe_b64decode(urlstate_data_bytes)
58-
lzma = LZMADecompressor()
59-
urlstate_data_decompressed = lzma.decompress(urlstate_data_b64decoded)
60-
urlstate_data_decrypted = _AESCipher(encryption_key).decrypt(
61-
urlstate_data_decompressed
62-
)
63-
lzma = LZMADecompressor()
64-
urlstate_data_decrypted_decompressed = lzma.decompress(urlstate_data_decrypted)
65-
urlstate_data_obj = json.loads(urlstate_data_decrypted_decompressed)
66-
except Exception as e:
67-
error_context = {
68-
"message": "Failed to load state data. Reinitializing empty state.",
69-
"reason": str(e),
70-
"urlstate_data": urlstate_data,
71-
}
72-
logger.warning(error_context)
53+
urlstate_data_obj = self.unpack(urlstate_data, encryption_key=encryption_key)
54+
if urlstate_data_obj is None:
7355
urlstate_data = {}
74-
else:
75-
urlstate_data = urlstate_data_obj
7656

7757
session_id = (
7858
urlstate_data[_SESSION_ID_KEY]
@@ -87,25 +67,53 @@ def __init__(self, urlstate_data=None, encryption_key=None):
8767
def session_id(self):
8868
return self.data.get(_SESSION_ID_KEY)
8969

90-
def urlstate(self, encryption_key):
70+
def unpack(self, data: str, encryption_key):
71+
"""
72+
73+
:param data: A string created by the method pack in this class.
74+
"""
75+
try:
76+
data_bytes = data.encode("utf-8")
77+
data_b64decoded = base64.urlsafe_b64decode(data_bytes)
78+
lzma = LZMADecompressor()
79+
data_decompressed = lzma.decompress(data_b64decoded)
80+
data_decrypted = AES_GCMEncrypter(key=encryption_key).decrypt(
81+
data_decompressed
82+
)
83+
lzma = LZMADecompressor()
84+
data_decrypted_decompressed = lzma.decompress(data_decrypted)
85+
data_obj = json.loads(data_decrypted_decompressed)
86+
except Exception as e:
87+
error_context = {
88+
"message": "Failed to load state data. Reinitializing empty state.",
89+
"reason": str(e),
90+
"urlstate_data": data,
91+
}
92+
logger.warning(error_context)
93+
data_obj = None
94+
95+
return data_obj
96+
97+
def pack(self, encryption_key):
9198
"""
92-
Will return a url safe representation of the state.
99+
Will return an url safe representation of the state.
93100
94101
:type encryption_key: Key used for encryption.
95102
:rtype: str
96103
97104
:return: Url representation av of the state.
98105
"""
106+
99107
lzma = LZMACompressor()
100-
urlstate_data = json.dumps(self.data)
101-
urlstate_data = lzma.compress(urlstate_data.encode("UTF-8"))
102-
urlstate_data += lzma.flush()
103-
urlstate_data = _AESCipher(encryption_key).encrypt(urlstate_data)
108+
_data = json.dumps(self.data)
109+
_data = lzma.compress(_data.encode("UTF-8"))
110+
_data += lzma.flush()
111+
_data = AES_GCMEncrypter(encryption_key).encrypt(_data)
104112
lzma = LZMACompressor()
105-
urlstate_data = lzma.compress(urlstate_data)
106-
urlstate_data += lzma.flush()
107-
urlstate_data = base64.urlsafe_b64encode(urlstate_data)
108-
return urlstate_data.decode("utf-8")
113+
_data = lzma.compress(_data)
114+
_data += lzma.flush()
115+
_data = base64.urlsafe_b64encode(_data)
116+
return _data.decode("utf-8")
109117

110118
def copy(self):
111119
"""
@@ -129,15 +137,15 @@ def state_dict(self):
129137

130138

131139
def state_to_cookie(
132-
state: State,
133-
*,
134-
name: str,
135-
path: str,
136-
encryption_key: str,
137-
secure: bool = None,
138-
httponly: bool = None,
139-
samesite: str = None,
140-
max_age: str = None,
140+
state: State,
141+
*,
142+
name: str,
143+
path: str,
144+
encryption_key: str,
145+
secure: bool = None,
146+
httponly: bool = None,
147+
samesite: str = None,
148+
max_age: str = None,
141149
) -> SimpleCookie:
142150
"""
143151
Saves a state to a cookie
@@ -205,71 +213,71 @@ def cookie_to_state(cookie_str: str, name: str, encryption_key: str) -> State:
205213
else:
206214
return state
207215

208-
209-
class _AESCipher(object):
210-
"""
211-
This class will perform AES encryption/decryption with a keylength of 256.
212-
213-
@see: http://stackoverflow.com/questions/12524994/encrypt-decrypt-using-pycrypto-aes-256
214-
"""
215-
216-
def __init__(self, key):
217-
"""
218-
Constructor
219-
220-
:type key: str
221-
222-
:param key: The key used for encryption and decryption. The longer key the better.
223-
"""
224-
self.bs = 32
225-
self.key = hashlib.sha256(key.encode()).digest()
226-
227-
def encrypt(self, raw):
228-
"""
229-
Encryptes the parameter raw.
230-
231-
:type raw: bytes
232-
:rtype: str
233-
234-
:param: bytes to be encrypted.
235-
236-
:return: A base 64 encoded string.
237-
"""
238-
raw = self._pad(raw)
239-
iv = Random.new().read(AES.block_size)
240-
cipher = AES.new(self.key, AES.MODE_CBC, iv)
241-
return base64.urlsafe_b64encode(iv + cipher.encrypt(raw))
242-
243-
def decrypt(self, enc):
244-
"""
245-
Decryptes the parameter enc.
246-
247-
:type enc: bytes
248-
:rtype: bytes
249-
250-
:param: The value to be decrypted.
251-
:return: The decrypted value.
252-
"""
253-
enc = base64.urlsafe_b64decode(enc)
254-
iv = enc[:AES.block_size]
255-
cipher = AES.new(self.key, AES.MODE_CBC, iv)
256-
return self._unpad(cipher.decrypt(enc[AES.block_size:]))
257-
258-
def _pad(self, b):
259-
"""
260-
Will padd the param to be of the correct length for the encryption alg.
261-
262-
:type b: bytes
263-
:rtype: bytes
264-
"""
265-
return b + (self.bs - len(b) % self.bs) * chr(self.bs - len(b) % self.bs).encode("UTF-8")
266-
267-
@staticmethod
268-
def _unpad(b):
269-
"""
270-
Removes the padding performed by the method _pad.
271-
272-
:type b: bytes
273-
:rtype: bytes
274-
"""
275-
return b[:-ord(b[len(b) - 1:])]
216+
#
217+
# class _AESCipher(object):
218+
# """
219+
# This class will perform AES encryption/decryption with a keylength of 256.
220+
#
221+
# @see: http://stackoverflow.com/questions/12524994/encrypt-decrypt-using-pycrypto-aes-256
222+
# """
223+
#
224+
# def __init__(self, key):
225+
# """
226+
# Constructor
227+
#
228+
# :type key: str
229+
#
230+
# :param key: The key used for encryption and decryption. The longer key the better.
231+
# """
232+
# self.bs = 32
233+
# self.key = hashlib.sha256(key.encode()).digest()
234+
#
235+
# def encrypt(self, raw):
236+
# """
237+
# Encryptes the parameter raw.
238+
#
239+
# :type raw: bytes
240+
# :rtype: str
241+
#
242+
# :param: bytes to be encrypted.
243+
#
244+
# :return: A base 64 encoded string.
245+
# """
246+
# raw = self._pad(raw)
247+
# iv = rndstr(AES.block_size)
248+
# cipher = AES.new(self.key, AES.MODE_CBC, iv)
249+
# return base64.urlsafe_b64encode(iv + cipher.encrypt(raw))
250+
#
251+
# def decrypt(self, enc):
252+
# """
253+
# Decryptes the parameter enc.
254+
#
255+
# :type enc: bytes
256+
# :rtype: bytes
257+
#
258+
# :param: The value to be decrypted.
259+
# :return: The decrypted value.
260+
# """
261+
# enc = base64.urlsafe_b64decode(enc)
262+
# iv = enc[:AES.block_size]
263+
# cipher = AES.new(self.key, AES.MODE_CBC, iv)
264+
# return self._unpad(cipher.decrypt(enc[AES.block_size:]))
265+
#
266+
# def _pad(self, b):
267+
# """
268+
# Will padd the param to be of the correct length for the encryption alg.
269+
#
270+
# :type b: bytes
271+
# :rtype: bytes
272+
# """
273+
# return b + (self.bs - len(b) % self.bs) * chr(self.bs - len(b) % self.bs).encode("UTF-8")
274+
#
275+
# @staticmethod
276+
# def _unpad(b):
277+
# """
278+
# Removes the padding performed by the method _pad.
279+
#
280+
# :type b: bytes
281+
# :rtype: bytes
282+
# """
283+
# return b[:-ord(b[len(b) - 1:])]

0 commit comments

Comments
 (0)