Skip to content

Commit ab40236

Browse files
ari-spfmp911de
authored andcommitted
Prevent access to EntityManager when looking up PersistenceProvider.
Signed-off-by: Ariel Morelli Andres <[email protected]> Closes: #3425 Original pull request: #3885
1 parent 52a5e31 commit ab40236

File tree

12 files changed

+200
-16
lines changed

12 files changed

+200
-16
lines changed

spring-data-jpa/src/main/java/org/springframework/data/jpa/provider/PersistenceProvider.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
* @author Jens Schauder
5757
* @author Greg Turnquist
5858
* @author Yuriy Tsarkov
59+
* @author Ariel Morelli Andres (Atlassian US, Inc.)
5960
*/
6061
public enum PersistenceProvider implements QueryExtractor, ProxyIdAccessor, QueryComment {
6162

@@ -316,15 +317,15 @@ public static PersistenceProvider fromEntityManager(EntityManager em) {
316317
}
317318

318319
/**
319-
* Determines the {@link PersistenceProvider} from the given {@link EntityManager}. If no special one can be
320+
* Determines the {@link PersistenceProvider} from the given {@link EntityManagerFactory}. If no special one can be
320321
* determined {@link #GENERIC_JPA} will be returned.
321322
*
322323
* @param emf must not be {@literal null}.
323324
* @return will never be {@literal null}.
324325
*/
325326
public static PersistenceProvider fromEntityManagerFactory(EntityManagerFactory emf) {
326327

327-
Assert.notNull(emf, "EntityManager must not be null");
328+
Assert.notNull(emf, "EntityManagerFactory must not be null");
328329

329330
Class<?> entityManagerType = emf.getPersistenceUnitUtil().getClass();
330331
PersistenceProvider cachedProvider = CACHE.get(entityManagerType);

spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/AbstractJpaQuery.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
import jakarta.persistence.TypedQuery;
2525

2626
import java.lang.reflect.Constructor;
27-
import java.util.AbstractMap;
2827
import java.util.ArrayList;
2928
import java.util.List;
3029
import java.util.function.UnaryOperator;

spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/support/Querydsl.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@
4040
import com.querydsl.jpa.JPQLTemplates;
4141
import com.querydsl.jpa.impl.AbstractJPAQuery;
4242
import com.querydsl.jpa.impl.JPAQuery;
43-
import org.jspecify.annotations.Nullable;
4443

4544
/**
4645
* Helper instance to ease access to Querydsl JPA query API.
@@ -87,7 +86,8 @@ public <T> AbstractJPAQuery<T, JPAQuery<T>> createQuery() {
8786
* Obtains the {@link JPQLTemplates} for the configured {@link EntityManager}. Can return {@literal null} to use the
8887
* default templates.
8988
*
90-
* @return the {@link JPQLTemplates} for the configured {@link EntityManager}, {@link JPQLTemplates#DEFAULT} by default.
89+
* @return the {@link JPQLTemplates} for the configured {@link EntityManager}, {@link JPQLTemplates#DEFAULT} by
90+
* default.
9191
* @since 3.5
9292
*/
9393
public JPQLTemplates getTemplates() {

spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/CrudMethodMetadataUnitTests.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,6 @@
1717

1818
import static org.mockito.Mockito.*;
1919

20-
import java.util.Collections;
21-
import java.util.Map;
22-
2320
import jakarta.persistence.EntityManager;
2421
import jakarta.persistence.EntityManagerFactory;
2522
import jakarta.persistence.LockModeType;
@@ -28,6 +25,9 @@
2825
import jakarta.persistence.criteria.CriteriaQuery;
2926
import jakarta.persistence.metamodel.Metamodel;
3027

28+
import java.util.Collections;
29+
import java.util.Map;
30+
3131
import org.junit.jupiter.api.BeforeEach;
3232
import org.junit.jupiter.api.Test;
3333
import org.junit.jupiter.api.extension.ExtendWith;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
/*
2+
* Copyright 2011-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.data.jpa.repository;
17+
18+
import java.util.Optional;
19+
20+
import org.hibernate.context.spi.CurrentTenantIdentifierResolver;
21+
import org.jspecify.annotations.Nullable;
22+
23+
/**
24+
* {@code CurrentTenantIdentifierResolver} instance for testing
25+
*
26+
* @author Ariel Morelli Andres (Atlassian US, Inc.)
27+
*/
28+
public class HibernateCurrentTenantIdentifierResolver implements CurrentTenantIdentifierResolver<String> {
29+
private static final ThreadLocal<@Nullable String> CURRENT_TENANT_IDENTIFIER = new ThreadLocal<>();
30+
31+
public static void setTenantIdentifier(String tenantIdentifier) {
32+
CURRENT_TENANT_IDENTIFIER.set(tenantIdentifier);
33+
}
34+
35+
public static void removeTenantIdentifier() {
36+
CURRENT_TENANT_IDENTIFIER.remove();
37+
}
38+
39+
@Override
40+
public String resolveCurrentTenantIdentifier() {
41+
return Optional.ofNullable(CURRENT_TENANT_IDENTIFIER.get())
42+
.orElseThrow(() -> new IllegalArgumentException("Could not resolve current tenant identifier"));
43+
}
44+
45+
@Override
46+
public boolean validateExistingCurrentSessions() {
47+
return true;
48+
}
49+
}
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
/*
2+
* Copyright 2011-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.data.jpa.repository;
17+
18+
import static org.assertj.core.api.Assertions.*;
19+
import static org.assertj.core.api.Assumptions.*;
20+
21+
import java.util.List;
22+
23+
import org.junit.jupiter.api.AfterEach;
24+
import org.junit.jupiter.api.Test;
25+
import org.junit.jupiter.api.extension.ExtendWith;
26+
import org.springframework.beans.factory.annotation.Autowired;
27+
import org.springframework.context.annotation.ComponentScan;
28+
import org.springframework.context.annotation.Configuration;
29+
import org.springframework.context.annotation.FilterType;
30+
import org.springframework.context.annotation.ImportResource;
31+
import org.springframework.data.jpa.domain.sample.Role;
32+
import org.springframework.data.jpa.provider.PersistenceProvider;
33+
import org.springframework.data.jpa.repository.config.EnableJpaRepositories;
34+
import org.springframework.data.jpa.repository.sample.RoleRepository;
35+
import org.springframework.test.context.ContextConfiguration;
36+
import org.springframework.test.context.junit.jupiter.SpringExtension;
37+
import org.springframework.transaction.annotation.Transactional;
38+
39+
import jakarta.persistence.EntityManager;
40+
41+
/**
42+
* Tests for repositories that use multi-tenancy. This tests verifies that repositories can be created an injected
43+
* despite not having a tenant available at creation time
44+
*
45+
* @author Ariel Morelli Andres (Atlassian US, Inc.)
46+
*/
47+
@ExtendWith(SpringExtension.class)
48+
@ContextConfiguration()
49+
class HibernateMultitenancyTests {
50+
51+
@Autowired RoleRepository roleRepository;
52+
@Autowired EntityManager em;
53+
54+
@AfterEach
55+
void tearDown() {
56+
HibernateCurrentTenantIdentifierResolver.removeTenantIdentifier();
57+
}
58+
59+
@Test
60+
void testPersistenceProviderFromFactoryWithoutTenant() {
61+
PersistenceProvider provider = PersistenceProvider.fromEntityManagerFactory(em.getEntityManagerFactory());
62+
assumeThat(provider).isEqualTo(PersistenceProvider.HIBERNATE);
63+
}
64+
65+
@Test
66+
void testRepositoryWithTenant() {
67+
HibernateCurrentTenantIdentifierResolver.setTenantIdentifier("tenant-id");
68+
assertThatNoException().isThrownBy(() -> roleRepository.findAll());
69+
}
70+
71+
@Test
72+
void testRepositoryWithoutTenantFails() {
73+
assertThatThrownBy(() -> roleRepository.findAll()).isInstanceOf(RuntimeException.class);
74+
}
75+
76+
@Transactional
77+
List<Role> insertAndQuery() {
78+
roleRepository.save(new Role("DRUMMER"));
79+
roleRepository.flush();
80+
return roleRepository.findAll();
81+
}
82+
83+
@ImportResource({ "classpath:multitenancy-test.xml" })
84+
@Configuration
85+
@EnableJpaRepositories(basePackageClasses = HibernateRepositoryTests.class, considerNestedRepositories = true,
86+
includeFilters = @ComponentScan.Filter(classes = { RoleRepository.class }, type = FilterType.ASSIGNABLE_TYPE))
87+
static class TestConfig {}
88+
}

spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/query/AbstractStringBasedJpaQueryUnitTests.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import static org.mockito.Mockito.*;
1919

2020
import jakarta.persistence.EntityManager;
21+
import jakarta.persistence.EntityManagerFactory;
2122
import jakarta.persistence.metamodel.Metamodel;
2223

2324
import java.lang.reflect.Method;
@@ -52,6 +53,7 @@
5253
*
5354
* @author Christoph Strobl
5455
* @author Mark Paluch
56+
* @author Ariel Morelli Andres
5557
*/
5658
class AbstractStringBasedJpaQueryUnitTests {
5759

@@ -137,10 +139,12 @@ static class InvocationCapturingStringQueryStub extends AbstractStringBasedJpaQu
137139
public EntityManager get() {
138140

139141
EntityManager em = Mockito.mock(EntityManager.class);
142+
EntityManagerFactory emf = Mockito.mock(EntityManagerFactory.class);
140143

141144
Metamodel meta = mock(Metamodel.class);
142145
when(em.getMetamodel()).thenReturn(meta);
143146
when(em.getDelegate()).thenReturn(new Object()); // some generic jpa
147+
when(em.getEntityManagerFactory()).thenReturn(emf);
144148

145149
return em;
146150
}

spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/query/NamedQueryUnitTests.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
import org.springframework.data.domain.Page;
3737
import org.springframework.data.domain.Pageable;
3838
import org.springframework.data.jpa.provider.QueryExtractor;
39-
import org.springframework.data.jpa.repository.QueryRewriter;
4039
import org.springframework.data.projection.ProjectionFactory;
4140
import org.springframework.data.projection.SpelAwareProxyProjectionFactory;
4241
import org.springframework.data.repository.core.RepositoryMetadata;

spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/query/NativeJpaQueryUnitTests.java

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,9 @@ void shouldApplySorting() {
7171
queryExtractor);
7272

7373
NativeJpaQuery query = new NativeJpaQuery(queryMethod, em, queryMethod.getRequiredDeclaredQuery(),
74-
queryMethod.getDeclaredCountQuery(),
75-
new JpaQueryConfiguration(QueryRewriterProvider.simple(), QueryEnhancerSelector.DEFAULT_SELECTOR,
76-
ValueExpressionDelegate.create(), EscapeCharacter.DEFAULT));
77-
QueryProvider sql = query.getSortedQuery(Sort.by("foo", "bar"),
78-
queryMethod.getResultProcessor().getReturnedType());
74+
queryMethod.getDeclaredCountQuery(), new JpaQueryConfiguration(QueryRewriterProvider.simple(),
75+
QueryEnhancerSelector.DEFAULT_SELECTOR, ValueExpressionDelegate.create(), EscapeCharacter.DEFAULT));
76+
QueryProvider sql = query.getSortedQuery(Sort.by("foo", "bar"), queryMethod.getResultProcessor().getReturnedType());
7977

8078
assertThat(sql.getQueryString()).isEqualTo("SELECT e FROM Employee e order by e.foo asc, e.bar asc");
8179
}

spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/support/JpaRepositoryFragmentsContributorUnitTests.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import static org.mockito.Mockito.*;
2020

2121
import jakarta.persistence.EntityManager;
22+
import jakarta.persistence.EntityManagerFactory;
2223

2324
import java.util.Iterator;
2425

@@ -40,6 +41,7 @@
4041
* Unit tests for {@link JpaRepositoryFragmentsContributor}.
4142
*
4243
* @author Mark Paluch
44+
* @author Ariel Morelli Andres
4345
*/
4446
class JpaRepositoryFragmentsContributorUnitTests {
4547

@@ -53,7 +55,9 @@ void composedContributorShouldCreateFragments() {
5355
when(entityPathResolver.createPath(any())).thenReturn((EntityPath) QCustomer.customer);
5456

5557
EntityManager entityManager = mock(EntityManager.class);
58+
EntityManagerFactory emf = mock(EntityManagerFactory.class);
5659
when(entityManager.getDelegate()).thenReturn(entityManager);
60+
when(entityManager.getEntityManagerFactory()).thenReturn(emf);
5761

5862
RepositoryComposition.RepositoryFragments fragments = contributor.contribute(
5963
AbstractRepositoryMetadata.getMetadata(QuerydslUserRepository.class),

spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/support/SimpleJpaRepositoryUnitTests.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
import java.lang.reflect.Method;
3232
import java.lang.reflect.Modifier;
3333
import java.util.Arrays;
34-
import java.util.Optional;
3534
import java.util.stream.Stream;
3635

3736
import org.junit.jupiter.api.BeforeEach;
@@ -44,7 +43,6 @@
4443
import org.mockito.junit.jupiter.MockitoExtension;
4544
import org.mockito.junit.jupiter.MockitoSettings;
4645
import org.mockito.quality.Strictness;
47-
4846
import org.springframework.data.domain.PageRequest;
4947
import org.springframework.data.jpa.domain.Specification;
5048
import org.springframework.data.jpa.domain.sample.User;
@@ -61,6 +59,7 @@
6159
* @author Jens Schauder
6260
* @author Greg Turnquist
6361
* @author Yanming Zhou
62+
* @author Ariel Morelli Andres (Atlassian US, Inc.)
6463
*/
6564
@ExtendWith(MockitoExtension.class)
6665
@MockitoSettings(strictness = Strictness.LENIENT)
@@ -85,6 +84,9 @@ class SimpleJpaRepositoryUnitTests {
8584
void setUp() {
8685

8786
when(em.getDelegate()).thenReturn(em);
87+
when(em.getEntityManagerFactory()).thenReturn(entityManagerFactory);
88+
89+
when(entityManagerFactory.getPersistenceUnitUtil()).thenReturn(persistenceUnitUtil);
8890

8991
when(information.getJavaType()).thenReturn(User.class);
9092
when(em.getCriteriaBuilder()).thenReturn(builder);
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
<?xml version="1.0" encoding="UTF-8"?>
2+
<beans xmlns="http://www.springframework.org/schema/beans"
3+
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
4+
xmlns:jdbc="http://www.springframework.org/schema/jdbc"
5+
xsi:schemaLocation="http://www.springframework.org/schema/jdbc https://www.springframework.org/schema/jdbc/spring-jdbc.xsd
6+
http://www.springframework.org/schema/beans https://www.springframework.org/schema/beans/spring-beans.xsd">
7+
8+
<import resource="hibernate.xml" />
9+
10+
<bean id="entityManagerFactory"
11+
class="org.springframework.orm.jpa.LocalContainerEntityManagerFactoryBean">
12+
<property name="dataSource" ref="dataSource" />
13+
<property name="persistenceUnitName" value="spring-data-jpa" />
14+
<property name="jpaVendorAdapter" ref="vendorAdaptor" />
15+
<property name="jpaProperties">
16+
<props>
17+
<prop key="hibernate.tenant_identifier_resolver">
18+
org.springframework.data.jpa.repository.HibernateCurrentTenantIdentifierResolver
19+
</prop>
20+
</props>
21+
</property>
22+
</bean>
23+
24+
<bean id="abstractVendorAdaptor" abstract="true">
25+
<property name="generateDdl" value="true" />
26+
<property name="database" value="HSQL" />
27+
</bean>
28+
29+
<bean id="transactionManager" class="org.springframework.orm.jpa.JpaTransactionManager">
30+
<property name="entityManagerFactory" ref="entityManagerFactory" />
31+
</bean>
32+
33+
<bean name="sampleEvaluationContextExtension" class="org.springframework.data.jpa.repository.sample.SampleEvaluationContextExtension"/>
34+
35+
<jdbc:embedded-database id="dataSource" type="HSQL" generate-name="true">
36+
<jdbc:script execution="INIT" separator="/;" location="classpath:scripts/hsqldb-init.sql"/>
37+
<jdbc:script execution="INIT" separator="/;" location="classpath:scripts/schema-stored-procedures.sql"/>
38+
</jdbc:embedded-database>
39+
40+
</beans>

0 commit comments

Comments
 (0)