|
| 1 | +/* |
| 2 | + * Copyright 2002-2025 the original author or authors. |
| 3 | + * |
| 4 | + * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | + * you may not use this file except in compliance with the License. |
| 6 | + * You may obtain a copy of the License at |
| 7 | + * |
| 8 | + * https://www.apache.org/licenses/LICENSE-2.0 |
| 9 | + * |
| 10 | + * Unless required by applicable law or agreed to in writing, software |
| 11 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | + * See the License for the specific language governing permissions and |
| 14 | + * limitations under the License. |
| 15 | + */ |
| 16 | + |
| 17 | +package org.springframework.security.saml2.provider.service.registration; |
| 18 | + |
| 19 | +import java.nio.charset.StandardCharsets; |
| 20 | +import java.security.cert.X509Certificate; |
| 21 | +import java.security.interfaces.RSAPrivateKey; |
| 22 | +import java.sql.ResultSet; |
| 23 | +import java.sql.SQLException; |
| 24 | +import java.sql.Types; |
| 25 | +import java.util.ArrayList; |
| 26 | +import java.util.Collection; |
| 27 | +import java.util.Iterator; |
| 28 | +import java.util.List; |
| 29 | +import java.util.function.Consumer; |
| 30 | + |
| 31 | +import com.fasterxml.jackson.core.type.TypeReference; |
| 32 | +import com.fasterxml.jackson.databind.ObjectMapper; |
| 33 | +import org.cryptacular.util.KeyPairUtil; |
| 34 | +import org.opensaml.security.x509.X509Support; |
| 35 | + |
| 36 | +import org.springframework.jdbc.core.ArgumentPreparedStatementSetter; |
| 37 | +import org.springframework.jdbc.core.JdbcOperations; |
| 38 | +import org.springframework.jdbc.core.PreparedStatementSetter; |
| 39 | +import org.springframework.jdbc.core.RowMapper; |
| 40 | +import org.springframework.jdbc.core.SqlParameterValue; |
| 41 | +import org.springframework.jdbc.support.lob.DefaultLobHandler; |
| 42 | +import org.springframework.jdbc.support.lob.LobHandler; |
| 43 | +import org.springframework.security.saml2.core.Saml2X509Credential; |
| 44 | +import org.springframework.util.Assert; |
| 45 | +import org.springframework.util.StringUtils; |
| 46 | + |
| 47 | +public class JdbcRelyingPartyRegistrationRepository implements IterableRelyingPartyRegistrationRepository { |
| 48 | + |
| 49 | + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); |
| 50 | + |
| 51 | + // @formatter:off |
| 52 | + static final String COLUMN_NAMES = "id, " |
| 53 | + + "entity_id, " |
| 54 | + + "name_id_format, " |
| 55 | + + "acs_location, " |
| 56 | + + "acs_binding, " |
| 57 | + + "signing_credentials, " |
| 58 | + + "decryption_credentials, " |
| 59 | + + "singlelogout_url, " |
| 60 | + + "singlelogout_response_url, " |
| 61 | + + "singlelogout_binding, " |
| 62 | + + "assertingparty_entity_id, " |
| 63 | + + "assertingparty_metadata_uri, " |
| 64 | + + "assertingparty_singlesignon_url, " |
| 65 | + + "assertingparty_singlesignon_binding, " |
| 66 | + + "assertingparty_singlesignon_sign_request, " |
| 67 | + + "assertingparty_verification_credentials, " |
| 68 | + + "assertingparty_singlelogout_url, " |
| 69 | + + "assertingparty_singlelogout_response_url, " |
| 70 | + + "assertingparty_singlelogout_binding" |
| 71 | + ; |
| 72 | + // @formatter:on |
| 73 | + |
| 74 | + private static final String TABLE_NAME = "saml2_relying_party_registration"; |
| 75 | + |
| 76 | + private static final String PK_FILTER = "id = ?"; |
| 77 | + |
| 78 | + private static final String ENTITY_ID_FILTER = "entity_id = ?"; |
| 79 | + |
| 80 | + // @formatter:off |
| 81 | + private static final String LOAD_BY_ID_SQL = "SELECT " + COLUMN_NAMES |
| 82 | + + " FROM " + TABLE_NAME |
| 83 | + + " WHERE " + PK_FILTER; |
| 84 | + |
| 85 | + private static final String LOAD_BY_ENTITY_ID_SQL = "SELECT " + COLUMN_NAMES |
| 86 | + + " FROM " + TABLE_NAME |
| 87 | + + " WHERE " + ENTITY_ID_FILTER; |
| 88 | + |
| 89 | + private static final String LOAD_ALL_SQL = "SELECT " + COLUMN_NAMES |
| 90 | + + " FROM " + TABLE_NAME; |
| 91 | + // @formatter:on |
| 92 | + |
| 93 | + protected final JdbcOperations jdbcOperations; |
| 94 | + |
| 95 | + protected RowMapper<RelyingPartyRegistration> relyingPartyRegistrationRowMapper; |
| 96 | + |
| 97 | + protected final LobHandler lobHandler; |
| 98 | + |
| 99 | + /** |
| 100 | + * Constructs a {@code JdbcOAuth2AuthorizedClientService} using the provided |
| 101 | + * parameters. |
| 102 | + * @param jdbcOperations the JDBC operations |
| 103 | + */ |
| 104 | + public JdbcRelyingPartyRegistrationRepository(JdbcOperations jdbcOperations) { |
| 105 | + this(jdbcOperations, new DefaultLobHandler()); |
| 106 | + } |
| 107 | + |
| 108 | + /** |
| 109 | + * Constructs a {@code JdbcRelyingPartyRegistrationRepository} using the provided |
| 110 | + * parameters. |
| 111 | + * @param jdbcOperations the JDBC operations |
| 112 | + * @param lobHandler the handler for large binary fields and large text fields |
| 113 | + */ |
| 114 | + public JdbcRelyingPartyRegistrationRepository(JdbcOperations jdbcOperations, LobHandler lobHandler) { |
| 115 | + Assert.notNull(jdbcOperations, "jdbcOperations cannot be null"); |
| 116 | + Assert.notNull(lobHandler, "lobHandler cannot be null"); |
| 117 | + this.jdbcOperations = jdbcOperations; |
| 118 | + this.lobHandler = lobHandler; |
| 119 | + RelyingPartyRegistrationRowMapper rowMapper = new RelyingPartyRegistrationRowMapper(); |
| 120 | + rowMapper.setLobHandler(lobHandler); |
| 121 | + this.relyingPartyRegistrationRowMapper = rowMapper; |
| 122 | + } |
| 123 | + |
| 124 | + /** |
| 125 | + * Sets the {@link RowMapper} used for mapping the current row in |
| 126 | + * {@code java.sql.ResultSet} to {@link RelyingPartyRegistration}. The default is |
| 127 | + * {@link RelyingPartyRegistrationRowMapper}. |
| 128 | + * @param relyingPartyRegistrationRowMapper the {@link RowMapper} used for mapping the |
| 129 | + * current row in {@code java.sql.ResultSet} to {@link RelyingPartyRegistration} |
| 130 | + */ |
| 131 | + public final void setAuthorizedClientRowMapper( |
| 132 | + RowMapper<RelyingPartyRegistration> relyingPartyRegistrationRowMapper) { |
| 133 | + Assert.notNull(relyingPartyRegistrationRowMapper, "relyingPartyRegistrationRowMapper cannot be null"); |
| 134 | + this.relyingPartyRegistrationRowMapper = relyingPartyRegistrationRowMapper; |
| 135 | + } |
| 136 | + |
| 137 | + @Override |
| 138 | + public RelyingPartyRegistration findByRegistrationId(String registrationId) { |
| 139 | + Assert.hasText(registrationId, "registrationId cannot be empty"); |
| 140 | + SqlParameterValue[] parameters = new SqlParameterValue[] { |
| 141 | + new SqlParameterValue(Types.VARCHAR, registrationId) }; |
| 142 | + PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters); |
| 143 | + List<RelyingPartyRegistration> result = this.jdbcOperations.query(LOAD_BY_ID_SQL, pss, |
| 144 | + this.relyingPartyRegistrationRowMapper); |
| 145 | + return !result.isEmpty() ? result.get(0) : null; |
| 146 | + } |
| 147 | + |
| 148 | + @Override |
| 149 | + public RelyingPartyRegistration findUniqueByAssertingPartyEntityId(String entityId) { |
| 150 | + Assert.hasText(entityId, "entityId cannot be empty"); |
| 151 | + SqlParameterValue[] parameters = new SqlParameterValue[] { new SqlParameterValue(Types.VARCHAR, entityId) }; |
| 152 | + PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters); |
| 153 | + List<RelyingPartyRegistration> result = this.jdbcOperations.query(LOAD_BY_ENTITY_ID_SQL, pss, |
| 154 | + this.relyingPartyRegistrationRowMapper); |
| 155 | + return !result.isEmpty() ? result.get(0) : null; |
| 156 | + } |
| 157 | + |
| 158 | + @Override |
| 159 | + public Iterator<RelyingPartyRegistration> iterator() { |
| 160 | + List<RelyingPartyRegistration> result = this.jdbcOperations.query(LOAD_ALL_SQL, |
| 161 | + this.relyingPartyRegistrationRowMapper); |
| 162 | + return result.iterator(); |
| 163 | + } |
| 164 | + |
| 165 | + /** |
| 166 | + * The default {@link RowMapper} that maps the current row in |
| 167 | + * {@code java.sql.ResultSet} to {@link RelyingPartyRegistration}. |
| 168 | + */ |
| 169 | + public static class RelyingPartyRegistrationRowMapper implements RowMapper<RelyingPartyRegistration> { |
| 170 | + |
| 171 | + protected LobHandler lobHandler = new DefaultLobHandler(); |
| 172 | + |
| 173 | + public final void setLobHandler(LobHandler lobHandler) { |
| 174 | + Assert.notNull(lobHandler, "lobHandler cannot be null"); |
| 175 | + this.lobHandler = lobHandler; |
| 176 | + } |
| 177 | + |
| 178 | + @Override |
| 179 | + public RelyingPartyRegistration mapRow(ResultSet rs, int rowNum) throws SQLException { |
| 180 | + String registrationId = rs.getString("id"); |
| 181 | + String entityId = StringUtils.hasText(rs.getString("entity_id")) ? rs.getString("entity_id") |
| 182 | + : "{baseUrl}/saml2/service-provider-metadata/{registrationId}"; |
| 183 | + String nameIdFormat = rs.getString("name_id_format"); |
| 184 | + String acsLocation = StringUtils.hasText(rs.getString("acs_location")) ? rs.getString("acs_location") |
| 185 | + : "{baseUrl}/login/saml2/sso/{registrationId}"; |
| 186 | + String acsBinding = StringUtils.hasText(rs.getString("acs_binding")) ? rs.getString("acs_binding") |
| 187 | + : Saml2MessageBinding.POST.getUrn(); |
| 188 | + List<Credential> signingCredentials = parseCredentials(getLobValue(rs, "signing_credentials")); |
| 189 | + List<Credential> decryptionCredentials = parseCredentials(getLobValue(rs, "decryption_credentials")); |
| 190 | + String singleLogoutUrl = rs.getString("singlelogout_url"); |
| 191 | + String singleLogoutResponseUrl = rs.getString("singlelogout_response_url"); |
| 192 | + Saml2MessageBinding singleLogoutBinding = Saml2MessageBinding.from(rs.getString("singlelogout_binding")); |
| 193 | + String assertingPartyEntityId = rs.getString("assertingparty_entity_id"); |
| 194 | + String assertingPartyMetadataUri = rs.getString("assertingparty_metadata_uri"); |
| 195 | + String assertingPartySingleSignOnUrl = rs.getString("assertingparty_singlesignon_url"); |
| 196 | + Saml2MessageBinding assertingPartySingleSignOnBinding = Saml2MessageBinding |
| 197 | + .from(rs.getString("assertingparty_singlesignon_binding")); |
| 198 | + Boolean assertingPartySingleSignOnSignRequest = rs.getBoolean("assertingparty_singlesignon_sign_request"); |
| 199 | + List<Credential> assertingPartyVerificationCredentials = parseCredentials( |
| 200 | + getLobValue(rs, "assertingparty_verification_credentials")); |
| 201 | + String assertingPartySingleLogoutUrl = rs.getString("assertingparty_singlelogout_url"); |
| 202 | + String assertingPartySingleLogoutResponseUrl = rs.getString("assertingparty_singlelogout_response_url"); |
| 203 | + Saml2MessageBinding assertingPartySingleLogoutBinding = Saml2MessageBinding |
| 204 | + .from(rs.getString("assertingparty_singlelogout_binding")); |
| 205 | + |
| 206 | + boolean usingMetadata = StringUtils.hasText(assertingPartyMetadataUri); |
| 207 | + RelyingPartyRegistration.Builder builder = (!usingMetadata) |
| 208 | + ? RelyingPartyRegistration.withRegistrationId(registrationId) |
| 209 | + : createBuilderUsingMetadata(assertingPartyEntityId, assertingPartyMetadataUri) |
| 210 | + .registrationId(registrationId); |
| 211 | + builder.assertionConsumerServiceLocation(acsLocation); |
| 212 | + builder.assertionConsumerServiceBinding(Saml2MessageBinding.from(acsBinding)); |
| 213 | + builder.assertingPartyMetadata(mapAssertingParty(assertingPartyEntityId, assertingPartySingleSignOnBinding, |
| 214 | + assertingPartySingleSignOnUrl, assertingPartySingleSignOnSignRequest, |
| 215 | + assertingPartySingleLogoutBinding, assertingPartySingleLogoutResponseUrl, |
| 216 | + assertingPartySingleLogoutUrl)); |
| 217 | + builder.signingX509Credentials((credentials) -> signingCredentials.stream() |
| 218 | + .map(this::asSigningCredential) |
| 219 | + .forEach(credentials::add)); |
| 220 | + builder.decryptionX509Credentials((credentials) -> decryptionCredentials.stream() |
| 221 | + .map(this::asDecryptionCredential) |
| 222 | + .forEach(credentials::add)); |
| 223 | + builder.assertingPartyMetadata((details) -> details |
| 224 | + .verificationX509Credentials((credentials) -> assertingPartyVerificationCredentials.stream() |
| 225 | + .map(this::asVerificationCredential) |
| 226 | + .forEach(credentials::add))); |
| 227 | + builder.singleLogoutServiceLocation(singleLogoutUrl); |
| 228 | + builder.singleLogoutServiceResponseLocation(singleLogoutResponseUrl); |
| 229 | + builder.singleLogoutServiceBinding(singleLogoutBinding); |
| 230 | + builder.entityId(entityId); |
| 231 | + builder.nameIdFormat(nameIdFormat); |
| 232 | + RelyingPartyRegistration registration = builder.build(); |
| 233 | + boolean signRequest = registration.getAssertingPartyMetadata().getWantAuthnRequestsSigned(); |
| 234 | + if (signRequest) { |
| 235 | + Assert.state(!signingCredentials.isEmpty(), |
| 236 | + "Signing credentials must not be empty when authentication requests require signing."); |
| 237 | + } |
| 238 | + return registration; |
| 239 | + } |
| 240 | + |
| 241 | + private Saml2X509Credential asSigningCredential(Credential credential) { |
| 242 | + RSAPrivateKey privateKey = readPrivateKey(credential.getPrivateKey()); |
| 243 | + X509Certificate certificate = readCertificate(credential.getCertificate()); |
| 244 | + return new Saml2X509Credential(privateKey, certificate, |
| 245 | + Saml2X509Credential.Saml2X509CredentialType.SIGNING); |
| 246 | + } |
| 247 | + |
| 248 | + private Saml2X509Credential asDecryptionCredential(Credential credential) { |
| 249 | + RSAPrivateKey privateKey = readPrivateKey(credential.getPrivateKey()); |
| 250 | + X509Certificate certificate = readCertificate(credential.getCertificate()); |
| 251 | + return new Saml2X509Credential(privateKey, certificate, |
| 252 | + Saml2X509Credential.Saml2X509CredentialType.DECRYPTION); |
| 253 | + } |
| 254 | + |
| 255 | + private Saml2X509Credential asVerificationCredential(Credential credential) { |
| 256 | + X509Certificate certificate = readCertificate(credential.getCertificate()); |
| 257 | + return new Saml2X509Credential(certificate, Saml2X509Credential.Saml2X509CredentialType.ENCRYPTION, |
| 258 | + Saml2X509Credential.Saml2X509CredentialType.VERIFICATION); |
| 259 | + } |
| 260 | + |
| 261 | + private RSAPrivateKey readPrivateKey(String privateKey) { |
| 262 | + Assert.state(privateKey != null, "No private key specified"); |
| 263 | + try { |
| 264 | + return (RSAPrivateKey) KeyPairUtil.decodePrivateKey(privateKey.getBytes(StandardCharsets.UTF_8)); |
| 265 | + } |
| 266 | + catch (Exception ex) { |
| 267 | + throw new IllegalArgumentException(ex); |
| 268 | + } |
| 269 | + } |
| 270 | + |
| 271 | + private X509Certificate readCertificate(String certificate) { |
| 272 | + Assert.state(certificate != null, "No certificate specified"); |
| 273 | + try { |
| 274 | + return X509Support.decodeCertificate(certificate); |
| 275 | + } |
| 276 | + catch (Exception ex) { |
| 277 | + throw new IllegalArgumentException(ex); |
| 278 | + } |
| 279 | + } |
| 280 | + |
| 281 | + private Consumer<AssertingPartyMetadata.Builder<?>> mapAssertingParty(String assertingPartyEntityId, |
| 282 | + Saml2MessageBinding assertingPartySingleSignOnBinding, String assertingPartySingleSignOnUrl, |
| 283 | + Boolean assertingPartySingleSignOnSignRequest, Saml2MessageBinding assertingPartySingleLogoutBinding, |
| 284 | + String assertingPartySingleLogoutResponseUrl, String assertingPartySingleLogoutUrl) { |
| 285 | + return (details) -> { |
| 286 | + applyingWhenNonNull(assertingPartyEntityId, details::entityId); |
| 287 | + applyingWhenNonNull(assertingPartySingleSignOnBinding, details::singleSignOnServiceBinding); |
| 288 | + applyingWhenNonNull(assertingPartySingleSignOnUrl, details::singleSignOnServiceLocation); |
| 289 | + applyingWhenNonNull(assertingPartySingleSignOnSignRequest, details::wantAuthnRequestsSigned); |
| 290 | + applyingWhenNonNull(assertingPartySingleLogoutUrl, details::singleLogoutServiceLocation); |
| 291 | + applyingWhenNonNull(assertingPartySingleLogoutResponseUrl, |
| 292 | + details::singleLogoutServiceResponseLocation); |
| 293 | + applyingWhenNonNull(assertingPartySingleLogoutBinding, details::singleLogoutServiceBinding); |
| 294 | + }; |
| 295 | + } |
| 296 | + |
| 297 | + private <T> void applyingWhenNonNull(T value, Consumer<T> consumer) { |
| 298 | + if (value != null) { |
| 299 | + consumer.accept(value); |
| 300 | + } |
| 301 | + } |
| 302 | + |
| 303 | + private RelyingPartyRegistration.Builder createBuilderUsingMetadata(String assertingPartyEntityId, |
| 304 | + String assertingPartyMetadataUri) { |
| 305 | + Collection<RelyingPartyRegistration.Builder> candidates = RelyingPartyRegistrations |
| 306 | + .collectionFromMetadataLocation(assertingPartyMetadataUri); |
| 307 | + for (RelyingPartyRegistration.Builder candidate : candidates) { |
| 308 | + if (assertingPartyEntityId == null || assertingPartyEntityId.equals(getEntityId(candidate))) { |
| 309 | + return candidate; |
| 310 | + } |
| 311 | + } |
| 312 | + throw new IllegalStateException("No relying party with Entity ID '" + assertingPartyEntityId + "' found"); |
| 313 | + } |
| 314 | + |
| 315 | + private Object getEntityId(RelyingPartyRegistration.Builder candidate) { |
| 316 | + String[] result = new String[1]; |
| 317 | + candidate.assertingPartyMetadata((builder) -> result[0] = builder.build().getEntityId()); |
| 318 | + return result[0]; |
| 319 | + } |
| 320 | + |
| 321 | + private List<Credential> parseCredentials(String credentials) { |
| 322 | + if (!StringUtils.hasText(credentials)) { |
| 323 | + return new ArrayList<>(); |
| 324 | + } |
| 325 | + try { |
| 326 | + return OBJECT_MAPPER.readValue(credentials, new TypeReference<>() { |
| 327 | + }); |
| 328 | + } |
| 329 | + catch (Exception ex) { |
| 330 | + throw new IllegalArgumentException(ex.getMessage(), ex); |
| 331 | + } |
| 332 | + } |
| 333 | + |
| 334 | + private String getLobValue(ResultSet rs, String columnName) throws SQLException { |
| 335 | + String columnValue = null; |
| 336 | + byte[] columnValueBytes = this.lobHandler.getBlobAsBytes(rs, columnName); |
| 337 | + if (columnValueBytes != null) { |
| 338 | + columnValue = new String(columnValueBytes, StandardCharsets.UTF_8); |
| 339 | + } |
| 340 | + return columnValue; |
| 341 | + } |
| 342 | + |
| 343 | + } |
| 344 | + |
| 345 | + public static class Credential { |
| 346 | + |
| 347 | + private String privateKey; |
| 348 | + |
| 349 | + private String certificate; |
| 350 | + |
| 351 | + public Credential() { |
| 352 | + } |
| 353 | + |
| 354 | + public Credential(String privateKey, String certificate) { |
| 355 | + this.privateKey = privateKey; |
| 356 | + this.certificate = certificate; |
| 357 | + } |
| 358 | + |
| 359 | + public String getPrivateKey() { |
| 360 | + return privateKey; |
| 361 | + } |
| 362 | + |
| 363 | + public void setPrivateKey(String privateKey) { |
| 364 | + this.privateKey = privateKey; |
| 365 | + } |
| 366 | + |
| 367 | + public String getCertificate() { |
| 368 | + return certificate; |
| 369 | + } |
| 370 | + |
| 371 | + public void setCertificate(String certificate) { |
| 372 | + this.certificate = certificate; |
| 373 | + } |
| 374 | + |
| 375 | + } |
| 376 | + |
| 377 | +} |
0 commit comments