|  | 
|  | 1 | +/* | 
|  | 2 | + * Copyright 2013-2020 the original author or authors. | 
|  | 3 | + * | 
|  | 4 | + * Licensed under the Apache License, Version 2.0 (the "License"); | 
|  | 5 | + * you may not use this file except in compliance with the License. | 
|  | 6 | + * You may obtain a copy of the License at | 
|  | 7 | + * | 
|  | 8 | + *      https://www.apache.org/licenses/LICENSE-2.0 | 
|  | 9 | + * | 
|  | 10 | + * Unless required by applicable law or agreed to in writing, software | 
|  | 11 | + * distributed under the License is distributed on an "AS IS" BASIS, | 
|  | 12 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
|  | 13 | + * See the License for the specific language governing permissions and | 
|  | 14 | + * limitations under the License. | 
|  | 15 | + */ | 
|  | 16 | + | 
|  | 17 | +package org.springframework.cloud.gateway.filter.cors; | 
|  | 18 | + | 
|  | 19 | +import java.time.Duration; | 
|  | 20 | +import java.util.LinkedHashMap; | 
|  | 21 | +import java.util.List; | 
|  | 22 | +import java.util.Map; | 
|  | 23 | + | 
|  | 24 | +import org.awaitility.Awaitility; | 
|  | 25 | +import org.junit.jupiter.api.BeforeEach; | 
|  | 26 | +import org.junit.jupiter.api.Test; | 
|  | 27 | +import org.junit.jupiter.api.extension.ExtendWith; | 
|  | 28 | +import org.mockito.ArgumentCaptor; | 
|  | 29 | +import org.mockito.Captor; | 
|  | 30 | +import org.mockito.Mock; | 
|  | 31 | +import org.mockito.junit.jupiter.MockitoExtension; | 
|  | 32 | +import reactor.core.publisher.Flux; | 
|  | 33 | + | 
|  | 34 | +import org.springframework.cloud.gateway.config.GlobalCorsProperties; | 
|  | 35 | +import org.springframework.cloud.gateway.event.RefreshRoutesResultEvent; | 
|  | 36 | +import org.springframework.cloud.gateway.handler.RoutePredicateHandlerMapping; | 
|  | 37 | +import org.springframework.cloud.gateway.handler.predicate.PathRoutePredicateFactory; | 
|  | 38 | +import org.springframework.cloud.gateway.route.Route; | 
|  | 39 | +import org.springframework.cloud.gateway.route.RouteLocator; | 
|  | 40 | +import org.springframework.web.cors.CorsConfiguration; | 
|  | 41 | + | 
|  | 42 | +import static org.assertj.core.api.Assertions.assertThat; | 
|  | 43 | +import static org.mockito.Mockito.verify; | 
|  | 44 | +import static org.mockito.Mockito.when; | 
|  | 45 | + | 
|  | 46 | +/** | 
|  | 47 | + * Tests for {@link CorsGatewayFilterApplicationListener}. | 
|  | 48 | + * | 
|  | 49 | + * <p> | 
|  | 50 | + * This test verifies that the merged CORS configurations - composed of per-route metadata | 
|  | 51 | + * and at the global level - maintain insertion order, as defined by the use of | 
|  | 52 | + * {@link LinkedHashMap}. Preserving insertion order helps for predictable and | 
|  | 53 | + * deterministic CORS behavior when resolving multiple matching path patterns. | 
|  | 54 | + * </p> | 
|  | 55 | + * | 
|  | 56 | + * <p> | 
|  | 57 | + * The test builds actual {@link Route} instances with {@code Path} predicates and | 
|  | 58 | + * verifies that the resulting configuration map passed to | 
|  | 59 | + * {@link RoutePredicateHandlerMapping#setCorsConfigurations(Map)} respects the declared | 
|  | 60 | + * order of: | 
|  | 61 | + * <ul> | 
|  | 62 | + * <li>Route-specific CORS configurations (in the order the routes are discovered)</li> | 
|  | 63 | + * <li>Global CORS configurations (in insertion order)</li> | 
|  | 64 | + * </ul> | 
|  | 65 | + * </p> | 
|  | 66 | + * | 
|  | 67 | + * @author Yavor Chamov | 
|  | 68 | + */ | 
|  | 69 | +@ExtendWith(MockitoExtension.class) | 
|  | 70 | +class CorsGatewayFilterApplicationListenerTests { | 
|  | 71 | + | 
|  | 72 | +	private static final String GLOBAL_PATH_1 = "/global1"; | 
|  | 73 | + | 
|  | 74 | +	private static final String GLOBAL_PATH_2 = "/global2"; | 
|  | 75 | + | 
|  | 76 | +	private static final String ROUTE_PATH_1 = "/route1"; | 
|  | 77 | + | 
|  | 78 | +	private static final String ROUTE_PATH_2 = "/route2"; | 
|  | 79 | + | 
|  | 80 | +	private static final String ORIGIN_GLOBAL_1 = "https://global1.com"; | 
|  | 81 | + | 
|  | 82 | +	private static final String ORIGIN_GLOBAL_2 = "https://global2.com"; | 
|  | 83 | + | 
|  | 84 | +	private static final String ORIGIN_ROUTE_1 = "https://route1.com"; | 
|  | 85 | + | 
|  | 86 | +	private static final String ORIGIN_ROUTE_2 = "https://route2.com"; | 
|  | 87 | + | 
|  | 88 | +	private static final String ROUTE_ID_1 = "route1"; | 
|  | 89 | + | 
|  | 90 | +	private static final String ROUTE_ID_2 = "route2"; | 
|  | 91 | + | 
|  | 92 | +	private static final String ROUTE_URI = "https://spring.io"; | 
|  | 93 | + | 
|  | 94 | +	private static final String METADATA_KEY = "cors"; | 
|  | 95 | + | 
|  | 96 | +	private static final String ALLOWED_ORIGINS_KEY = "allowedOrigins"; | 
|  | 97 | + | 
|  | 98 | +	@Mock | 
|  | 99 | +	private RoutePredicateHandlerMapping handlerMapping; | 
|  | 100 | + | 
|  | 101 | +	@Mock | 
|  | 102 | +	private RouteLocator routeLocator; | 
|  | 103 | + | 
|  | 104 | +	@Captor | 
|  | 105 | +	private ArgumentCaptor<Map<String, CorsConfiguration>> corsConfigurations; | 
|  | 106 | + | 
|  | 107 | +	private GlobalCorsProperties globalCorsProperties; | 
|  | 108 | + | 
|  | 109 | +	private CorsGatewayFilterApplicationListener listener; | 
|  | 110 | + | 
|  | 111 | +	@BeforeEach | 
|  | 112 | +	void setUp() { | 
|  | 113 | +		globalCorsProperties = new GlobalCorsProperties(); | 
|  | 114 | +		listener = new CorsGatewayFilterApplicationListener(globalCorsProperties, handlerMapping, routeLocator); | 
|  | 115 | +	} | 
|  | 116 | + | 
|  | 117 | +	@Test | 
|  | 118 | +	void testOnApplicationEvent_preservesInsertionOrder_withRealRoutes() { | 
|  | 119 | + | 
|  | 120 | +		globalCorsProperties.getCorsConfigurations().put(GLOBAL_PATH_1, createCorsConfig(ORIGIN_GLOBAL_1)); | 
|  | 121 | +		globalCorsProperties.getCorsConfigurations().put(GLOBAL_PATH_2, createCorsConfig(ORIGIN_GLOBAL_2)); | 
|  | 122 | + | 
|  | 123 | +		Route route1 = buildRoute(ROUTE_ID_1, ROUTE_PATH_1, ORIGIN_ROUTE_1); | 
|  | 124 | +		Route route2 = buildRoute(ROUTE_ID_2, ROUTE_PATH_2, ORIGIN_ROUTE_2); | 
|  | 125 | + | 
|  | 126 | +		when(routeLocator.getRoutes()).thenReturn(Flux.just(route1, route2)); | 
|  | 127 | + | 
|  | 128 | +		listener.onApplicationEvent(new RefreshRoutesResultEvent(this)); | 
|  | 129 | + | 
|  | 130 | +		Awaitility.await().atMost(Duration.ofSeconds(2)).untilAsserted(() -> { | 
|  | 131 | + | 
|  | 132 | +			verify(handlerMapping).setCorsConfigurations(corsConfigurations.capture()); | 
|  | 133 | + | 
|  | 134 | +			Map<String, CorsConfiguration> mergedCorsConfigurations = corsConfigurations.getValue(); | 
|  | 135 | +			assertThat(mergedCorsConfigurations.keySet()).containsExactly(ROUTE_PATH_1, ROUTE_PATH_2, GLOBAL_PATH_1, | 
|  | 136 | +					GLOBAL_PATH_2); | 
|  | 137 | +			assertThat(mergedCorsConfigurations.get(GLOBAL_PATH_1).getAllowedOrigins()) | 
|  | 138 | +				.containsExactly(ORIGIN_GLOBAL_1); | 
|  | 139 | +			assertThat(mergedCorsConfigurations.get(GLOBAL_PATH_2).getAllowedOrigins()) | 
|  | 140 | +				.containsExactly(ORIGIN_GLOBAL_2); | 
|  | 141 | +			assertThat(mergedCorsConfigurations.get(ROUTE_PATH_1).getAllowedOrigins()).containsExactly(ORIGIN_ROUTE_1); | 
|  | 142 | +			assertThat(mergedCorsConfigurations.get(ROUTE_PATH_2).getAllowedOrigins()).containsExactly(ORIGIN_ROUTE_2); | 
|  | 143 | +		}); | 
|  | 144 | +	} | 
|  | 145 | + | 
|  | 146 | +	private CorsConfiguration createCorsConfig(String origin) { | 
|  | 147 | + | 
|  | 148 | +		CorsConfiguration config = new CorsConfiguration(); | 
|  | 149 | +		config.setAllowedOrigins(List.of(origin)); | 
|  | 150 | +		return config; | 
|  | 151 | +	} | 
|  | 152 | + | 
|  | 153 | +	private Route buildRoute(String id, String path, String allowedOrigin) { | 
|  | 154 | + | 
|  | 155 | +		return Route.async() | 
|  | 156 | +			.id(id) | 
|  | 157 | +			.uri(ROUTE_URI) | 
|  | 158 | +			.predicate(new PathRoutePredicateFactory().apply(config -> config.setPatterns(List.of(path)))) | 
|  | 159 | +			.metadata(METADATA_KEY, Map.of(ALLOWED_ORIGINS_KEY, List.of(allowedOrigin))) | 
|  | 160 | +			.build(); | 
|  | 161 | +	} | 
|  | 162 | + | 
|  | 163 | +} | 
0 commit comments