Skip to content

Commit a001d9f

Browse files
Merge pull request #99 from sebdiem/master
add viewset support
2 parents 4f1d79e + 9b5f995 commit a001d9f

File tree

6 files changed

+71
-7
lines changed

6 files changed

+71
-7
lines changed

rest_framework_docs/api_docs.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,9 @@
88

99
class ApiDocumentation(object):
1010

11-
def __init__(self):
11+
def __init__(self, drf_router=None):
1212
self.endpoints = []
13+
self.drf_router = drf_router
1314
try:
1415
root_urlconf = import_string(settings.ROOT_URLCONF)
1516
except ImportError:
@@ -26,7 +27,7 @@ def get_all_view_names(self, urlpatterns, parent_pattern=None):
2627
parent_pattern = None if pattern._regex == "^" else pattern
2728
self.get_all_view_names(urlpatterns=pattern.url_patterns, parent_pattern=parent_pattern)
2829
elif isinstance(pattern, RegexURLPattern) and self._is_drf_view(pattern) and not self._is_format_endpoint(pattern):
29-
api_endpoint = ApiEndpoint(pattern, parent_pattern)
30+
api_endpoint = ApiEndpoint(pattern, parent_pattern, self.drf_router)
3031
self.endpoints.append(api_endpoint)
3132

3233
def _is_drf_view(self, pattern):

rest_framework_docs/api_endpoint.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66

77
class ApiEndpoint(object):
88

9-
def __init__(self, pattern, parent_pattern=None):
9+
def __init__(self, pattern, parent_pattern=None, drf_router=None):
10+
self.drf_router = drf_router
1011
self.pattern = pattern
1112
self.callback = pattern.callback
1213
# self.name = pattern.name
@@ -26,7 +27,39 @@ def __get_path__(self, parent_pattern):
2627
return simplify_regex(self.pattern.regex.pattern)
2728

2829
def __get_allowed_methods__(self):
29-
return [force_str(m).upper() for m in self.callback.cls.http_method_names if hasattr(self.callback.cls, m)]
30+
31+
viewset_methods = []
32+
if self.drf_router:
33+
for prefix, viewset, basename in self.drf_router.registry:
34+
if self.callback.cls != viewset:
35+
continue
36+
37+
lookup = self.drf_router.get_lookup_regex(viewset)
38+
routes = self.drf_router.get_routes(viewset)
39+
40+
for route in routes:
41+
42+
# Only actions which actually exist on the viewset will be bound
43+
mapping = self.drf_router.get_method_map(viewset, route.mapping)
44+
if not mapping:
45+
continue
46+
47+
# Build the url pattern
48+
regex = route.url.format(
49+
prefix=prefix,
50+
lookup=lookup,
51+
trailing_slash=self.drf_router.trailing_slash
52+
)
53+
if self.pattern.regex.pattern == regex:
54+
funcs, viewset_methods = zip(
55+
*[(mapping[m], m.upper()) for m in self.callback.cls.http_method_names if m in mapping]
56+
)
57+
viewset_methods = list(viewset_methods)
58+
if len(set(funcs)) == 1:
59+
self.docstring = inspect.getdoc(getattr(self.callback.cls, funcs[0]))
60+
61+
view_methods = [force_str(m).upper() for m in self.callback.cls.http_method_names if hasattr(self.callback.cls, m)]
62+
return viewset_methods + view_methods
3063

3164
def __get_docstring__(self):
3265
return inspect.getdoc(self.callback)

rest_framework_docs/views.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,15 @@
77
class DRFDocsView(TemplateView):
88

99
template_name = "rest_framework_docs/home.html"
10+
drf_router = None
1011

1112
def get_context_data(self, **kwargs):
1213
settings = DRFSettings().settings
1314
if settings["HIDE_DOCS"]:
1415
raise Http404("Django Rest Framework Docs are hidden. Check your settings.")
1516

1617
context = super(DRFDocsView, self).get_context_data(**kwargs)
17-
docs = ApiDocumentation()
18+
docs = ApiDocumentation(drf_router=self.drf_router)
1819
endpoints = docs.get_endpoints()
1920

2021
query = self.request.GET.get("search", "")

tests/tests.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def test_index_view_with_endpoints(self):
2727
response = self.client.get(reverse('drfdocs'))
2828

2929
self.assertEqual(response.status_code, 200)
30-
self.assertEqual(len(response.context["endpoints"]), 11)
30+
self.assertEqual(len(response.context["endpoints"]), 14)
3131

3232
# Test the login view
3333
self.assertEqual(response.context["endpoints"][0].name_parent, "accounts")
@@ -67,3 +67,14 @@ def test_index_view_docs_hidden(self):
6767

6868
self.assertEqual(response.status_code, 404)
6969
self.assertEqual(response.reason_phrase.upper(), "NOT FOUND")
70+
71+
def test_model_viewset(self):
72+
response = self.client.get(reverse('drfdocs'))
73+
74+
self.assertEqual(response.status_code, 200)
75+
self.assertEqual(response.context["endpoints"][10].path, '/organisation-model-viewsets/')
76+
self.assertEqual(response.context["endpoints"][11].path, '/organisation-model-viewsets/<pk>/')
77+
self.assertEqual(response.context["endpoints"][10].allowed_methods, ['GET', 'POST', 'OPTIONS'])
78+
self.assertEqual(response.context["endpoints"][11].allowed_methods, ['GET', 'PUT', 'PATCH', 'DELETE', 'OPTIONS'])
79+
self.assertEqual(response.context["endpoints"][12].allowed_methods, ['POST', 'OPTIONS'])
80+
self.assertEqual(response.context["endpoints"][12].docstring, 'This is a test.')

tests/urls.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from django.conf.urls import include, url
44
from django.contrib import admin
5+
from rest_framework.routers import SimpleRouter
6+
from rest_framework_docs.views import DRFDocsView
57
from tests import views
68

79
accounts_urls = [
@@ -23,13 +25,17 @@
2325
url(r'^(?P<slug>[\w-]+)/errored/$', view=views.OrganisationErroredView.as_view(), name="errored")
2426
]
2527

28+
router = SimpleRouter()
29+
router.register('organisation-model-viewsets', views.TestModelViewSet, base_name='organisation')
30+
2631
urlpatterns = [
2732
url(r'^admin/', include(admin.site.urls)),
28-
url(r'^docs/', include('rest_framework_docs.urls')),
33+
url(r'^docs/', DRFDocsView.as_view(drf_router=router), name='drfdocs'),
2934

3035
# API
3136
url(r'^accounts/', view=include(accounts_urls, namespace='accounts')),
3237
url(r'^organisations/', view=include(organisations_urls, namespace='organisations')),
38+
url(r'^', include(router.urls)),
3339

3440
# Endpoints without parents/namespaces
3541
url(r'^another-login/$', views.LoginView.as_view(), name="login"),

tests/views.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@
55
from rest_framework import parsers, renderers, generics, status
66
from rest_framework.authtoken.models import Token
77
from rest_framework.authtoken.serializers import AuthTokenSerializer
8+
from rest_framework.decorators import detail_route
89
from rest_framework.permissions import AllowAny
910
from rest_framework.response import Response
1011
from rest_framework.views import APIView
12+
from rest_framework.viewsets import ModelViewSet
1113
from tests.models import User, Organisation, Membership
1214
from tests import serializers
1315

@@ -132,3 +134,13 @@ def post(self, request):
132134

133135
def get_serializer_class(self):
134136
return AuthTokenSerializer
137+
138+
139+
class TestModelViewSet(ModelViewSet):
140+
queryset = Organisation.objects.all()
141+
serializer_class = serializers.OrganisationMembersSerializer
142+
143+
@detail_route(methods=['post'])
144+
def test_route(self, request):
145+
"""This is a test."""
146+
return Response()

0 commit comments

Comments
 (0)