Skip to content

Commit d0b3b64

Browse files
authored
Fix prefetch_related updates. (encode#4553)
1 parent aed4ed5 commit d0b3b64

File tree

2 files changed

+48
-0
lines changed

2 files changed

+48
-0
lines changed

rest_framework/mixins.py

+7
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,13 @@ def update(self, request, *args, **kwargs):
6868
serializer = self.get_serializer(instance, data=request.data, partial=partial)
6969
serializer.is_valid(raise_exception=True)
7070
self.perform_update(serializer)
71+
72+
if getattr(instance, '_prefetched_objects_cache', None):
73+
# If 'prefetch_related' has been applied to a queryset, we need to
74+
# refresh the instance from the database.
75+
instance = self.get_object()
76+
serializer = self.get_serializer(instance)
77+
7178
return Response(serializer.data)
7279

7380
def perform_update(self, serializer):

tests/test_prefetch_related.py

+41
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
from django.contrib.auth.models import Group, User
2+
from django.test import TestCase
3+
4+
from rest_framework import generics, serializers
5+
from rest_framework.test import APIRequestFactory
6+
7+
factory = APIRequestFactory()
8+
9+
10+
class UserSerializer(serializers.ModelSerializer):
11+
class Meta:
12+
model = User
13+
fields = ('id', 'username', 'email', 'groups')
14+
15+
16+
class UserUpdate(generics.UpdateAPIView):
17+
queryset = User.objects.all().prefetch_related('groups')
18+
serializer_class = UserSerializer
19+
20+
21+
class TestPrefetchRelatedUpdates(TestCase):
22+
def setUp(self):
23+
self.user = User.objects.create(username='tom', email='tom@example.com')
24+
self.groups = [Group.objects.create(name='a'), Group.objects.create(name='b')]
25+
self.user.groups = self.groups
26+
self.user.save()
27+
28+
def test_prefetch_related_updates(self):
29+
view = UserUpdate.as_view()
30+
pk = self.user.pk
31+
groups_pk = self.groups[0].pk
32+
request = factory.put('/', {'username': 'new', 'groups': [groups_pk]}, format='json')
33+
response = view(request, pk=pk)
34+
assert User.objects.get(pk=pk).groups.count() == 1
35+
expected = {
36+
'id': pk,
37+
'username': 'new',
38+
'groups': [1],
39+
'email': 'tom@example.com'
40+
}
41+
assert response.data == expected

0 commit comments

Comments
 (0)