Wednesday, July 22, 2009

Receiving a SAML Response with a Java servlet

As an addendum to my previous post: if you need to receive a SAML Response in a Java servlet using OpenSAML, you can use the code below. This is, obviously, a more common need than creating a Response object from an XML file. The `request` variable is simply the HttpServletRequest passed to the servlet's processRequest method.

import org.opensaml.ws.message.MessageContext;
import org.opensaml.ws.transport.http.HttpServletRequestAdapter;
import org.opensaml.common.binding.BasicSAMLMessageContext;
import org.opensaml.saml2.binding.decoding.HTTPPostDecoder;
import org.opensaml.saml2.core.Response;

// Wrap the servlet request so OpenSAML can read it, then let the HTTP-POST
// binding decoder populate the message context from the posted form data.
HttpServletRequestAdapter inboundTransport = new HttpServletRequestAdapter(request);
MessageContext samlContext = new BasicSAMLMessageContext();
samlContext.setInboundMessageTransport(inboundTransport);
new HTTPPostDecoder().decode(samlContext);

// The decoded inbound message, cast to a SAML 2.0 Response
// (a ClassCastException results if the POST carried another message type).
Response samlResponse = (Response) samlContext.getInboundMessage();

Monday, July 6, 2009

Processing SAML in Java using OpenSAML

I'm currently doing all of my SAML 2.0 work in a .NET environment, but I wanted to verify my SAML in Java too, so I created this tester using OpenSAML. It isn't pretty at this point, and I'm certainly not claiming it represents best practices, but I'm posting it in its current state because I think it illustrates some useful things. It simply unmarshalls XML into an OpenSAML Response object, verifies the signature on the response, and decrypts an assertion. Nothing terribly mind-blowing, but it might prove useful to someone.

import org.w3c.dom.Document;
import org.w3c.dom.Element;

import java.io.File;
import java.io.FileInputStream;
import java.io.InputStream;

import java.util.List;

import javax.xml.namespace.QName;
import javax.xml.validation.Schema;

import java.security.Key;
import java.security.KeyFactory;
import java.security.KeyStore;
import java.security.PrivateKey;
import java.security.PublicKey;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.security.interfaces.RSAPrivateKey;
import java.security.spec.X509EncodedKeySpec;

import org.opensaml.DefaultBootstrap;
import org.opensaml.common.xml.SAMLSchemaBuilder;
import org.opensaml.saml2.core.Assertion;
import org.opensaml.saml2.core.Attribute;
import org.opensaml.saml2.core.AttributeStatement;
import org.opensaml.saml2.core.Response;
import org.opensaml.saml2.encryption.Decrypter;
import org.opensaml.security.SAMLSignatureProfileValidator;
import org.opensaml.xml.Configuration;
import org.opensaml.xml.XMLObject;
import org.opensaml.xml.encryption.DecryptionException;
import org.opensaml.xml.encryption.InlineEncryptedKeyResolver;
import org.opensaml.xml.io.Unmarshaller;
import org.opensaml.xml.parse.BasicParserPool;
import org.opensaml.xml.security.keyinfo.StaticKeyInfoCredentialResolver;
import org.opensaml.xml.security.x509.BasicX509Credential;
import org.opensaml.xml.signature.Signature;
import org.opensaml.xml.signature.SignatureValidator;
import org.opensaml.xml.validation.ValidationException;


/**
 * Command-line tester for SAML 2.0 Response handling with OpenSAML.
 *
 * <p>Loads a Response from an XML file and verifies its signature against a certificate.
 * It then decrypts the first EncryptedAssertion using a private key from a JKS keystore and
 * prints every attribute value found in the decrypted assertion.
 *
 * <p>Usage (all arguments optional; missing or empty ones fall back to the defaults below):
 * {@code Tester [responseXml] [certificateFile] [keystoreFile] [keystorePassword] [keyAlias]}
 */
public class Tester
{
    // Defaults preserved from the original development machine.
    private static final String DEFAULT_RESPONSE_FILE =
        "C:\\Documents and Settings\\kgellis\\My Documents\\Eclipse\\Workspace\\MyOpenSamlTester\\Files\\Raw_AssertionEncrypted.xml";
    private static final String DEFAULT_CERTIFICATE_FILE =
        "C:\\Documents and Settings\\kgellis\\My Documents\\Eclipse\\Workspace\\MyOpenSamlTester\\Files\\Con-wayPublicKey.cer";
    private static final String DEFAULT_KEYSTORE_FILE =
        "C:\\Documents and Settings\\kgellis\\My Documents\\Eclipse\\Workspace\\MyOpenSamlTester\\Files\\WWCPrivateKey.jks";
    // NOTE(review): a real keystore password should not live in source; pass it as args[3].
    // The same password is used for the keystore and for the key entry.
    private static final String DEFAULT_KEYSTORE_PASSWORD = "!c3c0ld";
    private static final String DEFAULT_KEY_ALIAS = "pvktmp:bd5ba0e0-9718-48ea-b6e6-32cd9c852d76";

    public static void main(String[] args)
    {
        try
        {
            // initialize the OpenSAML library (must run before any unmarshalling)
            DefaultBootstrap.bootstrap();

            Response response = loadResponse(new File(argOrDefault(args, 0, DEFAULT_RESPONSE_FILE)));
            System.out.println("Response object created");
            // string concatenation (not toString()) so a missing value prints "null" instead of throwing
            System.out.println("Issue Instant: " + response.getIssueInstant());
            System.out.println("Signature Reference ID: " + response.getSignatureReferenceID());

            X509Certificate certificate = loadCertificate(new File(argOrDefault(args, 1, DEFAULT_CERTIFICATE_FILE)));
            System.out.println("Public Key created");

            if (!validateSignature(response, certificate))
            {
                return;
            }
            System.out.println("Signature is valid.");

            // start decryption of assertion
            String alias = argOrDefault(args, 4, DEFAULT_KEY_ALIAS);
            char[] password = argOrDefault(args, 3, DEFAULT_KEYSTORE_PASSWORD).toCharArray();
            PrivateKey privateKey = loadPrivateKey(new File(argOrDefault(args, 2, DEFAULT_KEYSTORE_FILE)), password, alias);
            if (privateKey == null)
            {
                System.out.println("No private key found for alias: " + alias);
                return;
            }
            System.out.println("Private Key created");
            System.out.println("Private Key Algorithm: " + privateKey.getAlgorithm());

            Assertion decryptedAssertion = decryptFirstAssertion(response, privateKey);
            if (decryptedAssertion == null)
            {
                return;
            }
            System.out.println("Assertion decryption succeeded.");
            System.out.println("Assertion ID: " + decryptedAssertion.getID());

            printAttributes(decryptedAssertion);
        }
        catch (Exception ex)
        {
            // command-line harness: report any unexpected failure and exit
            ex.printStackTrace();
        }
    }

    /** Returns {@code args[index]} when present and non-empty, otherwise {@code fallback}. */
    private static String argOrDefault(String[] args, int index, String fallback)
    {
        if (args != null && args.length > index && args[index] != null && args[index].length() > 0)
        {
            return args[index];
        }
        return fallback;
    }

    /**
     * Parses the XML file and unmarshalls its root element into a SAML 2.0 Response.
     *
     * @throws IllegalStateException if no unmarshaller is registered for the root element
     *     or the root element is not a Response
     */
    private static Response loadResponse(File xmlFile) throws Exception
    {
        // NOTE(review): the SAML 1.1 schema builder is used to validate a SAML 2.0 document;
        // confirm it bundles the SAML 2.0 schemas in the OpenSAML version in use.
        Schema schema = SAMLSchemaBuilder.getSAML11Schema();

        BasicParserPool parserPool = new BasicParserPool();
        parserPool.setNamespaceAware(true);
        parserPool.setIgnoreElementContentWhitespace(true);
        parserPool.setSchema(schema);

        Document document;
        InputStream in = new FileInputStream(xmlFile);
        try
        {
            document = parserPool.parse(in);
        }
        finally
        {
            in.close();
        }

        Element root = document.getDocumentElement();
        QName qName = new QName(root.getNamespaceURI(), root.getLocalName(), root.getPrefix());

        Unmarshaller unmarshaller = Configuration.getUnmarshallerFactory().getUnmarshaller(qName);
        if (unmarshaller == null)
        {
            throw new IllegalStateException("No unmarshaller registered for element " + qName);
        }

        XMLObject xmlObject = unmarshaller.unmarshall(root);
        if (!(xmlObject instanceof Response))
        {
            throw new IllegalStateException("Root element is not a SAML 2.0 Response: " + qName);
        }
        return (Response) xmlObject;
    }

    /** Reads a single X.509 certificate from the given file, always closing the stream. */
    private static X509Certificate loadCertificate(File certificateFile) throws Exception
    {
        InputStream in = new FileInputStream(certificateFile);
        try
        {
            CertificateFactory certificateFactory = CertificateFactory.getInstance("X.509");
            return (X509Certificate) certificateFactory.generateCertificate(in);
        }
        finally
        {
            in.close();
        }
    }

    /**
     * Validates the Response signature, printing the outcome on failure.
     *
     * <p>First checks that the signature conforms to the SAML signature profile, then
     * verifies it cryptographically against the certificate's public key.
     *
     * @return true only if the Response is signed and the signature passes both checks
     */
    private static boolean validateSignature(Response response, X509Certificate certificate)
    {
        Signature signature = response.getSignature();
        if (signature == null)
        {
            // previously this reached validate(null) and failed with a NullPointerException
            System.out.println("Signature is NOT valid.");
            System.out.println("Response is not signed.");
            return false;
        }

        // The certificate's key is used directly; rebuilding it via KeyFactory("RSA") was redundant
        // and rejected non-RSA certificates.
        BasicX509Credential publicCredential = new BasicX509Credential();
        publicCredential.setPublicKey(certificate.getPublicKey());
        SignatureValidator signatureValidator = new SignatureValidator(publicCredential);

        try
        {
            // Profile validation guards against signatures that are cryptographically valid
            // but do not actually cover this element (e.g. signature-wrapping attacks).
            new SAMLSignatureProfileValidator().validate(signature);
            signatureValidator.validate(signature);
        }
        catch (ValidationException ve)
        {
            System.out.println("Signature is NOT valid.");
            System.out.println(ve.getMessage());
            return false;
        }
        return true;
    }

    /**
     * Loads the keystore (JKS) and returns the private key stored under {@code alias}.
     *
     * @return the private key, or null if the alias has no private-key entry
     */
    private static PrivateKey loadPrivateKey(File keyStoreFile, char[] password, String alias) throws Exception
    {
        KeyStore keyStore = KeyStore.getInstance("JKS");
        InputStream in = new FileInputStream(keyStoreFile);
        try
        {
            keyStore.load(in, password);
        }
        finally
        {
            in.close();
        }

        Key key = keyStore.getKey(alias, password);
        return (key instanceof PrivateKey) ? (PrivateKey) key : null;
    }

    /**
     * Decrypts the first EncryptedAssertion in the Response with the given private key.
     *
     * @return the decrypted assertion, or null (after printing why) if the Response has no
     *     encrypted assertion or decryption fails
     */
    private static Assertion decryptFirstAssertion(Response response, PrivateKey privateKey)
    {
        if (response.getEncryptedAssertions().isEmpty())
        {
            System.out.println("Assertion decryption failed.");
            System.out.println("Response contains no encrypted assertions.");
            return null;
        }

        BasicX509Credential decryptionCredential = new BasicX509Credential();
        decryptionCredential.setPrivateKey(privateKey);

        StaticKeyInfoCredentialResolver keyResolver = new StaticKeyInfoCredentialResolver(decryptionCredential);
        // the encrypted key is expected inline in the EncryptedAssertion's KeyInfo
        Decrypter decrypter = new Decrypter(null, keyResolver, new InlineEncryptedKeyResolver());

        try
        {
            return decrypter.decrypt(response.getEncryptedAssertions().get(0));
        }
        catch (DecryptionException de)
        {
            System.out.println("Assertion decryption failed.");
            System.out.println(de.getMessage());
            return null;
        }
    }

    /** Prints every "name: value" pair from all attribute statements of the assertion. */
    private static void printAttributes(Assertion assertion)
    {
        List<AttributeStatement> attributeStatements = assertion.getAttributeStatements();
        for (AttributeStatement attributeStatement : attributeStatements)
        {
            List<Attribute> attributes = attributeStatement.getAttributes();
            for (Attribute attribute : attributes)
            {
                String attributeName = attribute.getName();

                List<XMLObject> attributeValues = attribute.getAttributeValues();
                for (XMLObject attributeValue : attributeValues)
                {
                    // values may be of any schema type, so read the raw element text
                    Element valueElement = attributeValue.getDOM();
                    String text = (valueElement == null) ? null : valueElement.getTextContent();
                    System.out.println(attributeName + ": " + text);
                }
            }
        }
    }
}