Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 44 additions & 57 deletions starlette/middleware/cors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,74 +2,71 @@

import functools
import re
from collections.abc import Sequence
from collections.abc import Collection

from starlette.datastructures import Headers, MutableHeaders
from starlette.responses import PlainTextResponse, Response
from starlette.types import ASGIApp, Message, Receive, Scope, Send

ALL_METHODS = ("DELETE", "GET", "HEAD", "OPTIONS", "PATCH", "POST", "PUT")
ALL_METHODS = {"DELETE", "GET", "HEAD", "OPTIONS", "PATCH", "POST", "PUT"}
SAFELISTED_HEADERS = {"Accept", "Accept-Language", "Content-Language", "Content-Type"}


class CORSMiddleware:
def __init__(
self,
app: ASGIApp,
allow_origins: Sequence[str] = (),
allow_methods: Sequence[str] = ("GET",),
allow_headers: Sequence[str] = (),
allow_origins: Collection[str] = {},
allow_methods: Collection[str] = {"GET"},
allow_headers: Collection[str] = {},
allow_credentials: bool = False,
allow_origin_regex: str | None = None,
allow_private_network: bool = False,
expose_headers: Sequence[str] = (),
expose_headers: Collection[str] = {},
max_age: int = 600,
) -> None:
if "*" in allow_methods:
allow_methods = ALL_METHODS

compiled_allow_origin_regex = None
if allow_origin_regex is not None:
compiled_allow_origin_regex = re.compile(allow_origin_regex)

allow_all_origins = "*" in allow_origins
allow_all_headers = "*" in allow_headers
preflight_explicit_allow_origin = not allow_all_origins or allow_credentials
explicit_allow_origin = allow_credentials or not allow_all_origins
allow_headers = SAFELISTED_HEADERS.union(allow_headers)

simple_headers: dict[str, str] = {}
if allow_all_origins:
if not explicit_allow_origin:
simple_headers["Access-Control-Allow-Origin"] = "*"
if allow_credentials:
simple_headers["Access-Control-Allow-Credentials"] = "true"
if expose_headers:
simple_headers["Access-Control-Expose-Headers"] = ", ".join(expose_headers)

preflight_headers: dict[str, str] = {}
if preflight_explicit_allow_origin:
# The origin value will be set in preflight_response() if it is allowed.
preflight_headers["Vary"] = "Origin"
else:
preflight_headers: dict[str, str] = {
"Access-Control-Allow-Methods": ", ".join(allow_methods),
"Access-Control-Max-Age": str(max_age),
}
if not explicit_allow_origin:
preflight_headers["Access-Control-Allow-Origin"] = "*"
preflight_headers.update(
{
"Access-Control-Allow-Methods": ", ".join(allow_methods),
"Access-Control-Max-Age": str(max_age),
}
)
allow_headers = sorted(SAFELISTED_HEADERS | set(allow_headers))
if allow_headers and not allow_all_headers:
preflight_headers["Access-Control-Allow-Headers"] = ", ".join(allow_headers)
if not allow_all_headers:
preflight_headers["Access-Control-Allow-Headers"] = ", ".join(sorted(allow_headers))
if allow_credentials:
preflight_headers["Access-Control-Allow-Credentials"] = "true"

preflight_vary: list[str] = ["Access-Control-Request-Headers", "Access-Control-Request-Private-Network"]
if allow_methods != ALL_METHODS:
preflight_vary.append("Access-Control-Request-Method")
if explicit_allow_origin:
preflight_vary.append("Origin")
preflight_headers["Vary"] = ", ".join(preflight_vary)

self.app = app
self.allow_origins = allow_origins
self.allow_methods = allow_methods
self.allow_headers = [h.lower() for h in allow_headers]
self.allow_headers = {h.lower() for h in allow_headers}
self.allow_all_origins = allow_all_origins
self.allow_all_headers = allow_all_headers
self.preflight_explicit_allow_origin = preflight_explicit_allow_origin
self.allow_origin_regex = compiled_allow_origin_regex
self.explicit_allow_origin = explicit_allow_origin
self.allow_origin_regex = re.compile(allow_origin_regex) if allow_origin_regex is not None else None
self.allow_private_network = allow_private_network
self.simple_headers = simple_headers
self.preflight_headers = preflight_headers
Expand All @@ -79,29 +76,28 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
await self.app(scope, receive, send)
return

method = scope["method"]
headers = Headers(scope=scope)
origin = headers.get("origin")

if origin is None:
if "origin" not in headers:
await self.app(scope, receive, send)
return

if method == "OPTIONS" and "access-control-request-method" in headers:
if scope["method"] == "OPTIONS" and "access-control-request-method" in headers:
response = self.preflight_response(request_headers=headers)
await response(scope, receive, send)
return

await self.simple_response(scope, receive, send, request_headers=headers)

def is_allowed_origin(self, origin: str) -> bool:
if self.allow_all_origins:
if origin in self.allow_origins:
return True

if origin == "null":
return False
if self.allow_origin_regex is not None and self.allow_origin_regex.fullmatch(origin):
return True

return origin in self.allow_origins
return self.allow_all_origins

def preflight_response(self, request_headers: Headers) -> Response:
requested_origin = request_headers["origin"]
Expand All @@ -112,8 +108,8 @@ def preflight_response(self, request_headers: Headers) -> Response:
headers = dict(self.preflight_headers)
failures: list[str] = []

if self.is_allowed_origin(origin=requested_origin):
if self.preflight_explicit_allow_origin:
if self.is_allowed_origin(requested_origin):
if self.explicit_allow_origin:
# The "else" case is already accounted for in self.preflight_headers
# and the value would be "*".
headers["Access-Control-Allow-Origin"] = requested_origin
Expand All @@ -123,13 +119,12 @@ def preflight_response(self, request_headers: Headers) -> Response:
if requested_method not in self.allow_methods:
failures.append("method")

# If we allow all headers, then we have to mirror back any requested
# headers in the response.
# When allow_headers is wildcard, mirror any requested headers.
if self.allow_all_headers and requested_headers is not None:
headers["Access-Control-Allow-Headers"] = requested_headers
elif requested_headers is not None:
for header in [h.lower() for h in requested_headers.split(",")]:
if header.strip() not in self.allow_headers:
for header in requested_headers.split(","):
if header.strip().lower() not in self.allow_headers:
failures.append("headers")
break

Expand Down Expand Up @@ -159,23 +154,15 @@ async def send(self, message: Message, send: Send, request_headers: Headers) ->

message.setdefault("headers", [])
headers = MutableHeaders(scope=message)
headers.update(self.simple_headers)
origin = request_headers["Origin"]
has_cookie = "cookie" in request_headers

# If request includes any cookie headers, then we must respond
# with the specific origin instead of '*'.
if self.allow_all_origins and has_cookie:
self.allow_explicit_origin(headers, origin)
if self.explicit_allow_origin:
headers.add_vary_header("Origin")

# If we only allow specific origins, then we have to mirror back
# the Origin header in the response.
elif not self.allow_all_origins and self.is_allowed_origin(origin=origin):
self.allow_explicit_origin(headers, origin)
if self.is_allowed_origin(origin):
headers.update(self.simple_headers)

await send(message)
if self.explicit_allow_origin:
headers["Access-Control-Allow-Origin"] = origin

@staticmethod
def allow_explicit_origin(headers: MutableHeaders, origin: str) -> None:
headers["Access-Control-Allow-Origin"] = origin
headers.add_vary_header("Origin")
await send(message)
Loading