diff --git a/auth0/v3/authentication/authorize_client.py b/auth0/v3/authentication/authorize_client.py index 66fb64e2..8650399d 100644 --- a/auth0/v3/authentication/authorize_client.py +++ b/auth0/v3/authentication/authorize_client.py @@ -1,8 +1,11 @@ +import sys +from requests.compat import urlencode, urlunparse, quote_plus from .base import AuthenticationBase +_ver = '{}{}{}'.format(*sys.version_info) -class AuthorizeClient(AuthenticationBase): +class AuthorizeClient(AuthenticationBase): """Authorize Client Args: @@ -12,6 +15,25 @@ class AuthorizeClient(AuthenticationBase): def __init__(self, domain): self.domain = domain + def get_authorize_url(self, client_id, audience=None, state=None, redirect_uri=None, + response_type='code', scope='openid', quote_via=quote_plus): + """ + use quote_via=urllib.quote to to urlencode spaces into "%20", the default is "+" + """ + params = { + 'client_id': client_id, + 'audience': audience, + 'response_type': response_type, + 'scope': scope, + 'state': state, + 'redirect_uri': redirect_uri + } + query = urlencode(params, doseq=True, quote_via=quote_via) \ + if _ver > '34' \ + else '&'.join(['{}={}'.format(quote_via(k, safe=''), quote_via(v, safe='')) + for k, v in params.items()]) + return urlunparse(['https', self.domain, '/authorize', None, query, None]) + def authorize(self, client_id, audience=None, state=None, redirect_uri=None, response_type='code', scope='openid'): """Authorization code grant @@ -30,4 +52,3 @@ def authorize(self, client_id, audience=None, state=None, redirect_uri=None, return self.get( 'https://%s/authorize' % self.domain, params=params) - diff --git a/auth0/v3/test/authentication/test_authorize_client.py b/auth0/v3/test/authentication/test_authorize_client.py index 1bc54ec2..2546b39e 100644 --- a/auth0/v3/test/authentication/test_authorize_client.py +++ b/auth0/v3/test/authentication/test_authorize_client.py @@ -1,10 +1,74 @@ import unittest +from requests.compat import quote, urlparse + import mock from ...authentication.authorize_client import AuthorizeClient class TestAuthorizeClient(unittest.TestCase): + def test_get_authorize_url(self): + expected_result = urlparse( + 'https://my.domain.com/authorize' + '?client_id=cid' + '&audience=https%3A%2F%2Ftest.com%2Fapi' + '&state=st+ate' + '&redirect_uri=http%3A%2F%2Flocalhost%3Fcallback' + '&response_type=code' + '&scope=openid+profile') + + a = AuthorizeClient('my.domain.com') + + actual_result = urlparse(a.get_authorize_url( + client_id='cid', + audience='https://test.com/api', + state='st ate', + redirect_uri='http://localhost?callback', + scope='openid profile' + )) + + self.assertEqual(actual_result.scheme, expected_result.scheme) + self.assertEqual(actual_result.hostname, expected_result.hostname) + self.assertEqual(actual_result.path, expected_result.path) + + # there is no guarantee the order of items in the query is equal to the method call, that's okay + expected_query = expected_result.query.split('&') + actual_query = actual_result.query.split('&') + self.assertEqual(sorted(expected_query), sorted(actual_query)) + + def test_get_authorize_url_quote(self): + """ + sometimes we want to urlencode spaces into %20, we can do that with quote_via + """ + expected_result = urlparse( + 'https://my.domain.com/authorize' + '?client_id=cid' + '&audience=https%3A%2F%2Ftest.com%2Fapi%3Ffoo%3Dbar' + '&state=st%20ate' + '&redirect_uri=http%3A%2F%2Flocalhost%3Fcallback%3D%23123' + '&response_type=code' + '&scope=openid%20profile' + ) + a = AuthorizeClient('my.domain.com') + + actual_result = urlparse(a.get_authorize_url( + client_id='cid', + audience='https://test.com/api?foo=bar', + state='st ate', + redirect_uri='http://localhost?callback=#123', + quote_via=quote, + scope='openid profile' + )) + + self.assertEqual(actual_result.scheme, expected_result.scheme) + self.assertEqual(actual_result.hostname, expected_result.hostname) + self.assertEqual(actual_result.path, expected_result.path) + + # there is no guarantee the order of items in the query is equal to the method call, that's okay + expected_query = expected_result.query.split('&') + actual_query = actual_result.query.split('&') + self.assertEqual(sorted(expected_query), sorted(actual_query)) + @mock.patch('auth0.v3.authentication.authorize_client.AuthorizeClient.get') def test_login(self, mock_get):