Last active
July 12, 2024 07:42
-
-
Save jimhe/6314378 to your computer and use it in GitHub Desktop.
Servlet that handles the SAML authentication request and response. On GET /saml, it redirects to the identity provider with the appropriate SAMLRequest parameter.
On POST /saml, it parses the POSTed SAML response and lets the user in only if the response is successful and properly signed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.comprehend.servlet; | |
import org.joda.time.DateTime;
import org.joda.time.DateTimeZone;
import org.opensaml.Configuration;
import org.opensaml.common.binding.BasicSAMLMessageContext;
import org.opensaml.common.xml.SAMLConstants;
import org.opensaml.saml2.binding.decoding.HTTPPostDecoder;
import org.opensaml.saml2.binding.encoding.HTTPRedirectDeflateEncoder;
import org.opensaml.saml2.core.*;
import org.opensaml.saml2.metadata.impl.SingleSignOnServiceImpl;
import org.opensaml.security.SAMLSignatureProfileValidator;
import org.opensaml.ws.message.MessageContext;
import org.opensaml.ws.message.decoder.MessageDecodingException;
import org.opensaml.ws.message.encoder.MessageEncodingException;
import org.opensaml.ws.transport.http.HttpServletRequestAdapter;
import org.opensaml.ws.transport.http.HttpServletResponseAdapter;
import org.opensaml.xml.ConfigurationException;
import org.opensaml.xml.XMLConfigurator;
import org.opensaml.xml.XMLObjectBuilderFactory;
import org.opensaml.xml.parse.BasicParserPool;
import org.opensaml.xml.parse.StaticBasicParserPool;
import org.opensaml.xml.parse.XMLParserException;
import org.opensaml.xml.security.x509.BasicX509Credential;
import org.opensaml.xml.signature.SignatureValidator;
import org.opensaml.xml.validation.ValidationException;

import javax.servlet.ServletException;
import javax.servlet.annotation.WebServlet;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.xml.namespace.QName;

import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.security.KeyFactory;
import java.security.NoSuchAlgorithmException;
import java.security.PublicKey;
import java.security.cert.CertificateException;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.security.spec.InvalidKeySpecException;
import java.security.spec.X509EncodedKeySpec;
import java.util.UUID;
@WebServlet("/saml") | |
public class SAMLServlet extends HttpServlet { | |
private static String redirectURL = "http://localhost:8080/saml"; | |
private static String idProviderURL = "https://secureauthtest.bsci.com/secureauth20"; | |
private static int millisecDiffAllowed = 1000; | |
private static String certificateFilePath = "/tmp/bsci-saml.cer"; // add this as resource? | |
private static String[] xmlToolingConfigs = { | |
"/default-config.xml", | |
"/schema-config.xml", | |
"/signature-config.xml", | |
"/signature-validation-config.xml", | |
"/encryption-config.xml", | |
"/encryption-validation-config.xml", | |
"/soap11-config.xml", | |
"/wsfed11-protocol-config.xml", | |
"/saml1-assertion-config.xml", | |
"/saml1-protocol-config.xml", | |
"/saml1-core-validation-config.xml", | |
"/saml2-assertion-config.xml", | |
"/saml2-protocol-config.xml", | |
"/saml2-core-validation-config.xml", | |
"/saml1-metadata-config.xml", | |
"/saml2-metadata-config.xml", | |
"/saml2-metadata-validation-config.xml", | |
"/saml2-metadata-attr-config.xml", | |
"/saml2-metadata-idp-discovery-config.xml", | |
"/saml2-metadata-ui-config.xml", | |
"/saml2-protocol-thirdparty-config.xml", | |
"/saml2-metadata-query-config.xml", | |
"/saml2-assertion-delegation-restriction-config.xml", | |
"/saml2-ecp-config.xml", | |
"/xacml10-saml2-profile-config.xml", | |
"/xacml11-saml2-profile-config.xml", | |
"/xacml20-context-config.xml", | |
"/xacml20-policy-config.xml", | |
"/xacml2-saml2-profile-config.xml", | |
"/xacml3-saml2-profile-config.xml", | |
"/wsaddressing-config.xml", | |
"/wssecurity-config.xml", | |
}; | |
@Override | |
public void init() { | |
try { | |
XMLConfigurator configurator = new XMLConfigurator(); | |
Class clazz = Configuration.class; | |
for (String config : xmlToolingConfigs) { | |
configurator.load(clazz.getResourceAsStream(config)); | |
} | |
StaticBasicParserPool pp = new StaticBasicParserPool(); | |
pp.setMaxPoolSize(50); | |
pp.initialize(); | |
Configuration.setParserPool(pp); | |
} catch (ConfigurationException e) { | |
e.printStackTrace(); | |
} catch (XMLParserException e) { | |
e.printStackTrace(); | |
} | |
} | |
@Override | |
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { | |
XMLObjectBuilderFactory builderFactory = Configuration.getBuilderFactory(); | |
AuthnRequest authnRequest = (AuthnRequest) builderFactory.getBuilder(AuthnRequest.DEFAULT_ELEMENT_NAME) | |
.buildObject(AuthnRequest.DEFAULT_ELEMENT_NAME); | |
String id = UUID.randomUUID().toString(); | |
authnRequest.setID(id); | |
authnRequest.setAssertionConsumerServiceURL(redirectURL); | |
Issuer issuer = (Issuer) builderFactory.getBuilder(Issuer.DEFAULT_ELEMENT_NAME) | |
.buildObject(Issuer.DEFAULT_ELEMENT_NAME); | |
issuer.setValue("secureauthtest.bsci.com"); | |
authnRequest.setIssuer(issuer); | |
NameIDPolicy nameIDPolicy = (NameIDPolicy) builderFactory.getBuilder(NameIDPolicy.DEFAULT_ELEMENT_NAME) | |
.buildObject(NameIDPolicy.DEFAULT_ELEMENT_NAME); | |
nameIDPolicy.setAllowCreate(true); | |
nameIDPolicy.setFormat("urn:oasis:names:tc:SAML:2.0:nameid-format:transient"); | |
authnRequest.setNameIDPolicy(nameIDPolicy); | |
BasicSAMLMessageContext context = new BasicSAMLMessageContext<>(); | |
context.setOutboundSAMLMessage(authnRequest); | |
HttpServletResponseAdapter responseAdapter = new HttpServletResponseAdapter(resp, true); | |
context.setOutboundMessageTransport(responseAdapter); | |
QName qName = new QName(SAMLConstants.SAML20MD_NS, "SingleSignOnService", SAMLConstants.SAML20MD_PREFIX); | |
SingleSignOnServiceImpl endpoint = (SingleSignOnServiceImpl) builderFactory.getBuilder(qName) | |
.buildObject(qName); | |
endpoint.setLocation(idProviderURL); | |
context.setPeerEntityEndpoint(endpoint); | |
HTTPRedirectDeflateEncoder encoder = new HTTPRedirectDeflateEncoder(); | |
try { | |
// this performs the redirect | |
encoder.encode(context); | |
} catch (MessageEncodingException e) { | |
e.printStackTrace(); | |
} | |
} | |
@Override | |
protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { | |
HTTPPostDecoder decoder = new HTTPPostDecoder(new BasicParserPool()); | |
HttpServletRequestAdapter adapter = new HttpServletRequestAdapter(req); | |
MessageContext context = new BasicSAMLMessageContext(); | |
context.setInboundMessageTransport(adapter); | |
try { | |
decoder.decode(context); | |
} catch (MessageDecodingException|org.opensaml.xml.security.SecurityException e) { | |
e.printStackTrace(); | |
denyEntry(); | |
return; | |
} | |
Response response = (Response) context.getInboundMessage(); | |
Status status = response.getStatus(); | |
// checks for success | |
if (!status.getStatusCode().getValue().equals(StatusCode.SUCCESS_URI)) { | |
denyEntry(); | |
return; | |
} | |
// performs a cryptographic validation of the signature | |
try { | |
SAMLSignatureProfileValidator profileValidator = new SAMLSignatureProfileValidator(); | |
profileValidator.validate(response.getSignature()); | |
} catch (ValidationException e) { | |
e.printStackTrace(); | |
denyEntry(); | |
return; | |
} | |
// checks that signature was signed by the trusted CA | |
try { | |
File certificateFile = new File(certificateFilePath); | |
FileInputStream certInputStream = new FileInputStream(certificateFile); | |
CertificateFactory certificateFactory = CertificateFactory.getInstance("X.509"); | |
X509Certificate certificate = (X509Certificate) certificateFactory.generateCertificate(certInputStream); | |
X509EncodedKeySpec publicKeySpec = new X509EncodedKeySpec(certificate.getPublicKey().getEncoded()); | |
KeyFactory keyFactory = KeyFactory.getInstance("RSA"); | |
PublicKey publicKey = keyFactory.generatePublic(publicKeySpec); | |
BasicX509Credential publicCredential = new BasicX509Credential(); | |
publicCredential.setPublicKey(publicKey); | |
SignatureValidator signatureValidator = new SignatureValidator(publicCredential); | |
signatureValidator.validate(response.getSignature()); | |
} catch (ValidationException|InvalidKeySpecException|CertificateException|NoSuchAlgorithmException e) { | |
e.printStackTrace(); | |
denyEntry(); | |
return; | |
} | |
// performs a time check | |
DateTime now = new DateTime().withZone(DateTimeZone.UTC); | |
int milliSecDifference = Math.abs( | |
now.getMillisOfDay() - | |
response.getIssueInstant().withZone(DateTimeZone.UTC).getMillisOfDay()); | |
if (milliSecDifference > millisecDiffAllowed) { | |
denyEntry(); | |
return; | |
} | |
// allow in | |
allowEntry(); | |
} | |
private void denyEntry() { | |
} | |
private void allowEntry() { | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment