diff --git a/AUTHORS b/AUTHORS index 2d8d5465b..4ebe787cd 100644 --- a/AUTHORS +++ b/AUTHORS @@ -102,6 +102,7 @@ Peter Karman Peter McDonald Petr DlouhĂ˝ pySilver +@realsuayip Rodney Richardson Rustem Saiargaliev Rustem Saiargaliev @@ -127,4 +128,5 @@ Yaroslav Halchenko Yuri Savin Miriam Forner Alex Kerkum -q0w +Tuhin Mitra +q0w \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index 1821a8ae1..a29772c13 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,12 +9,18 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added * Support for Django 5.2 * Support for Python 3.14 (Django >= 5.2.8) +* #1539 Add device authorization grant support + ### Fixed +* #1252 Fix crash when 'client' is in token request body +* #1496 Fix error when Bearer token string is empty but preceded by `Bearer` keyword. + @@ -27,7 +33,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * #1506 Support for Wildcard Origin and Redirect URIs - Adds a new setting [ALLOW_URL_WILDCARDS](https://django-oauth-toolkit.readthedocs.io/en/latest/settings.html#allow-uri-wildcards). This feature is useful for working with CI service such as cloudflare, netlify, and vercel that offer branch deployments for development previews and user acceptance testing. * #1586 Turkish language support added -* #1539 Add device authorization grant support ### Changed The project is now hosted in the django-oauth organization. diff --git a/oauth2_provider/middleware.py b/oauth2_provider/middleware.py index 65c9cf03c..5a8a86d87 100644 --- a/oauth2_provider/middleware.py +++ b/oauth2_provider/middleware.py @@ -52,8 +52,9 @@ def __init__(self, get_response): def __call__(self, request): authheader = request.META.get("HTTP_AUTHORIZATION", "") - if authheader.startswith("Bearer"): - tokenstring = authheader.split()[1] + splits = authheader.split(maxsplit=1) + if authheader.startswith("Bearer") and len(splits) == 2: + tokenstring = splits[1] AccessToken = get_access_token_model() try: token_checksum = hashlib.sha256(tokenstring.encode("utf-8")).hexdigest() diff --git a/oauth2_provider/oauth2_validators.py b/oauth2_provider/oauth2_validators.py index ec974b0c6..a202a6a82 100644 --- a/oauth2_provider/oauth2_validators.py +++ b/oauth2_provider/oauth2_validators.py @@ -214,19 +214,31 @@ def _load_application(self, client_id, request): If request.client was not set, load application instance for given client_id and store it in request.client """ - - # we want to be sure that request has the client attribute! - assert hasattr(request, "client"), '"request" instance has no "client" attribute' - + if request.client: + # check for cached client, to save the db hit if this has already been loaded + if not isinstance(request.client, Application): + # resetting request.client (client_id=%r): not an Application, something else set request.client erroneously + request.client = None + elif request.client.client_id != client_id: + # resetting request.client (client_id=%r): request.client.client_id does not match the given client_id + request.client = None + elif not request.client.is_usable(request): + # resetting request.client (client_id=%r): request.client is a valid Application, but is not usable + request.client = None + else: + # request.client is a valid Application, reusing it + return request.client try: - request.client = request.client or Application.objects.get(client_id=client_id) - # Check that the application can be used (defaults to always True) - if not request.client.is_usable(request): - log.debug("Failed body authentication: Application %r is disabled" % (client_id)) + # cache not hit, loading application from database for client_id %r + client = Application.objects.get(client_id=client_id) + if not client.is_usable(request): + # Failed to load application: Application %r is not usable return None + request.client = client + # Loaded application with client_id %r from database return request.client except Application.DoesNotExist: - log.debug("Failed body authentication: Application %r does not exist" % (client_id)) + # Failed to load application: Application with client_id %r does not exist return None def _set_oauth2_error_on_request(self, request, access_token, scopes): @@ -289,6 +301,7 @@ def client_authentication_required(self, request, *args, **kwargs): pass self._load_application(request.client_id, request) + log.debug("Determining if client authentication is required for client %r", request.client) if request.client: return request.client.client_type == AbstractApplication.CLIENT_CONFIDENTIAL diff --git a/tests/test_authorization_code.py b/tests/test_authorization_code.py index 360fac957..369b1939f 100644 --- a/tests/test_authorization_code.py +++ b/tests/test_authorization_code.py @@ -1308,6 +1308,27 @@ def test_request_body_params(self): self.assertEqual(content["scope"], "read write") self.assertEqual(content["expires_in"], self.oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + def test_request_body_params_client_typo(self): + """ + Verify that using incorrect parameter name (client instead of client_id) returns invalid_client error + """ + self.client.login(username="test_user", password="123456") + authorization_code = self.get_auth() + + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": "http://example.org", + "client": self.application.client_id, + "client_secret": CLEARTEXT_SECRET, + } + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) + self.assertEqual(response.status_code, 401) + + content = json.loads(response.content.decode("utf-8")) + self.assertEqual(content["error"], "invalid_client") + def test_public(self): """ Request an access token using client_type: public diff --git a/tests/test_oauth2_provider_middleware.py b/tests/test_oauth2_provider_middleware.py new file mode 100644 index 000000000..90610f78b --- /dev/null +++ b/tests/test_oauth2_provider_middleware.py @@ -0,0 +1,98 @@ +import datetime +import hashlib + +from django.contrib.auth import get_user_model +from django.test import RequestFactory, TestCase + +from oauth2_provider.middleware import OAuth2ExtraTokenMiddleware +from oauth2_provider.models import get_access_token_model, get_application_model + + +Application = get_application_model() +AccessToken = get_access_token_model() +User = get_user_model() + + +class TestOAuth2ExtraTokenMiddleware(TestCase): + def setUp(self): + self.factory = RequestFactory() + self.middleware = OAuth2ExtraTokenMiddleware(lambda r: None) + + # Create test user and application for valid token tests + self.user = User.objects.create_user("test_user", "test@example.com", "123456") + self.application = Application.objects.create( + name="Test Application", + user=self.user, + client_type=Application.CLIENT_CONFIDENTIAL, + authorization_grant_type=Application.GRANT_AUTHORIZATION_CODE, + ) + + def test_malformed_bearer_header_no_token(self): + """Test that 'Authorization: Bearer' without token doesn't crash""" + request = self.factory.get("/", HTTP_AUTHORIZATION="Bearer") + + # This should not raise an IndexError + _ = self.middleware(request) + + # Should not have access_token attribute + self.assertFalse(hasattr(request, "access_token")) + + def test_malformed_bearer_header_empty_token(self): + """Test that 'Authorization: Bearer ' with empty token doesn't crash""" + request = self.factory.get("/", HTTP_AUTHORIZATION="Bearer ") + + # This should not raise an IndexError + _ = self.middleware(request) + + # Should not have access_token attribute + self.assertFalse(hasattr(request, "access_token")) + + def test_valid_bearer_token(self): + """Test that valid bearer token works correctly""" + # Create a valid access token + token_string = "test-token-12345" + token_checksum = hashlib.sha256(token_string.encode("utf-8")).hexdigest() + access_token = AccessToken.objects.create( + user=self.user, + scope="read", + expires=datetime.datetime.now() + datetime.timedelta(days=1), + token=token_string, + token_checksum=token_checksum, + application=self.application, + ) + + request = self.factory.get("/", HTTP_AUTHORIZATION=f"Bearer {token_string}") + + _ = self.middleware(request) + + # Should have access_token attribute set + self.assertTrue(hasattr(request, "access_token")) + self.assertEqual(request.access_token, access_token) + + def test_invalid_bearer_token(self): + """Test that invalid bearer token doesn't crash but doesn't set access_token""" + request = self.factory.get("/", HTTP_AUTHORIZATION="Bearer invalid-token-xyz") + + # This should not raise an exception + _ = self.middleware(request) + + # Should not have access_token attribute + self.assertFalse(hasattr(request, "access_token")) + + def test_no_authorization_header(self): + """Test that request without Authorization header works normally""" + request = self.factory.get("/") + + _ = self.middleware(request) + + # Should not have access_token attribute + self.assertFalse(hasattr(request, "access_token")) + + def test_non_bearer_authorization_header(self): + """Test that non-Bearer authorization headers are ignored""" + request = self.factory.get("/", HTTP_AUTHORIZATION="Basic dXNlcjpwYXNz") + + _ = self.middleware(request) + + # Should not have access_token attribute + self.assertFalse(hasattr(request, "access_token")) diff --git a/tests/test_oauth2_validators.py b/tests/test_oauth2_validators.py index 7e7e46de7..3fb292060 100644 --- a/tests/test_oauth2_validators.py +++ b/tests/test_oauth2_validators.py @@ -216,8 +216,52 @@ def test_client_authentication_required(self): self.request.client = "" self.assertTrue(self.validator.client_authentication_required(self.request)) - def test_load_application_fails_when_request_has_no_client(self): - self.assertRaises(AssertionError, self.validator.authenticate_client_id, "client_id", {}) + def test_load_application_loads_client_id_when_request_has_no_client(self): + self.request.client = None + application = self.validator._load_application("client_id", self.request) + self.assertEqual(application, self.application) + + def test_load_application_uses_cached_when_request_has_valid_client_matching_client_id(self): + self.request.client = self.application + application = self.validator._load_application("client_id", self.request) + self.assertIs(application, self.application) + self.assertIs(self.request.client, self.application) + + def test_load_application_succeeds_when_request_has_invalid_client_valid_client_id(self): + self.request.client = 'invalid_client' + application = self.validator._load_application("client_id", self.request) + self.assertEqual(application, self.application) + self.assertEqual(self.request.client, self.application) + + def test_load_application_overwrites_client_on_client_id_mismatch(self): + another_application = Application.objects.create( + client_id="another_client_id", + client_secret=CLEARTEXT_SECRET, + user=self.user, + client_type=Application.CLIENT_PUBLIC, + authorization_grant_type=Application.GRANT_PASSWORD, + ) + self.request.client = another_application + application = self.validator._load_application("client_id", self.request) + self.assertEqual(application, self.application) + self.assertEqual(self.request.client, self.application) + another_application.delete() + + @mock.patch.object(Application, "is_usable") + def test_load_application_returns_none_when_client_not_usable_cached(self, mock_is_usable): + mock_is_usable.return_value = False + self.request.client = self.application + application = self.validator._load_application("client_id", self.request) + self.assertIsNone(application) + self.assertIsNone(self.request.client) + + @mock.patch.object(Application, "is_usable") + def test_load_application_returns_none_when_client_not_usable_db_lookup(self, mock_is_usable): + mock_is_usable.return_value = False + self.request.client = None + application = self.validator._load_application("client_id", self.request) + self.assertIsNone(application) + self.assertIsNone(self.request.client) def test_rotate_refresh_token__is_true(self): self.assertTrue(self.validator.rotate_refresh_token(mock.MagicMock()))