Skip to content

Instantly share code, notes, and snippets.

@jimhe
Last active July 12, 2024 07:42
Show Gist options
  • Save jimhe/6314378 to your computer and use it in GitHub Desktop.
Save jimhe/6314378 to your computer and use it in GitHub Desktop.
Servlet to handle SAML authentication requests and responses. On GET /saml, it redirects to the identity provider with the proper SAMLRequest parameter. On POST /saml, it parses the POSTed SAML response and requires it to be successful and properly signed before allowing the user in.
package com.comprehend.servlet;
import org.joda.time.DateTime;
import org.joda.time.DateTimeZone;
import org.opensaml.Configuration;
import org.opensaml.common.binding.BasicSAMLMessageContext;
import org.opensaml.common.xml.SAMLConstants;
import org.opensaml.saml2.binding.decoding.HTTPPostDecoder;
import org.opensaml.saml2.binding.encoding.HTTPRedirectDeflateEncoder;
import org.opensaml.saml2.core.*;
import org.opensaml.saml2.metadata.impl.SingleSignOnServiceImpl;
import org.opensaml.security.SAMLSignatureProfileValidator;
import org.opensaml.ws.message.MessageContext;
import org.opensaml.ws.message.decoder.MessageDecodingException;
import org.opensaml.ws.message.encoder.MessageEncodingException;
import org.opensaml.ws.transport.http.HttpServletRequestAdapter;
import org.opensaml.ws.transport.http.HttpServletResponseAdapter;
import org.opensaml.xml.ConfigurationException;
import org.opensaml.xml.XMLConfigurator;
import org.opensaml.xml.XMLObjectBuilderFactory;
import org.opensaml.xml.parse.BasicParserPool;
import org.opensaml.xml.parse.StaticBasicParserPool;
import org.opensaml.xml.parse.XMLParserException;
import org.opensaml.xml.security.x509.BasicX509Credential;
import org.opensaml.xml.signature.SignatureValidator;
import org.opensaml.xml.validation.ValidationException;
import javax.servlet.ServletException;
import javax.servlet.annotation.WebServlet;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.xml.namespace.QName;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.security.KeyFactory;
import java.security.NoSuchAlgorithmException;
import java.security.PublicKey;
import java.security.cert.CertificateException;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.security.spec.InvalidKeySpecException;
import java.security.spec.X509EncodedKeySpec;
import java.util.UUID;
/**
 * Servlet implementing a minimal SAML 2.0 service-provider exchange on {@code /saml}.
 *
 * <ul>
 *   <li>{@code GET /saml} builds an {@code AuthnRequest} and redirects the browser to the
 *       identity provider using the HTTP-Redirect (deflate) binding.</li>
 *   <li>{@code POST /saml} decodes the HTTP-POST binding response and accepts it only if it
 *       reports success, carries a signature that validates against the trusted certificate,
 *       and was issued within the allowed clock skew.</li>
 * </ul>
 *
 * <p>{@link #allowEntry()} and {@link #denyEntry()} are empty hooks: as written, this servlet
 * decides whether the response is acceptable but does not itself write anything to the HTTP
 * response or establish a session.
 */
@WebServlet("/saml")
public class SAMLServlet extends HttpServlet {

    /** Assertion consumer service URL the identity provider should POST the response to. */
    private static final String REDIRECT_URL = "http://localhost:8080/saml";

    /** Single sign-on endpoint of the identity provider. */
    private static final String ID_PROVIDER_URL = "https://secureauthtest.bsci.com/secureauth20";

    /** Maximum accepted distance, in milliseconds, between now and the response's IssueInstant. */
    private static final int MILLISEC_DIFF_ALLOWED = 1000;

    /** Location of the X.509 certificate whose public key must verify the response signature. */
    private static final String CERTIFICATE_FILE_PATH = "/tmp/bsci-saml.cer"; // add this as resource?

    /** Classpath resources (relative to the OpenSAML {@code Configuration} class) loaded at start-up. */
    private static final String[] XML_TOOLING_CONFIGS = {
        "/default-config.xml",
        "/schema-config.xml",
        "/signature-config.xml",
        "/signature-validation-config.xml",
        "/encryption-config.xml",
        "/encryption-validation-config.xml",
        "/soap11-config.xml",
        "/wsfed11-protocol-config.xml",
        "/saml1-assertion-config.xml",
        "/saml1-protocol-config.xml",
        "/saml1-core-validation-config.xml",
        "/saml2-assertion-config.xml",
        "/saml2-protocol-config.xml",
        "/saml2-core-validation-config.xml",
        "/saml1-metadata-config.xml",
        "/saml2-metadata-config.xml",
        "/saml2-metadata-validation-config.xml",
        "/saml2-metadata-attr-config.xml",
        "/saml2-metadata-idp-discovery-config.xml",
        "/saml2-metadata-ui-config.xml",
        "/saml2-protocol-thirdparty-config.xml",
        "/saml2-metadata-query-config.xml",
        "/saml2-assertion-delegation-restriction-config.xml",
        "/saml2-ecp-config.xml",
        "/xacml10-saml2-profile-config.xml",
        "/xacml11-saml2-profile-config.xml",
        "/xacml20-context-config.xml",
        "/xacml20-policy-config.xml",
        "/xacml2-saml2-profile-config.xml",
        "/xacml3-saml2-profile-config.xml",
        "/wsaddressing-config.xml",
        "/wssecurity-config.xml",
    };

    /**
     * Loads the OpenSAML XML tooling configuration files and installs a shared parser pool.
     *
     * @throws ServletException if a configuration resource is missing or cannot be loaded, or the
     *         parser pool cannot be initialised; failing here keeps a half-configured servlet
     *         from being put into service
     */
    @Override
    public void init() throws ServletException {
        try {
            XMLConfigurator configurator = new XMLConfigurator();
            for (String config : XML_TOOLING_CONFIGS) {
                // Close each classpath stream once it has been consumed.
                try (InputStream configStream = Configuration.class.getResourceAsStream(config)) {
                    if (configStream == null) {
                        throw new ServletException("OpenSAML configuration resource not found: " + config);
                    }
                    configurator.load(configStream);
                }
            }
            StaticBasicParserPool parserPool = new StaticBasicParserPool();
            parserPool.setMaxPoolSize(50);
            parserPool.initialize();
            Configuration.setParserPool(parserPool);
        } catch (ConfigurationException | XMLParserException | IOException e) {
            throw new ServletException("Unable to initialise OpenSAML", e);
        }
    }

    /**
     * Builds an {@code AuthnRequest} and redirects the caller to the identity provider.
     *
     * @param req  the incoming request (not inspected)
     * @param resp the response the redirect is written to
     * @throws ServletException if the request cannot be encoded into a redirect
     * @throws IOException declared by the servlet contract
     */
    @Override
    protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        XMLObjectBuilderFactory builderFactory = Configuration.getBuilderFactory();

        AuthnRequest authnRequest = (AuthnRequest) builderFactory.getBuilder(AuthnRequest.DEFAULT_ELEMENT_NAME)
                .buildObject(AuthnRequest.DEFAULT_ELEMENT_NAME);
        // The ID attribute is an xs:ID, which must not start with a digit; a bare UUID often does.
        authnRequest.setID("_" + UUID.randomUUID().toString());
        // IssueInstant is a required attribute of every SAML 2.0 request.
        authnRequest.setIssueInstant(new DateTime(DateTimeZone.UTC));
        authnRequest.setAssertionConsumerServiceURL(REDIRECT_URL);

        Issuer issuer = (Issuer) builderFactory.getBuilder(Issuer.DEFAULT_ELEMENT_NAME)
                .buildObject(Issuer.DEFAULT_ELEMENT_NAME);
        issuer.setValue("secureauthtest.bsci.com");
        authnRequest.setIssuer(issuer);

        NameIDPolicy nameIDPolicy = (NameIDPolicy) builderFactory.getBuilder(NameIDPolicy.DEFAULT_ELEMENT_NAME)
                .buildObject(NameIDPolicy.DEFAULT_ELEMENT_NAME);
        nameIDPolicy.setAllowCreate(true);
        nameIDPolicy.setFormat("urn:oasis:names:tc:SAML:2.0:nameid-format:transient");
        authnRequest.setNameIDPolicy(nameIDPolicy);

        BasicSAMLMessageContext context = new BasicSAMLMessageContext<>();
        context.setOutboundSAMLMessage(authnRequest);
        context.setOutboundMessageTransport(new HttpServletResponseAdapter(resp, true));

        QName qName = new QName(SAMLConstants.SAML20MD_NS, "SingleSignOnService", SAMLConstants.SAML20MD_PREFIX);
        SingleSignOnServiceImpl endpoint = (SingleSignOnServiceImpl) builderFactory.getBuilder(qName)
                .buildObject(qName);
        endpoint.setLocation(ID_PROVIDER_URL);
        context.setPeerEntityEndpoint(endpoint);

        try {
            // this performs the redirect
            new HTTPRedirectDeflateEncoder().encode(context);
        } catch (MessageEncodingException e) {
            // Without the redirect the client would receive an empty response; report the failure.
            throw new ServletException("Unable to encode SAML AuthnRequest", e);
        }
    }

    /**
     * Validates a SAML response delivered through the HTTP-POST binding.
     *
     * <p>The checks run in order: decode, status is success, signature is present and valid,
     * issue instant is recent. The first failing check calls {@link #denyEntry()}; if all pass,
     * {@link #allowEntry()} is called.
     *
     * @param req  the request carrying the SAML response
     * @param resp the servlet response (not written to by this method)
     * @throws ServletException declared by the servlet contract
     * @throws IOException if the trusted certificate file cannot be read
     */
    @Override
    protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        Response response = decodeResponse(req);
        if (response == null
                || !isSuccessful(response)
                || !hasTrustedSignature(response)
                || !isRecentlyIssued(response)) {
            denyEntry();
            return;
        }
        // allow in
        allowEntry();
    }

    /** Decodes the POSTed message; returns {@code null} if it cannot be decoded or is not a Response. */
    private Response decodeResponse(HttpServletRequest req) {
        HTTPPostDecoder decoder = new HTTPPostDecoder(new BasicParserPool());
        MessageContext context = new BasicSAMLMessageContext();
        context.setInboundMessageTransport(new HttpServletRequestAdapter(req));
        try {
            decoder.decode(context);
        } catch (MessageDecodingException | org.opensaml.xml.security.SecurityException e) {
            log("Unable to decode SAML response", e);
            return null;
        }
        // instanceof also rejects a null message, so the cast below cannot fail.
        if (!(context.getInboundMessage() instanceof Response)) {
            log("Inbound SAML message is missing or is not a Response");
            return null;
        }
        return (Response) context.getInboundMessage();
    }

    /** Returns whether the response carries the SAML success status code. */
    private static boolean isSuccessful(Response response) {
        Status status = response.getStatus();
        return status != null
                && status.getStatusCode() != null
                && StatusCode.SUCCESS_URI.equals(status.getStatusCode().getValue());
    }

    /**
     * Returns whether the response is signed, the signature conforms to the SAML signature
     * profile, and it verifies against the public key of the trusted certificate.
     */
    private boolean hasTrustedSignature(Response response) throws IOException {
        if (response.getSignature() == null) {
            // An unsigned response must never be accepted.
            log("SAML response is not signed");
            return false;
        }
        try {
            // Structural check of the signature against the SAML profile (not a cryptographic check).
            new SAMLSignatureProfileValidator().validate(response.getSignature());

            // Cryptographic check against the key of the trusted certificate.
            BasicX509Credential trustedCredential = new BasicX509Credential();
            trustedCredential.setPublicKey(loadTrustedCertificate().getPublicKey());
            new SignatureValidator(trustedCredential).validate(response.getSignature());
            return true;
        } catch (ValidationException | CertificateException e) {
            log("SAML response signature could not be validated", e);
            return false;
        }
    }

    /** Reads the trusted X.509 certificate from {@link #CERTIFICATE_FILE_PATH}, closing the file afterwards. */
    private static X509Certificate loadTrustedCertificate() throws IOException, CertificateException {
        try (FileInputStream certInputStream = new FileInputStream(new File(CERTIFICATE_FILE_PATH))) {
            CertificateFactory certificateFactory = CertificateFactory.getInstance("X.509");
            return (X509Certificate) certificateFactory.generateCertificate(certInputStream);
        }
    }

    /**
     * Returns whether the response's IssueInstant lies within {@link #MILLISEC_DIFF_ALLOWED} of now.
     *
     * <p>Absolute instants are compared. Comparing only the millisecond-of-day would accept a
     * response issued on a different day at the same time of day, and would reject valid
     * responses that straddle midnight.
     */
    private static boolean isRecentlyIssued(Response response) {
        DateTime issueInstant = response.getIssueInstant();
        if (issueInstant == null) {
            return false;
        }
        long milliSecDifference = Math.abs(new DateTime().getMillis() - issueInstant.getMillis());
        return milliSecDifference <= MILLISEC_DIFF_ALLOWED;
    }

    /** Hook invoked when the SAML response is rejected; intentionally empty here. */
    private void denyEntry() {
    }

    /** Hook invoked when the SAML response passed every check; intentionally empty here. */
    private void allowEntry() {
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment