Skip to content

Instantly share code, notes, and snippets.

@dehidehidehi
Created July 27, 2022 07:53
Show Gist options
  • Select an option

  • Save dehidehidehi/53f9489ac1ffb923f69603f05db3c7bd to your computer and use it in GitHub Desktop.

Select an option

Save dehidehidehi/53f9489ac1ffb923f69603f05db3c7bd to your computer and use it in GitHub Desktop.
Java - JWT Validator (RSA) Implementation
import io.jsonwebtoken.MalformedJwtException;
import java.time.Instant;
/**
 * Base JWT validator that checks structure, expiration and audience from values supplied at construction
 * time, and delegates the signature check to {@link #isValidSignature()} in subclasses.
 */
@AllArgsConstructor
@Getter(AccessLevel.PROTECTED)
@Slf4j
public abstract class AbstractSimpleJWTValidator implements JWTValidator {

    // NOTE: @AllArgsConstructor builds the constructor from these fields in declaration order and
    // subclasses call super(...) positionally, so do not reorder or insert fields.

    /** Encoded JWT; expected to consist of three period-separated segments. */
    private final String jwtToken;
    /** Presumably the key id taken from the token — confirm. Not read by any check in this class. */
    private final String localKeyId;
    /** Presumably the id of the verification key — confirm. Not read by any check in this class. */
    private final String publicKeyId;
    /** Expiration in seconds since the Unix epoch; {@code null} is treated as already expired. */
    private final Long expiration;
    /** Aud returned by the Jwt. **/
    private final String localAud;
    /** Aud which was sent by client application. **/
    private final String referenceAud;

    /**
     * Compares the audience returned by the token with the audience sent by the client.
     *
     * @return {@code true} only if a reference audience is set and equals the token audience exactly
     *     (case-sensitive)
     */
    @Override
    public boolean isValidAud() {
        // Fail closed: with no reference audience nothing matches, not even a missing token audience.
        return referenceAud != null && referenceAud.equals(localAud);
    }

    /**
     * Reports whether the token has expired.
     *
     * <p>RFC 7519 §4.1.4 requires the current time to be strictly before {@code exp}, so a token is
     * expired from its expiration second onwards. A missing expiration counts as expired so that
     * validation fails closed instead of throwing a NullPointerException.
     *
     * @return {@code true} if the expiration is missing or not in the future
     */
    @Override
    public boolean isExpiredToken() {
        if (expiration == null) {
            return true;
        }
        Instant expirationInstant = Instant.ofEpochSecond(expiration);
        return !Instant.now().isBefore(expirationInstant);
    }

    /**
     * Checks the token's overall shape.
     *
     * @throws MalformedJwtException if the token is {@code null} or not three period-separated segments
     */
    @Override
    public void validateStructure() throws MalformedJwtException {
        if (jwtToken == null || !JWTValidator.isValidStructure(jwtToken)) {
            throw new MalformedJwtException("JWT structure is invalid.");
        }
    }

    /**
     * Checks the token signature via the subclass's {@link #isValidSignature()}.
     *
     * @throws BadJWTException if the signature does not verify or cannot be checked
     */
    @Override
    public void validateSignature() throws BadJWTException {
        if (!isValidSignature()) {
            throw new BadJWTException("JWT signature is invalid.");
        }
    }

    /**
     * Checks that the token audience matches the client's reference audience.
     *
     * @throws BadJWTException if the audiences differ or the reference audience is missing
     */
    @Override
    public void validateAud() throws BadJWTException {
        if (!isValidAud()) {
            throw new BadJWTException("Aud doesn't match.");
        }
    }

    /**
     * Checks that the token carries an expiration and has not reached it.
     *
     * @throws BadJWTException if the expiration is missing or the token is expired
     */
    @Override
    public void validateTokenNotExpired() throws BadJWTException {
        if (expiration == null) {
            throw new BadJWTException("JWT has no expiration.");
        }
        if (isExpiredToken()) {
            throw new BadJWTException("JWT is expired.");
        }
    }

    /** Claim validation is not implemented by this base class. */
    @Override
    public void validateClaims() throws BadJWTException {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Runs the structure, expiration, audience and signature checks in that order, stopping at the
     * first failure. {@link #validateClaims()} is intentionally not called because it is not implemented.
     *
     * @throws JwtException if any check fails
     */
    @Override
    public void validate() throws JwtException {
        validateStructure();
        validateTokenNotExpired();
        validateAud();
        validateSignature();
        log.debug("JWT Token is valid.");
    }
}
/**
 * Thrown when a JWT fails one of the validator's checks, for example an invalid signature, a
 * mismatched audience or an expired token.
 */
public class BadJWTException extends JwtException {

    /**
     * Creates the exception and keeps the lower-level failure that caused it.
     *
     * @param reason human-readable description of the failure
     * @param underlying the exception that triggered this failure
     */
    public BadJWTException(String reason, Throwable underlying) {
        super(reason, underlying);
    }

    /**
     * Creates the exception without a cause.
     *
     * @param reason human-readable description of the failure
     */
    public BadJWTException(String reason) {
        super(reason);
    }
}
/**
 * Checked base exception for the JWT validation failures in this code; {@code BadJWTException}
 * specialises it.
 *
 * <p>NOTE(review): the io.jsonwebtoken library imported elsewhere in this file appears to define its own
 * {@code JwtException}; avoid importing both under the same simple name.
 */
public class JwtException extends Exception {

    /**
     * Creates the exception and keeps the lower-level failure that caused it.
     *
     * @param detail human-readable description of the failure
     * @param rootCause the exception that triggered this failure
     */
    public JwtException(String detail, Throwable rootCause) {
        super(detail, rootCause);
    }

    /**
     * Creates the exception without a cause.
     *
     * @param detail human-readable description of the failure
     */
    public JwtException(String detail) {
        super(detail);
    }
}
import io.jsonwebtoken.MalformedJwtException;
/** Contract for JWT validators: boolean checks plus {@code validate*} methods that throw on failure. */
public interface JWTValidator {

    /**
     * Checks whether {@code jwtToken} has the {@code header.payload.signature} shape: exactly two period
     * separators with a non-empty segment on each side.
     *
     * @param jwtToken the encoded token; may be {@code null}
     * @return {@code true} if the token consists of three non-empty period-separated segments
     */
    static boolean isValidStructure(String jwtToken) {
        if (jwtToken == null) {
            return false;
        }
        // A negative limit keeps trailing empty strings, so "a.b." yields three segments (one empty)
        // instead of silently collapsing to two.
        String[] segments = jwtToken.split("\\.", -1);
        if (segments.length != 3) {
            return false;
        }
        // Three non-empty segments also imply the former minimum length of 5 ("a.b.c").
        for (String segment : segments) {
            if (segment.isEmpty()) {
                return false;
            }
        }
        return true;
    }

    /** @return whether the token signature verifies; throws if it cannot be checked */
    boolean isValidSignature() throws BadJWTException;

    /** @return whether the token audience matches the expected audience */
    boolean isValidAud();

    /** @return whether the token is expired */
    boolean isExpiredToken();

    /** Throws if the token shape is invalid. */
    void validateStructure() throws MalformedJwtException;

    /** Throws if the signature is invalid. */
    void validateSignature() throws BadJWTException;

    /** Throws if the audience does not match. */
    void validateAud() throws BadJWTException;

    /** Throws if the token is expired. */
    void validateTokenNotExpired() throws BadJWTException;

    /** Throws if the token claims are invalid. */
    void validateClaims() throws BadJWTException;

    /** Runs the validations; throws on the first failure. */
    void validate() throws JwtException;
}
import java.math.BigInteger;
import java.nio.charset.StandardCharsets;
import java.security.*;
import java.security.spec.InvalidKeySpecException;
import java.security.spec.RSAPublicKeySpec;
import java.util.Base64;
/**
 * JWT validator that verifies the token signature with an RSA public key given as modulus and exponent.
 * The signature algorithm is fixed to SHA256withRSA; the token header is not consulted.
 */
public class JWTValidatorRSAImpl extends AbstractSimpleJWTValidator {

    /** JCA name for RSASSA-PKCS1-v1_5 with SHA-256 (JWA "RS256"). */
    private static final String SIGNATURE_ALGORITHM = "SHA256withRSA";
    private static final String KEY_ALGORITHM = "RSA";

    private final BigInteger decodedBigModulus;
    private final BigInteger decodedBigExponent;

    @Builder
    public JWTValidatorRSAImpl(String jwtToken, String localKeyId, String publicKeyId, Long expiration,
            String localAud, String referenceAud, BigInteger decodedBigModulus, BigInteger decodedBigExponent) {
        super(jwtToken, localKeyId, publicKeyId, expiration, localAud, referenceAud);
        this.decodedBigModulus = decodedBigModulus;
        this.decodedBigExponent = decodedBigExponent;
    }

    /**
     * Verifies the signature over everything before the last '.' against the base64url-decoded segment
     * after it.
     *
     * @return {@code true} if the signature verifies with the configured RSA key
     * @throws BadJWTException if the token has no signature segment, the signature is not valid
     *     base64url, the key components are missing or invalid, or verification itself fails
     */
    @Override
    public boolean isValidSignature() throws BadJWTException {
        String token = getJwtToken();
        int lastDot = token == null ? -1 : token.lastIndexOf('.');
        if (lastDot < 0) {
            throw new BadJWTException("JWT is missing or has no signature segment.");
        }
        String signedData = token.substring(0, lastDot);
        byte[] signature = decodeSignature(token.substring(lastDot + 1));
        PublicKey publicKey = computePublicKeyFromModulusAndExponent();
        Signature verifier = newSignatureInstance();
        try {
            verifier.initVerify(publicKey);
            // The JWS signing input is ASCII (base64url header '.' base64url payload); encode explicitly
            // rather than depending on the platform default charset.
            verifier.update(signedData.getBytes(StandardCharsets.US_ASCII));
            return verifier.verify(signature);
        } catch (SignatureException | InvalidKeyException e) {
            throw new BadJWTException(e.getMessage(), e);
        }
    }

    /** Decodes the base64url signature segment, reporting malformed input as a validation failure. */
    private static byte[] decodeSignature(String signatureB64u) throws BadJWTException {
        try {
            return Base64.getUrlDecoder().decode(signatureB64u);
        } catch (IllegalArgumentException e) {
            throw new BadJWTException("JWT signature is not valid base64url.", e);
        }
    }

    /** Builds the RSA public key from the configured modulus and exponent. */
    private PublicKey computePublicKeyFromModulusAndExponent() throws BadJWTException {
        if (decodedBigModulus == null || decodedBigExponent == null) {
            throw new BadJWTException("RSA public key modulus and exponent are required.");
        }
        RSAPublicKeySpec rsaPublicKeySpec = new RSAPublicKeySpec(decodedBigModulus, decodedBigExponent);
        try {
            return KeyFactory.getInstance(KEY_ALGORITHM).generatePublic(rsaPublicKeySpec);
        } catch (InvalidKeySpecException e) {
            throw new BadJWTException(e.getMessage(), e);
        } catch (NoSuchAlgorithmException e) {
            // Java SE implementations are required to provide an RSA KeyFactory.
            throw new IllegalStateException("RSA KeyFactory is unavailable.", e);
        }
    }

    /** Returns a fresh SHA256withRSA {@link Signature} instance. */
    private static Signature newSignatureInstance() {
        try {
            return Signature.getInstance(SIGNATURE_ALGORITHM);
        } catch (NoSuchAlgorithmException e) {
            // Java SE implementations are required to support SHA256withRSA.
            throw new IllegalStateException("SHA256withRSA is unavailable.", e);
        }
    }
}
import java.math.BigInteger;
/**
 * Fluent builder for {@link JWTValidatorRSAImpl}. Each setter records one constructor argument and
 * returns this builder; values never set are passed to the constructor as {@code null}.
 *
 * <p>NOTE(review): the {@code JWTValidatorRSAImpl} constructor is also annotated with Lombok
 * {@code @Builder}, so this class duplicates the generated {@code builder()}.
 */
public class JWTValidatorRSAImplBuilder {
    private String jwtToken;
    private String localKeyId;
    private String publicKeyId;
    private Long expiration;
    private String localAud;
    private String referenceAud;
    private BigInteger decodedBigModulus;
    private BigInteger decodedBigExponent;

    /** Sets the encoded token. */
    public JWTValidatorRSAImplBuilder setJwtToken(String value) {
        jwtToken = value;
        return this;
    }

    /** Sets the local key id. */
    public JWTValidatorRSAImplBuilder setLocalKeyId(String value) {
        localKeyId = value;
        return this;
    }

    /** Sets the public key id. */
    public JWTValidatorRSAImplBuilder setPublicKeyId(String value) {
        publicKeyId = value;
        return this;
    }

    /** Sets the expiration (epoch seconds). */
    public JWTValidatorRSAImplBuilder setExpiration(Long value) {
        expiration = value;
        return this;
    }

    /** Sets the audience returned by the token. */
    public JWTValidatorRSAImplBuilder setLocalAud(String value) {
        localAud = value;
        return this;
    }

    /** Sets the audience sent by the client. */
    public JWTValidatorRSAImplBuilder setReferenceAud(String value) {
        referenceAud = value;
        return this;
    }

    /** Sets the RSA modulus of the verification key. */
    public JWTValidatorRSAImplBuilder setDecodedBigModulus(BigInteger value) {
        decodedBigModulus = value;
        return this;
    }

    /** Sets the RSA public exponent of the verification key. */
    public JWTValidatorRSAImplBuilder setDecodedBigExponent(BigInteger value) {
        decodedBigExponent = value;
        return this;
    }

    /** Creates a validator from the values collected so far. */
    public JWTValidatorRSAImpl createJWTValidatorRSAImpl() {
        return new JWTValidatorRSAImpl(
                jwtToken,
                localKeyId,
                publicKeyId,
                expiration,
                localAud,
                referenceAud,
                decodedBigModulus,
                decodedBigExponent);
    }
}
import java.math.BigInteger;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import static org.assertj.core.api.Assertions.*;
import static org.mockito.Mockito.*;
/** Unit tests for {@link JWTValidatorRSAImpl} and the checks it inherits. */
class JWTValidatorRSAImplTest {

    // TODO: replace these placeholders with real fixtures. The declared type of `publicKey` must expose
    // getModulus()/getExponent() returning BigInteger, and `response.getIdTokenEncoded()` must return a
    // token signed by that key.
    private final PublicKey publicKey = new Object(); // replace
    private final JwtWrapperResponse response = new Object(); // replace

    @Test
    void jwtValidate_call_allValidationsWereAlsoCalled() throws Exception {
        JWTValidatorRSAImpl validator = spy(JWTValidatorRSAImpl.builder().build());
        doNothing().when(validator).validateStructure();
        doNothing().when(validator).validateSignature();
        doNothing().when(validator).validateTokenNotExpired();
        doNothing().when(validator).validateAud();

        validator.validate();

        verify(validator).validateStructure();
        verify(validator).validateSignature();
        verify(validator).validateTokenNotExpired();
        verify(validator).validateAud();
    }

    @ParameterizedTest
    @ValueSource(strings = {"...", "123", "1.2abc"})
    void validateStructure_invalidStructuredJwt_throwsMalformedJwtException(String badStructureJwt) {
        JWTValidatorRSAImpl validator = JWTValidatorRSAImpl.builder().jwtToken(badStructureJwt).build();
        assertThatThrownBy(validator::validateStructure).isInstanceOf(MalformedJwtException.class);
    }

    @Test
    void validateStructure_threeSegments_passes() {
        JWTValidatorRSAImpl validator = JWTValidatorRSAImpl.builder().jwtToken("abc.def.ghi").build();
        assertThatCode(validator::validateStructure).doesNotThrowAnyException();
    }

    @Test
    void isValidSignature_validSignature_true() throws Exception {
        JWTValidatorRSAImpl validator = signatureValidator(publicKey.getModulus(), publicKey.getExponent());
        assertThat(validator.isValidSignature()).isTrue();
    }

    @Test
    void validateSignature_invalidModulus_throwsBadJwtException() {
        BigInteger badModulus = publicKey.getModulus().add(BigInteger.ONE);
        JWTValidatorRSAImpl validator = signatureValidator(badModulus, publicKey.getExponent());
        assertThatThrownBy(validator::validateSignature).isInstanceOf(BadJWTException.class);
    }

    @Test
    void validateSignature_invalidExponent_throwsBadJwtException() {
        BigInteger badExponent = publicKey.getExponent().add(BigInteger.ONE);
        JWTValidatorRSAImpl validator = signatureValidator(publicKey.getModulus(), badExponent);
        assertThatThrownBy(validator::validateSignature).isInstanceOf(BadJWTException.class);
    }

    @Test
    void validateSignature_validSignature_passes() {
        JWTValidatorRSAImpl validator = signatureValidator(publicKey.getModulus(), publicKey.getExponent());
        assertThatCode(validator::validateSignature).doesNotThrowAnyException();
    }

    @Test
    void isTokenExpired_expirationDateInTheFuture_false() {
        JWTValidatorRSAImpl validator = JWTValidatorRSAImpl.builder().expiration(epochSecondsFromNow(5)).build();
        assertThat(validator.isExpiredToken()).isFalse();
    }

    @Test
    void isTokenExpired_expirationDateInThePast_true() {
        JWTValidatorRSAImpl validator = JWTValidatorRSAImpl.builder().expiration(epochSecondsFromNow(-1)).build();
        assertThat(validator.isExpiredToken()).isTrue();
    }

    @Test
    void validateNotExpired_youngToken_passes() {
        JWTValidatorRSAImpl validator = JWTValidatorRSAImpl.builder().expiration(epochSecondsFromNow(5)).build();
        assertThatCode(validator::validateTokenNotExpired).doesNotThrowAnyException();
    }

    @Test
    void validateNotExpired_expiredToken_throwsBadJwtException() {
        JWTValidatorRSAImpl validator = JWTValidatorRSAImpl.builder().expiration(epochSecondsFromNow(-2)).build();
        assertThatThrownBy(validator::validateTokenNotExpired).isInstanceOf(BadJWTException.class);
    }

    @Test
    void isValidAud_matching_true() {
        JWTValidatorRSAImpl validator = JWTValidatorRSAImpl.builder().localAud("abc").referenceAud("abc").build();
        assertThat(validator.isValidAud()).isTrue();
    }

    @ParameterizedTest
    @ValueSource(strings = {"ABC", "aBc", "def"})
    void isValidAud_notMatching_false(String badAud) {
        JWTValidatorRSAImpl validator = JWTValidatorRSAImpl.builder().localAud(badAud).referenceAud("abc").build();
        assertThat(validator.isValidAud()).isFalse();
    }

    @Test
    void validateAud_matching_passes() {
        JWTValidatorRSAImpl validator = JWTValidatorRSAImpl.builder().localAud("abc").referenceAud("abc").build();
        assertThatCode(validator::validateAud).doesNotThrowAnyException();
    }

    @Test
    void validateAud_notMatching_throwsBadJwtException() {
        JWTValidatorRSAImpl validator = JWTValidatorRSAImpl.builder().localAud("abc").referenceAud("def").build();
        assertThatThrownBy(validator::validateAud).isInstanceOf(BadJWTException.class);
    }

    /** Builds a validator for the fixture token, verified against the given RSA key components. */
    private JWTValidatorRSAImpl signatureValidator(BigInteger modulus, BigInteger exponent) {
        // NOTE(review): localKeyId is set to the encoded token, which looks like a copy-paste slip;
        // isValidSignature does not read it, so it does not affect these tests.
        return JWTValidatorRSAImpl.builder()
                .jwtToken(response.getIdTokenEncoded())
                .decodedBigModulus(modulus)
                .decodedBigExponent(exponent)
                .localKeyId(response.getIdTokenEncoded())
                .build();
    }

    /** Returns the Unix time, in whole seconds, {@code offsetSeconds} away from now (negative = past). */
    private static long epochSecondsFromNow(long offsetSeconds) {
        return Instant.now().plusSeconds(offsetSeconds).getEpochSecond();
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment