Skip to content

Commit 06ef82e

Browse files
committed
Consistent type-based bean lookup for internal resolution paths
Includes additional tests for List/ObjectProvider dependencies. See gh-35101
1 parent 2e9e45e commit 06ef82e

File tree

3 files changed

+108
-23
lines changed

3 files changed

+108
-23
lines changed

spring-beans/src/main/java/org/springframework/beans/factory/support/DefaultListableBeanFactory.java

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -496,7 +496,7 @@ public void ifUnique(Consumer<T> dependencyConsumer) throws BeansException {
496496
@Override
497497
public Stream<T> stream() {
498498
return Arrays.stream(beanNamesForStream(requiredType, true, allowEagerInit))
499-
.map(name -> (T) getBean(name))
499+
.map(name -> (T) resolveBean(name, requiredType))
500500
.filter(bean -> !(bean instanceof NullBean));
501501
}
502502
@SuppressWarnings("unchecked")
@@ -508,7 +508,7 @@ public Stream<T> orderedStream() {
508508
}
509509
Map<String, T> matchingBeans = CollectionUtils.newLinkedHashMap(beanNames.length);
510510
for (String beanName : beanNames) {
511-
Object beanInstance = getBean(beanName);
511+
Object beanInstance = resolveBean(beanName, requiredType);
512512
if (!(beanInstance instanceof NullBean)) {
513513
matchingBeans.put(beanName, (T) beanInstance);
514514
}
@@ -521,7 +521,7 @@ public Stream<T> orderedStream() {
521521
public Stream<T> stream(Predicate<Class<?>> customFilter, boolean includeNonSingletons) {
522522
return Arrays.stream(beanNamesForStream(requiredType, includeNonSingletons, allowEagerInit))
523523
.filter(name -> customFilter.test(getType(name)))
524-
.map(name -> (T) getBean(name))
524+
.map(name -> (T) resolveBean(name, requiredType))
525525
.filter(bean -> !(bean instanceof NullBean));
526526
}
527527
@SuppressWarnings("unchecked")
@@ -534,7 +534,7 @@ public Stream<T> orderedStream(Predicate<Class<?>> customFilter, boolean include
534534
Map<String, T> matchingBeans = CollectionUtils.newLinkedHashMap(beanNames.length);
535535
for (String beanName : beanNames) {
536536
if (customFilter.test(getType(beanName))) {
537-
Object beanInstance = getBean(beanName);
537+
Object beanInstance = resolveBean(beanName, requiredType);
538538
if (!(beanInstance instanceof NullBean)) {
539539
matchingBeans.put(beanName, (T) beanInstance);
540540
}
@@ -1207,6 +1207,17 @@ private void instantiateSingleton(String beanName) {
12071207
}
12081208
}
12091209

1210+
private Object resolveBean(String beanName, ResolvableType requiredType) {
1211+
try {
1212+
// Need to provide required type for SmartFactoryBean
1213+
return getBean(beanName, requiredType.toClass());
1214+
}
1215+
catch (BeanNotOfRequiredTypeException ex) {
1216+
// Probably a null bean...
1217+
return getBean(beanName);
1218+
}
1219+
}
1220+
12101221
private static String getThreadNamePrefix() {
12111222
String name = Thread.currentThread().getName();
12121223
int numberSeparator = name.lastIndexOf('-');
@@ -1542,7 +1553,7 @@ else if (candidateNames.length > 1) {
15421553
Map<String, Object> candidates = CollectionUtils.newLinkedHashMap(candidateNames.length);
15431554
for (String beanName : candidateNames) {
15441555
if (containsSingleton(beanName) && args == null) {
1545-
Object beanInstance = getBean(beanName);
1556+
Object beanInstance = resolveBean(beanName, requiredType);
15461557
candidates.put(beanName, (beanInstance instanceof NullBean ? null : beanInstance));
15471558
}
15481559
else {
@@ -1659,7 +1670,7 @@ else if (descriptor.supportsLazyResolution()) {
16591670
if (autowiredBeanNames != null) {
16601671
autowiredBeanNames.add(dependencyName);
16611672
}
1662-
Object dependencyBean = getBean(dependencyName);
1673+
Object dependencyBean = resolveBean(dependencyName, descriptor.getResolvableType());
16631674
return resolveInstance(dependencyBean, descriptor, type, dependencyName);
16641675
}
16651676
}
@@ -2582,24 +2593,26 @@ private Stream<Object> resolveStream(boolean ordered) {
25822593

25832594
@Override
25842595
public Stream<Object> stream(Predicate<Class<?>> customFilter, boolean includeNonSingletons) {
2585-
return Arrays.stream(beanNamesForStream(this.descriptor.getResolvableType(), includeNonSingletons, true))
2596+
ResolvableType type = this.descriptor.getResolvableType();
2597+
return Arrays.stream(beanNamesForStream(type, includeNonSingletons, true))
25862598
.filter(name -> AutowireUtils.isAutowireCandidate(DefaultListableBeanFactory.this, name))
25872599
.filter(name -> customFilter.test(getType(name)))
2588-
.map(name -> getBean(name))
2600+
.map(name -> resolveBean(name, type))
25892601
.filter(bean -> !(bean instanceof NullBean));
25902602
}
25912603

25922604
@Override
25932605
public Stream<Object> orderedStream(Predicate<Class<?>> customFilter, boolean includeNonSingletons) {
2594-
String[] beanNames = beanNamesForStream(this.descriptor.getResolvableType(), includeNonSingletons, true);
2606+
ResolvableType type = this.descriptor.getResolvableType();
2607+
String[] beanNames = beanNamesForStream(type, includeNonSingletons, true);
25952608
if (beanNames.length == 0) {
25962609
return Stream.empty();
25972610
}
25982611
Map<String, Object> matchingBeans = CollectionUtils.newLinkedHashMap(beanNames.length);
25992612
for (String beanName : beanNames) {
26002613
if (AutowireUtils.isAutowireCandidate(DefaultListableBeanFactory.this, beanName) &&
26012614
customFilter.test(getType(beanName))) {
2602-
Object beanInstance = getBean(beanName);
2615+
Object beanInstance = resolveBean(beanName, type);
26032616
if (!(beanInstance instanceof NullBean)) {
26042617
matchingBeans.put(beanName, beanInstance);
26052618
}

spring-beans/src/main/java/org/springframework/beans/factory/support/StaticListableBeanFactory.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ public <T> ObjectProvider<T> getBeanProvider(ResolvableType requiredType, boolea
307307
public T getObject() throws BeansException {
308308
String[] beanNames = getBeanNamesForType(requiredType);
309309
if (beanNames.length == 1) {
310-
return (T) getBean(beanNames[0], requiredType);
310+
return (T) getBean(beanNames[0], requiredType.toClass());
311311
}
312312
else if (beanNames.length > 1) {
313313
throw new NoUniqueBeanDefinitionException(requiredType, beanNames);
@@ -333,7 +333,7 @@ else if (beanNames.length > 1) {
333333
public @Nullable T getIfAvailable() throws BeansException {
334334
String[] beanNames = getBeanNamesForType(requiredType);
335335
if (beanNames.length == 1) {
336-
return (T) getBean(beanNames[0]);
336+
return (T) getBean(beanNames[0], requiredType.toClass());
337337
}
338338
else if (beanNames.length > 1) {
339339
throw new NoUniqueBeanDefinitionException(requiredType, beanNames);
@@ -346,15 +346,16 @@ else if (beanNames.length > 1) {
346346
public @Nullable T getIfUnique() throws BeansException {
347347
String[] beanNames = getBeanNamesForType(requiredType);
348348
if (beanNames.length == 1) {
349-
return (T) getBean(beanNames[0]);
349+
return (T) getBean(beanNames[0], requiredType.toClass());
350350
}
351351
else {
352352
return null;
353353
}
354354
}
355355
@Override
356356
public Stream<T> stream() {
357-
return Arrays.stream(getBeanNamesForType(requiredType)).map(name -> (T) getBean(name));
357+
return Arrays.stream(getBeanNamesForType(requiredType))
358+
.map(name -> (T) getBean(name, requiredType.toClass()));
358359
}
359360
};
360361
}

spring-beans/src/test/java/org/springframework/beans/factory/BeanFactoryUtilsTests.java

Lines changed: 80 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -464,8 +464,40 @@ void supportsMultipleTypesWithDefaultFactory() {
464464
lbf.registerSingleton("fb2", fb2);
465465
lbf.registerSingleton("sfb1", sfb1);
466466
lbf.registerSingleton("sfb2", sfb2);
467-
468-
testSupportsMultipleTypesWithStaticFactory(lbf);
467+
lbf.registerBeanDefinition("recipient",
468+
new RootBeanDefinition(Recipient.class, RootBeanDefinition.AUTOWIRE_CONSTRUCTOR, false));
469+
470+
Recipient recipient = lbf.getBean("recipient", Recipient.class);
471+
assertThat(recipient.sfb1).isSameAs(lbf.getBean("sfb1", TestBean.class));
472+
assertThat(recipient.sfb2).isSameAs(lbf.getBean("sfb2", TestBean.class));
473+
474+
List<ITestBean> testBeanList = recipient.testBeanList;
475+
assertThat(testBeanList).hasSize(5);
476+
assertThat(testBeanList.get(0)).isSameAs(bean);
477+
assertThat(testBeanList.get(1)).isSameAs(fb1.getObject());
478+
assertThat(testBeanList.get(2)).isInstanceOf(TestBean.class);
479+
assertThat(testBeanList.get(3)).isSameAs(lbf.getBean("sfb1", TestBean.class));
480+
assertThat(testBeanList.get(4)).isSameAs(lbf.getBean("sfb2", TestBean.class));
481+
482+
List<CharSequence> stringList = recipient.stringList;
483+
assertThat(stringList).hasSize(2);
484+
assertThat(stringList.get(0)).isSameAs(lbf.getBean("sfb1", String.class));
485+
assertThat(stringList.get(1)).isSameAs(lbf.getBean("sfb2", String.class));
486+
487+
testBeanList = recipient.testBeanProvider.stream().toList();
488+
assertThat(testBeanList).hasSize(5);
489+
assertThat(testBeanList.get(0)).isSameAs(bean);
490+
assertThat(testBeanList.get(1)).isSameAs(fb1.getObject());
491+
assertThat(testBeanList.get(2)).isInstanceOf(TestBean.class);
492+
assertThat(testBeanList.get(3)).isSameAs(lbf.getBean("sfb1", TestBean.class));
493+
assertThat(testBeanList.get(4)).isSameAs(lbf.getBean("sfb2", TestBean.class));
494+
495+
stringList = recipient.stringProvider.stream().toList();
496+
assertThat(stringList).hasSize(2);
497+
assertThat(stringList.get(0)).isSameAs(lbf.getBean("sfb1", String.class));
498+
assertThat(stringList.get(1)).isSameAs(lbf.getBean("sfb2", String.class));
499+
500+
testSupportsMultipleTypes(lbf);
469501
}
470502

471503
@Test
@@ -483,22 +515,35 @@ void supportsMultipleTypesWithStaticFactory() {
483515
lbf.addBean("sfb1", sfb1);
484516
lbf.addBean("sfb2", sfb2);
485517

486-
testSupportsMultipleTypesWithStaticFactory(lbf);
518+
testSupportsMultipleTypes(lbf);
487519
}
488520

489-
void testSupportsMultipleTypesWithStaticFactory(ListableBeanFactory lbf) {
521+
void testSupportsMultipleTypes(ListableBeanFactory lbf) {
522+
List<ITestBean> testBeanList = lbf.getBeanProvider(ITestBean.class).stream().toList();
523+
assertThat(testBeanList).hasSize(5);
524+
assertThat(testBeanList.get(0)).isSameAs(lbf.getBean("bean", TestBean.class));
525+
assertThat(testBeanList.get(1)).isSameAs(lbf.getBean("fb1", TestBean.class));
526+
assertThat(testBeanList.get(2)).isInstanceOf(TestBean.class);
527+
assertThat(testBeanList.get(3)).isSameAs(lbf.getBean("sfb1", TestBean.class));
528+
assertThat(testBeanList.get(4)).isSameAs(lbf.getBean("sfb2", TestBean.class));
529+
530+
List<CharSequence> stringList = lbf.getBeanProvider(CharSequence.class).stream().toList();
531+
assertThat(stringList).hasSize(2);
532+
assertThat(stringList.get(0)).isSameAs(lbf.getBean("sfb1", String.class));
533+
assertThat(stringList.get(1)).isSameAs(lbf.getBean("sfb2", String.class));
534+
490535
Map<String, ?> beans = BeanFactoryUtils.beansOfTypeIncludingAncestors(lbf, ITestBean.class);
491536
assertThat(beans).hasSize(5);
492537
assertThat(beans.get("bean")).isSameAs(lbf.getBean("bean"));
493-
assertThat(beans.get("fb1")).isSameAs(lbf.getBean("&fb1", DummyFactory.class).getObject());
538+
assertThat(beans.get("fb1")).isSameAs(lbf.getBean("fb1",TestBean.class));
494539
assertThat(beans.get("fb2")).isInstanceOf(TestBean.class);
495-
assertThat(beans.get("sfb1")).isInstanceOf(TestBean.class);
496-
assertThat(beans.get("sfb2")).isInstanceOf(TestBean.class);
540+
assertThat(beans.get("sfb1")).isSameAs(lbf.getBean("sfb1", TestBean.class));
541+
assertThat(beans.get("sfb2")).isSameAs(lbf.getBean("sfb2", TestBean.class));
497542

498543
beans = BeanFactoryUtils.beansOfTypeIncludingAncestors(lbf, CharSequence.class);
499544
assertThat(beans).hasSize(2);
500-
assertThat(beans.get("sfb1")).isInstanceOf(String.class);
501-
assertThat(beans.get("sfb2")).isInstanceOf(String.class);
545+
assertThat(beans.get("sfb1")).isSameAs(lbf.getBean("sfb1", String.class));
546+
assertThat(beans.get("sfb2")).isSameAs(lbf.getBean("sfb1", String.class));
502547

503548
assertThat(lbf.getBean("sfb1", ITestBean.class)).isInstanceOf(TestBean.class);
504549
assertThat(lbf.getBean("sfb2", ITestBean.class)).isInstanceOf(TestBean.class);
@@ -604,4 +649,30 @@ public boolean supportsType(Class<?> type) {
604649
}
605650
}
606651

652+
653+
static class Recipient {
654+
655+
public Recipient(ITestBean sfb1, ITestBean sfb2, List<ITestBean> testBeanList, List<CharSequence> stringList,
656+
ObjectProvider<ITestBean> testBeanProvider, ObjectProvider<CharSequence> stringProvider) {
657+
this.sfb1 = sfb1;
658+
this.sfb2 = sfb2;
659+
this.testBeanList = testBeanList;
660+
this.stringList = stringList;
661+
this.testBeanProvider = testBeanProvider;
662+
this.stringProvider = stringProvider;
663+
}
664+
665+
ITestBean sfb1;
666+
667+
ITestBean sfb2;
668+
669+
List<ITestBean> testBeanList;
670+
671+
List<CharSequence> stringList;
672+
673+
ObjectProvider<ITestBean> testBeanProvider;
674+
675+
ObjectProvider<CharSequence> stringProvider;
676+
}
677+
607678
}

0 commit comments

Comments
 (0)