Created July 27, 2022 07:53
-
-
Save dehidehidehi/53f9489ac1ffb923f69603f05db3c7bd to your computer and use it in GitHub Desktop.
Java - JWT Validator (RSA) Implementation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import io.jsonwebtoken.MalformedJwtException; | |
| import java.time.Instant; | |
| @AllArgsConstructor | |
| @Getter(AccessLevel.PROTECTED) | |
| @Slf4j | |
| public abstract class AbstractSimpleJWTValidator implements JWTValidator { | |
| private String jwtToken; | |
| private String localKeyId; | |
| private String publicKeyId; | |
| private Long expiration; | |
| /** Aud returned by the Jwt. **/ | |
| private String localAud; | |
| /** Aud which was sent by client application. **/ | |
| private String referenceAud; | |
| @Override public boolean isValidAud() { | |
| return referenceAud.equals(localAud); | |
| } | |
| @Override public boolean isExpiredToken() { | |
| Instant expirationInstant = Instant.ofEpochSecond(expiration); | |
| return expirationInstant.isBefore(Instant.now()); | |
| } | |
| @Override public void validateStructure() throws MalformedJwtException { | |
| if (!JWTValidator.isValidStructure(jwtToken)) | |
| throw new MalformedJwtException("JWT structure is invalid."); | |
| } | |
| @Override public void validateSignature() throws BadJWTException { | |
| if (!isValidSignature()) | |
| throw new BadJWTException("JWT signature is invalid."); | |
| } | |
| @Override public void validateAud() throws BadJWTException { | |
| if (!isValidAud()) | |
| throw new BadJWTException("Aud doesn't match."); | |
| } | |
| @Override public void validateTokenNotExpired() throws BadJWTException { | |
| if (isExpiredToken()) | |
| throw new BadJWTException("JWT is expired."); | |
| } | |
| @Override public void validateClaims() throws BadJWTException { | |
| throw new UnsupportedOperationException("Not implemented"); | |
| } | |
| @Override public void validate() throws JwtException { | |
| validateStructure(); | |
| validateTokenNotExpired(); | |
| validateAud(); | |
| validateSignature(); | |
| log.debug("JWT Token is valid."); | |
| } | |
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| public class BadJWTException extends JwtException { | |
| public BadJWTException(String message) { | |
| super(message); | |
| } | |
| public BadJWTException(String message, Throwable cause) { | |
| super(message, cause); | |
| } | |
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * Checked base exception for JWT validation failures; callers of {@code validate()} catch this type.
 */
public class JwtException extends Exception {

    /** Explicit serialization id; exceptions are {@link java.io.Serializable} via {@link Throwable}. */
    private static final long serialVersionUID = 1L;

    /**
     * @param message description of the failure
     */
    public JwtException(String message) {
        super(message);
    }

    /**
     * @param message description of the failure
     * @param cause underlying exception that triggered the failure
     */
    public JwtException(String message, Throwable cause) {
        super(message, cause);
    }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import io.jsonwebtoken.MalformedJwtException; | |
| public interface JWTValidator { | |
| /** Checks whether jwtToken's structure contains 2 period separators. **/ | |
| static boolean isValidStructure(String jwtToken) { | |
| boolean hasTwoPeriods = jwtToken.split("\\.").length == 3; | |
| boolean hasMinimumLength = jwtToken.length() >= 5; | |
| return hasMinimumLength && hasTwoPeriods; | |
| } | |
| boolean isValidSignature() throws BadJWTException; | |
| boolean isValidAud(); | |
| boolean isExpiredToken(); | |
| void validateStructure() throws MalformedJwtException; | |
| void validateSignature() throws BadJWTException; | |
| void validateAud() throws BadJWTException; | |
| void validateTokenNotExpired() throws BadJWTException; | |
| void validateClaims() throws BadJWTException; | |
| void validate() throws JwtException; | |
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.math.BigInteger;
import java.nio.charset.StandardCharsets;
import java.security.*;
import java.security.spec.RSAPublicKeySpec;
import java.util.Base64;
| public class JWTValidatorRSAImpl extends AbstractSimpleJWTValidator { | |
| private final BigInteger decodedBigModulus; | |
| private final BigInteger decodedBigExponent; | |
| @Builder | |
| public JWTValidatorRSAImpl(String jwtToken, String localKeyId, String publicKeyId, Long expiration, | |
| String localAud, String referenceAud, BigInteger decodedBigModulus, BigInteger decodedBigExponent) { | |
| super(jwtToken, localKeyId, publicKeyId, expiration, localAud, referenceAud); | |
| this.decodedBigModulus = decodedBigModulus; | |
| this.decodedBigExponent = decodedBigExponent; | |
| } | |
| @Override public boolean isValidSignature() throws BadJWTException { | |
| // Declare | |
| PublicKey publicKey = computePublicKeyFromModulusAndExponent(); | |
| String signedData = getJwtToken().substring(0, getJwtToken().lastIndexOf(".")); | |
| String signatureB64u = getJwtToken().substring(getJwtToken().lastIndexOf(".") + 1); | |
| byte[] signature = Base64.getUrlDecoder().decode(signatureB64u); | |
| Signature sig = tryGetSignatureInstanceSha256WithRSA(); | |
| // Logic | |
| try { | |
| sig.initVerify(publicKey); | |
| sig.update(signedData.getBytes()); | |
| boolean isVerify = sig.verify(signature); | |
| return isVerify; | |
| } catch (SignatureException | InvalidKeyException e) { | |
| throw new BadJWTException(e.getMessage(), e); | |
| } | |
| } | |
| private PublicKey computePublicKeyFromModulusAndExponent() { | |
| RSAPublicKeySpec rsaPublicKeySpec = new RSAPublicKeySpec(decodedBigModulus, decodedBigExponent); | |
| return tryGetRSAPublicKey(rsaPublicKeySpec); | |
| } | |
| @SneakyThrows | |
| private PublicKey tryGetRSAPublicKey(RSAPublicKeySpec rsaPublicKeySpec) { | |
| return KeyFactory.getInstance("RSA").generatePublic(rsaPublicKeySpec); | |
| } | |
| @SneakyThrows | |
| private Signature tryGetSignatureInstanceSha256WithRSA() { | |
| return Signature.getInstance("SHA256withRSA"); | |
| } | |
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import java.math.BigInteger; | |
| public class JWTValidatorRSAImplBuilder { | |
| private String jwtToken; | |
| private String localKeyId; | |
| private String publicKeyId; | |
| private Long expiration; | |
| private String localAud; | |
| private String referenceAud; | |
| private BigInteger decodedBigModulus; | |
| private BigInteger decodedBigExponent; | |
| public JWTValidatorRSAImplBuilder setJwtToken(String jwtToken) { | |
| this.jwtToken = jwtToken; | |
| return this; | |
| } | |
| public JWTValidatorRSAImplBuilder setLocalKeyId(String localKeyId) { | |
| this.localKeyId = localKeyId; | |
| return this; | |
| } | |
| public JWTValidatorRSAImplBuilder setPublicKeyId(String publicKeyId) { | |
| this.publicKeyId = publicKeyId; | |
| return this; | |
| } | |
| public JWTValidatorRSAImplBuilder setExpiration(Long expiration) { | |
| this.expiration = expiration; | |
| return this; | |
| } | |
| public JWTValidatorRSAImplBuilder setLocalAud(String localAud) { | |
| this.localAud = localAud; | |
| return this; | |
| } | |
| public JWTValidatorRSAImplBuilder setReferenceAud(String referenceAud) { | |
| this.referenceAud = referenceAud; | |
| return this; | |
| } | |
| public JWTValidatorRSAImplBuilder setDecodedBigModulus(BigInteger decodedBigModulus) { | |
| this.decodedBigModulus = decodedBigModulus; | |
| return this; | |
| } | |
| public JWTValidatorRSAImplBuilder setDecodedBigExponent(BigInteger decodedBigExponent) { | |
| this.decodedBigExponent = decodedBigExponent; | |
| return this; | |
| } | |
| public JWTValidatorRSAImpl createJWTValidatorRSAImpl() { | |
| return new JWTValidatorRSAImpl(jwtToken, localKeyId, publicKeyId, expiration, localAud, referenceAud, | |
| decodedBigModulus, decodedBigExponent); | |
| } | |
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import java.math.BigInteger; | |
| import java.time.Instant; | |
| import java.time.temporal.ChronoUnit; | |
| import static org.assertj.core.api.Assertions.*; | |
| import static org.mockito.Mockito.*; | |
| class JWTValidatorRSAImplTest { | |
| private final PublicKey publicKey = new Object(); // replace | |
| private final JwtWrapperResponse response = new Object(); // replace | |
| private JWTValidatorRSAImpl jwtValidatorRSA; | |
| @AfterEach | |
| void tearDown() { | |
| jwtValidatorRSA = null; | |
| } | |
| @Test | |
| void jwtValidate_call_allValidationsWereAlsoCalled() throws Exception { | |
| jwtValidatorRSA = JWTValidatorRSAImpl.builder().build(); | |
| jwtValidatorRSA = spy(jwtValidatorRSA); | |
| doNothing().when(jwtValidatorRSA).validateStructure(); | |
| doNothing().when(jwtValidatorRSA).validateSignature(); | |
| doNothing().when(jwtValidatorRSA).validateTokenNotExpired(); | |
| doNothing().when(jwtValidatorRSA).validateAud(); | |
| jwtValidatorRSA.validate(); | |
| verify(jwtValidatorRSA).validateStructure(); | |
| verify(jwtValidatorRSA).validateSignature(); | |
| verify(jwtValidatorRSA).validateTokenNotExpired(); | |
| verify(jwtValidatorRSA).validateAud(); | |
| } | |
| @ParameterizedTest | |
| @ValueSource(strings = {"...", "123", "1.2abc"}) | |
| void validateStructure_invalidStructuredJwt_throwsMalformedJwtException(String badStructureJwt) { | |
| jwtValidatorRSA = JWTValidatorRSAImpl.builder().jwtToken(badStructureJwt).build(); | |
| assertThatThrownBy(jwtValidatorRSA::validateStructure).isInstanceOf(MalformedJwtException.class); | |
| } | |
| @Test | |
| void isValidSignature_validSignature_True() { | |
| jwtValidatorRSA = JWTValidatorRSAImpl.builder() | |
| .jwtToken(response.getIdTokenEncoded()) | |
| .decodedBigModulus(publicKey.getModulus()) | |
| .decodedBigExponent(publicKey.getExponent()) | |
| .localKeyId(response.getIdTokenEncoded()) | |
| .build(); | |
| assertThatCode(jwtValidatorRSA::validateSignature).doesNotThrowAnyException(); | |
| } | |
| @Test | |
| void isValidSignature_invalidSignatureModulus_False() { | |
| BigInteger badModulus = publicKey.getModulus().add(BigInteger.ONE); | |
| jwtValidatorRSA = JWTValidatorRSAImpl.builder() | |
| .jwtToken(response.getIdTokenEncoded()) | |
| .decodedBigModulus(badModulus) | |
| .decodedBigExponent(publicKey.getExponent()) | |
| .localKeyId(response.getIdTokenEncoded()) | |
| .build(); | |
| assertThatThrownBy(jwtValidatorRSA::validateSignature).isInstanceOf(BadJWTException.class); | |
| } | |
| @Test | |
| void isValidSignature_invalidSignatureExponent_False() { | |
| BigInteger badExponent = publicKey.getExponent().add(BigInteger.ONE); | |
| jwtValidatorRSA = JWTValidatorRSAImpl.builder() | |
| .jwtToken(response.getIdTokenEncoded()) | |
| .decodedBigModulus(publicKey.getModulus()) | |
| .decodedBigExponent(badExponent) | |
| .localKeyId(response.getIdTokenEncoded()) | |
| .build(); | |
| assertThatThrownBy(jwtValidatorRSA::validateSignature).isInstanceOf(BadJWTException.class); | |
| } | |
| @Test | |
| void validateSignature_validSignature_passes() { | |
| jwtValidatorRSA = JWTValidatorRSAImpl.builder() | |
| .jwtToken(response.getIdTokenEncoded()) | |
| .decodedBigModulus(publicKey.getModulus()) | |
| .decodedBigExponent(publicKey.getExponent()) | |
| .localKeyId(response.getIdTokenEncoded()) | |
| .build(); | |
| assertThatCode(jwtValidatorRSA::validateSignature).doesNotThrowAnyException(); | |
| } | |
| @Test | |
| void isTokenExpired_expirationDateInTheFuture_false() { | |
| Instant expirationDateInTheFuture = Instant.now().plus(5, ChronoUnit.SECONDS); | |
| long expirationTimeStamp = expirationDateInTheFuture.toEpochMilli() / 1000; | |
| jwtValidatorRSA = JWTValidatorRSAImpl.builder().expiration(expirationTimeStamp).build(); | |
| assertThat(jwtValidatorRSA.isExpiredToken()).isFalse(); | |
| } | |
| @Test | |
| void isTokenExpired_expirationDateInThePast_true() { | |
| Instant expirationDateInThePast = Instant.now().minus(1, ChronoUnit.SECONDS); | |
| long expirationTimeStamp = expirationDateInThePast.toEpochMilli() / 1000; | |
| jwtValidatorRSA = JWTValidatorRSAImpl.builder().expiration(expirationTimeStamp).build(); | |
| assertThat(jwtValidatorRSA.isExpiredToken()).isTrue(); | |
| } | |
| @Test | |
| void validateNotExpired_youngToken_passes() { | |
| Instant expirationDateInTheFuture = Instant.now().plus(5, ChronoUnit.SECONDS); | |
| long expirationTimeStamp = expirationDateInTheFuture.toEpochMilli() / 1000; | |
| jwtValidatorRSA = JWTValidatorRSAImpl.builder().expiration(expirationTimeStamp).build(); | |
| assertThatCode(jwtValidatorRSA::validateTokenNotExpired).doesNotThrowAnyException(); | |
| } | |
| @Test | |
| void validateNotExpired_expiredToken_throwsBadJwtException() { | |
| Instant expirationDateInThePast = Instant.now().minus(2, ChronoUnit.SECONDS); | |
| long expirationTimeStamp = expirationDateInThePast.toEpochMilli() / 1000; | |
| jwtValidatorRSA = JWTValidatorRSAImpl.builder().expiration(expirationTimeStamp).build(); | |
| assertThatThrownBy(jwtValidatorRSA::validateTokenNotExpired).isInstanceOf(BadJWTException.class); | |
| } | |
| @Test | |
| void isValidAud_matching_true() { | |
| jwtValidatorRSA = JWTValidatorRSAImpl.builder().localAud("abc").referenceAud("abc").build(); | |
| assertThat(jwtValidatorRSA.isValidAud()).isTrue(); | |
| } | |
| @ParameterizedTest | |
| @ValueSource(strings = {"ABC", "aBc", "def"}) | |
| void isValidAud_notMatching_false(String badAud) { | |
| jwtValidatorRSA = JWTValidatorRSAImpl.builder().localAud(badAud).referenceAud("abc").build(); | |
| assertThat(jwtValidatorRSA.isValidAud()).isFalse(); | |
| } | |
| @Test | |
| void validateAud_matching_passes() { | |
| jwtValidatorRSA = JWTValidatorRSAImpl.builder().localAud("abc").referenceAud("abc").build(); | |
| assertThatCode(jwtValidatorRSA::validateAud).doesNotThrowAnyException(); | |
| } | |
| @Test | |
| void validateAud_notMatching_throwsBadJwtException() { | |
| jwtValidatorRSA = JWTValidatorRSAImpl.builder().localAud("abc").referenceAud("def").build(); | |
| assertThatThrownBy(jwtValidatorRSA::validateAud).isInstanceOf(BadJWTException.class); | |
| } | |
| } |
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment