diff --git a/starlette/middleware/cors.py b/starlette/middleware/cors.py index 36ba9f93f..496457687 100644 --- a/starlette/middleware/cors.py +++ b/starlette/middleware/cors.py @@ -2,13 +2,13 @@ import functools import re -from collections.abc import Sequence +from collections.abc import Collection from starlette.datastructures import Headers, MutableHeaders from starlette.responses import PlainTextResponse, Response from starlette.types import ASGIApp, Message, Receive, Scope, Send -ALL_METHODS = ("DELETE", "GET", "HEAD", "OPTIONS", "PATCH", "POST", "PUT") +ALL_METHODS = {"DELETE", "GET", "HEAD", "OPTIONS", "PATCH", "POST", "PUT"} SAFELISTED_HEADERS = {"Accept", "Accept-Language", "Content-Language", "Content-Type"} @@ -16,60 +16,57 @@ class CORSMiddleware: def __init__( self, app: ASGIApp, - allow_origins: Sequence[str] = (), - allow_methods: Sequence[str] = ("GET",), - allow_headers: Sequence[str] = (), + allow_origins: Collection[str] = {}, + allow_methods: Collection[str] = {"GET"}, + allow_headers: Collection[str] = {}, allow_credentials: bool = False, allow_origin_regex: str | None = None, allow_private_network: bool = False, - expose_headers: Sequence[str] = (), + expose_headers: Collection[str] = {}, max_age: int = 600, ) -> None: if "*" in allow_methods: allow_methods = ALL_METHODS - compiled_allow_origin_regex = None - if allow_origin_regex is not None: - compiled_allow_origin_regex = re.compile(allow_origin_regex) - allow_all_origins = "*" in allow_origins allow_all_headers = "*" in allow_headers - preflight_explicit_allow_origin = not allow_all_origins or allow_credentials + explicit_allow_origin = allow_credentials or not allow_all_origins + allow_headers = SAFELISTED_HEADERS.union(allow_headers) simple_headers: dict[str, str] = {} - if allow_all_origins: + if not explicit_allow_origin: simple_headers["Access-Control-Allow-Origin"] = "*" if allow_credentials: simple_headers["Access-Control-Allow-Credentials"] = "true" if expose_headers: simple_headers["Access-Control-Expose-Headers"] = ", ".join(expose_headers) - preflight_headers: dict[str, str] = {} - if preflight_explicit_allow_origin: - # The origin value will be set in preflight_response() if it is allowed. - preflight_headers["Vary"] = "Origin" - else: + preflight_headers: dict[str, str] = { + "Access-Control-Allow-Methods": ", ".join(allow_methods), + "Access-Control-Max-Age": str(max_age), + } + if not explicit_allow_origin: preflight_headers["Access-Control-Allow-Origin"] = "*" - preflight_headers.update( - { - "Access-Control-Allow-Methods": ", ".join(allow_methods), - "Access-Control-Max-Age": str(max_age), - } - ) - allow_headers = sorted(SAFELISTED_HEADERS | set(allow_headers)) - if allow_headers and not allow_all_headers: - preflight_headers["Access-Control-Allow-Headers"] = ", ".join(allow_headers) + if not allow_all_headers: + preflight_headers["Access-Control-Allow-Headers"] = ", ".join(sorted(allow_headers)) if allow_credentials: preflight_headers["Access-Control-Allow-Credentials"] = "true" + preflight_vary: list[str] = ["Access-Control-Request-Headers", "Access-Control-Request-Private-Network"] + if allow_methods != ALL_METHODS: + preflight_vary.append("Access-Control-Request-Method") + if explicit_allow_origin: + preflight_vary.append("Origin") + preflight_headers["Vary"] = ", ".join(preflight_vary) + self.app = app self.allow_origins = allow_origins self.allow_methods = allow_methods - self.allow_headers = [h.lower() for h in allow_headers] + self.allow_headers = {h.lower() for h in allow_headers} self.allow_all_origins = allow_all_origins self.allow_all_headers = allow_all_headers - self.preflight_explicit_allow_origin = preflight_explicit_allow_origin - self.allow_origin_regex = compiled_allow_origin_regex + self.explicit_allow_origin = explicit_allow_origin + self.allow_origin_regex = re.compile(allow_origin_regex) if allow_origin_regex is not None else None self.allow_private_network = allow_private_network self.simple_headers = simple_headers self.preflight_headers = preflight_headers @@ -79,15 +76,13 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: await self.app(scope, receive, send) return - method = scope["method"] headers = Headers(scope=scope) - origin = headers.get("origin") - if origin is None: + if "origin" not in headers: await self.app(scope, receive, send) return - if method == "OPTIONS" and "access-control-request-method" in headers: + if scope["method"] == "OPTIONS" and "access-control-request-method" in headers: response = self.preflight_response(request_headers=headers) await response(scope, receive, send) return @@ -95,13 +90,14 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: await self.simple_response(scope, receive, send, request_headers=headers) def is_allowed_origin(self, origin: str) -> bool: - if self.allow_all_origins: + if origin in self.allow_origins: return True - + if origin == "null": + return False if self.allow_origin_regex is not None and self.allow_origin_regex.fullmatch(origin): return True - return origin in self.allow_origins + return self.allow_all_origins def preflight_response(self, request_headers: Headers) -> Response: requested_origin = request_headers["origin"] @@ -112,8 +108,8 @@ def preflight_response(self, request_headers: Headers) -> Response: headers = dict(self.preflight_headers) failures: list[str] = [] - if self.is_allowed_origin(origin=requested_origin): - if self.preflight_explicit_allow_origin: + if self.is_allowed_origin(requested_origin): + if self.explicit_allow_origin: # The "else" case is already accounted for in self.preflight_headers # and the value would be "*". headers["Access-Control-Allow-Origin"] = requested_origin @@ -123,13 +119,12 @@ def preflight_response(self, request_headers: Headers) -> Response: if requested_method not in self.allow_methods: failures.append("method") - # If we allow all headers, then we have to mirror back any requested - # headers in the response. + # When allow_headers is wildcard, mirror any requested headers. if self.allow_all_headers and requested_headers is not None: headers["Access-Control-Allow-Headers"] = requested_headers elif requested_headers is not None: - for header in [h.lower() for h in requested_headers.split(",")]: - if header.strip() not in self.allow_headers: + for header in requested_headers.split(","): + if header.strip().lower() not in self.allow_headers: failures.append("headers") break @@ -159,23 +154,15 @@ async def send(self, message: Message, send: Send, request_headers: Headers) -> message.setdefault("headers", []) headers = MutableHeaders(scope=message) - headers.update(self.simple_headers) origin = request_headers["Origin"] - has_cookie = "cookie" in request_headers - # If request includes any cookie headers, then we must respond - # with the specific origin instead of '*'. - if self.allow_all_origins and has_cookie: - self.allow_explicit_origin(headers, origin) + if self.explicit_allow_origin: + headers.add_vary_header("Origin") - # If we only allow specific origins, then we have to mirror back - # the Origin header in the response. - elif not self.allow_all_origins and self.is_allowed_origin(origin=origin): - self.allow_explicit_origin(headers, origin) + if self.is_allowed_origin(origin): + headers.update(self.simple_headers) - await send(message) + if self.explicit_allow_origin: + headers["Access-Control-Allow-Origin"] = origin - @staticmethod - def allow_explicit_origin(headers: MutableHeaders, origin: str) -> None: - headers["Access-Control-Allow-Origin"] = origin - headers.add_vary_header("Origin") + await send(message) diff --git a/tests/middleware/test_cors.py b/tests/middleware/test_cors.py index cbee7d6e7..1bcf62ed5 100644 --- a/tests/middleware/test_cors.py +++ b/tests/middleware/test_cors.py @@ -1,12 +1,25 @@ +from httpx import Response + from starlette.applications import Starlette from starlette.middleware import Middleware -from starlette.middleware.cors import CORSMiddleware +from starlette.middleware.cors import ALL_METHODS, CORSMiddleware from starlette.requests import Request from starlette.responses import PlainTextResponse from starlette.routing import Route from tests.types import TestClientFactory +def assert_vary(response: Response, expected: set[str] | None) -> None: + vary_header = response.headers.get("vary") + if expected is None: + assert vary_header is None + return + + assert vary_header is not None + actual = {value.strip() for value in vary_header.split(",") if value.strip()} + assert actual == expected + + def test_cors_allow_all( test_client_factory: TestClientFactory, ) -> None: @@ -37,29 +50,25 @@ def homepage(request: Request) -> PlainTextResponse: } response = client.options("/", headers=headers) assert response.status_code == 200 - assert response.text == "OK" assert response.headers["access-control-allow-origin"] == "https://example.org" assert response.headers["access-control-allow-headers"] == "X-Example" assert response.headers["access-control-allow-credentials"] == "true" - assert response.headers["vary"] == "Origin" + assert_vary( + response, + {"Access-Control-Request-Headers", "Access-Control-Request-Private-Network", "Origin"}, + ) + allowed_methods = {method.strip() for method in response.headers["access-control-allow-methods"].split(",")} + assert allowed_methods == ALL_METHODS # Test standard response headers = {"Origin": "https://example.org"} response = client.get("/", headers=headers) assert response.status_code == 200 assert response.text == "Homepage" - assert response.headers["access-control-allow-origin"] == "*" - assert response.headers["access-control-expose-headers"] == "X-Status" - assert response.headers["access-control-allow-credentials"] == "true" - - # Test standard credentialed response - headers = {"Origin": "https://example.org", "Cookie": "star_cookie=sugar"} - response = client.get("/", headers=headers) - assert response.status_code == 200 - assert response.text == "Homepage" assert response.headers["access-control-allow-origin"] == "https://example.org" assert response.headers["access-control-expose-headers"] == "X-Status" assert response.headers["access-control-allow-credentials"] == "true" + assert_vary(response, {"Origin"}) # Test non-CORS response response = client.get("/") @@ -101,7 +110,10 @@ def homepage(request: Request) -> PlainTextResponse: assert response.headers["access-control-allow-origin"] == "*" assert response.headers["access-control-allow-headers"] == "X-Example" assert "access-control-allow-credentials" not in response.headers - assert "vary" not in response.headers + assert_vary( + response, + {"Access-Control-Request-Headers", "Access-Control-Request-Private-Network"}, + ) # Test standard response headers = {"Origin": "https://example.org"} @@ -111,6 +123,7 @@ def homepage(request: Request) -> PlainTextResponse: assert response.headers["access-control-allow-origin"] == "*" assert response.headers["access-control-expose-headers"] == "X-Status" assert "access-control-allow-credentials" not in response.headers + assert_vary(response, None) # Test non-CORS response response = client.get("/") @@ -146,12 +159,20 @@ def homepage(request: Request) -> PlainTextResponse: } response = client.options("/", headers=headers) assert response.status_code == 200 - assert response.text == "OK" assert response.headers["access-control-allow-origin"] == "https://example.org" - assert response.headers["access-control-allow-headers"] == ( - "Accept, Accept-Language, Content-Language, Content-Type, X-Example" - ) + allowed_headers = {h.strip() for h in response.headers["access-control-allow-headers"].split(",")} + assert allowed_headers == {"Accept", "Accept-Language", "Content-Language", "Content-Type", "X-Example"} assert "access-control-allow-credentials" not in response.headers + assert_vary( + response, + { + "Access-Control-Request-Headers", + "Access-Control-Request-Private-Network", + "Access-Control-Request-Method", + "Origin", + }, + ) + assert response.headers["access-control-allow-methods"] == "GET" # Test standard response headers = {"Origin": "https://example.org"} @@ -160,6 +181,15 @@ def homepage(request: Request) -> PlainTextResponse: assert response.text == "Homepage" assert response.headers["access-control-allow-origin"] == "https://example.org" assert "access-control-allow-credentials" not in response.headers + assert_vary(response, {"Origin"}) + + # Test disallowed standard response + headers = {"Origin": "https://another.org"} + response = client.get("/", headers=headers) + assert response.status_code == 200 + assert response.text == "Homepage" + assert "access-control-allow-origin" not in response.headers + assert_vary(response, {"Origin"}) # Test non-CORS response response = client.get("/") @@ -197,6 +227,15 @@ def homepage(request: Request) -> None: assert response.status_code == 400 assert response.text == "Disallowed CORS origin, method, headers" assert "access-control-allow-origin" not in response.headers + assert_vary( + response, + { + "Access-Control-Request-Headers", + "Access-Control-Request-Private-Network", + "Access-Control-Request-Method", + "Origin", + }, + ) # Bug specific test, https://github.com/Kludex/starlette/pull/1199 # Test preflight response text with multiple disallowed headers @@ -209,41 +248,6 @@ def homepage(request: Request) -> None: assert response.text == "Disallowed CORS headers" -def test_preflight_allows_request_origin_if_origins_wildcard_and_credentials_allowed( - test_client_factory: TestClientFactory, -) -> None: - def homepage(request: Request) -> None: - return # pragma: no cover - - app = Starlette( - routes=[Route("/", endpoint=homepage)], - middleware=[ - Middleware( - CORSMiddleware, - allow_origins=["*"], - allow_methods=["POST"], - allow_credentials=True, - ) - ], - ) - - client = test_client_factory(app) - - # Test pre-flight response - headers = { - "Origin": "https://example.org", - "Access-Control-Request-Method": "POST", - } - response = client.options( - "/", - headers=headers, - ) - assert response.status_code == 200 - assert response.headers["access-control-allow-origin"] == "https://example.org" - assert response.headers["access-control-allow-credentials"] == "true" - assert response.headers["vary"] == "Origin" - - def test_cors_preflight_allow_all_methods( test_client_factory: TestClientFactory, ) -> None: @@ -262,10 +266,11 @@ def homepage(request: Request) -> None: "Access-Control-Request-Method": "POST", } + response = client.options("/", headers=headers) + assert response.status_code == 200 + allowed_methods = {m.strip() for m in response.headers["access-control-allow-methods"].split(",")} for method in ("DELETE", "GET", "HEAD", "OPTIONS", "PATCH", "POST", "PUT"): - response = client.options("/", headers=headers) - assert response.status_code == 200 - assert method in response.headers["access-control-allow-methods"] + assert method in allowed_methods def test_cors_allow_all_methods( @@ -324,14 +329,7 @@ def homepage(request: Request) -> PlainTextResponse: assert response.text == "Homepage" assert response.headers["access-control-allow-origin"] == "https://example.org" assert response.headers["access-control-allow-credentials"] == "true" - - # Test standard credentialed response - headers = {"Origin": "https://example.org", "Cookie": "star_cookie=sugar"} - response = client.get("/", headers=headers) - assert response.status_code == 200 - assert response.text == "Homepage" - assert response.headers["access-control-allow-origin"] == "https://example.org" - assert response.headers["access-control-allow-credentials"] == "true" + assert_vary(response, {"Origin"}) # Test disallowed standard response # Note that enforcement is a browser concern. The disallowed-ness is reflected @@ -341,6 +339,7 @@ def homepage(request: Request) -> PlainTextResponse: assert response.status_code == 200 assert response.text == "Homepage" assert "access-control-allow-origin" not in response.headers + assert_vary(response, {"Origin"}) # Test pre-flight response headers = { @@ -350,12 +349,19 @@ def homepage(request: Request) -> PlainTextResponse: } response = client.options("/", headers=headers) assert response.status_code == 200 - assert response.text == "OK" assert response.headers["access-control-allow-origin"] == "https://another.com" - assert response.headers["access-control-allow-headers"] == ( - "Accept, Accept-Language, Content-Language, Content-Type, X-Example" - ) + allowed_headers = {h.strip() for h in response.headers["access-control-allow-headers"].split(",")} + assert allowed_headers == {"Accept", "Accept-Language", "Content-Language", "Content-Type", "X-Example"} assert response.headers["access-control-allow-credentials"] == "true" + assert_vary( + response, + { + "Access-Control-Request-Headers", + "Access-Control-Request-Private-Network", + "Access-Control-Request-Method", + "Origin", + }, + ) # Test disallowed pre-flight response headers = { @@ -395,6 +401,7 @@ def homepage(request: Request) -> PlainTextResponse: assert response.text == "Homepage" assert response.headers["access-control-allow-origin"] == "https://subdomain.example.org" assert "access-control-allow-credentials" not in response.headers + assert_vary(response, {"Origin"}) # Test disallowed standard response headers = {"Origin": "https://subdomain.example.org.hacker.com"} @@ -402,100 +409,93 @@ def homepage(request: Request) -> PlainTextResponse: assert response.status_code == 200 assert response.text == "Homepage" assert "access-control-allow-origin" not in response.headers + assert_vary(response, {"Origin"}) -def test_cors_credentialed_requests_return_specific_origin( +def test_cors_vary_header_behavior( test_client_factory: TestClientFactory, ) -> None: def homepage(request: Request) -> PlainTextResponse: - return PlainTextResponse("Homepage", status_code=200) + return PlainTextResponse("Homepage", status_code=200, headers={"Vary": "Accept-Encoding"}) - app = Starlette( + # Test 1: Specific origins add Vary: Origin + app_specific = Starlette( routes=[Route("/", endpoint=homepage)], - middleware=[Middleware(CORSMiddleware, allow_origins=["*"])], + middleware=[Middleware(CORSMiddleware, allow_origins=["https://example.org"])], ) - client = test_client_factory(app) + client = test_client_factory(app_specific) - # Test credentialed request - headers = {"Origin": "https://example.org", "Cookie": "star_cookie=sugar"} - response = client.get("/", headers=headers) + response = client.get("/", headers={"Origin": "https://example.org"}) assert response.status_code == 200 - assert response.text == "Homepage" assert response.headers["access-control-allow-origin"] == "https://example.org" - assert "access-control-allow-credentials" not in response.headers - - -def test_cors_vary_header_defaults_to_origin( - test_client_factory: TestClientFactory, -) -> None: - def homepage(request: Request) -> PlainTextResponse: - return PlainTextResponse("Homepage", status_code=200) + assert_vary(response, {"Accept-Encoding", "Origin"}) - app = Starlette( + # Test 2: Wildcard without credentials does not add Vary: Origin + app_wildcard = Starlette( routes=[Route("/", endpoint=homepage)], - middleware=[Middleware(CORSMiddleware, allow_origins=["https://example.org"])], + middleware=[Middleware(CORSMiddleware, allow_origins=["*"])], ) + client = test_client_factory(app_wildcard) - headers = {"Origin": "https://example.org"} - - client = test_client_factory(app) - - response = client.get("/", headers=headers) + response = client.get("/", headers={"Origin": "https://someplace.org"}) assert response.status_code == 200 - assert response.headers["vary"] == "Origin" - - -def test_cors_vary_header_is_not_set_for_non_credentialed_request( - test_client_factory: TestClientFactory, -) -> None: - def homepage(request: Request) -> PlainTextResponse: - return PlainTextResponse("Homepage", status_code=200, headers={"Vary": "Accept-Encoding"}) + assert response.headers["access-control-allow-origin"] == "*" + assert_vary(response, {"Accept-Encoding"}) - app = Starlette( + # Test 3: Wildcard with credentials adds Vary: Origin + app_wildcard_creds = Starlette( routes=[Route("/", endpoint=homepage)], - middleware=[Middleware(CORSMiddleware, allow_origins=["*"])], + middleware=[Middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True)], ) - client = test_client_factory(app) + client = test_client_factory(app_wildcard_creds) - response = client.get("/", headers={"Origin": "https://someplace.org"}) + response = client.get("/", headers={"Cookie": "foo=bar", "Origin": "https://someplace.org"}) assert response.status_code == 200 - assert response.headers["vary"] == "Accept-Encoding" + assert response.headers["access-control-allow-origin"] == "https://someplace.org" + assert response.headers["access-control-allow-credentials"] == "true" + assert_vary(response, {"Accept-Encoding", "Origin"}) -def test_cors_vary_header_is_properly_set_for_credentialed_request( +def test_cors_preflight_vary_with_wildcard_origins_specific_methods( test_client_factory: TestClientFactory, ) -> None: - def homepage(request: Request) -> PlainTextResponse: - return PlainTextResponse("Homepage", status_code=200, headers={"Vary": "Accept-Encoding"}) + def homepage(request: Request) -> None: + pass # pragma: no cover app = Starlette( routes=[Route("/", endpoint=homepage)], - middleware=[Middleware(CORSMiddleware, allow_origins=["*"])], + middleware=[Middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["GET", "POST"])], ) client = test_client_factory(app) - response = client.get("/", headers={"Cookie": "foo=bar", "Origin": "https://someplace.org"}) + # Preflight Vary should include Request-Method even with wildcard origins when methods are restricted + response = client.options("/", headers={"Origin": "https://example.org", "Access-Control-Request-Method": "POST"}) assert response.status_code == 200 - assert response.headers["vary"] == "Accept-Encoding, Origin" + assert_vary( + response, + {"Access-Control-Request-Headers", "Access-Control-Request-Private-Network", "Access-Control-Request-Method"}, + ) -def test_cors_vary_header_is_properly_set_when_allow_origins_is_not_wildcard( +def test_cors_preflight_vary_with_specific_origins_wildcard_methods( test_client_factory: TestClientFactory, ) -> None: - def homepage(request: Request) -> PlainTextResponse: - return PlainTextResponse("Homepage", status_code=200, headers={"Vary": "Accept-Encoding"}) + def homepage(request: Request) -> None: + pass # pragma: no cover app = Starlette( - routes=[ - Route("/", endpoint=homepage), - ], - middleware=[Middleware(CORSMiddleware, allow_origins=["https://example.org"])], + routes=[Route("/", endpoint=homepage)], + middleware=[Middleware(CORSMiddleware, allow_origins=["https://example.org"], allow_methods=["*"])], ) client = test_client_factory(app) - response = client.get("/", headers={"Origin": "https://example.org"}) + # Preflight Vary should NOT include Request-Method when methods are unrestricted + response = client.options("/", headers={"Origin": "https://example.org", "Access-Control-Request-Method": "POST"}) assert response.status_code == 200 - assert response.headers["vary"] == "Accept-Encoding, Origin" + assert_vary( + response, + {"Access-Control-Request-Headers", "Access-Control-Request-Private-Network", "Origin"}, + ) def test_cors_allowed_origin_does_not_leak_between_credentialed_requests( @@ -514,22 +514,26 @@ def homepage(request: Request) -> PlainTextResponse: allow_origins=["*"], allow_headers=["*"], allow_methods=["*"], + allow_credentials=True, ) ], ) client = test_client_factory(app) - response = client.get("/", headers={"Origin": "https://someplace.org"}) - assert response.headers["access-control-allow-origin"] == "*" - assert "access-control-allow-credentials" not in response.headers + first_origin = "https://first.example" + second_origin = "https://second.example" - response = client.get("/", headers={"Cookie": "foo=bar", "Origin": "https://someplace.org"}) - assert response.headers["access-control-allow-origin"] == "https://someplace.org" - assert "access-control-allow-credentials" not in response.headers + response = client.get("/", headers={"Origin": first_origin}) + assert response.headers["access-control-allow-origin"] == first_origin + assert response.headers["access-control-allow-credentials"] == "true" - response = client.get("/", headers={"Origin": "https://someplace.org"}) - assert response.headers["access-control-allow-origin"] == "*" - assert "access-control-allow-credentials" not in response.headers + response = client.get("/", headers={"Cookie": "foo=bar", "Origin": second_origin}) + assert response.headers["access-control-allow-origin"] == second_origin + assert response.headers["access-control-allow-credentials"] == "true" + + response = client.get("/", headers={"Origin": first_origin}) + assert response.headers["access-control-allow-origin"] == first_origin + assert response.headers["access-control-allow-credentials"] == "true" def test_cors_private_network_access_allowed(test_client_factory: TestClientFactory) -> None: @@ -558,12 +562,20 @@ def homepage(request: Request) -> PlainTextResponse: assert response.status_code == 200 assert response.text == "OK" assert response.headers["access-control-allow-private-network"] == "true" + assert_vary( + response, + {"Access-Control-Request-Headers", "Access-Control-Request-Private-Network"}, + ) # Test preflight without Private Network Access request response = client.options("/", headers=headers_without_pna) assert response.status_code == 200 assert response.text == "OK" assert "access-control-allow-private-network" not in response.headers + assert_vary( + response, + {"Access-Control-Request-Headers", "Access-Control-Request-Private-Network"}, + ) # The access-control-allow-private-network header is not set for non-preflight requests response = client.get("/", headers=headers_with_pna) @@ -605,3 +617,220 @@ def homepage(request: Request) -> None: ... # pragma: no cover assert response.status_code == 400 assert response.text == "Disallowed CORS private-network" assert "access-control-allow-private-network" not in response.headers + assert_vary( + response, + {"Access-Control-Request-Headers", "Access-Control-Request-Private-Network"}, + ) + + +def test_cors_null_origin_rejection(test_client_factory: TestClientFactory) -> None: + def homepage(request: Request) -> PlainTextResponse: + return PlainTextResponse("Homepage", status_code=200) + + # Test 1: Null rejected with wildcard origins + app_wildcard = Starlette( + routes=[Route("/", endpoint=homepage)], + middleware=[Middleware(CORSMiddleware, allow_origins=["*"])], + ) + client = test_client_factory(app_wildcard) + + response = client.options("/", headers={"Origin": "null", "Access-Control-Request-Method": "GET"}) + assert response.status_code == 400 + assert "origin" in response.text.lower() + + response = client.get("/", headers={"Origin": "null"}) + assert response.status_code == 200 + assert "access-control-allow-origin" not in response.headers + + # Test 2: Null rejected with regex that matches everything + app_regex = Starlette( + routes=[Route("/", endpoint=homepage)], + middleware=[Middleware(CORSMiddleware, allow_origin_regex=r".*")], + ) + client = test_client_factory(app_regex) + + response = client.options("/", headers={"Origin": "null", "Access-Control-Request-Method": "GET"}) + assert response.status_code == 400 + assert "origin" in response.text.lower() + + # Test 3: Null rejected even when regex explicitly includes it + app_regex_explicit = Starlette( + routes=[Route("/", endpoint=homepage)], + middleware=[Middleware(CORSMiddleware, allow_origin_regex=r"null|https://.*")], + ) + client = test_client_factory(app_regex_explicit) + + response = client.options("/", headers={"Origin": "null", "Access-Control-Request-Method": "GET"}) + assert response.status_code == 400 + assert "origin" in response.text.lower() + + # Verify HTTPS origins still work with the regex + response = client.get("/", headers={"Origin": "https://example.org"}) + assert response.status_code == 200 + assert response.headers["access-control-allow-origin"] == "https://example.org" + + +def test_cors_null_origin_explicitly_allowed(test_client_factory: TestClientFactory) -> None: + def homepage(request: Request) -> PlainTextResponse: + return PlainTextResponse("Homepage", status_code=200) + + app = Starlette( + routes=[Route("/", endpoint=homepage)], + middleware=[Middleware(CORSMiddleware, allow_origins=["null", "https://example.org"])], + ) + + client = test_client_factory(app) + + # Null origin should be allowed when explicitly whitelisted + response = client.options("/", headers={"Origin": "null", "Access-Control-Request-Method": "GET"}) + assert response.status_code == 200 + assert response.headers["access-control-allow-origin"] == "null" + + # Simple request should also allow null origin + response = client.get("/", headers={"Origin": "null"}) + assert response.status_code == 200 + assert response.headers["access-control-allow-origin"] == "null" + + +def test_cors_method_case_sensitive(test_client_factory: TestClientFactory) -> None: + def homepage(request: Request) -> None: + pass # pragma: no cover + + app = Starlette( + routes=[Route("/", endpoint=homepage)], + middleware=[Middleware(CORSMiddleware, allow_origins=["https://example.org"], allow_methods=["POST"])], + ) + + client = test_client_factory(app) + + # Uppercase POST should be allowed + response = client.options("/", headers={"Origin": "https://example.org", "Access-Control-Request-Method": "POST"}) + assert response.status_code == 200 + + # Lowercase "post" should be rejected (methods are case-sensitive per HTTP spec) + response = client.options("/", headers={"Origin": "https://example.org", "Access-Control-Request-Method": "post"}) + assert response.status_code == 400 + assert "method" in response.text.lower() + + +def test_cors_empty_origins_list(test_client_factory: TestClientFactory) -> None: + def homepage(request: Request) -> PlainTextResponse: + return PlainTextResponse("Homepage", status_code=200) + + app = Starlette( + routes=[Route("/", endpoint=homepage)], + middleware=[Middleware(CORSMiddleware, allow_origins=[])], + ) + client = test_client_factory(app) + + response = client.options("/", headers={"Origin": "https://example.org", "Access-Control-Request-Method": "GET"}) + assert response.status_code == 400 + assert "access-control-allow-origin" not in response.headers + assert_vary( + response, + { + "Access-Control-Request-Headers", + "Access-Control-Request-Private-Network", + "Access-Control-Request-Method", + "Origin", + }, + ) + + response = client.get("/", headers={"Origin": "https://example.org"}) + assert response.status_code == 200 + assert "access-control-allow-origin" not in response.headers + assert_vary(response, {"Origin"}) + + +def test_cors_origins_list_and_regex_both_accepted(test_client_factory: TestClientFactory) -> None: + def homepage(request: Request) -> PlainTextResponse: + return PlainTextResponse("Homepage", status_code=200) + + app = Starlette( + routes=[Route("/", endpoint=homepage)], + middleware=[ + Middleware( + CORSMiddleware, + allow_origins=["https://example.org"], + allow_origin_regex=r"https://.*\.trusted\.com", + ) + ], + ) + client = test_client_factory(app) + + # Origin in explicit list should be accepted + response = client.get("/", headers={"Origin": "https://example.org"}) + assert response.status_code == 200 + assert response.headers["access-control-allow-origin"] == "https://example.org" + + # Origin matching regex should be accepted + response = client.get("/", headers={"Origin": "https://api.trusted.com"}) + assert response.status_code == 200 + assert response.headers["access-control-allow-origin"] == "https://api.trusted.com" + + # Origin matching neither should be rejected + response = client.get("/", headers={"Origin": "https://evil.com"}) + assert response.status_code == 200 + assert "access-control-allow-origin" not in response.headers + + +def test_cors_max_age_header(test_client_factory: TestClientFactory) -> None: + def homepage(request: Request) -> None: + pass # pragma: no cover + + app_default = Starlette( + routes=[Route("/", endpoint=homepage)], + middleware=[Middleware(CORSMiddleware, allow_origins=["*"])], + ) + client = test_client_factory(app_default) + + response = client.options("/", headers={"Origin": "https://example.org", "Access-Control-Request-Method": "GET"}) + assert response.status_code == 200 + assert response.headers["access-control-max-age"] == "600" + + app_custom = Starlette( + routes=[Route("/", endpoint=homepage)], + middleware=[Middleware(CORSMiddleware, allow_origins=["*"], max_age=7200)], + ) + client = test_client_factory(app_custom) + + response = client.options("/", headers={"Origin": "https://example.org", "Access-Control-Request-Method": "GET"}) + assert response.status_code == 200 + assert response.headers["access-control-max-age"] == "7200" + + +def test_cors_no_origin_header_no_cors_processing(test_client_factory: TestClientFactory) -> None: + def homepage(request: Request) -> PlainTextResponse: + return PlainTextResponse("Homepage", status_code=200) + + app = Starlette( + routes=[Route("/", endpoint=homepage)], + middleware=[Middleware(CORSMiddleware, allow_origins=["*"])], + ) + client = test_client_factory(app) + + response = client.get("/") + assert response.status_code == 200 + assert "access-control-allow-origin" not in response.headers + assert_vary(response, None) + + +def test_cors_header_name_case_insensitive(test_client_factory: TestClientFactory) -> None: + def homepage(request: Request) -> None: + pass # pragma: no cover + + app = Starlette( + routes=[Route("/", endpoint=homepage)], + middleware=[Middleware(CORSMiddleware, allow_origins=["*"], allow_headers=["X-Custom-Header"])], + ) + client = test_client_factory(app) + + response = client.options( + "/", + headers={ + "Origin": "https://example.org", + "Access-Control-Request-Method": "GET", + "Access-Control-Request-Headers": "x-custom-header", + }, + ) + assert response.status_code == 200