diff --git a/EXAMPLES.md b/EXAMPLES.md index a1695a78..959e0314 100644 --- a/EXAMPLES.md +++ b/EXAMPLES.md @@ -7,7 +7,7 @@ - [Connections](#connections) - [Error handling](#error-handling) - [Asynchronous environments](#asynchronous-environments) - + ## Authentication SDK ### ID token validation diff --git a/auth0/authentication/async_token_verifier.py b/auth0/authentication/async_token_verifier.py index 64b97e5e..058e493f 100644 --- a/auth0/authentication/async_token_verifier.py +++ b/auth0/authentication/async_token_verifier.py @@ -1,8 +1,16 @@ """Token Verifier module""" +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + from .. import TokenValidationError from ..rest_async import AsyncRestClient from .token_verifier import AsymmetricSignatureVerifier, JwksFetcher, TokenVerifier +if TYPE_CHECKING: + from aiohttp import ClientSession + from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey + class AsyncAsymmetricSignatureVerifier(AsymmetricSignatureVerifier): """Async verifier for RSA signatures, which rely on public key certificates. @@ -12,11 +20,11 @@ class AsyncAsymmetricSignatureVerifier(AsymmetricSignatureVerifier): algorithm (str, optional): The expected signing algorithm. Defaults to "RS256". """ - def __init__(self, jwks_url, algorithm="RS256"): + def __init__(self, jwks_url: str, algorithm: str = "RS256") -> None: super().__init__(jwks_url, algorithm) self._fetcher = AsyncJwksFetcher(jwks_url) - def set_session(self, session): + def set_session(self, session: ClientSession) -> None: """Set Client Session to improve performance by reusing session. Args: @@ -32,7 +40,7 @@ async def _fetch_key(self, key_id=None): key_id (str): The key's key id.""" return await self._fetcher.get_key(key_id) - async def verify_signature(self, token): + async def verify_signature(self, token) -> dict[str, Any]: """Verifies the signature of the given JSON web token. Args: @@ -57,11 +65,11 @@ class AsyncJwksFetcher(JwksFetcher): cache_ttl (str, optional): The lifetime of the JWK set cache in seconds. Defaults to 600 seconds. """ - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self._async_client = AsyncRestClient(None) - def set_session(self, session): + def set_session(self, session: ClientSession) -> None: """Set Client Session to improve performance by reusing session. Args: @@ -70,7 +78,7 @@ def set_session(self, session): """ self._async_client.set_session(session) - async def _fetch_jwks(self, force=False): + async def _fetch_jwks(self, force: bool = False) -> dict[str, RSAPublicKey]: """Attempts to obtain the JWK set from the cache, as long as it's still valid. When not, it will perform a network request to the jwks_url to obtain a fresh result and update the cache value with it. @@ -90,7 +98,7 @@ async def _fetch_jwks(self, force=False): self._cache_is_fresh = False return self._cache_value - async def get_key(self, key_id): + async def get_key(self, key_id: str) -> RSAPublicKey: """Obtains the JWK associated with the given key id. Args: @@ -126,7 +134,13 @@ class AsyncTokenVerifier(TokenVerifier): Defaults to 60 seconds. """ - def __init__(self, signature_verifier, issuer, audience, leeway=0): + def __init__( + self, + signature_verifier: AsyncAsymmetricSignatureVerifier, + issuer: str, + audience: str, + leeway: int = 0, + ) -> None: if not signature_verifier or not isinstance( signature_verifier, AsyncAsymmetricSignatureVerifier ): @@ -140,7 +154,7 @@ def __init__(self, signature_verifier, issuer, audience, leeway=0): self._sv = signature_verifier self._clock = None # legacy testing requirement - def set_session(self, session): + def set_session(self, session: ClientSession) -> None: """Set Client Session to improve performance by reusing session. Args: @@ -149,7 +163,13 @@ def set_session(self, session): """ self._sv.set_session(session) - async def verify(self, token, nonce=None, max_age=None, organization=None): + async def verify( + self, + token: str, + nonce: str | None = None, + max_age: int | None = None, + organization: str | None = None, + ) -> dict[str, Any]: """Attempts to verify the given ID token, following the steps defined in the OpenID Connect spec. Args: diff --git a/auth0/authentication/base.py b/auth0/authentication/base.py index 4e6417b0..01c79d2e 100644 --- a/auth0/authentication/base.py +++ b/auth0/authentication/base.py @@ -1,4 +1,9 @@ +from __future__ import annotations + +from typing import Any + from auth0.rest import RestClient, RestClientOptions +from auth0.types import RequestData, TimeoutType from .client_authentication import add_client_authentication @@ -21,15 +26,15 @@ class AuthenticationBase: def __init__( self, - domain, - client_id, - client_secret=None, - client_assertion_signing_key=None, - client_assertion_signing_alg=None, - telemetry=True, - timeout=5.0, - protocol="https", - ): + domain: str, + client_id: str, + client_secret: str | None = None, + client_assertion_signing_key: str | None = None, + client_assertion_signing_alg: str | None = None, + telemetry: bool = True, + timeout: TimeoutType = 5.0, + protocol: str = "https", + ) -> None: self.domain = domain self.client_id = client_id self.client_secret = client_secret @@ -41,7 +46,7 @@ def __init__( options=RestClientOptions(telemetry=telemetry, timeout=timeout, retries=0), ) - def _add_client_authentication(self, payload): + def _add_client_authentication(self, payload: dict[str, Any]) -> dict[str, Any]: return add_client_authentication( payload, self.domain, @@ -51,13 +56,28 @@ def _add_client_authentication(self, payload): self.client_assertion_signing_alg, ) - def post(self, url, data=None, headers=None): + def post( + self, + url: str, + data: RequestData | None = None, + headers: dict[str, str] | None = None, + ) -> Any: return self.client.post(url, data=data, headers=headers) - def authenticated_post(self, url, data=None, headers=None): + def authenticated_post( + self, + url: str, + data: dict[str, Any], + headers: dict[str, str] | None = None, + ) -> Any: return self.client.post( url, data=self._add_client_authentication(data), headers=headers ) - def get(self, url, params=None, headers=None): + def get( + self, + url: str, + params: dict[str, Any] | None = None, + headers: dict[str, str] | None = None, + ) -> Any: return self.client.get(url, params, headers) diff --git a/auth0/authentication/client_authentication.py b/auth0/authentication/client_authentication.py index 7ab742f9..849058f4 100644 --- a/auth0/authentication/client_authentication.py +++ b/auth0/authentication/client_authentication.py @@ -1,18 +1,24 @@ +from __future__ import annotations + import datetime import uuid +from typing import Any import jwt def create_client_assertion_jwt( - domain, client_id, client_assertion_signing_key, client_assertion_signing_alg -): + domain: str, + client_id: str, + client_assertion_signing_key: str, + client_assertion_signing_alg: str | None, +) -> str: """Creates a JWT for the client_assertion field. Args: domain (str): The domain of your Auth0 tenant client_id (str): Your application's client ID - client_assertion_signing_key (str, optional): Private key used to sign the client assertion JWT + client_assertion_signing_key (str): Private key used to sign the client assertion JWT client_assertion_signing_alg (str, optional): Algorithm used to sign the client assertion JWT (defaults to 'RS256') Returns: @@ -35,20 +41,20 @@ def create_client_assertion_jwt( def add_client_authentication( - payload, - domain, - client_id, - client_secret, - client_assertion_signing_key, - client_assertion_signing_alg, -): + payload: dict[str, Any], + domain: str, + client_id: str, + client_secret: str | None, + client_assertion_signing_key: str | None, + client_assertion_signing_alg: str | None, +) -> dict[str, Any]: """Adds the client_assertion or client_secret fields to authenticate a payload. Args: payload (dict): The POST payload that needs additional fields to be authenticated. domain (str): The domain of your Auth0 tenant client_id (str): Your application's client ID - client_secret (str): Your application's client secret + client_secret (str, optional): Your application's client secret client_assertion_signing_key (str, optional): Private key used to sign the client assertion JWT client_assertion_signing_alg (str, optional): Algorithm used to sign the client assertion JWT (defaults to 'RS256') diff --git a/auth0/authentication/database.py b/auth0/authentication/database.py index c4691b27..9bfd6144 100644 --- a/auth0/authentication/database.py +++ b/auth0/authentication/database.py @@ -1,4 +1,6 @@ -import warnings +from __future__ import annotations + +from typing import Any from .base import AuthenticationBase @@ -12,17 +14,17 @@ class Database(AuthenticationBase): def signup( self, - email, - password, - connection, - username=None, - user_metadata=None, - given_name=None, - family_name=None, - name=None, - nickname=None, - picture=None, - ): + email: str, + password: str, + connection: str, + username: str | None = None, + user_metadata: dict[str, Any] | None = None, + given_name: str | None = None, + family_name: str | None = None, + name: str | None = None, + nickname: str | None = None, + picture: str | None = None, + ) -> dict[str, Any]: """Signup using email and password. Args: @@ -50,7 +52,7 @@ def signup( See: https://auth0.com/docs/api/authentication#signup """ - body = { + body: dict[str, Any] = { "client_id": self.client_id, "email": email, "password": password, @@ -71,11 +73,14 @@ def signup( if picture: body.update({"picture": picture}) - return self.post( + data: dict[str, Any] = self.post( f"{self.protocol}://{self.domain}/dbconnections/signup", data=body ) + return data - def change_password(self, email, connection, password=None): + def change_password( + self, email: str, connection: str, password: str | None = None + ) -> str: """Asks to change a password for a given user. email (str): The user's email address. @@ -88,7 +93,8 @@ def change_password(self, email, connection, password=None): "connection": connection, } - return self.post( + data: str = self.post( f"{self.protocol}://{self.domain}/dbconnections/change_password", data=body, ) + return data diff --git a/auth0/authentication/delegated.py b/auth0/authentication/delegated.py index 58ae4cb8..1266db11 100644 --- a/auth0/authentication/delegated.py +++ b/auth0/authentication/delegated.py @@ -1,3 +1,7 @@ +from __future__ import annotations + +from typing import Any + from .base import AuthenticationBase @@ -10,13 +14,13 @@ class Delegated(AuthenticationBase): def get_token( self, - target, - api_type, - grant_type, - id_token=None, - refresh_token=None, - scope="openid", - ): + target: str, + api_type: str, + grant_type: str, + id_token: str | None = None, + refresh_token: str | None = None, + scope: str = "openid", + ) -> Any: """Obtain a delegation token.""" if id_token and refresh_token: diff --git a/auth0/authentication/enterprise.py b/auth0/authentication/enterprise.py index 0b4b3c5f..518d1001 100644 --- a/auth0/authentication/enterprise.py +++ b/auth0/authentication/enterprise.py @@ -1,3 +1,5 @@ +from typing import Any + from .base import AuthenticationBase @@ -9,7 +11,7 @@ class Enterprise(AuthenticationBase): domain (str): Your auth0 domain (e.g: my-domain.us.auth0.com) """ - def saml_metadata(self): + def saml_metadata(self) -> Any: """Get SAML2.0 Metadata.""" return self.get( @@ -18,7 +20,7 @@ def saml_metadata(self): ) ) - def wsfed_metadata(self): + def wsfed_metadata(self) -> Any: """Returns the WS-Federation Metadata.""" url = "{}://{}/wsfed/FederationMetadata/2007-06/FederationMetadata.xml" diff --git a/auth0/authentication/get_token.py b/auth0/authentication/get_token.py index 4986e55b..9de89291 100644 --- a/auth0/authentication/get_token.py +++ b/auth0/authentication/get_token.py @@ -1,5 +1,8 @@ +from __future__ import annotations + +from typing import Any + from .base import AuthenticationBase -from .client_authentication import add_client_authentication class GetToken(AuthenticationBase): @@ -12,10 +15,10 @@ class GetToken(AuthenticationBase): def authorization_code( self, - code, - redirect_uri, - grant_type="authorization_code", - ): + code: str, + redirect_uri: str | None, + grant_type: str = "authorization_code", + ) -> Any: """Authorization code grant This is the OAuth 2.0 grant that regular web apps utilize in order @@ -47,11 +50,11 @@ def authorization_code( def authorization_code_pkce( self, - code_verifier, - code, - redirect_uri, - grant_type="authorization_code", - ): + code_verifier: str, + code: str, + redirect_uri: str | None, + grant_type: str = "authorization_code", + ) -> Any: """Authorization code pkce grant This is the OAuth 2.0 grant that mobile apps utilize in order to access an API. @@ -86,9 +89,9 @@ def authorization_code_pkce( def client_credentials( self, - audience, - grant_type="client_credentials", - ): + audience: str, + grant_type: str = "client_credentials", + ) -> Any: """Client credentials grant This is the OAuth 2.0 grant that server processes utilize in @@ -116,13 +119,13 @@ def client_credentials( def login( self, - username, - password, - scope=None, - realm=None, - audience=None, - grant_type="http://auth0.com/oauth/grant-type/password-realm", - ): + username: str, + password: str, + scope: str | None = None, + realm: str | None = None, + audience: str | None = None, + grant_type: str = "http://auth0.com/oauth/grant-type/password-realm", + ) -> Any: """Calls /oauth/token endpoint with password-realm grant type @@ -168,10 +171,10 @@ def login( def refresh_token( self, - refresh_token, - scope="", - grant_type="refresh_token", - ): + refresh_token: str, + scope: str = "", + grant_type: str = "refresh_token", + ) -> Any: """Calls /oauth/token endpoint with refresh token grant type Use this endpoint to refresh an access token, using the refresh token you got during authorization. @@ -199,7 +202,9 @@ def refresh_token( }, ) - def passwordless_login(self, username, otp, realm, scope, audience): + def passwordless_login( + self, username: str, otp: str, realm: str, scope: str, audience: str + ) -> Any: """Calls /oauth/token endpoint with http://auth0.com/oauth/grant-type/passwordless/otp grant type Once the verification code was received, login the user using this endpoint with their diff --git a/auth0/authentication/passwordless.py b/auth0/authentication/passwordless.py index 63d26b4d..dc4ac1af 100644 --- a/auth0/authentication/passwordless.py +++ b/auth0/authentication/passwordless.py @@ -1,4 +1,6 @@ -import warnings +from __future__ import annotations + +from typing import Any from .base import AuthenticationBase @@ -11,7 +13,9 @@ class Passwordless(AuthenticationBase): domain (str): Your auth0 domain (e.g: my-domain.us.auth0.com) """ - def email(self, email, send="link", auth_params=None): + def email( + self, email: str, send: str = "link", auth_params: dict[str, str] | None = None + ) -> Any: """Start flow sending an email. Given the user email address, it will send an email with: @@ -35,7 +39,7 @@ def email(self, email, send="link", auth_params=None): auth_params (dict, optional): Parameters to append or override. """ - data = { + data: dict[str, Any] = { "client_id": self.client_id, "connection": "email", "email": email, @@ -48,7 +52,7 @@ def email(self, email, send="link", auth_params=None): f"{self.protocol}://{self.domain}/passwordless/start", data=data ) - def sms(self, phone_number): + def sms(self, phone_number: str) -> Any: """Start flow sending an SMS message. Given the user phone number, it will send an SMS with diff --git a/auth0/authentication/revoke_token.py b/auth0/authentication/revoke_token.py index ded6397b..29223d45 100644 --- a/auth0/authentication/revoke_token.py +++ b/auth0/authentication/revoke_token.py @@ -1,3 +1,5 @@ +from typing import Any + from .base import AuthenticationBase @@ -8,7 +10,7 @@ class RevokeToken(AuthenticationBase): domain (str): Your auth0 domain (e.g: my-domain.us.auth0.com) """ - def revoke_refresh_token(self, token): + def revoke_refresh_token(self, token: str) -> Any: """Revokes a Refresh Token if it has been compromised Each revocation request invalidates not only the specific token, but all other tokens diff --git a/auth0/authentication/social.py b/auth0/authentication/social.py index c2517038..dc9b6a3a 100644 --- a/auth0/authentication/social.py +++ b/auth0/authentication/social.py @@ -1,3 +1,5 @@ +from typing import Any + from .base import AuthenticationBase @@ -9,7 +11,7 @@ class Social(AuthenticationBase): domain (str): Your auth0 domain (e.g: my-domain.us.auth0.com) """ - def login(self, access_token, connection, scope="openid"): + def login(self, access_token: str, connection: str, scope: str = "openid") -> Any: """Login using a social provider's access token Given the social provider's access_token and the connection specified, diff --git a/auth0/authentication/token_verifier.py b/auth0/authentication/token_verifier.py index 08331efc..030eda27 100644 --- a/auth0/authentication/token_verifier.py +++ b/auth0/authentication/token_verifier.py @@ -1,12 +1,18 @@ """Token Verifier module""" +from __future__ import annotations + import json import time +from typing import TYPE_CHECKING, Any, ClassVar import jwt import requests from auth0.exceptions import TokenValidationError +if TYPE_CHECKING: + from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey + class SignatureVerifier: """Abstract class that will verify a given JSON web token's signature @@ -16,7 +22,7 @@ class SignatureVerifier: algorithm (str): The expected signing algorithm (e.g. RS256). """ - DISABLE_JWT_CHECKS = { + DISABLE_JWT_CHECKS: ClassVar[dict[str, bool]] = { "verify_signature": True, "verify_exp": False, "verify_nbf": False, @@ -28,24 +34,24 @@ class SignatureVerifier: "require_nbf": False, } - def __init__(self, algorithm): + def __init__(self, algorithm: str) -> None: if not algorithm or type(algorithm) != str: raise ValueError("algorithm must be specified.") self._algorithm = algorithm - def _fetch_key(self, key_id=None): + def _fetch_key(self, key_id: str) -> str | RSAPublicKey: """Obtains the key associated to the given key id. Must be implemented by subclasses. Args: - key_id (str, optional): The id of the key to fetch. + key_id (str): The id of the key to fetch. Returns: the key to use for verifying a cryptographic signature """ raise NotImplementedError - def _get_kid(self, token): + def _get_kid(self, token: str) -> str | None: """Gets the key id from the kid claim of the header of the token Args: @@ -72,7 +78,7 @@ def _get_kid(self, token): return header.get("kid", None) - def _decode_jwt(self, token, secret_or_certificate): + def _decode_jwt(self, token: str, secret_or_certificate: str) -> dict[str, Any]: """Verifies and decodes the given JSON web token with the given public key or shared secret. Args: @@ -94,7 +100,7 @@ def _decode_jwt(self, token, secret_or_certificate): raise TokenValidationError("Invalid token signature.") return decoded - def verify_signature(self, token): + def verify_signature(self, token: str) -> dict[str, Any]: """Verifies the signature of the given JSON web token. Args: @@ -105,9 +111,11 @@ def verify_signature(self, token): or the token's signature doesn't match the calculated one. """ kid = self._get_kid(token) + if kid is None: + kid = "" secret_or_certificate = self._fetch_key(key_id=kid) - return self._decode_jwt(token, secret_or_certificate) + return self._decode_jwt(token, secret_or_certificate) # type: ignore[arg-type] class SymmetricSignatureVerifier(SignatureVerifier): @@ -118,11 +126,11 @@ class SymmetricSignatureVerifier(SignatureVerifier): algorithm (str, optional): The expected signing algorithm. Defaults to "HS256". """ - def __init__(self, shared_secret, algorithm="HS256"): + def __init__(self, shared_secret: str, algorithm: str = "HS256") -> None: super().__init__(algorithm) self._shared_secret = shared_secret - def _fetch_key(self, key_id=None): + def _fetch_key(self, key_id: str = "") -> str: return self._shared_secret @@ -135,20 +143,19 @@ class JwksFetcher: cache_ttl (str, optional): The lifetime of the JWK set cache in seconds. Defaults to 600 seconds. """ - CACHE_TTL = 600 # 10 min cache lifetime + CACHE_TTL: ClassVar[int] = 600 # 10 min cache lifetime - def __init__(self, jwks_url, cache_ttl=CACHE_TTL): + def __init__(self, jwks_url: str, cache_ttl: int = CACHE_TTL) -> None: self._jwks_url = jwks_url self._init_cache(cache_ttl) - return - def _init_cache(self, cache_ttl): - self._cache_value = {} - self._cache_date = 0 + def _init_cache(self, cache_ttl: int) -> None: + self._cache_value: dict[str, RSAPublicKey] = {} + self._cache_date = 0.0 self._cache_ttl = cache_ttl self._cache_is_fresh = False - def _cache_expired(self): + def _cache_expired(self) -> bool: """Checks if the cache is expired Returns: @@ -156,7 +163,7 @@ def _cache_expired(self): """ return self._cache_date + self._cache_ttl < time.time() - def _cache_jwks(self, jwks): + def _cache_jwks(self, jwks: dict[str, Any]) -> None: """Cache the response of the JWKS request Args: @@ -166,7 +173,7 @@ def _cache_jwks(self, jwks): self._cache_is_fresh = True self._cache_date = time.time() - def _fetch_jwks(self, force=False): + def _fetch_jwks(self, force: bool = False) -> dict[str, RSAPublicKey]: """Attempts to obtain the JWK set from the cache, as long as it's still valid. When not, it will perform a network request to the jwks_url to obtain a fresh result and update the cache value with it. @@ -178,7 +185,7 @@ def _fetch_jwks(self, force=False): self._cache_value = {} response = requests.get(self._jwks_url) if response.ok: - jwks = response.json() + jwks: dict[str, Any] = response.json() self._cache_jwks(jwks) return self._cache_value @@ -186,20 +193,22 @@ def _fetch_jwks(self, force=False): return self._cache_value @staticmethod - def _parse_jwks(jwks): + def _parse_jwks(jwks: dict[str, Any]) -> dict[str, RSAPublicKey]: """ Converts a JWK string representation into a binary certificate in PEM format. """ - keys = {} + keys: dict[str, RSAPublicKey] = {} for key in jwks["keys"]: # noinspection PyUnresolvedReferences # requirement already includes cryptography -> pyjwt[crypto] - rsa_key = jwt.algorithms.RSAAlgorithm.from_jwk(json.dumps(key)) + rsa_key: RSAPublicKey = jwt.algorithms.RSAAlgorithm.from_jwk( + json.dumps(key) + ) keys[key["kid"]] = rsa_key return keys - def get_key(self, key_id): + def get_key(self, key_id: str) -> RSAPublicKey: """Obtains the JWK associated with the given key id. Args: @@ -232,11 +241,16 @@ class AsymmetricSignatureVerifier(SignatureVerifier): cache_ttl (int, optional): The lifetime of the JWK set cache in seconds. Defaults to 600 seconds. """ - def __init__(self, jwks_url, algorithm="RS256", cache_ttl=JwksFetcher.CACHE_TTL): + def __init__( + self, + jwks_url: str, + algorithm: str = "RS256", + cache_ttl: int = JwksFetcher.CACHE_TTL, + ) -> None: super().__init__(algorithm) self._fetcher = JwksFetcher(jwks_url, cache_ttl) - def _fetch_key(self, key_id=None): + def _fetch_key(self, key_id: str) -> RSAPublicKey: return self._fetcher.get_key(key_id) @@ -252,7 +266,13 @@ class TokenVerifier: Defaults to 60 seconds. """ - def __init__(self, signature_verifier, issuer, audience, leeway=0): + def __init__( + self, + signature_verifier: SignatureVerifier, + issuer: str, + audience: str, + leeway: int = 0, + ) -> None: if not signature_verifier or not isinstance( signature_verifier, SignatureVerifier ): @@ -266,7 +286,13 @@ def __init__(self, signature_verifier, issuer, audience, leeway=0): self._sv = signature_verifier self._clock = None # visible for testing - def verify(self, token, nonce=None, max_age=None, organization=None): + def verify( + self, + token: str, + nonce: str | None = None, + max_age: int | None = None, + organization: str | None = None, + ) -> dict[str, Any]: """Attempts to verify the given ID token, following the steps defined in the OpenID Connect spec. Args: @@ -296,7 +322,13 @@ def verify(self, token, nonce=None, max_age=None, organization=None): return payload - def _verify_payload(self, payload, nonce=None, max_age=None, organization=None): + def _verify_payload( + self, + payload: dict[str, Any], + nonce: str | None = None, + max_age: int | None = None, + organization: str | None = None, + ) -> None: # Issuer if "iss" not in payload or not isinstance(payload["iss"], str): raise TokenValidationError( diff --git a/auth0/authentication/users.py b/auth0/authentication/users.py index 255c90f6..9535edab 100644 --- a/auth0/authentication/users.py +++ b/auth0/authentication/users.py @@ -1,4 +1,9 @@ +from __future__ import annotations + +from typing import Any + from auth0.rest import RestClient, RestClientOptions +from auth0.types import TimeoutType class Users: @@ -13,11 +18,11 @@ class Users: def __init__( self, - domain, - telemetry=True, - timeout=5.0, - protocol="https", - ): + domain: str, + telemetry: bool = True, + timeout: TimeoutType = 5.0, + protocol: str = "https", + ) -> None: self.domain = domain self.protocol = protocol self.client = RestClient( @@ -31,7 +36,7 @@ def __init__( domain (str): Your auth0 domain (e.g: username.auth0.com) """ - def userinfo(self, access_token): + def userinfo(self, access_token: str) -> dict[str, Any]: """Returns the user information based on the Auth0 access token. This endpoint will work only if openid was granted as a scope for the access_token. @@ -42,7 +47,8 @@ def userinfo(self, access_token): The user profile. """ - return self.client.get( + data: dict[str, Any] = self.client.get( url=f"{self.protocol}://{self.domain}/userinfo", headers={"Authorization": f"Bearer {access_token}"}, ) + return data diff --git a/auth0/exceptions.py b/auth0/exceptions.py index 7f9aa325..8515be04 100644 --- a/auth0/exceptions.py +++ b/auth0/exceptions.py @@ -1,16 +1,27 @@ +from __future__ import annotations + +from typing import Any + + class Auth0Error(Exception): - def __init__(self, status_code, error_code, message, content=None): + def __init__( + self, + status_code: int, + error_code: str, + message: str, + content: Any | None = None, + ) -> None: self.status_code = status_code self.error_code = error_code self.message = message self.content = content - def __str__(self): + def __str__(self) -> str: return f"{self.status_code}: {self.message}" class RateLimitError(Auth0Error): - def __init__(self, error_code, message, reset_at): + def __init__(self, error_code: str, message: str, reset_at: int) -> None: super().__init__(status_code=429, error_code=error_code, message=message) self.reset_at = reset_at diff --git a/auth0/rest.py b/auth0/rest.py index c84e5d7f..41282b74 100644 --- a/auth0/rest.py +++ b/auth0/rest.py @@ -1,13 +1,20 @@ +from __future__ import annotations + import base64 import json import platform import sys from random import randint from time import sleep +from typing import TYPE_CHECKING, Any, Mapping import requests from auth0.exceptions import Auth0Error, RateLimitError +from auth0.types import RequestData, TimeoutType + +if TYPE_CHECKING: + from auth0.rest_async import RequestsResponse UNKNOWN_ERROR = "a0.sdk.internal.unknown" @@ -32,25 +39,22 @@ class RestClientOptions: (defaults to 3) """ - def __init__(self, telemetry=None, timeout=None, retries=None): - self.telemetry = True - self.timeout = 5.0 - self.retries = 3 - - if telemetry is not None: - self.telemetry = telemetry - - if timeout is not None: - self.timeout = timeout - - if retries is not None: - self.retries = retries + def __init__( + self, + telemetry: bool = True, + timeout: TimeoutType = 5.0, + retries: int = 3, + ) -> None: + self.telemetry = telemetry + self.timeout = timeout + self.retries = retries class RestClient: """Provides simple methods for handling all RESTful api endpoints. Args: + jwt (str, optional): The JWT to be used with the RestClient. telemetry (bool, optional): Enable or disable Telemetry (defaults to True) timeout (float or tuple, optional): Change the requests @@ -64,7 +68,13 @@ class RestClient: (defaults to 3) """ - def __init__(self, jwt, telemetry=True, timeout=5.0, options=None): + def __init__( + self, + jwt: str | None, + telemetry: bool = True, + timeout: TimeoutType = 5.0, + options: RestClientOptions | None = None, + ) -> None: if options is None: options = RestClientOptions(telemetry=telemetry, timeout=timeout) @@ -111,22 +121,27 @@ def __init__(self, jwt, telemetry=True, timeout=5.0, options=None): self.timeout = options.timeout # Returns a hard cap for the maximum number of retries allowed (10) - def MAX_REQUEST_RETRIES(self): + def MAX_REQUEST_RETRIES(self) -> int: return 10 # Returns the maximum amount of jitter to introduce in milliseconds (100ms) - def MAX_REQUEST_RETRY_JITTER(self): + def MAX_REQUEST_RETRY_JITTER(self) -> int: return 100 # Returns the maximum delay window allowed (1000ms) - def MAX_REQUEST_RETRY_DELAY(self): + def MAX_REQUEST_RETRY_DELAY(self) -> int: return 1000 # Returns the minimum delay window allowed (100ms) - def MIN_REQUEST_RETRY_DELAY(self): + def MIN_REQUEST_RETRY_DELAY(self) -> int: return 100 - def get(self, url, params=None, headers=None): + def get( + self, + url: str, + params: dict[str, Any] | None = None, + headers: dict[str, str] | None = None, + ) -> Any: request_headers = self.base_headers.copy() request_headers.update(headers or {}) @@ -162,7 +177,12 @@ def get(self, url, params=None, headers=None): # Return the final Response return self._process_response(response) - def post(self, url, data=None, headers=None): + def post( + self, + url: str, + data: RequestData | None = None, + headers: dict[str, str] | None = None, + ) -> Any: request_headers = self.base_headers.copy() request_headers.update(headers or {}) @@ -171,7 +191,12 @@ def post(self, url, data=None, headers=None): ) return self._process_response(response) - def file_post(self, url, data=None, files=None): + def file_post( + self, + url: str, + data: RequestData | None = None, + files: dict[str, Any] | None = None, + ) -> Any: headers = self.base_headers.copy() headers.pop("Content-Type", None) @@ -180,7 +205,7 @@ def file_post(self, url, data=None, files=None): ) return self._process_response(response) - def patch(self, url, data=None): + def patch(self, url: str, data: RequestData | None = None) -> Any: headers = self.base_headers.copy() response = requests.patch( @@ -188,7 +213,7 @@ def patch(self, url, data=None): ) return self._process_response(response) - def put(self, url, data=None): + def put(self, url: str, data: RequestData | None = None) -> Any: headers = self.base_headers.copy() response = requests.put( @@ -196,7 +221,12 @@ def put(self, url, data=None): ) return self._process_response(response) - def delete(self, url, params=None, data=None): + def delete( + self, + url: str, + params: dict[str, Any] | None = None, + data: RequestData | None = None, + ) -> Any: headers = self.base_headers.copy() response = requests.delete( @@ -208,7 +238,7 @@ def delete(self, url, params=None, data=None): ) return self._process_response(response) - def _calculate_wait(self, attempt): + def _calculate_wait(self, attempt: int) -> int: # Retry the request. Apply a exponential backoff for subsequent attempts, using this formula: # max(MIN_REQUEST_RETRY_DELAY, min(MAX_REQUEST_RETRY_DELAY, (100ms * (2 ** attempt - 1)) + random_between(1, MAX_REQUEST_RETRY_JITTER))) @@ -225,14 +255,14 @@ def _calculate_wait(self, attempt): wait = max(self.MIN_REQUEST_RETRY_DELAY(), wait) self._metrics["retries"] = attempt - self._metrics["backoff"].append(wait) + self._metrics["backoff"].append(wait) # type: ignore[attr-defined] return wait - def _process_response(self, response): + def _process_response(self, response: requests.Response) -> Any: return self._parse(response).content() - def _parse(self, response): + def _parse(self, response: requests.Response) -> Response: if not response.text: return EmptyResponse(response.status_code) try: @@ -242,12 +272,14 @@ def _parse(self, response): class Response: - def __init__(self, status_code, content, headers): + def __init__( + self, status_code: int, content: Any, headers: Mapping[str, str] + ) -> None: self._status_code = status_code self._content = content self._headers = headers - def content(self): + def content(self) -> Any: if self._is_error(): if self._status_code == 429: reset_at = int(self._headers.get("x-ratelimit-reset", "-1")) @@ -272,7 +304,7 @@ def content(self): else: return self._content - def _is_error(self): + def _is_error(self) -> bool: return self._status_code is None or self._status_code >= 400 # Adding these methods to force implementation in subclasses because they are references in this parent class @@ -284,11 +316,11 @@ def _error_message(self): class JsonResponse(Response): - def __init__(self, response): + def __init__(self, response: requests.Response | RequestsResponse) -> None: content = json.loads(response.text) super().__init__(response.status_code, content, response.headers) - def _error_code(self): + def _error_code(self) -> str: if "errorCode" in self._content: return self._content.get("errorCode") elif "error" in self._content: @@ -298,7 +330,7 @@ def _error_code(self): else: return UNKNOWN_ERROR - def _error_message(self): + def _error_message(self) -> str: if "error_description" in self._content: return self._content.get("error_description") message = self._content.get("message", "") @@ -308,22 +340,22 @@ def _error_message(self): class PlainResponse(Response): - def __init__(self, response): + def __init__(self, response: requests.Response | RequestsResponse) -> None: super().__init__(response.status_code, response.text, response.headers) - def _error_code(self): + def _error_code(self) -> str: return UNKNOWN_ERROR - def _error_message(self): + def _error_message(self) -> str: return self._content class EmptyResponse(Response): - def __init__(self, status_code): + def __init__(self, status_code: int) -> None: super().__init__(status_code, "", {}) - def _error_code(self): + def _error_code(self) -> str: return UNKNOWN_ERROR - def _error_message(self): + def _error_message(self) -> str: return "" diff --git a/auth0/rest_async.py b/auth0/rest_async.py index c0fe02a3..183cfbb9 100644 --- a/auth0/rest_async.py +++ b/auth0/rest_async.py @@ -1,13 +1,17 @@ +from __future__ import annotations + import asyncio +from typing import Any import aiohttp from auth0.exceptions import RateLimitError +from auth0.types import RequestData -from .rest import EmptyResponse, JsonResponse, PlainResponse, RestClient +from .rest import EmptyResponse, JsonResponse, PlainResponse, Response, RestClient -def _clean_params(params): +def _clean_params(params: dict[Any, Any] | None) -> dict[Any, Any] | None: if params is None: return params return {k: v for k, v in params.items() if v is not None} @@ -30,9 +34,9 @@ class AsyncRestClient(RestClient): (defaults to 3) """ - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) - self._session = None + self._session: aiohttp.ClientSession | None = None sock_connect, sock_read = ( self.timeout if isinstance(self.timeout, tuple) @@ -42,13 +46,13 @@ def __init__(self, *args, **kwargs): sock_connect=sock_connect, sock_read=sock_read ) - def set_session(self, session): + def set_session(self, session: aiohttp.ClientSession) -> None: """Set Client Session to improve performance by reusing session. Session should be closed manually or within context manager. """ self._session = session - async def _request(self, *args, **kwargs): + async def _request(self, *args: Any, **kwargs: Any) -> Any: kwargs["headers"] = kwargs.get("headers", self.base_headers) kwargs["timeout"] = self.timeout if self._session is not None: @@ -61,7 +65,12 @@ async def _request(self, *args, **kwargs): async with session.request(*args, **kwargs) as response: return await self._process_response(response) - async def get(self, url, params=None, headers=None): + async def get( + self, + url: str, + params: dict[str, Any] | None = None, + headers: dict[str, str] | None = None, + ) -> Any: request_headers = self.base_headers.copy() request_headers.update(headers or {}) # Track the API request attempt number @@ -92,32 +101,47 @@ async def get(self, url, params=None, headers=None): # sleep() functions in seconds, so convert the milliseconds formula above accordingly await asyncio.sleep(wait / 1000) - async def post(self, url, data=None, headers=None): + async def post( + self, + url: str, + data: RequestData | None = None, + headers: dict[str, str] | None = None, + ) -> Any: request_headers = self.base_headers.copy() request_headers.update(headers or {}) return await self._request("post", url, json=data, headers=request_headers) - async def file_post(self, url, data=None, files=None): + async def file_post( + self, + url: str, + data: dict[str, Any], + files: dict[str, Any], + ) -> Any: headers = self.base_headers.copy() headers.pop("Content-Type", None) return await self._request("post", url, data={**data, **files}, headers=headers) - async def patch(self, url, data=None): + async def patch(self, url: str, data: RequestData | None = None) -> Any: return await self._request("patch", url, json=data) - async def put(self, url, data=None): + async def put(self, url: str, data: RequestData | None = None) -> Any: return await self._request("put", url, json=data) - async def delete(self, url, params=None, data=None): + async def delete( + self, + url: str, + params: dict[str, Any] | None = None, + data: RequestData | None = None, + ) -> Any: return await self._request( "delete", url, json=data, params=_clean_params(params) or {} ) - async def _process_response(self, response): + async def _process_response(self, response: aiohttp.ClientResponse) -> Any: parsed_response = await self._parse(response) return parsed_response.content() - async def _parse(self, response): + async def _parse(self, response: aiohttp.ClientResponse) -> Response: text = await response.text() requests_response = RequestsResponse(response, text) if not text: @@ -129,7 +153,7 @@ async def _parse(self, response): class RequestsResponse: - def __init__(self, response, text): + def __init__(self, response: aiohttp.ClientResponse, text: str) -> None: self.status_code = response.status self.headers = response.headers self.text = text diff --git a/auth0/types.py b/auth0/types.py new file mode 100644 index 00000000..c1929cf2 --- /dev/null +++ b/auth0/types.py @@ -0,0 +1,5 @@ +from typing import Any, Dict, List, Tuple, Union + +TimeoutType = Union[float, Tuple[float, float]] + +RequestData = Union[Dict[str, Any], List[Any]] diff --git a/auth0/utils.py b/auth0/utils.py index a6909f04..807e9016 100644 --- a/auth0/utils.py +++ b/auth0/utils.py @@ -1,4 +1,4 @@ -def is_async_available(): +def is_async_available() -> bool: try: import asyncio diff --git a/docs/source/conf.py b/docs/source/conf.py index b3cdbc2e..d364fdb1 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -93,3 +93,6 @@ def find_version(*file_paths): # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". html_static_path = [] + +# Sphinx somehow can't find this one +nitpick_ignore = [("py:class", "RSAPublicKey")] diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 00000000..af08759b --- /dev/null +++ b/mypy.ini @@ -0,0 +1,14 @@ +[mypy] +python_version = 3.7 + +[mypy-auth0.test.*,auth0.test_async.*] +ignore_errors = True + +[mypy-auth0.management.*] +ignore_errors = True + +[mypy-auth0.rest_async] +disable_error_code=override + +[mypy-auth0.authentication.async_token_verifier] +disable_error_code=override, misc, attr-defined