Skip to content

Enhance SslInfo to support multiple certificate stores. #45355

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -17,6 +17,7 @@
package org.springframework.boot.actuate.ssl;

import java.util.List;
import java.util.Set;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
@@ -38,6 +39,7 @@
* Tests for {@link SslHealthIndicator}.
*
* @author Jonatan Ivanov
* @author Joshua Chen
*/
class SslHealthIndicatorTests {

@@ -55,7 +57,7 @@ void setUp() {
this.validity = mock(CertificateValidityInfo.class);
given(sslInfo.getBundles()).willReturn(List.of(bundle));
given(bundle.getCertificateChains()).willReturn(List.of(certificateChain));
given(certificateChain.getCertificates()).willReturn(List.of(certificateInfo));
given(certificateChain.getCertificates()).willReturn(Set.of(certificateInfo));
given(certificateInfo.getValidity()).willReturn(this.validity);
}

Original file line number Diff line number Diff line change
@@ -24,23 +24,28 @@
import java.security.cert.X509Certificate;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Date;
import java.util.List;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;

import javax.security.auth.x500.X500Principal;

import org.springframework.boot.info.SslInfo.CertificateValidityInfo.Status;
import org.springframework.boot.ssl.SslBundle;
import org.springframework.boot.ssl.SslBundles;
import org.springframework.boot.ssl.SslStoreBundle;
import org.springframework.util.ObjectUtils;

/**
* Information about the certificates that the application uses.
*
* @author Jonatan Ivanov
* @author Joshua Chen
* @since 3.4.0
*/
public class SslInfo {
@@ -72,7 +77,21 @@ public final class BundleInfo {

private BundleInfo(String name, SslBundle sslBundle) {
this.name = name;
this.certificateChains = extractCertificateChains(sslBundle.getStores().getKeyStore());
this.certificateChains = extractCertificateChains(sslBundle.getStores());
}

private List<CertificateChainInfo> extractCertificateChains(SslStoreBundle storeBundle) {
if (storeBundle == null) {
return Collections.emptyList();
}
List<CertificateChainInfo> certificateChains = new ArrayList<>();
if (storeBundle.getKeyStore() != null) {
certificateChains.addAll(extractCertificateChains(storeBundle.getKeyStore()));
}
if (storeBundle.getTrustStore() != null) {
certificateChains.addAll(extractCertificateChains(storeBundle.getTrustStore()));
}
return certificateChains;
}

private List<CertificateChainInfo> extractCertificateChains(KeyStore keyStore) {
@@ -107,29 +126,34 @@ public final class CertificateChainInfo {

private final String alias;

private final List<CertificateInfo> certificates;
private final Set<CertificateInfo> certificates;

CertificateChainInfo(KeyStore keyStore, String alias) {
this.alias = alias;
this.certificates = extractCertificates(keyStore, alias);
}

private List<CertificateInfo> extractCertificates(KeyStore keyStore, String alias) {
private Set<CertificateInfo> extractCertificates(KeyStore keyStore, String alias) {
try {
Certificate[] certificates = keyStore.getCertificateChain(alias);
return (!ObjectUtils.isEmpty(certificates))
? Arrays.stream(certificates).map(CertificateInfo::new).toList() : Collections.emptyList();
Set<CertificateInfo> certificateInfos = new java.util.HashSet<>(!ObjectUtils.isEmpty(certificates)
? Arrays.stream(certificates).map(CertificateInfo::new).collect(Collectors.toSet())
: Collections.emptySet());
if (keyStore.getCertificate(alias) != null) {
certificateInfos.add(new CertificateInfo(keyStore.getCertificate(alias)));
}
return certificateInfos;
}
catch (KeyStoreException ex) {
return Collections.emptyList();
return Collections.emptySet();
}
}

public String getAlias() {
return this.alias;
}

public List<CertificateInfo> getCertificates() {
public Set<CertificateInfo> getCertificates() {
return this.certificates;
}

@@ -208,6 +232,23 @@ private <R> R extract(Function<X509Certificate, R> extractor) {
return (this.certificate != null) ? extractor.apply(this.certificate) : null;
}

@Override
public boolean equals(Object other) {
if (this == other) {
return true;
}
if (other == null || getClass() != other.getClass()) {
return false;
}
return (this.certificate != null) ? this.certificate.equals(((CertificateInfo) other).certificate)
: super.equals(other);
}

@Override
public int hashCode() {
return (this.certificate != null) ? this.certificate.hashCode() : super.hashCode();
}

}

/**
Original file line number Diff line number Diff line change
@@ -23,10 +23,13 @@
import java.nio.file.Path;
import java.time.Duration;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.EnumSource;

import org.springframework.boot.info.SslInfo.BundleInfo;
import org.springframework.boot.info.SslInfo.CertificateChainInfo;
@@ -46,13 +49,15 @@
* Tests for {@link SslInfo}.
*
* @author Jonatan Ivanov
* @author Joshua Chen
*/
class SslInfoTests {

@Test
@ParameterizedTest
@EnumSource(StoreType.class)
@WithPackageResources("test.p12")
void validCertificatesShouldProvideSslInfo() {
SslInfo sslInfo = createSslInfo("classpath:test.p12");
void validCertificatesShouldProvideSslInfo(StoreType storeType) {
SslInfo sslInfo = createSslInfo(storeType, "classpath:test.p12");
assertThat(sslInfo.getBundles()).hasSize(1);
BundleInfo bundle = sslInfo.getBundles().get(0);
assertThat(bundle.getName()).isEqualTo("test-0");
@@ -62,10 +67,10 @@ void validCertificatesShouldProvideSslInfo() {
assertThat(bundle.getCertificateChains().get(1).getAlias()).isEqualTo("test-alias");
assertThat(bundle.getCertificateChains().get(1).getCertificates()).hasSize(1);
assertThat(bundle.getCertificateChains().get(2).getAlias()).isEqualTo("spring-boot-cert");
assertThat(bundle.getCertificateChains().get(2).getCertificates()).isEmpty();
assertThat(bundle.getCertificateChains().get(2).getCertificates()).hasSize(1);
assertThat(bundle.getCertificateChains().get(3).getAlias()).isEqualTo("test-alias-cert");
assertThat(bundle.getCertificateChains().get(3).getCertificates()).isEmpty();
CertificateInfo cert1 = bundle.getCertificateChains().get(0).getCertificates().get(0);
assertThat(bundle.getCertificateChains().get(3).getCertificates()).hasSize(1);
CertificateInfo cert1 = bundle.getCertificateChains().get(0).getCertificates().iterator().next();
assertThat(cert1.getSubject()).isEqualTo("CN=localhost,OU=Spring,O=VMware,L=Palo Alto,ST=California,C=US");
assertThat(cert1.getIssuer()).isEqualTo(cert1.getSubject());
assertThat(cert1.getSerialNumber()).isNotEmpty();
@@ -76,7 +81,7 @@ void validCertificatesShouldProvideSslInfo() {
assertThat(cert1.getValidity()).isNotNull();
assertThat(cert1.getValidity().getStatus()).isSameAs(Status.VALID);
assertThat(cert1.getValidity().getMessage()).isNull();
CertificateInfo cert2 = bundle.getCertificateChains().get(1).getCertificates().get(0);
CertificateInfo cert2 = bundle.getCertificateChains().get(1).getCertificates().iterator().next();
assertThat(cert2.getSubject()).isEqualTo("CN=localhost,OU=Spring,O=VMware,L=Palo Alto,ST=California,C=US");
assertThat(cert2.getIssuer()).isEqualTo(cert2.getSubject());
assertThat(cert2.getSerialNumber()).isNotEmpty();
@@ -89,19 +94,20 @@ void validCertificatesShouldProvideSslInfo() {
assertThat(cert2.getValidity().getMessage()).isNull();
}

@Test
@ParameterizedTest
@EnumSource(StoreType.class)
@WithPackageResources("test-not-yet-valid.p12")
void notYetValidCertificateShouldProvideSslInfo() {
SslInfo sslInfo = createSslInfo("classpath:test-not-yet-valid.p12");
void notYetValidCertificateShouldProvideSslInfo(StoreType storeType) {
SslInfo sslInfo = createSslInfo(storeType, "classpath:test-not-yet-valid.p12");
assertThat(sslInfo.getBundles()).hasSize(1);
BundleInfo bundle = sslInfo.getBundles().get(0);
assertThat(bundle.getName()).isEqualTo("test-0");
assertThat(bundle.getCertificateChains()).hasSize(1);
CertificateChainInfo certificateChain = bundle.getCertificateChains().get(0);
assertThat(certificateChain.getAlias()).isEqualTo("spring-boot");
List<CertificateInfo> certs = certificateChain.getCertificates();
Set<CertificateInfo> certs = certificateChain.getCertificates();
assertThat(certs).hasSize(1);
CertificateInfo cert = certs.get(0);
CertificateInfo cert = certs.iterator().next();
assertThat(cert.getSubject()).isEqualTo("CN=localhost,OU=Spring,O=VMware,L=Palo Alto,ST=California,C=US");
assertThat(cert.getIssuer()).isEqualTo(cert.getSubject());
assertThat(cert.getSerialNumber()).isNotEmpty();
@@ -124,9 +130,9 @@ void expiredCertificateShouldProvideSslInfo() {
assertThat(bundle.getCertificateChains()).hasSize(1);
CertificateChainInfo certificateChain = bundle.getCertificateChains().get(0);
assertThat(certificateChain.getAlias()).isEqualTo("spring-boot");
List<CertificateInfo> certs = certificateChain.getCertificates();
Set<CertificateInfo> certs = certificateChain.getCertificates();
assertThat(certs).hasSize(1);
CertificateInfo cert = certs.get(0);
CertificateInfo cert = certs.iterator().next();
assertThat(cert.getSubject()).isEqualTo("CN=localhost,OU=Spring,O=VMware,L=Palo Alto,ST=California,C=US");
assertThat(cert.getIssuer()).isEqualTo(cert.getSubject());
assertThat(cert.getSerialNumber()).isNotEmpty();
@@ -150,9 +156,9 @@ void soonToBeExpiredCertificateShouldProvideSslInfo(@TempDir Path tempDir)
assertThat(bundle.getCertificateChains()).hasSize(1);
CertificateChainInfo certificateChain = bundle.getCertificateChains().get(0);
assertThat(certificateChain.getAlias()).isEqualTo("spring-boot");
List<CertificateInfo> certs = certificateChain.getCertificates();
Set<CertificateInfo> certs = certificateChain.getCertificates();
assertThat(certs).hasSize(1);
CertificateInfo cert = certs.get(0);
CertificateInfo cert = certs.iterator().next();
assertThat(cert.getSubject()).isEqualTo("CN=localhost,OU=Spring,O=VMware,L=Palo Alto,ST=California,C=US");
assertThat(cert.getIssuer()).isEqualTo(cert.getSubject());
assertThat(cert.getSerialNumber()).isNotEmpty();
@@ -178,7 +184,7 @@ void multipleBundlesShouldProvideSslInfo(@TempDir Path tempDir) throws IOExcepti
.flatMap((bundle) -> bundle.getCertificateChains().stream())
.flatMap((certificateChain) -> certificateChain.getCertificates().stream())
.toList();
assertThat(certs).hasSize(5);
assertThat(certs).hasSize(7);
assertThat(certs).allSatisfy((cert) -> {
assertThat(cert.getSubject()).isEqualTo("CN=localhost,OU=Spring,O=VMware,L=Palo Alto,ST=California,C=US");
assertThat(cert.getIssuer()).isEqualTo(cert.getSubject());
@@ -227,10 +233,20 @@ void nullKeyStore() {
}

private SslInfo createSslInfo(String... locations) {
return createSslInfo(StoreType.KEYSTORE, locations);
}

private SslInfo createSslInfo(StoreType storeType, String... locations) {
DefaultSslBundleRegistry sslBundleRegistry = new DefaultSslBundleRegistry();
for (int i = 0; i < locations.length; i++) {
JksSslStoreDetails keyStoreDetails = JksSslStoreDetails.forLocation(locations[i]).withPassword("secret");
SslStoreBundle sslStoreBundle = new JksSslStoreBundle(keyStoreDetails, null);
JksSslStoreDetails storeDetails = JksSslStoreDetails.forLocation(locations[i]).withPassword("secret");
SslStoreBundle sslStoreBundle;
if (storeType == StoreType.TRUSTSTORE) {
sslStoreBundle = new JksSslStoreBundle(null, storeDetails);
}
else {
sslStoreBundle = new JksSslStoreBundle(storeDetails, null);
}
sslBundleRegistry.registerBundle("test-%d".formatted(i), SslBundle.of(sslStoreBundle));
}
return new SslInfo(sslBundleRegistry, Duration.ofDays(7));
@@ -270,4 +286,10 @@ private ProcessBuilder createProcessBuilder(Path keystore) {
return processBuilder;
}

private enum StoreType {

KEYSTORE, TRUSTSTORE

}

}