1717package org .springframework .test .context .bean .override ;
1818
1919import java .lang .annotation .Annotation ;
20+ import java .lang .reflect .AnnotatedElement ;
2021import java .lang .reflect .Field ;
2122import java .lang .reflect .Modifier ;
23+ import java .util .ArrayList ;
2224import java .util .Arrays ;
2325import java .util .Collections ;
2426import java .util .HashSet ;
25- import java .util .LinkedList ;
2627import java .util .List ;
2728import java .util .Objects ;
2829import java .util .Set ;
2930import java .util .concurrent .atomic .AtomicBoolean ;
31+ import java .util .function .BiConsumer ;
32+ import java .util .function .Predicate ;
3033
3134import org .springframework .beans .BeanUtils ;
3235import org .springframework .beans .factory .config .BeanDefinition ;
5760 *
5861 * <p>Concrete implementations of {@code BeanOverrideHandler} can store additional
5962 * metadata to use during override {@linkplain #createOverrideInstance instance
60- * creation} — for example, based on further processing of the annotation
61- * or the annotated field .
63+ * creation} — for example, based on further processing of the annotation,
64+ * the annotated field, or the annotated class .
6265 *
6366 * <p><strong>NOTE</strong>: Only <em>singleton</em> beans can be overridden.
6467 * Any attempt to override a non-singleton bean will result in an exception.
7073 */
7174public abstract class BeanOverrideHandler {
7275
76+ @ Nullable
7377 private final Field field ;
7478
7579 private final Set <Annotation > fieldAnnotations ;
@@ -82,7 +86,7 @@ public abstract class BeanOverrideHandler {
8286 private final BeanOverrideStrategy strategy ;
8387
8488
85- protected BeanOverrideHandler (Field field , ResolvableType beanType , @ Nullable String beanName ,
89+ protected BeanOverrideHandler (@ Nullable Field field , ResolvableType beanType , @ Nullable String beanName ,
8690 BeanOverrideStrategy strategy ) {
8791
8892 this .field = field ;
@@ -96,57 +100,115 @@ protected BeanOverrideHandler(Field field, ResolvableType beanType, @Nullable St
96100 * Process the given {@code testClass} and build the corresponding
97101 * {@code BeanOverrideHandler} list derived from {@link BeanOverride @BeanOverride}
98102 * fields in the test class and its type hierarchy.
99- * <p>This method does not search the enclosing class hierarchy.
103+ * <p>This method does not search the enclosing class hierarchy and does not
104+ * search for {@code @BeanOverride} declarations on classes or interfaces.
100105 * @param testClass the test class to process
101106 * @return a list of bean override handlers
102- * @see org.springframework.test.context.TestContextAnnotationUtils#searchEnclosingClass (Class)
107+ * @see #forTestClass (Class, Predicate )
103108 */
104109 public static List <BeanOverrideHandler > forTestClass (Class <?> testClass ) {
105- List <BeanOverrideHandler > handlers = new LinkedList <>();
106- findHandlers (testClass , testClass , handlers );
110+ return forTestClass (testClass , false , clazz -> false );
111+ }
112+
113+ /**
114+ * Process the given {@code testClass} and build the corresponding
115+ * {@code BeanOverrideHandler} list derived from {@link BeanOverride @BeanOverride}
116+ * fields in the test class and in its type hierarchy as well as from
117+ * {@code @BeanOverride} declarations on classes and interfaces.
118+ * <p>This method additionally searches for {@code @BeanOverride} declarations
119+ * in the enclosing class hierarchy if the supplied predicate evaluates to
120+ * {@code true}.
121+ * @param testClass the test class to process
122+ * @param searchEnclosingClass a predicate which evaluates to {@code true}
123+ * if a search should be performed on the enclosing class — for example,
124+ * {@code TestContextAnnotationUtils::searchEnclosingClass}
125+ * @return a list of bean override handlers
126+ * @since 6.2.2
127+ * @see org.springframework.test.context.TestContextAnnotationUtils#searchEnclosingClass(Class)
128+ */
129+ public static List <BeanOverrideHandler > forTestClass (Class <?> testClass , Predicate <Class <?>> searchEnclosingClass ) {
130+ return forTestClass (testClass , true , searchEnclosingClass );
131+ }
132+
133+ private static List <BeanOverrideHandler > forTestClass (Class <?> testClass , boolean searchOnTypes ,
134+ Predicate <Class <?>> searchEnclosingClass ) {
135+
136+ List <BeanOverrideHandler > handlers = new ArrayList <>();
137+ findHandlers (testClass , testClass , handlers , searchOnTypes , searchEnclosingClass );
107138 return handlers ;
108139 }
109140
110141 /**
111- * Find handlers using tail recursion to ensure that "locally declared"
112- * bean overrides take precedence over inherited bean overrides.
142+ * Find handlers using tail recursion to ensure that "locally declared" bean overrides
143+ * take precedence over inherited bean overrides.
144+ * <p>Note: the search algorithm is effectively the inverse of the algorithm used in
145+ * {@link org.springframework.test.context.TestContextAnnotationUtils#findAnnotationDescriptor(Class, Class)},
146+ * but with tail recursion the semantics should be the same.
113147 * @since 6.2.2
114148 */
115- private static void findHandlers (Class <?> clazz , Class <?> testClass , List <BeanOverrideHandler > handlers ) {
116- if (clazz == null || clazz == Object .class ) {
117- return ;
149+ private static void findHandlers (Class <?> clazz , Class <?> testClass , List <BeanOverrideHandler > handlers ,
150+ boolean searchOnTypes , Predicate <Class <?>> searchEnclosingClass ) {
151+
152+ // 1) Search enclosing class hierarchy.
153+ if (searchEnclosingClass .test (clazz )) {
154+ findHandlers (clazz .getEnclosingClass (), testClass , handlers , searchOnTypes , searchEnclosingClass );
118155 }
119156
120- // 1) Search type hierarchy.
121- findHandlers (clazz .getSuperclass (), testClass , handlers );
157+ // 2) Search class hierarchy.
158+ Class <?> superclass = clazz .getSuperclass ();
159+ if (superclass != null && superclass != Object .class ) {
160+ findHandlers (superclass , testClass , handlers , searchOnTypes , searchEnclosingClass );
161+ }
122162
123- // 2) Process fields in current class.
163+ // 3) Search interfaces.
164+ for (Class <?> ifc : clazz .getInterfaces ()) {
165+ findHandlers (ifc , testClass , handlers , searchOnTypes , searchEnclosingClass );
166+ }
167+
168+ // 4) Process current class.
169+ if (searchOnTypes ) {
170+ processClass (clazz , testClass , handlers );
171+ }
172+
173+ // 5) Process fields in current class.
124174 ReflectionUtils .doWithLocalFields (clazz , field -> processField (field , testClass , handlers ));
125175 }
126176
177+ private static void processClass (Class <?> clazz , Class <?> testClass , List <BeanOverrideHandler > handlers ) {
178+ processElement (clazz , testClass , (processor , composedAnnotation ) ->
179+ processor .createHandlers (composedAnnotation , testClass ).forEach (handlers ::add ));
180+ }
181+
127182 private static void processField (Field field , Class <?> testClass , List <BeanOverrideHandler > handlers ) {
128183 AtomicBoolean overrideAnnotationFound = new AtomicBoolean ();
129- MergedAnnotations . from (field , DIRECT ). stream ( BeanOverride . class ). forEach ( mergedAnnotation -> {
184+ processElement (field , testClass , ( processor , composedAnnotation ) -> {
130185 Assert .state (!Modifier .isStatic (field .getModifiers ()),
131186 () -> "@BeanOverride field must not be static: " + field );
187+ Assert .state (overrideAnnotationFound .compareAndSet (false , true ),
188+ () -> "Multiple @BeanOverride annotations found on field: " + field );
189+ handlers .add (processor .createHandler (composedAnnotation , testClass , field ));
190+ });
191+ }
192+
193+ private static void processElement (AnnotatedElement element , Class <?> testClass ,
194+ BiConsumer <BeanOverrideProcessor , Annotation > consumer ) {
195+
196+ MergedAnnotations .from (element , DIRECT ).stream (BeanOverride .class ).forEach (mergedAnnotation -> {
132197 MergedAnnotation <?> metaSource = mergedAnnotation .getMetaSource ();
133198 Assert .state (metaSource != null , "@BeanOverride annotation must be meta-present" );
134199
135200 BeanOverride beanOverride = mergedAnnotation .synthesize ();
136201 BeanOverrideProcessor processor = BeanUtils .instantiateClass (beanOverride .value ());
137202 Annotation composedAnnotation = metaSource .synthesize ();
138-
139- Assert .state (overrideAnnotationFound .compareAndSet (false , true ),
140- () -> "Multiple @BeanOverride annotations found on field: " + field );
141- BeanOverrideHandler handler = processor .createHandler (composedAnnotation , testClass , field );
142- handlers .add (handler );
203+ consumer .accept (processor , composedAnnotation );
143204 });
144205 }
145206
146207
147208 /**
148209 * Get the annotated {@link Field}.
149210 */
211+ @ Nullable
150212 public final Field getField () {
151213 return this .field ;
152214 }
@@ -243,20 +305,23 @@ public boolean equals(Object other) {
243305 !Objects .equals (this .strategy , that .strategy )) {
244306 return false ;
245307 }
308+
309+ // by-name lookup
246310 if (this .beanName != null ) {
247311 return true ;
248312 }
249313
250314 // by-type lookup
251- return (Objects .equals (this .field .getName (), that .field .getName ()) &&
315+ return (this .field != null && that .field != null &&
316+ Objects .equals (this .field .getName (), that .field .getName ()) &&
252317 this .fieldAnnotations .equals (that .fieldAnnotations ));
253318 }
254319
255320 @ Override
256321 public int hashCode () {
257322 int hash = Objects .hash (getClass (), this .beanType .getType (), this .beanName , this .strategy );
258- return (this .beanName != null ? hash : hash +
259- Objects .hash (this .field . getName (), this .fieldAnnotations ));
323+ return (this .beanName != null ? hash :
324+ hash + Objects .hash (( this .field != null ? this . field . getName () : null ), this .fieldAnnotations ));
260325 }
261326
262327 @ Override
@@ -269,7 +334,10 @@ public String toString() {
269334 .toString ();
270335 }
271336
272- private static Set <Annotation > annotationSet (Field field ) {
337+ private static Set <Annotation > annotationSet (@ Nullable Field field ) {
338+ if (field == null ) {
339+ return Collections .emptySet ();
340+ }
273341 Annotation [] annotations = field .getAnnotations ();
274342 return (annotations .length != 0 ? new HashSet <>(Arrays .asList (annotations )) : Collections .emptySet ());
275343 }
0 commit comments