|
package com.example.keycloak.aws; |
|
|
|
import com.fasterxml.jackson.databind.JsonNode; |
|
import com.fasterxml.jackson.databind.ObjectMapper; |
|
import org.keycloak.models.KeycloakSession; |
|
import software.amazon.awssdk.regions.Region; |
|
import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; |
|
import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest; |
|
import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueResponse; |
|
|
|
import java.io.IOException; |
|
import java.time.Duration; |
|
import java.time.Instant; |
|
import java.util.Map; |
|
import java.util.Objects; |
|
import java.util.concurrent.ConcurrentHashMap; |
|
|
|
public class AwsSecretsManagerProvider implements SecretsProvider, AutoCloseable { |
|
private static final ObjectMapper MAPPER = new ObjectMapper(); |
|
|
|
private static final class Entry { |
|
final JsonNode json; |
|
final String versionId; |
|
final Instant expiresAt; |
|
Entry(JsonNode json, String versionId, Instant expiresAt) { |
|
this.json = json; |
|
this.versionId = versionId; |
|
this.expiresAt = expiresAt; |
|
} |
|
} |
|
|
|
private static volatile SecretsManagerClient client; |
|
|
|
private final KeycloakSession session; |
|
private final Region region; |
|
private final Duration ttl; |
|
private final Map<String, Entry> cache = new ConcurrentHashMap<>(); |
|
|
|
public AwsSecretsManagerProvider(KeycloakSession session, String region, Duration ttl) { |
|
this.session = session; |
|
this.region = Region.of(region); |
|
this.ttl = ttl; |
|
if (client == null) { |
|
synchronized (AwsSecretsManagerProvider.class) { |
|
if (client == null) { |
|
client = SecretsManagerClient.builder().region(this.region).build(); |
|
} |
|
} |
|
} |
|
} |
|
|
|
@Override |
|
public String getSecretField(String secretId, String fieldName) { |
|
Entry e = cache.get(secretId); |
|
if (e == null || Instant.now().isAfter(e.expiresAt)) { |
|
e = refresh(secretId, e); // fetch AWSCURRENT; detect rotation by VersionId change |
|
cache.put(secretId, e); |
|
} |
|
JsonNode node = e.json.get(fieldName); |
|
if (node == null) throw new IllegalArgumentException("Field " + fieldName + " not found in secret " + secretId); |
|
return node.asText(); |
|
} |
|
|
|
private Entry refresh(String secretId, Entry previous) { |
|
GetSecretValueResponse resp = client.getSecretValue(GetSecretValueRequest.builder() |
|
.secretId(secretId) |
|
.versionStage("AWSCURRENT") // pin to current to follow rotation |
|
.build()); |
|
|
|
String newVersion = resp.versionId(); |
|
JsonNode json; |
|
try { |
|
json = MAPPER.readTree(resp.secretString()); |
|
} catch (IOException ex) { |
|
throw new RuntimeException("Failed to parse secret JSON", ex); |
|
} |
|
|
|
// If VersionId changed, rotation detected; replace immediately. If same, just extend TTL. |
|
boolean rotated = previous == null || !Objects.equals(previous.versionId, newVersion); |
|
Instant expiry = Instant.now().plus(ttl); |
|
return new Entry(json, newVersion, expiry); |
|
} |
|
|
|
@Override |
|
public void close() { |
|
// no-op; client is shared |
|
} |
|
} |