Skip to content

Commit 1970e2d

Browse files
author
chao.wang
committed
Add JdbcRelyingPartyRegistrationRepository
Closes spring-projectsgh-16012
1 parent 56e757a commit 1970e2d

File tree

7 files changed

+914
-0
lines changed

7 files changed

+914
-0
lines changed

saml2/saml2-service-provider/spring-security-saml2-service-provider.gradle

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ dependencies {
106106
provided 'jakarta.servlet:jakarta.servlet-api'
107107

108108
optional 'com.fasterxml.jackson.core:jackson-databind'
109+
optional 'org.springframework:spring-jdbc'
109110

110111
testImplementation 'com.squareup.okhttp3:mockwebserver'
111112
testImplementation "org.assertj:assertj-core"
@@ -118,6 +119,7 @@ dependencies {
118119
testImplementation "org.springframework:spring-test"
119120

120121
testRuntimeOnly 'org.junit.platform:junit-platform-launcher'
122+
testRuntimeOnly 'org.hsqldb:hsqldb'
121123
}
122124

123125
jar {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,377 @@
1+
/*
2+
* Copyright 2002-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.security.saml2.provider.service.registration;
18+
19+
import java.nio.charset.StandardCharsets;
20+
import java.security.cert.X509Certificate;
21+
import java.security.interfaces.RSAPrivateKey;
22+
import java.sql.ResultSet;
23+
import java.sql.SQLException;
24+
import java.sql.Types;
25+
import java.util.ArrayList;
26+
import java.util.Collection;
27+
import java.util.Iterator;
28+
import java.util.List;
29+
import java.util.function.Consumer;
30+
31+
import com.fasterxml.jackson.core.type.TypeReference;
32+
import com.fasterxml.jackson.databind.ObjectMapper;
33+
import org.cryptacular.util.KeyPairUtil;
34+
import org.opensaml.security.x509.X509Support;
35+
36+
import org.springframework.jdbc.core.ArgumentPreparedStatementSetter;
37+
import org.springframework.jdbc.core.JdbcOperations;
38+
import org.springframework.jdbc.core.PreparedStatementSetter;
39+
import org.springframework.jdbc.core.RowMapper;
40+
import org.springframework.jdbc.core.SqlParameterValue;
41+
import org.springframework.jdbc.support.lob.DefaultLobHandler;
42+
import org.springframework.jdbc.support.lob.LobHandler;
43+
import org.springframework.security.saml2.core.Saml2X509Credential;
44+
import org.springframework.util.Assert;
45+
import org.springframework.util.StringUtils;
46+
47+
public class JdbcRelyingPartyRegistrationRepository implements IterableRelyingPartyRegistrationRepository {
48+
49+
private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
50+
51+
// @formatter:off
52+
static final String COLUMN_NAMES = "id, "
53+
+ "entity_id, "
54+
+ "name_id_format, "
55+
+ "acs_location, "
56+
+ "acs_binding, "
57+
+ "signing_credentials, "
58+
+ "decryption_credentials, "
59+
+ "singlelogout_url, "
60+
+ "singlelogout_response_url, "
61+
+ "singlelogout_binding, "
62+
+ "assertingparty_entity_id, "
63+
+ "assertingparty_metadata_uri, "
64+
+ "assertingparty_singlesignon_url, "
65+
+ "assertingparty_singlesignon_binding, "
66+
+ "assertingparty_singlesignon_sign_request, "
67+
+ "assertingparty_verification_credentials, "
68+
+ "assertingparty_singlelogout_url, "
69+
+ "assertingparty_singlelogout_response_url, "
70+
+ "assertingparty_singlelogout_binding"
71+
;
72+
// @formatter:on
73+
74+
private static final String TABLE_NAME = "saml2_relying_party_registration";
75+
76+
private static final String PK_FILTER = "id = ?";
77+
78+
private static final String ENTITY_ID_FILTER = "entity_id = ?";
79+
80+
// @formatter:off
81+
private static final String LOAD_BY_ID_SQL = "SELECT " + COLUMN_NAMES
82+
+ " FROM " + TABLE_NAME
83+
+ " WHERE " + PK_FILTER;
84+
85+
private static final String LOAD_BY_ENTITY_ID_SQL = "SELECT " + COLUMN_NAMES
86+
+ " FROM " + TABLE_NAME
87+
+ " WHERE " + ENTITY_ID_FILTER;
88+
89+
private static final String LOAD_ALL_SQL = "SELECT " + COLUMN_NAMES
90+
+ " FROM " + TABLE_NAME;
91+
// @formatter:on
92+
93+
protected final JdbcOperations jdbcOperations;
94+
95+
protected RowMapper<RelyingPartyRegistration> relyingPartyRegistrationRowMapper;
96+
97+
protected final LobHandler lobHandler;
98+
99+
/**
100+
* Constructs a {@code JdbcOAuth2AuthorizedClientService} using the provided
101+
* parameters.
102+
* @param jdbcOperations the JDBC operations
103+
*/
104+
public JdbcRelyingPartyRegistrationRepository(JdbcOperations jdbcOperations) {
105+
this(jdbcOperations, new DefaultLobHandler());
106+
}
107+
108+
/**
109+
* Constructs a {@code JdbcRelyingPartyRegistrationRepository} using the provided
110+
* parameters.
111+
* @param jdbcOperations the JDBC operations
112+
* @param lobHandler the handler for large binary fields and large text fields
113+
*/
114+
public JdbcRelyingPartyRegistrationRepository(JdbcOperations jdbcOperations, LobHandler lobHandler) {
115+
Assert.notNull(jdbcOperations, "jdbcOperations cannot be null");
116+
Assert.notNull(lobHandler, "lobHandler cannot be null");
117+
this.jdbcOperations = jdbcOperations;
118+
this.lobHandler = lobHandler;
119+
RelyingPartyRegistrationRowMapper rowMapper = new RelyingPartyRegistrationRowMapper();
120+
rowMapper.setLobHandler(lobHandler);
121+
this.relyingPartyRegistrationRowMapper = rowMapper;
122+
}
123+
124+
/**
125+
* Sets the {@link RowMapper} used for mapping the current row in
126+
* {@code java.sql.ResultSet} to {@link RelyingPartyRegistration}. The default is
127+
* {@link RelyingPartyRegistrationRowMapper}.
128+
* @param relyingPartyRegistrationRowMapper the {@link RowMapper} used for mapping the
129+
* current row in {@code java.sql.ResultSet} to {@link RelyingPartyRegistration}
130+
*/
131+
public final void setAuthorizedClientRowMapper(
132+
RowMapper<RelyingPartyRegistration> relyingPartyRegistrationRowMapper) {
133+
Assert.notNull(relyingPartyRegistrationRowMapper, "relyingPartyRegistrationRowMapper cannot be null");
134+
this.relyingPartyRegistrationRowMapper = relyingPartyRegistrationRowMapper;
135+
}
136+
137+
@Override
138+
public RelyingPartyRegistration findByRegistrationId(String registrationId) {
139+
Assert.hasText(registrationId, "registrationId cannot be empty");
140+
SqlParameterValue[] parameters = new SqlParameterValue[] {
141+
new SqlParameterValue(Types.VARCHAR, registrationId) };
142+
PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters);
143+
List<RelyingPartyRegistration> result = this.jdbcOperations.query(LOAD_BY_ID_SQL, pss,
144+
this.relyingPartyRegistrationRowMapper);
145+
return !result.isEmpty() ? result.get(0) : null;
146+
}
147+
148+
@Override
149+
public RelyingPartyRegistration findUniqueByAssertingPartyEntityId(String entityId) {
150+
Assert.hasText(entityId, "entityId cannot be empty");
151+
SqlParameterValue[] parameters = new SqlParameterValue[] { new SqlParameterValue(Types.VARCHAR, entityId) };
152+
PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters);
153+
List<RelyingPartyRegistration> result = this.jdbcOperations.query(LOAD_BY_ENTITY_ID_SQL, pss,
154+
this.relyingPartyRegistrationRowMapper);
155+
return !result.isEmpty() ? result.get(0) : null;
156+
}
157+
158+
@Override
159+
public Iterator<RelyingPartyRegistration> iterator() {
160+
List<RelyingPartyRegistration> result = this.jdbcOperations.query(LOAD_ALL_SQL,
161+
this.relyingPartyRegistrationRowMapper);
162+
return result.iterator();
163+
}
164+
165+
/**
166+
* The default {@link RowMapper} that maps the current row in
167+
* {@code java.sql.ResultSet} to {@link RelyingPartyRegistration}.
168+
*/
169+
public static class RelyingPartyRegistrationRowMapper implements RowMapper<RelyingPartyRegistration> {
170+
171+
protected LobHandler lobHandler = new DefaultLobHandler();
172+
173+
public final void setLobHandler(LobHandler lobHandler) {
174+
Assert.notNull(lobHandler, "lobHandler cannot be null");
175+
this.lobHandler = lobHandler;
176+
}
177+
178+
@Override
179+
public RelyingPartyRegistration mapRow(ResultSet rs, int rowNum) throws SQLException {
180+
String registrationId = rs.getString("id");
181+
String entityId = StringUtils.hasText(rs.getString("entity_id")) ? rs.getString("entity_id")
182+
: "{baseUrl}/saml2/service-provider-metadata/{registrationId}";
183+
String nameIdFormat = rs.getString("name_id_format");
184+
String acsLocation = StringUtils.hasText(rs.getString("acs_location")) ? rs.getString("acs_location")
185+
: "{baseUrl}/login/saml2/sso/{registrationId}";
186+
String acsBinding = StringUtils.hasText(rs.getString("acs_binding")) ? rs.getString("acs_binding")
187+
: Saml2MessageBinding.POST.getUrn();
188+
List<Credential> signingCredentials = parseCredentials(getLobValue(rs, "signing_credentials"));
189+
List<Credential> decryptionCredentials = parseCredentials(getLobValue(rs, "decryption_credentials"));
190+
String singleLogoutUrl = rs.getString("singlelogout_url");
191+
String singleLogoutResponseUrl = rs.getString("singlelogout_response_url");
192+
Saml2MessageBinding singleLogoutBinding = Saml2MessageBinding.from(rs.getString("singlelogout_binding"));
193+
String assertingPartyEntityId = rs.getString("assertingparty_entity_id");
194+
String assertingPartyMetadataUri = rs.getString("assertingparty_metadata_uri");
195+
String assertingPartySingleSignOnUrl = rs.getString("assertingparty_singlesignon_url");
196+
Saml2MessageBinding assertingPartySingleSignOnBinding = Saml2MessageBinding
197+
.from(rs.getString("assertingparty_singlesignon_binding"));
198+
Boolean assertingPartySingleSignOnSignRequest = rs.getBoolean("assertingparty_singlesignon_sign_request");
199+
List<Credential> assertingPartyVerificationCredentials = parseCredentials(
200+
getLobValue(rs, "assertingparty_verification_credentials"));
201+
String assertingPartySingleLogoutUrl = rs.getString("assertingparty_singlelogout_url");
202+
String assertingPartySingleLogoutResponseUrl = rs.getString("assertingparty_singlelogout_response_url");
203+
Saml2MessageBinding assertingPartySingleLogoutBinding = Saml2MessageBinding
204+
.from(rs.getString("assertingparty_singlelogout_binding"));
205+
206+
boolean usingMetadata = StringUtils.hasText(assertingPartyMetadataUri);
207+
RelyingPartyRegistration.Builder builder = (!usingMetadata)
208+
? RelyingPartyRegistration.withRegistrationId(registrationId)
209+
: createBuilderUsingMetadata(assertingPartyEntityId, assertingPartyMetadataUri)
210+
.registrationId(registrationId);
211+
builder.assertionConsumerServiceLocation(acsLocation);
212+
builder.assertionConsumerServiceBinding(Saml2MessageBinding.from(acsBinding));
213+
builder.assertingPartyMetadata(mapAssertingParty(assertingPartyEntityId, assertingPartySingleSignOnBinding,
214+
assertingPartySingleSignOnUrl, assertingPartySingleSignOnSignRequest,
215+
assertingPartySingleLogoutBinding, assertingPartySingleLogoutResponseUrl,
216+
assertingPartySingleLogoutUrl));
217+
builder.signingX509Credentials((credentials) -> signingCredentials.stream()
218+
.map(this::asSigningCredential)
219+
.forEach(credentials::add));
220+
builder.decryptionX509Credentials((credentials) -> decryptionCredentials.stream()
221+
.map(this::asDecryptionCredential)
222+
.forEach(credentials::add));
223+
builder.assertingPartyMetadata((details) -> details
224+
.verificationX509Credentials((credentials) -> assertingPartyVerificationCredentials.stream()
225+
.map(this::asVerificationCredential)
226+
.forEach(credentials::add)));
227+
builder.singleLogoutServiceLocation(singleLogoutUrl);
228+
builder.singleLogoutServiceResponseLocation(singleLogoutResponseUrl);
229+
builder.singleLogoutServiceBinding(singleLogoutBinding);
230+
builder.entityId(entityId);
231+
builder.nameIdFormat(nameIdFormat);
232+
RelyingPartyRegistration registration = builder.build();
233+
boolean signRequest = registration.getAssertingPartyMetadata().getWantAuthnRequestsSigned();
234+
if (signRequest) {
235+
Assert.state(!signingCredentials.isEmpty(),
236+
"Signing credentials must not be empty when authentication requests require signing.");
237+
}
238+
return registration;
239+
}
240+
241+
private Saml2X509Credential asSigningCredential(Credential credential) {
242+
RSAPrivateKey privateKey = readPrivateKey(credential.getPrivateKey());
243+
X509Certificate certificate = readCertificate(credential.getCertificate());
244+
return new Saml2X509Credential(privateKey, certificate,
245+
Saml2X509Credential.Saml2X509CredentialType.SIGNING);
246+
}
247+
248+
private Saml2X509Credential asDecryptionCredential(Credential credential) {
249+
RSAPrivateKey privateKey = readPrivateKey(credential.getPrivateKey());
250+
X509Certificate certificate = readCertificate(credential.getCertificate());
251+
return new Saml2X509Credential(privateKey, certificate,
252+
Saml2X509Credential.Saml2X509CredentialType.DECRYPTION);
253+
}
254+
255+
private Saml2X509Credential asVerificationCredential(Credential credential) {
256+
X509Certificate certificate = readCertificate(credential.getCertificate());
257+
return new Saml2X509Credential(certificate, Saml2X509Credential.Saml2X509CredentialType.ENCRYPTION,
258+
Saml2X509Credential.Saml2X509CredentialType.VERIFICATION);
259+
}
260+
261+
private RSAPrivateKey readPrivateKey(String privateKey) {
262+
Assert.state(privateKey != null, "No private key specified");
263+
try {
264+
return (RSAPrivateKey) KeyPairUtil.decodePrivateKey(privateKey.getBytes(StandardCharsets.UTF_8));
265+
}
266+
catch (Exception ex) {
267+
throw new IllegalArgumentException(ex);
268+
}
269+
}
270+
271+
private X509Certificate readCertificate(String certificate) {
272+
Assert.state(certificate != null, "No certificate specified");
273+
try {
274+
return X509Support.decodeCertificate(certificate);
275+
}
276+
catch (Exception ex) {
277+
throw new IllegalArgumentException(ex);
278+
}
279+
}
280+
281+
private Consumer<AssertingPartyMetadata.Builder<?>> mapAssertingParty(String assertingPartyEntityId,
282+
Saml2MessageBinding assertingPartySingleSignOnBinding, String assertingPartySingleSignOnUrl,
283+
Boolean assertingPartySingleSignOnSignRequest, Saml2MessageBinding assertingPartySingleLogoutBinding,
284+
String assertingPartySingleLogoutResponseUrl, String assertingPartySingleLogoutUrl) {
285+
return (details) -> {
286+
applyingWhenNonNull(assertingPartyEntityId, details::entityId);
287+
applyingWhenNonNull(assertingPartySingleSignOnBinding, details::singleSignOnServiceBinding);
288+
applyingWhenNonNull(assertingPartySingleSignOnUrl, details::singleSignOnServiceLocation);
289+
applyingWhenNonNull(assertingPartySingleSignOnSignRequest, details::wantAuthnRequestsSigned);
290+
applyingWhenNonNull(assertingPartySingleLogoutUrl, details::singleLogoutServiceLocation);
291+
applyingWhenNonNull(assertingPartySingleLogoutResponseUrl,
292+
details::singleLogoutServiceResponseLocation);
293+
applyingWhenNonNull(assertingPartySingleLogoutBinding, details::singleLogoutServiceBinding);
294+
};
295+
}
296+
297+
private <T> void applyingWhenNonNull(T value, Consumer<T> consumer) {
298+
if (value != null) {
299+
consumer.accept(value);
300+
}
301+
}
302+
303+
private RelyingPartyRegistration.Builder createBuilderUsingMetadata(String assertingPartyEntityId,
304+
String assertingPartyMetadataUri) {
305+
Collection<RelyingPartyRegistration.Builder> candidates = RelyingPartyRegistrations
306+
.collectionFromMetadataLocation(assertingPartyMetadataUri);
307+
for (RelyingPartyRegistration.Builder candidate : candidates) {
308+
if (assertingPartyEntityId == null || assertingPartyEntityId.equals(getEntityId(candidate))) {
309+
return candidate;
310+
}
311+
}
312+
throw new IllegalStateException("No relying party with Entity ID '" + assertingPartyEntityId + "' found");
313+
}
314+
315+
private Object getEntityId(RelyingPartyRegistration.Builder candidate) {
316+
String[] result = new String[1];
317+
candidate.assertingPartyMetadata((builder) -> result[0] = builder.build().getEntityId());
318+
return result[0];
319+
}
320+
321+
private List<Credential> parseCredentials(String credentials) {
322+
if (!StringUtils.hasText(credentials)) {
323+
return new ArrayList<>();
324+
}
325+
try {
326+
return OBJECT_MAPPER.readValue(credentials, new TypeReference<>() {
327+
});
328+
}
329+
catch (Exception ex) {
330+
throw new IllegalArgumentException(ex.getMessage(), ex);
331+
}
332+
}
333+
334+
private String getLobValue(ResultSet rs, String columnName) throws SQLException {
335+
String columnValue = null;
336+
byte[] columnValueBytes = this.lobHandler.getBlobAsBytes(rs, columnName);
337+
if (columnValueBytes != null) {
338+
columnValue = new String(columnValueBytes, StandardCharsets.UTF_8);
339+
}
340+
return columnValue;
341+
}
342+
343+
}
344+
345+
public static class Credential {
346+
347+
private String privateKey;
348+
349+
private String certificate;
350+
351+
public Credential() {
352+
}
353+
354+
public Credential(String privateKey, String certificate) {
355+
this.privateKey = privateKey;
356+
this.certificate = certificate;
357+
}
358+
359+
public String getPrivateKey() {
360+
return privateKey;
361+
}
362+
363+
public void setPrivateKey(String privateKey) {
364+
this.privateKey = privateKey;
365+
}
366+
367+
public String getCertificate() {
368+
return certificate;
369+
}
370+
371+
public void setCertificate(String certificate) {
372+
this.certificate = certificate;
373+
}
374+
375+
}
376+
377+
}

0 commit comments

Comments
 (0)