10000 Generate OpenAPI schema field types from validators. (#6674) · coderanger/django-rest-framework@2d65f82 · GitHub
[go: up one dir, main page]

Skip to content

Commit 2d65f82

Browse files
n2ygkcarltongibson
authored andcommitted
Generate OpenAPI schema field types from validators. (encode#6674)
1 parent a63860f commit 2d65f82

File tree

3 files changed

+204
-4
lines changed

3 files changed

+204
-4
lines changed

rest_framework/schemas/openapi.py

Lines changed: 105 additions & 4 deletions
< 57AE td data-grid-cell-id="diff-99d3943485b50b323271fe5399dab8b2c778f42cbc0749b0f702e4cab0a81b2c-270-325-0" data-selected="false" role="gridcell" style="background-color:var(--diffBlob-additionNum-bgColor, var(--diffBlob-addition-bgColor-num));text-align:center" tabindex="-1" valign="top" class="focusable-grid-cell diff-line-number position-relative left-side">
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
11
import warnings
22

3+
from django.core.validators import (
4+
DecimalValidator, EmailValidator, MaxLengthValidator, MaxValueValidator,
5+
MinLengthValidator, MinValueValidator, RegexValidator, URLValidator
6+
)
37
from django.db import models
48
from django.utils.encoding import force_text
59

610
from rest_framework import exceptions, serializers
711
from rest_framework.compat import uritemplate
12+
from rest_framework.fields import empty
813

914
from .generators import BaseSchemaGenerator
1015
from .inspectors import ViewInspector
@@ -268,18 +273,76 @@ def _map_field(self, field):
268273
'format': 'date-time',
269274
}
270275

276+
# "Formats such as "email", "uuid", and so on, MAY be used even though undefined by this specification."
277+
# see: https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#data-types
278+
# see also: https://swagger.io/docs/specification/data-models/data-types/#string
279+
if isinstance(field, serializers.EmailField):
280+
return {
281+
'type': 'string',
282+
'format': 'email'
283+
}
284+
285+
if isinstance(field, serializers.URLField):
286+
return {
287+
'type': 'string',
288+
'format': 'uri'
289+
}
290+
291+
if isinstance(field, serializers.UUIDField):
292+
return {
293+
'type': 'string',
294+
'format': 'uuid'
295+
}
296+
297+
if isinstance(field, serializers.IPAddressField):
298+
content = {
299+
'type': 'string',
300+
}
301+
if field.protocol != 'both':
302+
content['format'] = field.protocol
303+
return content
304+
305+
# DecimalField has multipleOf based on decimal_places
306+
if isinstance(field, serializers.DecimalField):
307+
content = {
308+
'type': 'number'
309+
}
310+
if field.decimal_places:
311+
content['multipleOf'] = float('.' + (field.decimal_places - 1) * '0' + '1')
312+
if field.max_whole_digits:
313+
content['maximum'] = int(field.max_whole_digits * '9') + 1
314+
content['minimum'] = -content['maximum']
315+
self._map_min_max(field, content)
316+
return content
317+
318+
if isinstance(field, serializers.FloatField):
319+
content = {
320+
'type': 'number'
321+
}
322+
self._map_min_max(field, content)
323+
return content
324+
325+
if isinstance(field, serializers.IntegerField):
326+
content = {
327+
'type': 'integer'
328+
}
329+
self._map_min_max(field, content)
330+
return content
331+
271332
# Simplest cases, default to 'string' type:
272333
FIELD_CLASS_SCHEMA_TYPE = {
273334
serializers.BooleanField: 'boolean',
274-
serializers.DecimalField: 'number',
275-
serializers.FloatField: 'number',
276-
serializers.IntegerField: 'integer',
277-
278335
serializers.JSONField: 'object',
279336
serializers.DictField: 'object',
280337
}
281338
return {'type': FIELD_CLASS_SCHEMA_TYPE.get(field.__class__, 'string')}
282339

340+
def _map_min_max(self, field, content):
341+
if field.max_value:
342+
content['maximum'] = field.max_value
343+
if field.min_value:
344+
content['minimum'] = field.min_value
345+
283346
def _map_serializer(self, serializer):
284347
# Assuming we have a valid serializer instance.
285348
# TODO:
@@ -303,13 +366,51 @@ def _map_serializer(self, serializer):
303366
schema['writeOnly'] = True
304367
if field.allow_null:
305368
schema['nullable'] = True
369+
if field.default and field.default != empty: # why don't they use None?!
370+
schema['default'] = field.default
371+
if field.help_text:
372+
schema['description'] = field.help_text
373+
self._map_field_validators(field.validators, schema)
306374

307375
properties[field.field_name] = schema
308376
return {
309377
'required': required,
310378
'properties': properties,
311379
}
312380

381+
def _map_field_validators(self, validators, schema):
382+
"""
383+
map field validators
384+
:param list:validators: list of field validators
385+
:param dict:schema: schema that the validators get added to
386+
"""
387+
for v in validators:
388+
# "Formats such as "email", "uuid", and so on, MAY be used even though undefined by this specification."
389+
# https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#data-types
390+
if isinstance(v, EmailValidator):
391+
schema['format'] = 'email'
392+
if isinstance(v, URLValidator):
393+
schema['format'] = 'uri'
394+
if isinstance(v, RegexValidator):
395+
schema['pattern'] = v.regex.pattern
396+
elif isinstance(v, MaxLengthValidator):
397+
schema['maxLength'] = v.limit_value
398+
elif isinstance(v, MinLengthValidator):
399+
schema['minLength'] = v.limit_value
400+
elif isinstance(v, MaxValueValidator):
401+
schema['maximum'] = v.limit_value
402+
elif isinstance(v, MinValueValidator):
403+
schema['minimum'] = v.limit_value
404+
elif isinstance(v, DecimalValidator):
405+
if v.decimal_places:
406+
schema['multipleOf'] = float('.' + (v.decimal_places - 1) * '0' + '1')
407+
if v.max_digits:
408+
digits = v.max_digits
409+
if v.decimal_places is not None and v.decimal_places > 0:
410+
digits -= v.decimal_places
411+
schema['maximum'] = int(digits * '9') + 1
412+
schema['minimum'] = -schema['maximum']
413+
313414
def _get_request_body(self, path, method):
314415
view = self.view
315416

tests/schemas/test_openapi.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,3 +257,53 @@ def test_serializer_datefield(self):
257257

258258
assert response_schema['date']['format'] == 'date'
259259
assert response_schema['datetime']['format'] == 'date-time'
260+
261+
def test_serializer_validators(self):
262+
patterns = [
263+
url(r'^example/?$', views.ExampleValdidatedAPIView.as_view()),
264+
]
265+
generator = SchemaGenerator(patterns=patterns)
266+
267+
request = create_request('/')
268+
schema = generator.get_schema(request=request)
269+
270+
response = schema['paths']['/example/']['get']['responses']
271+
response_schema = response['200']['content']['application/json']['schema']['properties']
272+
273+
assert response_schema['integer']['type'] == 'integer'
274+
assert response_schema['integer']['maximum'] == 99
275+
assert response_schema['integer']['minimum'] == -11
276+
277+
assert response_schema['string']['minLength'] == 2
278+
assert response_schema['string']['maxLength'] == 10
279+
280+
assert response_schema['regex']['pattern'] == r'[ABC]12{3}'
281+
assert response_schema['regex']['description'] == 'must have an A, B, or C followed by 1222'
282+
283+
assert response_schema['decimal1']['type'] == 'number'
284+
assert response_schema['decimal1']['multipleOf'] == .01
285+
assert response_schema['decimal1']['maximum'] == 10000
286+
assert response_schema['decimal1']['minimum'] == -10000
287+
288+
assert response_schema['decimal2']['type'] == 'number'
289+
assert response_schema['decimal2']['multipleOf'] == .0001
290+
291+
assert response_schema['email']['type'] == 'string'
292+
assert response_schema['email']['format'] == 'email'
293+
assert response_schema['email']['default'] == 'foo@bar.com'
294+
295+
assert response_schema['url']['type'] == 'string'
296+
assert response_schema['url']['nullable'] is True
297+
assert response_schema['url']['default'] == 'http://www.example.com'
298+
299+
assert response_schema['uuid']['type'] == 'string'
300+
assert response_schema['uuid']['format'] == 'uuid'
301+
302+
assert response_schema['ip4']['type'] == 'string'
303+
assert response_schema['ip4']['format'] == 'ipv4'
304+
305+
assert response_schema['ip6']['type'] == 'string'
306+
assert response_schema['ip6']['format'] == 'ipv6'
307+
308+
assert response_schema['ip']['type'] == 'string'
309+
assert 'format' not in response_schema['ip']

tests/schemas/views.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,10 @@
1+
import uuid
2+
3+
from django.core.validators import (
4+
DecimalValidator, MaxLengthValidator, MaxValueValidator,
5+
MinLengthValidator, MinValueValidator, RegexValidator
6+
)
7+
18
from rest_framework import generics, permissions, serializers
29
from rest_framework.decorators import action
310
from rest_framework.response import Response
@@ -56,3 +63,45 @@ def new(self, *args, **kwargs):
5663
@action(detail=False)
5764
def old(self, *args, **kwargs):
5865
pass
66+
67+
68+
# Validators and/or equivalent Field attributes.
69+
class ExampleValidatedSerializer(serializers.Serializer):
70+
integer = serializers.IntegerField(
71+
validators=(
72+
MaxValueValidator(limit_value=99),
73+
MinValueValidator(limit_value=-11),
74+
)
75+
)
76+
string = serializers.CharField(
77+
validators=(
78+
MaxLengthValidator(limit_value=10),
79+
MinLengthValidator(limit_value=2),
80+
)
81+
)
82+
regex = serializers.CharField(
83+
validators=(
84+
RegexValidator(regex=r'[ABC]12{3}'),
85+
),
86+
help_text='must have an A, B, or C followed by 1222'
87+
)
88+
decimal1 = serializers.DecimalField(max_digits=6, decimal_places=2)
89+
decimal2 = serializers.DecimalField(max_digits=5, decimal_places=0,
90+
validators=(DecimalValidator(max_digits=17, decimal_places=4),))
91+
email = serializers.EmailField(default='foo@bar.com')
92+
url = serializers.URLField(default='http://www.example.com', allow_null=True)
93+
uuid = serializers.UUIDField()
94+
ip4 = serializers.IPAddressField(protocol='ipv4')
95+
ip6 = serializers.IPAddressField(protocol='ipv6')
96+
ip = serializers.IPAddressField()
97+
98+
99+
class ExampleValdidatedAPIView(generics.GenericAPIView):
100+
serializer_class = ExampleValidatedSerializer
101+
102+
def get(self, *args, **kwargs):
103+
serializer = self.get_serializer(integer=33, string='hello', regex='foo', decimal1=3.55,
104+
decimal2=5.33, email='a@b.co',
105+
url='http://localhost', uuid=uuid.uuid4(), ip4='127.0.0.1', ip6='::1',
106+
ip='192.168.1.1')
107+
return Response(serializer.data)

0 commit comments

Comments
 (0)
0