Skip to content

Commit d00f963

Browse files
Add MODEL_SERIALIZER_FIELD_MAPPING settings
1 parent f113ab6 commit d00f963

File tree

4 files changed

+115
-43
lines changed

4 files changed

+115
-43
lines changed

docs/api-guide/settings.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,17 @@ Default: `ordering`
143143

144144
---
145145

146+
## Serializer settings
147+
148+
#### MODEL_SERIALIZER_FIELD_MAPPING
149+
150+
Extra field mapping used to extend or override mapping of django db fields to serializer fields which is used by
151+
ModelSerializer to set up fields for serializer.
152+
153+
Default: `{}`
154+
155+
---
156+
146157
## Versioning settings
147158

148159
#### DEFAULT_VERSION

rest_framework/serializers.py

Lines changed: 56 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -894,48 +894,8 @@ class ModelSerializer(Serializer):
894894
* A set of default validators are automatically populated.
895895
* Default `.create()` and `.update()` implementations are provided.
896896
897-
The process of automatically determining a set of serializer fields
898-
based on the model fields is reasonably complex, but you almost certainly
899-
don't need to dig into the implementation.
900-
901-
If the `ModelSerializer` class *doesn't* generate the set of fields that
902-
you need you should either declare the extra/differing fields explicitly on
903-
the serializer class, or simply use a `Serializer` class.
904897
"""
905-
serializer_field_mapping = {
906-
models.AutoField: IntegerField,
907-
models.BigIntegerField: IntegerField,
908-
models.BooleanField: BooleanField,
909-
models.CharField: CharField,
910-
models.CommaSeparatedIntegerField: CharField,
911-
models.DateField: DateField,
912-
models.DateTimeField: DateTimeField,
913-
models.DecimalField: DecimalField,
914-
models.DurationField: DurationField,
915-
models.EmailField: EmailField,
916-
models.Field: ModelField,
917-
models.FileField: FileField,
918-
models.FloatField: FloatField,
919-
models.ImageField: ImageField,
920-
models.IntegerField: IntegerField,
921-
models.NullBooleanField: BooleanField,
922-
models.PositiveIntegerField: IntegerField,
923-
models.PositiveSmallIntegerField: IntegerField,
924-
models.SlugField: SlugField,
925-
models.SmallIntegerField: IntegerField,
926-
models.TextField: CharField,
927-
models.TimeField: TimeField,
928-
models.URLField: URLField,
929-
models.UUIDField: UUIDField,
930-
models.GenericIPAddressField: IPAddressField,
931-
models.FilePathField: FilePathField,
932-
}
933-
if hasattr(models, 'JSONField'):
934-
serializer_field_mapping[models.JSONField] = JSONField
935-
if postgres_fields:
936-
serializer_field_mapping[postgres_fields.HStoreField] = HStoreField
937-
serializer_field_mapping[postgres_fields.ArrayField] = ListField
938-
serializer_field_mapping[postgres_fields.JSONField] = JSONField
898+
939899
serializer_related_field = PrimaryKeyRelatedField
940900
serializer_related_to_field = SlugRelatedField
941901
serializer_url_field = HyperlinkedIdentityField
@@ -950,6 +910,61 @@ class ModelSerializer(Serializer):
950910
# "HTTP 201 Created" responses.
951911
url_field_name = None
952912

913+
@property
914+
def serializer_field_mapping(self) -> dict[type[models.Field], type[Field]]:
915+
"""Get mapping of django model field to serializer field.
916+
917+
The process of automatically determining a set of serializer fields
918+
based on the model fields is reasonably complex, but you almost certainly
919+
don't need to dig into the implementation.
920+
921+
If the `ModelSerializer` class *doesn't* generate the set of fields that
922+
you need you should either extend serializer_field_mapping with
923+
the extra/differing fields explicitly, or simply use a `Serializer`
924+
class.
925+
926+
"""
927+
serializer_field_mapping = {
928+
models.AutoField: IntegerField,
929+
models.BigIntegerField: IntegerField,
930+
models.BooleanField: BooleanField,
931+
models.CharField: CharField,
932+
models.CommaSeparatedIntegerField: CharField,
933+
models.DateField: DateField,
934+
models.DateTimeField: DateTimeField,
935+
models.DecimalField: DecimalField,
936+
models.DurationField: DurationField,
937+
models.EmailField: EmailField,
938+
models.Field: ModelField,
939+
models.FileField: FileField,
940+
models.FloatField: FloatField,
941+
models.ImageField: ImageField,
942+
models.IntegerField: IntegerField,
943+
models.NullBooleanField: BooleanField,
944+
models.PositiveIntegerField: IntegerField,
945+
models.PositiveSmallIntegerField: IntegerField,
946+
models.SlugField: SlugField,
947+
models.SmallIntegerField: IntegerField,
948+
models.TextField: CharField,
949+
models.TimeField: TimeField,
950+
models.URLField: URLField,
951+
models.UUIDField: UUIDField,
952+
models.GenericIPAddressField: IPAddressField,
953+
models.FilePathField: FilePathField,
954+
}
955+
if hasattr(models, 'JSONField'):
956+
serializer_field_mapping[models.JSONField] = JSONField
957+
if postgres_fields:
958+
serializer_field_mapping[postgres_fields.HStoreField] = HStoreField
959+
serializer_field_mapping[postgres_fields.ArrayField] = ListField
960+
serializer_field_mapping[postgres_fields.JSONField] = JSONField
961+
for (
962+
model_field,
963+
serializer_field,
964+
) in api_settings.MODEL_SERIALIZER_FIELD_MAPPING.items():
965+
serializer_field_mapping[model_field] = serializer_field
966+
return serializer_field_mapping
967+
953968
# Default `create` and `update` behavior...
954969
def create(self, validated_data):
955970
"""

rest_framework/settings.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,9 @@
126126
'retrieve': 'read',
127127
'destroy': 'delete'
128128
},
129+
130+
# Serializers
131+
'MODEL_SERIALIZER_FIELD_MAPPING': {}
129132
}
130133

131134

@@ -147,7 +150,8 @@
147150
'UNAUTHENTICATED_USER',
148151
'UNAUTHENTICATED_TOKEN',
149152
'VIEW_NAME_FUNCTION',
150-
'VIEW_DESCRIPTION_FUNCTION'
153+
'VIEW_DESCRIPTION_FUNCTION',
154+
'MODEL_SERIALIZER_FIELD_MAPPING',
151155
]
152156

153157

@@ -168,6 +172,16 @@ def perform_import(val, setting_name):
168172
return import_from_string(val, setting_name)
169173
elif isinstance(val, (list, tuple)):
170174
return [import_from_string(item, setting_name) for item in val]
175+
elif isinstance(val, (dict)):
176+
return {
177+
import_from_string(
178+
key,
179+
setting_name,
180+
): import_from_string(
181+
value,
182+
setting_name,
183+
) for key, value in val.items()
184+
}
171185
return val
172186

173187

tests/test_model_serializer.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from django.db import models
2222
from django.db.models.signals import m2m_changed
2323
from django.dispatch import receiver
24-
from django.test import TestCase
24+
from django.test import TestCase, override_settings
2525

2626
from rest_framework import serializers
2727
from rest_framework.compat import postgres_fields
@@ -43,6 +43,12 @@ class CustomField(models.Field):
4343
pass
4444

4545

46+
class CustomCharFieldField(serializers.CharField):
47+
"""
48+
A custom serializer field simply for testing purposes.
49+
"""
50+
51+
4652
class OneFieldModel(models.Model):
4753
char_field = models.CharField(max_length=100)
4854

@@ -194,6 +200,32 @@ class Meta:
194200
custom_field = ModelField\(model_field=<tests.test_model_serializer.CustomField: custom_field>\)
195201
file_path_field = FilePathField\(path=%r\)
196202
""" % tempfile.gettempdir())
203+
print(expected)
204+
assert re.search(expected, repr(TestSerializer())) is not None
205+
206+
@override_settings(
207+
REST_FRAMEWORK={
208+
'MODEL_SERIALIZER_FIELD_MAPPING': {
209+
'django.db.models.CharField': 'tests.test_model_serializer.CustomCharFieldField',
210+
}
211+
},
212+
)
213+
def test_custom_fields(self):
214+
"""
215+
If MODEL_SERIALIZER_FIELD_MAPPING is set than model fields should map
216+
to their equivalent serializer fields.
217+
"""
218+
class TestSerializer(serializers.ModelSerializer):
219+
class Meta:
220+
model = RegularFieldsModel
221+
fields = (
222+
"char_field",
223+
)
224+
225+
expected = dedent(r"""
226+
TestSerializer\(\):
227+
char_field = CustomCharFieldField\(max_length=100\)
228+
""")
197229
assert re.search(expected, repr(TestSerializer())) is not None
198230

199231
def test_field_options(self):

0 commit comments

Comments
 (0)