33server.
44"""
55import base64
6+ from collections import UserDict
67import copy
7- import hashlib
88import json
99import logging
10- from collections import UserDict
11- from satosa . cookies import SimpleCookie
10+ from lzma import LZMACompressor
11+ from lzma import LZMADecompressor
1212from uuid import uuid4
1313
14- from lzma import LZMACompressor , LZMADecompressor
14+ # from cryptography.hazmat.primitives.ciphers.algorithms import AES
15+ from cryptojwt .jwe .aes import AES_GCMEncrypter
1516
16- from Cryptodome import Random
17- from Cryptodome .Cipher import AES
18-
19- import satosa .logging_util as lu
17+ from satosa .cookies import SimpleCookie
2018from satosa .exception import SATOSAStateError
21-
19+ import satosa . logging_util as lu
2220
2321logger = logging .getLogger (__name__ )
2422
2725
2826class State (UserDict ):
2927 """
30- This class holds a state attribute object. A state object must be able to be converted to
31- a json string, otherwise will an exception be raised.
28+ This class holds a state attribute object. A state object must be possible to convert to
29+ a json string, otherwise an exception will be raised.
3230 """
3331
3432 def __init__ (self , urlstate_data = None , encryption_key = None ):
@@ -52,27 +50,9 @@ def __init__(self, urlstate_data=None, encryption_key=None):
5250 raise ValueError ("If an 'urlstate_data' is supplied 'encrypt_key' must be specified." )
5351
5452 if urlstate_data :
55- try :
56- urlstate_data_bytes = urlstate_data .encode ("utf-8" )
57- urlstate_data_b64decoded = base64 .urlsafe_b64decode (urlstate_data_bytes )
58- lzma = LZMADecompressor ()
59- urlstate_data_decompressed = lzma .decompress (urlstate_data_b64decoded )
60- urlstate_data_decrypted = _AESCipher (encryption_key ).decrypt (
61- urlstate_data_decompressed
62- )
63- lzma = LZMADecompressor ()
64- urlstate_data_decrypted_decompressed = lzma .decompress (urlstate_data_decrypted )
65- urlstate_data_obj = json .loads (urlstate_data_decrypted_decompressed )
66- except Exception as e :
67- error_context = {
68- "message" : "Failed to load state data. Reinitializing empty state." ,
69- "reason" : str (e ),
70- "urlstate_data" : urlstate_data ,
71- }
72- logger .warning (error_context )
53+ urlstate_data_obj = self .unpack (urlstate_data , encryption_key = encryption_key )
54+ if urlstate_data_obj is None :
7355 urlstate_data = {}
74- else :
75- urlstate_data = urlstate_data_obj
7656
7757 session_id = (
7858 urlstate_data [_SESSION_ID_KEY ]
@@ -87,25 +67,53 @@ def __init__(self, urlstate_data=None, encryption_key=None):
8767 def session_id (self ):
8868 return self .data .get (_SESSION_ID_KEY )
8969
90- def urlstate (self , encryption_key ):
70+ def unpack (self , data : str , encryption_key ):
71+ """
72+
73+ :param data: A string created by the method pack in this class.
74+ """
75+ try :
76+ data_bytes = data .encode ("utf-8" )
77+ data_b64decoded = base64 .urlsafe_b64decode (data_bytes )
78+ lzma = LZMADecompressor ()
79+ data_decompressed = lzma .decompress (data_b64decoded )
80+ data_decrypted = AES_GCMEncrypter (key = encryption_key ).decrypt (
81+ data_decompressed
82+ )
83+ lzma = LZMADecompressor ()
84+ data_decrypted_decompressed = lzma .decompress (data_decrypted )
85+ data_obj = json .loads (data_decrypted_decompressed )
86+ except Exception as e :
87+ error_context = {
88+ "message" : "Failed to load state data. Reinitializing empty state." ,
89+ "reason" : str (e ),
90+ "urlstate_data" : data ,
91+ }
92+ logger .warning (error_context )
93+ data_obj = None
94+
95+ return data_obj
96+
97+ def pack (self , encryption_key ):
9198 """
92- Will return a url safe representation of the state.
99+ Will return an url safe representation of the state.
93100
94101 :type encryption_key: Key used for encryption.
95102 :rtype: str
96103
97104 :return: Url representation av of the state.
98105 """
106+
99107 lzma = LZMACompressor ()
100- urlstate_data = json .dumps (self .data )
101- urlstate_data = lzma .compress (urlstate_data .encode ("UTF-8" ))
102- urlstate_data += lzma .flush ()
103- urlstate_data = _AESCipher (encryption_key ).encrypt (urlstate_data )
108+ _data = json .dumps (self .data )
109+ _data = lzma .compress (_data .encode ("UTF-8" ))
110+ _data += lzma .flush ()
111+ _data = AES_GCMEncrypter (encryption_key ).encrypt (_data )
104112 lzma = LZMACompressor ()
105- urlstate_data = lzma .compress (urlstate_data )
106- urlstate_data += lzma .flush ()
107- urlstate_data = base64 .urlsafe_b64encode (urlstate_data )
108- return urlstate_data .decode ("utf-8" )
113+ _data = lzma .compress (_data )
114+ _data += lzma .flush ()
115+ _data = base64 .urlsafe_b64encode (_data )
116+ return _data .decode ("utf-8" )
109117
110118 def copy (self ):
111119 """
@@ -129,15 +137,15 @@ def state_dict(self):
129137
130138
131139def state_to_cookie (
132- state : State ,
133- * ,
134- name : str ,
135- path : str ,
136- encryption_key : str ,
137- secure : bool = None ,
138- httponly : bool = None ,
139- samesite : str = None ,
140- max_age : str = None ,
140+ state : State ,
141+ * ,
142+ name : str ,
143+ path : str ,
144+ encryption_key : str ,
145+ secure : bool = None ,
146+ httponly : bool = None ,
147+ samesite : str = None ,
148+ max_age : str = None ,
141149) -> SimpleCookie :
142150 """
143151 Saves a state to a cookie
@@ -205,71 +213,71 @@ def cookie_to_state(cookie_str: str, name: str, encryption_key: str) -> State:
205213 else :
206214 return state
207215
208-
209- class _AESCipher (object ):
210- """
211- This class will perform AES encryption/decryption with a keylength of 256.
212-
213- @see: http://stackoverflow.com/questions/12524994/encrypt-decrypt-using-pycrypto-aes-256
214- """
215-
216- def __init__ (self , key ):
217- """
218- Constructor
219-
220- :type key: str
221-
222- :param key: The key used for encryption and decryption. The longer key the better.
223- """
224- self .bs = 32
225- self .key = hashlib .sha256 (key .encode ()).digest ()
226-
227- def encrypt (self , raw ):
228- """
229- Encryptes the parameter raw.
230-
231- :type raw: bytes
232- :rtype: str
233-
234- :param: bytes to be encrypted.
235-
236- :return: A base 64 encoded string.
237- """
238- raw = self ._pad (raw )
239- iv = Random . new (). read (AES .block_size )
240- cipher = AES .new (self .key , AES .MODE_CBC , iv )
241- return base64 .urlsafe_b64encode (iv + cipher .encrypt (raw ))
242-
243- def decrypt (self , enc ):
244- """
245- Decryptes the parameter enc.
246-
247- :type enc: bytes
248- :rtype: bytes
249-
250- :param: The value to be decrypted.
251- :return: The decrypted value.
252- """
253- enc = base64 .urlsafe_b64decode (enc )
254- iv = enc [:AES .block_size ]
255- cipher = AES .new (self .key , AES .MODE_CBC , iv )
256- return self ._unpad (cipher .decrypt (enc [AES .block_size :]))
257-
258- def _pad (self , b ):
259- """
260- Will padd the param to be of the correct length for the encryption alg.
261-
262- :type b: bytes
263- :rtype: bytes
264- """
265- return b + (self .bs - len (b ) % self .bs ) * chr (self .bs - len (b ) % self .bs ).encode ("UTF-8" )
266-
267- @staticmethod
268- def _unpad (b ):
269- """
270- Removes the padding performed by the method _pad.
271-
272- :type b: bytes
273- :rtype: bytes
274- """
275- return b [:- ord (b [len (b ) - 1 :])]
216+ #
217+ # class _AESCipher(object):
218+ # """
219+ # This class will perform AES encryption/decryption with a keylength of 256.
220+ #
221+ # @see: http://stackoverflow.com/questions/12524994/encrypt-decrypt-using-pycrypto-aes-256
222+ # """
223+ #
224+ # def __init__(self, key):
225+ # """
226+ # Constructor
227+ #
228+ # :type key: str
229+ #
230+ # :param key: The key used for encryption and decryption. The longer key the better.
231+ # """
232+ # self.bs = 32
233+ # self.key = hashlib.sha256(key.encode()).digest()
234+ #
235+ # def encrypt(self, raw):
236+ # """
237+ # Encryptes the parameter raw.
238+ #
239+ # :type raw: bytes
240+ # :rtype: str
241+ #
242+ # :param: bytes to be encrypted.
243+ #
244+ # :return: A base 64 encoded string.
245+ # """
246+ # raw = self._pad(raw)
247+ # iv = rndstr (AES.block_size)
248+ # cipher = AES.new(self.key, AES.MODE_CBC, iv)
249+ # return base64.urlsafe_b64encode(iv + cipher.encrypt(raw))
250+ #
251+ # def decrypt(self, enc):
252+ # """
253+ # Decryptes the parameter enc.
254+ #
255+ # :type enc: bytes
256+ # :rtype: bytes
257+ #
258+ # :param: The value to be decrypted.
259+ # :return: The decrypted value.
260+ # """
261+ # enc = base64.urlsafe_b64decode(enc)
262+ # iv = enc[:AES.block_size]
263+ # cipher = AES.new(self.key, AES.MODE_CBC, iv)
264+ # return self._unpad(cipher.decrypt(enc[AES.block_size:]))
265+ #
266+ # def _pad(self, b):
267+ # """
268+ # Will padd the param to be of the correct length for the encryption alg.
269+ #
270+ # :type b: bytes
271+ # :rtype: bytes
272+ # """
273+ # return b + (self.bs - len(b) % self.bs) * chr(self.bs - len(b) % self.bs).encode("UTF-8")
274+ #
275+ # @staticmethod
276+ # def _unpad(b):
277+ # """
278+ # Removes the padding performed by the method _pad.
279+ #
280+ # :type b: bytes
281+ # :rtype: bytes
282+ # """
283+ # return b[:-ord(b[len(b) - 1:])]
0 commit comments