diff --git a/rest_framework_simplejwt/models.py b/rest_framework_simplejwt/models.py index a0e2c8345..859dfd3ae 100644 --- a/rest_framework_simplejwt/models.py +++ b/rest_framework_simplejwt/models.py @@ -34,11 +34,11 @@ def __str__(self) -> str: return f"TokenUser {self.id}" @cached_property - def id(self) -> Union[int, str]: + def id(self) -> str: return self.token[api_settings.USER_ID_CLAIM] @cached_property - def pk(self) -> Union[int, str]: + def pk(self) -> str: return self.id @cached_property diff --git a/rest_framework_simplejwt/tokens.py b/rest_framework_simplejwt/tokens.py index 431540760..4d633e843 100644 --- a/rest_framework_simplejwt/tokens.py +++ b/rest_framework_simplejwt/tokens.py @@ -225,8 +225,7 @@ def for_user(cls: type[T], user: AuthUser) -> T: ) user_id = getattr(user, api_settings.USER_ID_FIELD) - if not isinstance(user_id, int): - user_id = str(user_id) + user_id = str(user_id) token = cls() token[api_settings.USER_ID_CLAIM] = user_id diff --git a/setup.py b/setup.py index aa08d5295..c9cfba619 100755 --- a/setup.py +++ b/setup.py @@ -59,7 +59,7 @@ install_requires=[ "django>=4.2", "djangorestframework>=3.14", - "pyjwt>=1.7.1,<2.10.0", + "pyjwt>=1.7.1", ], python_requires=">=3.9", extras_require=extras_require, diff --git a/tests/test_authentication.py b/tests/test_authentication.py index cb6c3dc3f..b17a12caa 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -225,6 +225,17 @@ def test_get_user_with_check_revoke_token(self): # Otherwise, should return correct user self.assertEqual(self.backend.get_user(payload).id, u.id) + def test_get_user_with_str_user_id_claim(self): + """ + Verify that even though the user id is an int, it can be verified + and retrieved if the user id claim value is a string + """ + + user = User.objects.create_user(username="testuser") + payload = {api_settings.USER_ID_CLAIM: str(user.id)} + auth_user = self.backend.get_user(payload) + self.assertEqual(auth_user.id, user.id) + class TestJWTStatelessUserAuthentication(TestCase): def setUp(self): diff --git a/tests/test_models.py b/tests/test_models.py index c757eea75..719470639 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -12,7 +12,7 @@ class TestTokenUser(TestCase): def setUp(self): self.token = AuthToken() - self.token[api_settings.USER_ID_CLAIM] = 42 + self.token[api_settings.USER_ID_CLAIM] = "42" self.token["username"] = "deep-thought" self.token["some_other_stuff"] = "arstarst" @@ -40,13 +40,13 @@ def test_str(self): self.assertEqual(str(self.user), "TokenUser 42") def test_id(self): - self.assertEqual(self.user.id, 42) + self.assertEqual(self.user.id, "42") def test_pk(self): - self.assertEqual(self.user.pk, 42) + self.assertEqual(self.user.pk, "42") def test_is_staff(self): - payload = {api_settings.USER_ID_CLAIM: 42} + payload = {api_settings.USER_ID_CLAIM: "42"} user = TokenUser(payload) self.assertFalse(user.is_staff) @@ -57,7 +57,7 @@ def test_is_staff(self): self.assertTrue(user.is_staff) def test_is_superuser(self): - payload = {api_settings.USER_ID_CLAIM: 42} + payload = {api_settings.USER_ID_CLAIM: "42"} user = TokenUser(payload) self.assertFalse(user.is_superuser) @@ -68,15 +68,15 @@ def test_is_superuser(self): self.assertTrue(user.is_superuser) def test_eq(self): - user1 = TokenUser({api_settings.USER_ID_CLAIM: 1}) - user2 = TokenUser({api_settings.USER_ID_CLAIM: 2}) - user3 = TokenUser({api_settings.USER_ID_CLAIM: 1}) + user1 = TokenUser({api_settings.USER_ID_CLAIM: "1"}) + user2 = TokenUser({api_settings.USER_ID_CLAIM: "2"}) + user3 = TokenUser({api_settings.USER_ID_CLAIM: "1"}) self.assertNotEqual(user1, user2) self.assertEqual(user1, user3) def test_eq_not_implemented(self): - user1 = TokenUser({api_settings.USER_ID_CLAIM: 1}) + user1 = TokenUser({api_settings.USER_ID_CLAIM: "1"}) user2 = "user2" self.assertFalse(user1 == user2) diff --git a/tests/test_tokens.py b/tests/test_tokens.py index 5605bae08..25485789c 100644 --- a/tests/test_tokens.py +++ b/tests/test_tokens.py @@ -383,8 +383,7 @@ def test_for_user(self): token = MyToken.for_user(self.user) user_id = getattr(self.user, api_settings.USER_ID_FIELD) - if not isinstance(user_id, int): - user_id = str(user_id) + user_id = str(user_id) self.assertEqual(token[api_settings.USER_ID_CLAIM], user_id) @@ -404,6 +403,10 @@ def test_get_token_backend(self): self.assertEqual(token.get_token_backend(), token_backend) + def test_token_user_id_claim_should_always_be_string(self): + token = MyToken.for_user(self.user) + self.assertIsInstance(token[api_settings.USER_ID_CLAIM], str) + class TestSlidingToken(TestCase): def test_init(self): diff --git a/tox.ini b/tox.ini index bc63d3640..109266c2f 100644 --- a/tox.ini +++ b/tox.ini @@ -36,7 +36,7 @@ deps= drf314: djangorestframework>=3.14,<3.15 drf315: djangorestframework>=3.15,<3.16 pyjwt171: pyjwt>=1.7.1,<1.8 - pyjwt2: pyjwt>=2,<2.10.0 + pyjwt2: pyjwt>=2,<3 allowlist_externals=make [testenv:docs]