From 3ad5c55c29241bb77d417a0dd84b613bb7c61df0 Mon Sep 17 00:00:00 2001 From: Roland Hedberg Date: Sun, 1 Oct 2023 11:17:06 +0200 Subject: [PATCH] Remove pysaml2 dependencies from tests. --- src/satosa/state.py | 244 ++++++++-------- tests/conftest.py | 236 +-------------- tests/conftest_oidc.py | 53 ++++ tests/conftest_saml.py | 196 +++++++++++++ tests/flows/test_oidc-saml.py | 91 +++--- tests/flows/test_saml-oidc.py | 4 + tests/satosa/test_state.py | 4 +- tests/util.py | 218 -------------- tests/util_saml2.py | 521 ++++++++++++++++++++++++++++++++++ 9 files changed, 963 insertions(+), 604 deletions(-) create mode 100644 tests/conftest_oidc.py create mode 100644 tests/conftest_saml.py create mode 100644 tests/util_saml2.py diff --git a/src/satosa/state.py b/src/satosa/state.py index 1fc768425..f5c6677b9 100644 --- a/src/satosa/state.py +++ b/src/satosa/state.py @@ -3,22 +3,20 @@ server. """ import base64 +from collections import UserDict import copy -import hashlib import json import logging -from collections import UserDict -from satosa.cookies import SimpleCookie +from lzma import LZMACompressor +from lzma import LZMADecompressor from uuid import uuid4 -from lzma import LZMACompressor, LZMADecompressor +# from cryptography.hazmat.primitives.ciphers.algorithms import AES +from cryptojwt.jwe.aes import AES_GCMEncrypter -from Cryptodome import Random -from Cryptodome.Cipher import AES - -import satosa.logging_util as lu +from satosa.cookies import SimpleCookie from satosa.exception import SATOSAStateError - +import satosa.logging_util as lu logger = logging.getLogger(__name__) @@ -27,8 +25,8 @@ class State(UserDict): """ - This class holds a state attribute object. A state object must be able to be converted to - a json string, otherwise will an exception be raised. + This class holds a state attribute object. A state object must be possible to convert to + a json string, otherwise an exception will be raised. """ def __init__(self, urlstate_data=None, encryption_key=None): @@ -52,27 +50,9 @@ def __init__(self, urlstate_data=None, encryption_key=None): raise ValueError("If an 'urlstate_data' is supplied 'encrypt_key' must be specified.") if urlstate_data: - try: - urlstate_data_bytes = urlstate_data.encode("utf-8") - urlstate_data_b64decoded = base64.urlsafe_b64decode(urlstate_data_bytes) - lzma = LZMADecompressor() - urlstate_data_decompressed = lzma.decompress(urlstate_data_b64decoded) - urlstate_data_decrypted = _AESCipher(encryption_key).decrypt( - urlstate_data_decompressed - ) - lzma = LZMADecompressor() - urlstate_data_decrypted_decompressed = lzma.decompress(urlstate_data_decrypted) - urlstate_data_obj = json.loads(urlstate_data_decrypted_decompressed) - except Exception as e: - error_context = { - "message": "Failed to load state data. Reinitializing empty state.", - "reason": str(e), - "urlstate_data": urlstate_data, - } - logger.warning(error_context) + urlstate_data_obj = self.unpack(urlstate_data, encryption_key=encryption_key) + if urlstate_data_obj is None: urlstate_data = {} - else: - urlstate_data = urlstate_data_obj session_id = ( urlstate_data[_SESSION_ID_KEY] @@ -87,25 +67,53 @@ def __init__(self, urlstate_data=None, encryption_key=None): def session_id(self): return self.data.get(_SESSION_ID_KEY) - def urlstate(self, encryption_key): + def unpack(self, data: str, encryption_key): + """ + + :param data: A string created by the method pack in this class. + """ + try: + data_bytes = data.encode("utf-8") + data_b64decoded = base64.urlsafe_b64decode(data_bytes) + lzma = LZMADecompressor() + data_decompressed = lzma.decompress(data_b64decoded) + data_decrypted = AES_GCMEncrypter(key=encryption_key).decrypt( + data_decompressed + ) + lzma = LZMADecompressor() + data_decrypted_decompressed = lzma.decompress(data_decrypted) + data_obj = json.loads(data_decrypted_decompressed) + except Exception as e: + error_context = { + "message": "Failed to load state data. Reinitializing empty state.", + "reason": str(e), + "urlstate_data": data, + } + logger.warning(error_context) + data_obj = None + + return data_obj + + def pack(self, encryption_key): """ - Will return a url safe representation of the state. + Will return an url safe representation of the state. :type encryption_key: Key used for encryption. :rtype: str :return: Url representation av of the state. """ + lzma = LZMACompressor() - urlstate_data = json.dumps(self.data) - urlstate_data = lzma.compress(urlstate_data.encode("UTF-8")) - urlstate_data += lzma.flush() - urlstate_data = _AESCipher(encryption_key).encrypt(urlstate_data) + _data = json.dumps(self.data) + _data = lzma.compress(_data.encode("UTF-8")) + _data += lzma.flush() + _data = AES_GCMEncrypter(encryption_key).encrypt(_data) lzma = LZMACompressor() - urlstate_data = lzma.compress(urlstate_data) - urlstate_data += lzma.flush() - urlstate_data = base64.urlsafe_b64encode(urlstate_data) - return urlstate_data.decode("utf-8") + _data = lzma.compress(_data) + _data += lzma.flush() + _data = base64.urlsafe_b64encode(_data) + return _data.decode("utf-8") def copy(self): """ @@ -129,15 +137,15 @@ def state_dict(self): def state_to_cookie( - state: State, - *, - name: str, - path: str, - encryption_key: str, - secure: bool = None, - httponly: bool = None, - samesite: str = None, - max_age: str = None, + state: State, + *, + name: str, + path: str, + encryption_key: str, + secure: bool = None, + httponly: bool = None, + samesite: str = None, + max_age: str = None, ) -> SimpleCookie: """ Saves a state to a cookie @@ -205,71 +213,71 @@ def cookie_to_state(cookie_str: str, name: str, encryption_key: str) -> State: else: return state - -class _AESCipher(object): - """ - This class will perform AES encryption/decryption with a keylength of 256. - - @see: http://stackoverflow.com/questions/12524994/encrypt-decrypt-using-pycrypto-aes-256 - """ - - def __init__(self, key): - """ - Constructor - - :type key: str - - :param key: The key used for encryption and decryption. The longer key the better. - """ - self.bs = 32 - self.key = hashlib.sha256(key.encode()).digest() - - def encrypt(self, raw): - """ - Encryptes the parameter raw. - - :type raw: bytes - :rtype: str - - :param: bytes to be encrypted. - - :return: A base 64 encoded string. - """ - raw = self._pad(raw) - iv = Random.new().read(AES.block_size) - cipher = AES.new(self.key, AES.MODE_CBC, iv) - return base64.urlsafe_b64encode(iv + cipher.encrypt(raw)) - - def decrypt(self, enc): - """ - Decryptes the parameter enc. - - :type enc: bytes - :rtype: bytes - - :param: The value to be decrypted. - :return: The decrypted value. - """ - enc = base64.urlsafe_b64decode(enc) - iv = enc[:AES.block_size] - cipher = AES.new(self.key, AES.MODE_CBC, iv) - return self._unpad(cipher.decrypt(enc[AES.block_size:])) - - def _pad(self, b): - """ - Will padd the param to be of the correct length for the encryption alg. - - :type b: bytes - :rtype: bytes - """ - return b + (self.bs - len(b) % self.bs) * chr(self.bs - len(b) % self.bs).encode("UTF-8") - - @staticmethod - def _unpad(b): - """ - Removes the padding performed by the method _pad. - - :type b: bytes - :rtype: bytes - """ - return b[:-ord(b[len(b) - 1:])] +# +# class _AESCipher(object): +# """ +# This class will perform AES encryption/decryption with a keylength of 256. +# +# @see: http://stackoverflow.com/questions/12524994/encrypt-decrypt-using-pycrypto-aes-256 +# """ +# +# def __init__(self, key): +# """ +# Constructor +# +# :type key: str +# +# :param key: The key used for encryption and decryption. The longer key the better. +# """ +# self.bs = 32 +# self.key = hashlib.sha256(key.encode()).digest() +# +# def encrypt(self, raw): +# """ +# Encryptes the parameter raw. +# +# :type raw: bytes +# :rtype: str +# +# :param: bytes to be encrypted. +# +# :return: A base 64 encoded string. +# """ +# raw = self._pad(raw) +# iv = rndstr(AES.block_size) +# cipher = AES.new(self.key, AES.MODE_CBC, iv) +# return base64.urlsafe_b64encode(iv + cipher.encrypt(raw)) +# +# def decrypt(self, enc): +# """ +# Decryptes the parameter enc. +# +# :type enc: bytes +# :rtype: bytes +# +# :param: The value to be decrypted. +# :return: The decrypted value. +# """ +# enc = base64.urlsafe_b64decode(enc) +# iv = enc[:AES.block_size] +# cipher = AES.new(self.key, AES.MODE_CBC, iv) +# return self._unpad(cipher.decrypt(enc[AES.block_size:])) +# +# def _pad(self, b): +# """ +# Will padd the param to be of the correct length for the encryption alg. +# +# :type b: bytes +# :rtype: bytes +# """ +# return b + (self.bs - len(b) % self.bs) * chr(self.bs - len(b) % self.bs).encode("UTF-8") +# +# @staticmethod +# def _unpad(b): +# """ +# Removes the padding performed by the method _pad. +# +# :type b: bytes +# :rtype: bytes +# """ +# return b[:-ord(b[len(b) - 1:])] diff --git a/tests/conftest.py b/tests/conftest.py index f0602a028..8ce87d4d6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,15 +1,11 @@ -import copy import os import pytest -from saml2 import BINDING_HTTP_REDIRECT, BINDING_HTTP_POST -from saml2.extension.idpdisc import BINDING_DISCO -from saml2.saml import NAME_FORMAT_URI, NAMEID_FORMAT_TRANSIENT, NAMEID_FORMAT_PERSISTENT from satosa.context import Context from satosa.state import State -from .util import create_metadata_from_config_dict -from .util import generate_cert, write_cert +from .util import generate_cert +from .util import write_cert BASE_URL = "https://test-proxy.com" @@ -36,87 +32,6 @@ def cert_and_key(tmpdir): return cert_path, key_path -@pytest.fixture -def sp_conf(cert_and_key): - sp_base = "http://example.com" - spconfig = { - "entityid": "{}/unittest_sp.xml".format(sp_base), - "service": { - "sp": { - "endpoints": { - "assertion_consumer_service": [ - ("%s/acs/redirect" % sp_base, BINDING_HTTP_REDIRECT) - ], - "discovery_response": [("%s/disco" % sp_base, BINDING_DISCO)] - }, - "want_response_signed": False, - "allow_unsolicited": True, - "name_id_format": [NAMEID_FORMAT_PERSISTENT] - }, - }, - "cert_file": cert_and_key[0], - "key_file": cert_and_key[1], - "metadata": {"inline": []}, - } - - return spconfig - - -@pytest.fixture -def idp_conf(cert_and_key): - idp_base = "http://idp.example.com" - - idpconfig = { - "entityid": "{}/{}/proxy.xml".format(idp_base, "Saml2IDP"), - "description": "A SAML2SAML proxy", - "service": { - "idp": { - "name": "Proxy IdP", - "endpoints": { - "single_sign_on_service": [ - ("%s/sso/redirect" % idp_base, BINDING_HTTP_REDIRECT), - ], - }, - "policy": { - "default": { - "lifetime": {"minutes": 15}, - "attribute_restrictions": None, # means all I have - "name_form": NAME_FORMAT_URI, - "fail_on_missing_requested": False - }, - }, - "subject_data": {}, - "name_id_format": [NAMEID_FORMAT_TRANSIENT, - NAMEID_FORMAT_PERSISTENT], - "want_authn_requests_signed": False, - "ui_info": { - "display_name": [{"text": "SATOSA Test IdP", "lang": "en"}], - "description": [{"text": "Test IdP for SATOSA unit tests.", "lang": "en"}], - "logo": [{"text": "https://idp.example.com/static/logo.png", "width": "120", "height": "60", - "lang": "en"}], - }, - }, - }, - "cert_file": cert_and_key[0], - "key_file": cert_and_key[1], - "metadata": {"inline": []}, - "organization": { - "name": [["Test IdP Org.", "en"]], - "display_name": [["Test IdP", "en"]], - "url": [["https://idp.example.com/about", "en"]] - }, - "contact_person": [ - {"given_name": "Test IdP", "sur_name": "Support", "email_address": ["help@idp.example.com"], - "contact_type": "support" - }, - {"given_name": "Test IdP", "sur_name": "Tech support", - "email_address": ["tech@idp.example.com"], "contact_type": "technical"} - ] - } - - return idpconfig - - @pytest.fixture def context(): context = Context() @@ -180,153 +95,6 @@ def response_microservice_config(): return data -@pytest.fixture -def saml_frontend_config(cert_and_key, sp_conf): - data = { - "module": "satosa.frontends.saml2.SAMLFrontend", - "name": "SAML2Frontend", - "config": { - "idp_config": { - "entityid": "frontend-entity_id", - "service": { - "idp": { - "endpoints": { - "single_sign_on_service": [] - }, - "name": "Frontend IdP", - "name_id_format": NAMEID_FORMAT_TRANSIENT, - "policy": { - "default": { - "attribute_restrictions": None, - "fail_on_missing_requested": False, - "lifetime": {"minutes": 15}, - "name_form": NAME_FORMAT_URI - } - } - } - }, - "cert_file": cert_and_key[0], - "key_file": cert_and_key[1], - "metadata": {"inline": [create_metadata_from_config_dict(sp_conf)]}, - "organization": { - "name": [["SATOSA Org.", "en"]], - "display_name": [["SATOSA", "en"]], - "url": [["https://satosa.example.com/about", "en"]] - }, - "contact_person": [ - {"given_name": "SATOSA", "sur_name": "Support", "email_address": ["help@satosa.example.com"], - "contact_type": "support" - }, - {"given_name": "SATOSA", "sur_name": "Tech Support", "email_address": ["tech@satosa.example.com"], - "contact_type": "technical" - } - ] - }, - - "endpoints": { - "single_sign_on_service": {BINDING_HTTP_POST: "sso/post", - BINDING_HTTP_REDIRECT: "sso/redirect"} - } - } - } - - return data - - -@pytest.fixture -def saml_backend_config(idp_conf): - name = "SAML2Backend" - data = { - "module": "satosa.backends.saml2.SAMLBackend", - "name": name, - "config": { - "sp_config": { - "entityid": "backend-entity_id", - "organization": {"display_name": "Example Identities", "name": "Test Identities Org.", - "url": "http://www.example.com"}, - "contact_person": [ - {"contact_type": "technical", "email_address": "technical@example.com", - "given_name": "Technical"}, - {"contact_type": "support", "email_address": "support@example.com", "given_name": "Support"} - ], - "service": { - "sp": { - "want_response_signed": False, - "allow_unsolicited": True, - "endpoints": { - "assertion_consumer_service": [ - ("{}/{}/acs/redirect".format(BASE_URL, name), BINDING_HTTP_REDIRECT)], - "discovery_response": [("{}/disco", BINDING_DISCO)] - - } - } - }, - "metadata": {"inline": [create_metadata_from_config_dict(idp_conf)]} - } - } - } - return data - - -@pytest.fixture -def saml_mirror_frontend_config(saml_frontend_config): - data = copy.deepcopy(saml_frontend_config) - data["module"] = "satosa.frontends.saml2.SAMLMirrorFrontend" - data["name"] = "SAMLMirrorFrontend" - return data - - -@pytest.fixture -def oidc_backend_config(): - data = { - "module": "satosa.backends.openid_connect.OpenIDConnectBackend", - "name": "OIDCBackend", - "config": { - "provider_metadata": { - "issuer": "https://op.example.com", - "authorization_endpoint": "https://example.com/authorization" - }, - "client": { - "auth_req_params": { - "response_type": "code", - "scope": "openid, profile, email, address, phone" - }, - "client_metadata": { - "client_id": "backend_client", - "application_name": "SATOSA", - "application_type": "web", - "contacts": ["suppert@example.com"], - "redirect_uris": ["http://example.com/OIDCBackend"], - "subject_type": "public", - } - }, - "entity_info": { - "contact_person": [{ - "contact_type": "technical", - "email_address": ["technical_test@example.com", "support_test@example.com"], - "given_name": "Test", - "sur_name": "OP" - }, { - "contact_type": "support", - "email_address": ["support_test@example.com"], - "given_name": "Support_test" - }], - "organization": { - "display_name": ["OP Identities", "en"], - "name": [["En test-OP", "se"], ["A test OP", "en"]], - "url": [["http://www.example.com", "en"], ["http://www.example.se", "se"]], - "ui_info": { - "description": [["This is a test OP", "en"]], - "display_name": [["OP - TEST", "en"]] - } - } - } - } - } - - return data - - @pytest.fixture def account_linking_module_config(signing_key_path): account_linking_config = { diff --git a/tests/conftest_oidc.py b/tests/conftest_oidc.py new file mode 100644 index 000000000..96fd45ced --- /dev/null +++ b/tests/conftest_oidc.py @@ -0,0 +1,53 @@ +import pytest + + +@pytest.fixture +def oidc_backend_config(): + data = { + "module": "satosa.backends.openid_connect.OpenIDConnectBackend", + "name": "OIDCBackend", + "config": { + "provider_metadata": { + "issuer": "https://op.example.com", + "authorization_endpoint": "https://example.com/authorization" + }, + "client": { + "auth_req_params": { + "response_type": "code", + "scope": "openid, profile, email, address, phone" + }, + "client_metadata": { + "client_id": "backend_client", + "application_name": "SATOSA", + "application_type": "web", + "contacts": ["suppert@example.com"], + "redirect_uris": ["http://example.com/OIDCBackend"], + "subject_type": "public", + } + }, + "entity_info": { + "contact_person": [{ + "contact_type": "technical", + "email_address": ["technical_test@example.com", "support_test@example.com"], + "given_name": "Test", + "sur_name": "OP" + }, { + "contact_type": "support", + "email_address": ["support_test@example.com"], + "given_name": "Support_test" + }], + "organization": { + "display_name": ["OP Identities", "en"], + "name": [["En test-OP", "se"], ["A test OP", "en"]], + "url": [["http://www.example.com", "en"], ["http://www.example.se", "se"]], + "ui_info": { + "description": [["This is a test OP", "en"]], + "display_name": [["OP - TEST", "en"]] + } + } + } + } + } + + return data + diff --git a/tests/conftest_saml.py b/tests/conftest_saml.py new file mode 100644 index 000000000..76220c06c --- /dev/null +++ b/tests/conftest_saml.py @@ -0,0 +1,196 @@ +import copy + +import pytest + +saml2 = pytest.importorskip('saml2') +from saml2 import BINDING_HTTP_REDIRECT, BINDING_HTTP_POST +from saml2.extension.idpdisc import BINDING_DISCO +from saml2.saml import NAME_FORMAT_URI, NAMEID_FORMAT_TRANSIENT, NAMEID_FORMAT_PERSISTENT + +from .util import create_metadata_from_config_dict + +BASE_URL = "https://test-proxy.com" + + +@pytest.fixture +def sp_conf(cert_and_key): + sp_base = "http://example.com" + spconfig = { + "entityid": "{}/unittest_sp.xml".format(sp_base), + "service": { + "sp": { + "endpoints": { + "assertion_consumer_service": [ + ("%s/acs/redirect" % sp_base, BINDING_HTTP_REDIRECT) + ], + "discovery_response": [("%s/disco" % sp_base, BINDING_DISCO)] + }, + "want_response_signed": False, + "allow_unsolicited": True, + "name_id_format": [NAMEID_FORMAT_PERSISTENT] + }, + }, + "cert_file": cert_and_key[0], + "key_file": cert_and_key[1], + "metadata": {"inline": []}, + } + + return spconfig + + +@pytest.fixture +def idp_conf(cert_and_key): + idp_base = "http://idp.example.com" + + idpconfig = { + "entityid": "{}/{}/proxy.xml".format(idp_base, "Saml2IDP"), + "description": "A SAML2SAML proxy", + "service": { + "idp": { + "name": "Proxy IdP", + "endpoints": { + "single_sign_on_service": [ + ("%s/sso/redirect" % idp_base, BINDING_HTTP_REDIRECT), + ], + }, + "policy": { + "default": { + "lifetime": {"minutes": 15}, + "attribute_restrictions": None, # means all I have + "name_form": NAME_FORMAT_URI, + "fail_on_missing_requested": False + }, + }, + "subject_data": {}, + "name_id_format": [NAMEID_FORMAT_TRANSIENT, + NAMEID_FORMAT_PERSISTENT], + "want_authn_requests_signed": False, + "ui_info": { + "display_name": [{"text": "SATOSA Test IdP", "lang": "en"}], + "description": [{"text": "Test IdP for SATOSA unit tests.", "lang": "en"}], + "logo": [{"text": "https://idp.example.com/static/logo.png", "width": "120", + "height": "60", + "lang": "en"}], + }, + }, + }, + "cert_file": cert_and_key[0], + "key_file": cert_and_key[1], + "metadata": {"inline": []}, + "organization": { + "name": [["Test IdP Org.", "en"]], + "display_name": [["Test IdP", "en"]], + "url": [["https://idp.example.com/about", "en"]] + }, + "contact_person": [ + {"given_name": "Test IdP", "sur_name": "Support", + "email_address": ["help@idp.example.com"], + "contact_type": "support" + }, + {"given_name": "Test IdP", "sur_name": "Tech support", + "email_address": ["tech@idp.example.com"], "contact_type": "technical"} + ] + } + + return idpconfig + + +@pytest.fixture +def saml_frontend_config(cert_and_key, sp_conf): + data = { + "module": "satosa.frontends.saml2.SAMLFrontend", + "name": "SAML2Frontend", + "config": { + "idp_config": { + "entityid": "frontend-entity_id", + "service": { + "idp": { + "endpoints": { + "single_sign_on_service": [] + }, + "name": "Frontend IdP", + "name_id_format": NAMEID_FORMAT_TRANSIENT, + "policy": { + "default": { + "attribute_restrictions": None, + "fail_on_missing_requested": False, + "lifetime": {"minutes": 15}, + "name_form": NAME_FORMAT_URI + } + } + } + }, + "cert_file": cert_and_key[0], + "key_file": cert_and_key[1], + "metadata": {"inline": [create_metadata_from_config_dict(sp_conf)]}, + "organization": { + "name": [["SATOSA Org.", "en"]], + "display_name": [["SATOSA", "en"]], + "url": [["https://satosa.example.com/about", "en"]] + }, + "contact_person": [ + {"given_name": "SATOSA", "sur_name": "Support", + "email_address": ["help@satosa.example.com"], + "contact_type": "support" + }, + {"given_name": "SATOSA", "sur_name": "Tech Support", + "email_address": ["tech@satosa.example.com"], + "contact_type": "technical" + } + ] + }, + + "endpoints": { + "single_sign_on_service": {BINDING_HTTP_POST: "sso/post", + BINDING_HTTP_REDIRECT: "sso/redirect"} + } + } + } + + return data + + +@pytest.fixture +def saml_backend_config(idp_conf): + name = "SAML2Backend" + data = { + "module": "satosa.backends.saml2.SAMLBackend", + "name": name, + "config": { + "sp_config": { + "entityid": "backend-entity_id", + "organization": {"display_name": "Example Identities", + "name": "Test Identities Org.", + "url": "http://www.example.com"}, + "contact_person": [ + {"contact_type": "technical", "email_address": "technical@example.com", + "given_name": "Technical"}, + {"contact_type": "support", "email_address": "support@example.com", + "given_name": "Support"} + ], + "service": { + "sp": { + "want_response_signed": False, + "allow_unsolicited": True, + "endpoints": { + "assertion_consumer_service": [ + ("{}/{}/acs/redirect".format(BASE_URL, name), + BINDING_HTTP_REDIRECT)], + "discovery_response": [("{}/disco", BINDING_DISCO)] + + } + } + }, + "metadata": {"inline": [create_metadata_from_config_dict(idp_conf)]} + } + } + } + return data + + +@pytest.fixture +def saml_mirror_frontend_config(saml_frontend_config): + data = copy.deepcopy(saml_frontend_config) + data["module"] = "satosa.frontends.saml2.SAMLMirrorFrontend" + data["name"] = "SAMLMirrorFrontend" + return data diff --git a/tests/flows/test_oidc-saml.py b/tests/flows/test_oidc-saml.py index 2a299bfef..7e47fd6f9 100644 --- a/tests/flows/test_oidc-saml.py +++ b/tests/flows/test_oidc-saml.py @@ -1,19 +1,28 @@ -import os -import json import base64 -from urllib.parse import urlparse, urlencode, parse_qsl +import json +import os +from urllib.parse import parse_qsl +from urllib.parse import urlencode +from urllib.parse import urlparse + +from cryptojwt import JWS +from cryptojwt.jwk.rsa import RSAKey +from cryptojwt.jwk.rsa import import_rsa_key_from_cert_file -import mongomock import pytest -from jwkest.jwk import rsa_load, RSAKey -from jwkest.jws import JWS -from oic.oic.message import ClaimsRequest, Claims +oic = pytest.importorskip("oic") +mongomock = pytest.importorskip("mongomock") + +from oic.oic.message import Claims +from oic.oic.message import ClaimsRequest from pyop.storage import StorageBase from saml2 import BINDING_HTTP_REDIRECT from saml2.config import IdPConfig from werkzeug.test import Client from werkzeug.wrappers import Response + + from satosa.metadata_creation.saml_metadata import create_entity_descriptors from satosa.proxy_server import make_app from satosa.satosa_config import SATOSAConfig @@ -21,13 +30,13 @@ from tests.users import OIDC_USERS from tests.util import FakeIdP - CLIENT_ID = "client1" CLIENT_SECRET = "secret" CLIENT_REDIRECT_URI = "https://client.example.com/cb" REDIRECT_URI = "https://client.example.com/cb" DB_URI = "mongodb://localhost/satosa" + @pytest.fixture(scope="session") def client_db_path(tmpdir_factory): tmpdir = str(tmpdir_factory.getbasetemp()) @@ -46,6 +55,7 @@ def client_db_path(tmpdir_factory): return path + @pytest.fixture def oidc_frontend_config(signing_key_path): data = { @@ -94,30 +104,34 @@ def _client_setup(self): "response_types": ["id_token"] } - def test_full_flow(self, satosa_config_dict, oidc_frontend_config, saml_backend_config, idp_conf): + def test_full_flow(self, satosa_config_dict, oidc_frontend_config, saml_backend_config, + idp_conf): self._client_setup() subject_id = "testuser1" # proxy config satosa_config_dict["FRONTEND_MODULES"] = [oidc_frontend_config] satosa_config_dict["BACKEND_MODULES"] = [saml_backend_config] - satosa_config_dict["INTERNAL_ATTRIBUTES"]["attributes"] = {attr_name: {"openid": [attr_name], - "saml": [attr_name]} - for attr_name in USERS[subject_id]} + satosa_config_dict["INTERNAL_ATTRIBUTES"]["attributes"] = { + attr_name: {"openid": [attr_name], + "saml": [attr_name]} + for attr_name in USERS[subject_id]} _, backend_metadata = create_entity_descriptors(SATOSAConfig(satosa_config_dict)) # application test_client = Client(make_app(SATOSAConfig(satosa_config_dict)), Response) # get frontend OP config info - provider_config = json.loads(test_client.get("/.well-known/openid-configuration").data.decode("utf-8")) + provider_config = json.loads( + test_client.get("/.well-known/openid-configuration").data.decode("utf-8")) # create auth req claims_request = ClaimsRequest(id_token=Claims(**{k: None for k in USERS[subject_id]})) req_args = {"scope": "openid", "response_type": "id_token", "client_id": CLIENT_ID, "redirect_uri": REDIRECT_URI, "nonce": "nonce", "claims": claims_request.to_json()} - auth_req = urlparse(provider_config["authorization_endpoint"]).path + "?" + urlencode(req_args) + auth_req = urlparse(provider_config["authorization_endpoint"]).path + "?" + urlencode( + req_args) # make auth req to proxy proxied_auth_req = test_client.get(auth_req) @@ -144,8 +158,11 @@ def test_full_flow(self, satosa_config_dict, oidc_frontend_config, saml_backend_ # verify auth resp from proxy resp_dict = dict(parse_qsl(urlparse(authn_resp.data.decode("utf-8")).fragment)) - signing_key = RSAKey(key=rsa_load(oidc_frontend_config["config"]["signing_key_path"]), - use="sig", alg="RS256") + signing_key = import_rsa_key_from_cert_file( + oidc_frontend_config["config"]["signing_key_path"] + ) + # signing_key = RSAKey(key=rsa_load(oidc_frontend_config["config"]["signing_key_path"]), + # use="sig", alg="RS256") id_token_claims = JWS().verify_compact(resp_dict["id_token"], keys=[signing_key]) assert all( @@ -153,29 +170,33 @@ def test_full_flow(self, satosa_config_dict, oidc_frontend_config, saml_backend_ for name, values in OIDC_USERS[subject_id].items() ) - def test_full_stateless_id_token_flow(self, satosa_config_dict, oidc_stateless_frontend_config, saml_backend_config, idp_conf): + def test_full_stateless_id_token_flow(self, satosa_config_dict, oidc_stateless_frontend_config, + saml_backend_config, idp_conf): subject_id = "testuser1" # proxy config satosa_config_dict["FRONTEND_MODULES"] = [oidc_stateless_frontend_config] satosa_config_dict["BACKEND_MODULES"] = [saml_backend_config] - satosa_config_dict["INTERNAL_ATTRIBUTES"]["attributes"] = {attr_name: {"openid": [attr_name], - "saml": [attr_name]} - for attr_name in USERS[subject_id]} + satosa_config_dict["INTERNAL_ATTRIBUTES"]["attributes"] = { + attr_name: {"openid": [attr_name], + "saml": [attr_name]} + for attr_name in USERS[subject_id]} _, backend_metadata = create_entity_descriptors(SATOSAConfig(satosa_config_dict)) # application test_client = Client(make_app(SATOSAConfig(satosa_config_dict)), Response) # get frontend OP config info - provider_config = json.loads(test_client.get("/.well-known/openid-configuration").data.decode("utf-8")) + provider_config = json.loads( + test_client.get("/.well-known/openid-configuration").data.decode("utf-8")) # create auth req claims_request = ClaimsRequest(id_token=Claims(**{k: None for k in USERS[subject_id]})) req_args = {"scope": "openid", "response_type": "id_token", "client_id": CLIENT_ID, "redirect_uri": REDIRECT_URI, "nonce": "nonce", "claims": claims_request.to_json()} - auth_req = urlparse(provider_config["authorization_endpoint"]).path + "?" + urlencode(req_args) + auth_req = urlparse(provider_config["authorization_endpoint"]).path + "?" + urlencode( + req_args) # make auth req to proxy proxied_auth_req = test_client.get(auth_req) @@ -202,8 +223,9 @@ def test_full_stateless_id_token_flow(self, satosa_config_dict, oidc_stateless_f # verify auth resp from proxy resp_dict = dict(parse_qsl(urlparse(authn_resp.data.decode("utf-8")).fragment)) - signing_key = RSAKey(key=rsa_load(oidc_stateless_frontend_config["config"]["signing_key_path"]), - use="sig", alg="RS256") + signing_key = RSAKey( + key=rsa_load(oidc_stateless_frontend_config["config"]["signing_key_path"]), + use="sig", alg="RS256") id_token_claims = JWS().verify_compact(resp_dict["id_token"], keys=[signing_key]) assert all( @@ -211,29 +233,33 @@ def test_full_stateless_id_token_flow(self, satosa_config_dict, oidc_stateless_f for name, values in OIDC_USERS[subject_id].items() ) - def test_full_stateless_code_flow(self, satosa_config_dict, oidc_stateless_frontend_config, saml_backend_config, idp_conf): + def test_full_stateless_code_flow(self, satosa_config_dict, oidc_stateless_frontend_config, + saml_backend_config, idp_conf): subject_id = "testuser1" # proxy config satosa_config_dict["FRONTEND_MODULES"] = [oidc_stateless_frontend_config] satosa_config_dict["BACKEND_MODULES"] = [saml_backend_config] - satosa_config_dict["INTERNAL_ATTRIBUTES"]["attributes"] = {attr_name: {"openid": [attr_name], - "saml": [attr_name]} - for attr_name in USERS[subject_id]} + satosa_config_dict["INTERNAL_ATTRIBUTES"]["attributes"] = { + attr_name: {"openid": [attr_name], + "saml": [attr_name]} + for attr_name in USERS[subject_id]} _, backend_metadata = create_entity_descriptors(SATOSAConfig(satosa_config_dict)) # application test_client = Client(make_app(SATOSAConfig(satosa_config_dict)), Response) # get frontend OP config info - provider_config = json.loads(test_client.get("/.well-known/openid-configuration").data.decode("utf-8")) + provider_config = json.loads( + test_client.get("/.well-known/openid-configuration").data.decode("utf-8")) # create auth req claims_request = ClaimsRequest(id_token=Claims(**{k: None for k in USERS[subject_id]})) req_args = {"scope": "openid", "response_type": "code", "client_id": CLIENT_ID, "redirect_uri": REDIRECT_URI, "nonce": "nonce", "claims": claims_request.to_json()} - auth_req = urlparse(provider_config["authorization_endpoint"]).path + "?" + urlencode(req_args) + auth_req = urlparse(provider_config["authorization_endpoint"]).path + "?" + urlencode( + req_args) # make auth req to proxy proxied_auth_req = test_client.get(auth_req) @@ -275,8 +301,9 @@ def test_full_stateless_code_flow(self, satosa_config_dict, oidc_stateless_front # verify auth resp from proxy resp_dict = json.loads(authn_resp.data.decode("utf-8")) - signing_key = RSAKey(key=rsa_load(oidc_stateless_frontend_config["config"]["signing_key_path"]), - use="sig", alg="RS256") + signing_key = RSAKey( + key=rsa_load(oidc_stateless_frontend_config["config"]["signing_key_path"]), + use="sig", alg="RS256") id_token_claims = JWS().verify_compact(resp_dict["id_token"], keys=[signing_key]) assert all( diff --git a/tests/flows/test_saml-oidc.py b/tests/flows/test_saml-oidc.py index bc41acfe1..845a44895 100644 --- a/tests/flows/test_saml-oidc.py +++ b/tests/flows/test_saml-oidc.py @@ -1,6 +1,10 @@ import time from urllib.parse import urlparse, parse_qsl, urlencode +import pytest +oic = pytest.importorskip("oic", reason="No pyoidc") +saml2 = pytest.importorskip("saml2", reason="No pysaml2") + from oic.oic.message import IdToken from saml2 import BINDING_HTTP_REDIRECT from saml2.config import SPConfig diff --git a/tests/satosa/test_state.py b/tests/satosa/test_state.py index eadee2182..af969718a 100644 --- a/tests/satosa/test_state.py +++ b/tests/satosa/test_state.py @@ -65,9 +65,9 @@ def test_urlstate_length_should_fit_in_browser_cookie(self): state["my_dict_hash"] = my_dict_hash state["my_dict_router"] = my_dict_router state["my_dict_backend"] = my_dict_backend - urlstate = state.urlstate(enc_key) + urlstate = state.pack(enc_key) # Some browsers only support 2000bytes, and since state is not the only parameter it should - # not be greater then half that size. + # not be greater than half that size. urlstate_len = len(quote_plus(urlstate)) print("Size of state on the url is:%s" % urlstate_len) assert urlstate_len < 1000, "Urlstate is way to long!" diff --git a/tests/util.py b/tests/util.py index c26c796fe..cedbd7894 100644 --- a/tests/util.py +++ b/tests/util.py @@ -6,16 +6,7 @@ from datetime import datetime from urllib.parse import parse_qsl, urlparse -from Cryptodome.PublicKey import RSA from bs4 import BeautifulSoup -from saml2 import server, BINDING_HTTP_POST, BINDING_HTTP_REDIRECT -from saml2.authn_context import AuthnBroker, authn_context_class_ref, PASSWORD -from saml2.cert import OpenSSLWrapper -from saml2.client import Saml2Client -from saml2.config import Config -from saml2.metadata import entity_descriptor -from saml2.saml import name_id_from_string, NAMEID_FORMAT_TRANSIENT, NAMEID_FORMAT_PERSISTENT -from saml2.samlp import NameIDPolicy from satosa.backends.base import BackendModule from satosa.frontends.base import FrontendModule @@ -25,215 +16,6 @@ from satosa.response import Response -class FakeSP(Saml2Client): - """ - A SAML service provider that can be used to perform tests. - """ - - def __init__(self, config): - """ - :type config: {dict} - :param config: SP SAML configuration. - """ - Saml2Client.__init__(self, config) - - def make_auth_req(self, entity_id, nameid_format=None, relay_state="relay_state", - request_binding=BINDING_HTTP_REDIRECT, response_binding=BINDING_HTTP_REDIRECT, - subject=None): - """ - :type entity_id: str - :rtype: str - - :param entity_id: SAML entity id - :return: Authentication URL. - """ - # Picks a binding to use for sending the Request to the IDP - _binding, destination = self.pick_binding( - 'single_sign_on_service', - [request_binding], 'idpsso', - entity_id=entity_id) - - kwargs = {} - if subject: - kwargs['subject'] = subject - - req_id, req = self.create_authn_request( - destination, - binding=response_binding, - nameid_format=nameid_format, - **kwargs - ) - - ht_args = self.apply_binding(_binding, '%s' % req, destination, - relay_state=relay_state) - - if _binding == BINDING_HTTP_POST: - form_post_html = "\n".join(ht_args["data"]) - doctree = BeautifulSoup(form_post_html, "html.parser") - saml_request = doctree.find("input", {"name": "SAMLRequest"})["value"] - resp = {"SAMLRequest": saml_request, "RelayState": relay_state} - elif _binding == BINDING_HTTP_REDIRECT: - resp = dict(parse_qsl(urlparse(dict(ht_args["headers"])["Location"]).query)) - - return destination, resp - - -class FakeIdP(server.Server): - """ - A SAML IdP that can be used to perform tests. - """ - - def __init__(self, user_db, config): - """ - :type user_db: {dict} - :type config: {dict} - - :param user_db: A dictionary with the user id as key and parameter dictionary as value. - :param config: IdP SAML configuration. - """ - server.Server.__init__(self, config=config) - self.user_db = user_db - - def __create_authn_response(self, saml_request, relay_state, binding, - userid, response_binding=BINDING_HTTP_POST): - """ - Handles a SAML request, validates and creates a SAML response but - does not apply the binding to encode it. - :type saml_request: str - :type relay_state: str - :type binding: str - :type userid: str - :rtype: tuple [string, saml2.samlp.Response] - - :param saml_request: - :param relay_state: RelayState is a parameter used by some SAML - protocol implementations to identify the specific resource at the - resource provider in an IDP initiated single sign on scenario. - :param binding: - :param userid: The user identification. - :return: A tuple containing the destination and instance of - saml2.samlp.Response - """ - auth_req = self.parse_authn_request(saml_request, binding) - binding_out, destination = self.pick_binding( - 'assertion_consumer_service', - bindings=[response_binding], - entity_id=auth_req.message.issuer.text, request=auth_req.message) - - resp_args = self.response_args(auth_req.message) - authn_broker = AuthnBroker() - authn_broker.add(authn_context_class_ref(PASSWORD), lambda: None, 10, - 'unittest_idp.xml') - authn_broker.get_authn_by_accr(PASSWORD) - resp_args['authn'] = authn_broker.get_authn_by_accr(PASSWORD) - - resp = self.create_authn_response(self.user_db[userid], - userid=userid, - **resp_args) - - return destination, resp - - def __apply_binding_to_authn_response(self, - resp, - response_binding, - relay_state, - destination): - """ - Applies the binding to the response. - """ - if response_binding == BINDING_HTTP_POST: - saml_response = base64.b64encode(str(resp).encode("utf-8")) - resp = {"SAMLResponse": saml_response, "RelayState": relay_state} - elif response_binding == BINDING_HTTP_REDIRECT: - http_args = self.apply_binding( - response_binding, - '%s' % resp, - destination, - relay_state, - response=True) - resp = dict(parse_qsl(urlparse( - dict(http_args["headers"])["Location"]).query)) - - return resp - - def handle_auth_req(self, saml_request, relay_state, binding, userid, - response_binding=BINDING_HTTP_POST): - """ - Handles a SAML request, validates and creates a SAML response. - :type saml_request: str - :type relay_state: str - :type binding: str - :type userid: str - :rtype: tuple - - :param saml_request: - :param relay_state: RelayState is a parameter used by some SAML - protocol implementations to identify the specific resource at the - resource provider in an IDP initiated single sign on scenario. - :param binding: - :param userid: The user identification. - :return: A tuple with the destination and encoded response as a string - """ - - destination, _resp = self.__create_authn_response( - saml_request, - relay_state, - binding, - userid, - response_binding) - - resp = self.__apply_binding_to_authn_response( - _resp, - response_binding, - relay_state, - destination) - - return destination, resp - - def handle_auth_req_no_name_id(self, saml_request, relay_state, binding, - userid, response_binding=BINDING_HTTP_POST): - """ - Handles a SAML request, validates and creates a SAML response but - without a element. - :type saml_request: str - :type relay_state: str - :type binding: str - :type userid: str - :rtype: tuple - - :param saml_request: - :param relay_state: RelayState is a parameter used by some SAML - protocol implementations to identify the specific resource at the - resource provider in an IDP initiated single sign on scenario. - :param binding: - :param userid: The user identification. - :return: A tuple with the destination and encoded response as a string - """ - - destination, _resp = self.__create_authn_response( - saml_request, - relay_state, - binding, - userid, - response_binding) - - # Remove the element from the response. - _resp.assertion.subject.name_id = None - - resp = self.__apply_binding_to_authn_response( - _resp, - response_binding, - relay_state, - destination) - - return destination, resp - - -def create_metadata_from_config_dict(config): - nspair = {"xs": "http://www.w3.org/2001/XMLSchema"} - conf = Config().load(config) - return entity_descriptor(conf).to_string(nspair).decode("utf-8") - def generate_cert(): cert_info = { diff --git a/tests/util_saml2.py b/tests/util_saml2.py new file mode 100644 index 000000000..c26c796fe --- /dev/null +++ b/tests/util_saml2.py @@ -0,0 +1,521 @@ +""" +Contains help methods and classes to perform tests. +""" +import base64 +import tempfile +from datetime import datetime +from urllib.parse import parse_qsl, urlparse + +from Cryptodome.PublicKey import RSA +from bs4 import BeautifulSoup +from saml2 import server, BINDING_HTTP_POST, BINDING_HTTP_REDIRECT +from saml2.authn_context import AuthnBroker, authn_context_class_ref, PASSWORD +from saml2.cert import OpenSSLWrapper +from saml2.client import Saml2Client +from saml2.config import Config +from saml2.metadata import entity_descriptor +from saml2.saml import name_id_from_string, NAMEID_FORMAT_TRANSIENT, NAMEID_FORMAT_PERSISTENT +from saml2.samlp import NameIDPolicy + +from satosa.backends.base import BackendModule +from satosa.frontends.base import FrontendModule +from satosa.internal import AuthenticationInformation +from satosa.internal import InternalData +from satosa.micro_services.base import RequestMicroService, ResponseMicroService +from satosa.response import Response + + +class FakeSP(Saml2Client): + """ + A SAML service provider that can be used to perform tests. + """ + + def __init__(self, config): + """ + :type config: {dict} + :param config: SP SAML configuration. + """ + Saml2Client.__init__(self, config) + + def make_auth_req(self, entity_id, nameid_format=None, relay_state="relay_state", + request_binding=BINDING_HTTP_REDIRECT, response_binding=BINDING_HTTP_REDIRECT, + subject=None): + """ + :type entity_id: str + :rtype: str + + :param entity_id: SAML entity id + :return: Authentication URL. + """ + # Picks a binding to use for sending the Request to the IDP + _binding, destination = self.pick_binding( + 'single_sign_on_service', + [request_binding], 'idpsso', + entity_id=entity_id) + + kwargs = {} + if subject: + kwargs['subject'] = subject + + req_id, req = self.create_authn_request( + destination, + binding=response_binding, + nameid_format=nameid_format, + **kwargs + ) + + ht_args = self.apply_binding(_binding, '%s' % req, destination, + relay_state=relay_state) + + if _binding == BINDING_HTTP_POST: + form_post_html = "\n".join(ht_args["data"]) + doctree = BeautifulSoup(form_post_html, "html.parser") + saml_request = doctree.find("input", {"name": "SAMLRequest"})["value"] + resp = {"SAMLRequest": saml_request, "RelayState": relay_state} + elif _binding == BINDING_HTTP_REDIRECT: + resp = dict(parse_qsl(urlparse(dict(ht_args["headers"])["Location"]).query)) + + return destination, resp + + +class FakeIdP(server.Server): + """ + A SAML IdP that can be used to perform tests. + """ + + def __init__(self, user_db, config): + """ + :type user_db: {dict} + :type config: {dict} + + :param user_db: A dictionary with the user id as key and parameter dictionary as value. + :param config: IdP SAML configuration. + """ + server.Server.__init__(self, config=config) + self.user_db = user_db + + def __create_authn_response(self, saml_request, relay_state, binding, + userid, response_binding=BINDING_HTTP_POST): + """ + Handles a SAML request, validates and creates a SAML response but + does not apply the binding to encode it. + :type saml_request: str + :type relay_state: str + :type binding: str + :type userid: str + :rtype: tuple [string, saml2.samlp.Response] + + :param saml_request: + :param relay_state: RelayState is a parameter used by some SAML + protocol implementations to identify the specific resource at the + resource provider in an IDP initiated single sign on scenario. + :param binding: + :param userid: The user identification. + :return: A tuple containing the destination and instance of + saml2.samlp.Response + """ + auth_req = self.parse_authn_request(saml_request, binding) + binding_out, destination = self.pick_binding( + 'assertion_consumer_service', + bindings=[response_binding], + entity_id=auth_req.message.issuer.text, request=auth_req.message) + + resp_args = self.response_args(auth_req.message) + authn_broker = AuthnBroker() + authn_broker.add(authn_context_class_ref(PASSWORD), lambda: None, 10, + 'unittest_idp.xml') + authn_broker.get_authn_by_accr(PASSWORD) + resp_args['authn'] = authn_broker.get_authn_by_accr(PASSWORD) + + resp = self.create_authn_response(self.user_db[userid], + userid=userid, + **resp_args) + + return destination, resp + + def __apply_binding_to_authn_response(self, + resp, + response_binding, + relay_state, + destination): + """ + Applies the binding to the response. + """ + if response_binding == BINDING_HTTP_POST: + saml_response = base64.b64encode(str(resp).encode("utf-8")) + resp = {"SAMLResponse": saml_response, "RelayState": relay_state} + elif response_binding == BINDING_HTTP_REDIRECT: + http_args = self.apply_binding( + response_binding, + '%s' % resp, + destination, + relay_state, + response=True) + resp = dict(parse_qsl(urlparse( + dict(http_args["headers"])["Location"]).query)) + + return resp + + def handle_auth_req(self, saml_request, relay_state, binding, userid, + response_binding=BINDING_HTTP_POST): + """ + Handles a SAML request, validates and creates a SAML response. + :type saml_request: str + :type relay_state: str + :type binding: str + :type userid: str + :rtype: tuple + + :param saml_request: + :param relay_state: RelayState is a parameter used by some SAML + protocol implementations to identify the specific resource at the + resource provider in an IDP initiated single sign on scenario. + :param binding: + :param userid: The user identification. + :return: A tuple with the destination and encoded response as a string + """ + + destination, _resp = self.__create_authn_response( + saml_request, + relay_state, + binding, + userid, + response_binding) + + resp = self.__apply_binding_to_authn_response( + _resp, + response_binding, + relay_state, + destination) + + return destination, resp + + def handle_auth_req_no_name_id(self, saml_request, relay_state, binding, + userid, response_binding=BINDING_HTTP_POST): + """ + Handles a SAML request, validates and creates a SAML response but + without a element. + :type saml_request: str + :type relay_state: str + :type binding: str + :type userid: str + :rtype: tuple + + :param saml_request: + :param relay_state: RelayState is a parameter used by some SAML + protocol implementations to identify the specific resource at the + resource provider in an IDP initiated single sign on scenario. + :param binding: + :param userid: The user identification. + :return: A tuple with the destination and encoded response as a string + """ + + destination, _resp = self.__create_authn_response( + saml_request, + relay_state, + binding, + userid, + response_binding) + + # Remove the element from the response. + _resp.assertion.subject.name_id = None + + resp = self.__apply_binding_to_authn_response( + _resp, + response_binding, + relay_state, + destination) + + return destination, resp + + +def create_metadata_from_config_dict(config): + nspair = {"xs": "http://www.w3.org/2001/XMLSchema"} + conf = Config().load(config) + return entity_descriptor(conf).to_string(nspair).decode("utf-8") + + +def generate_cert(): + cert_info = { + "cn": "localhost", + "country_code": "se", + "state": "ac", + "city": "Umea", + "organization": "ITS", + "organization_unit": "DIRG" + } + osw = OpenSSLWrapper() + cert_str, key_str = osw.create_certificate(cert_info, request=False) + return cert_str, key_str + + +def write_cert(cert_path, key_path): + cert, key = generate_cert() + with open(cert_path, "wb") as cert_file: + cert_file.write(cert) + with open(key_path, "wb") as key_file: + key_file.write(key) + + +class FileGenerator(object): + """ + Creates different types of temporary files that is useful for testing. + """ + _instance = None + + def __init__(self): + if FileGenerator._instance: + raise TypeError('Singletons must be accessed through `get_instance()`.') + else: + FileGenerator._instance = self + self.generate_certs = {} + self.metadata = {} + + @staticmethod + def get_instance(): + """ + :rtype: FileGenerator + + :return: A singleton instance of the class. + """ + if FileGenerator._instance is None: + FileGenerator._instance = FileGenerator() + return FileGenerator._instance + + def generate_cert(self, code=None): + """ + Will generate a certificate and key. If code is used the same certificate and key will + always be returned for the same code. + :type code: str + :rtype: (tempfile._TemporaryFileWrapper, tempfile._TemporaryFileWrapper) + + :param: code: A unique code to represent a certificate and key. + :return: A certificate and key temporary file. + """ + if code in self.generate_certs: + return self.generate_certs[code] + + cert_str, key_str = generate_cert() + + cert_file = tempfile.NamedTemporaryFile() + cert_file.write(cert_str) + cert_file.flush() + key_file = tempfile.NamedTemporaryFile() + key_file.write(key_str) + key_file.flush() + if code is not None: + self.generate_certs[code] = cert_file, key_file + return cert_file, key_file + + def create_metadata(self, config, code=None): + """ + Will generate a metadata file. If code is used the same metadata file will + always be returned for the same code. + :type config: {dict} + :type code: str + + :param config: A SAML configuration. + :param code: A unique code to represent a certificate and key. + """ + if code in self.metadata: + return self.metadata[code] + + desc = create_metadata_from_config_dict(config) + + tmp_file = tempfile.NamedTemporaryFile() + tmp_file.write(desc.encode("utf-8")) + tmp_file.flush() + if code: + self.metadata[code] = tmp_file + return tmp_file + + +def private_to_public_key(pk_file): + f = open(pk_file, 'r') + pk = RSA.importKey(f.read()) + return pk.publickey().exportKey('PEM') + + +def create_name_id(): + """ + :rtype: str + + :return: Returns a SAML nameid as XML string. + """ + test_name_id = """ + + tmatsuo@example.com + +""" + return name_id_from_string(test_name_id) + + +def create_name_id_policy_transient(): + """ + Creates a transient name id policy. + :return: + """ + nameid_format = NAMEID_FORMAT_TRANSIENT + name_id_policy = NameIDPolicy(format=nameid_format) + return name_id_policy + + +def create_name_id_policy_persistent(): + """ + Creates a transient name id policy. + :return: + """ + nameid_format = NAMEID_FORMAT_PERSISTENT + name_id_policy = NameIDPolicy(format=nameid_format) + return name_id_policy + + +class FakeBackend(BackendModule): + def __init__(self, start_auth_func=None, internal_attributes=None, + base_url="", name="FakeBackend", + register_endpoints_func=None): + super().__init__(None, internal_attributes, base_url, name) + + self.start_auth_func = start_auth_func + self.register_endpoints_func = register_endpoints_func + + def start_auth(self, context, request_info, state): + """ + TODO comment + :type context: TODO comment + :type request_info: TODO comment + :type state: TODO comment + + :param context: TODO comment + :param request_info: TODO comment + :param state: TODO comment + """ + if self.start_auth: + return self.start_auth(context, request_info, state) + return None + + def register_endpoints(self): + """ + TODO comment + """ + if self.register_endpoints_func: + return self.register_endpoints_func() + return None + + +class FakeFrontend(FrontendModule): + """ + TODO comment + """ + + def __init__(self, handle_authn_request_func=None, internal_attributes=None, + base_url="", name="FakeFrontend", + handle_authn_response_func=None, + register_endpoints_func=None): + super().__init__(None, internal_attributes, base_url, name) + self.handle_authn_request_func = handle_authn_request_func + self.handle_authn_response_func = handle_authn_response_func + self.register_endpoints_func = register_endpoints_func + + def handle_authn_request(self, context, binding_in): + """ + TODO comment + + :type context: + :type binding_in: + + :param context: + :param binding_in: + :return: + """ + if self.handle_authn_request_func: + return self.handle_authn_request_func(context, binding_in) + return None + + def handle_authn_response(self, context, internal_response, state): + """ + TODO comment + :type context: TODO comment + :type internal_response: TODO comment + :type state: TODO comment + + :param context: TODO comment + :param internal_response: TODO comment + :param state: TODO comment + :return: TODO comment + """ + if self.handle_authn_response_func: + return self.handle_authn_response_func(context, internal_response, state) + return None + + def register_endpoints(self, backend_names): + if self.register_endpoints_func: + return self.register_endpoints_func(backend_names) + + +class TestBackend(BackendModule): + __test__ = False + + def __init__(self, auth_callback_func, internal_attributes, config, base_url, name): + super().__init__(auth_callback_func, internal_attributes, base_url, name) + + def register_endpoints(self): + return [("^{}/response$".format(self.name), self.handle_response)] + + def start_auth(self, context, internal_request): + return Response("Auth request received, passed to test backend") + + def handle_response(self, context): + auth_info = AuthenticationInformation("test", str(datetime.now()), "test_issuer") + internal_resp = InternalData(auth_info=auth_info) + internal_resp.attributes = context.request + internal_resp.subject_id = "test_user" + return self.auth_callback_func(context, internal_resp) + + +class TestFrontend(FrontendModule): + __test__ = False + + def __init__(self, auth_req_callback_func, internal_attributes, config, base_url, name): + super().__init__(auth_req_callback_func, internal_attributes, base_url, name) + + def register_endpoints(self, backend_names): + url_map = [("^{}/{}/request$".format(p, self.name), self.handle_request) for p in backend_names] + return url_map + + def handle_request(self, context): + internal_req = InternalData( + subject_type=NAMEID_FORMAT_TRANSIENT, requester="test_client" + ) + return self.auth_req_callback_func(context, internal_req) + + def handle_authn_response(self, context, internal_resp): + return Response("Auth response received, passed to test frontend") + + +class TestRequestMicroservice(RequestMicroService): + __test__ = False + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def register_endpoints(self): + return [("^request_microservice/callback$", self.callback)] + + def callback(self): + pass + + +class TestResponseMicroservice(ResponseMicroService): + __test__ = False + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def register_endpoints(self): + return [("^response_microservice/callback$", self.callback)] + + def callback(self): + pass