Skip to content

Instantly share code, notes, and snippets.

@nejckorasa
Created July 31, 2018 06:51
Show Gist options
  • Save nejckorasa/d68b8bafb6235ef33ad552d2e2c12fd6 to your computer and use it in GitHub Desktop.
Save nejckorasa/d68b8bafb6235ef33ad552d2e2c12fd6 to your computer and use it in GitHub Desktop.
Custom CORS filter for Spring framework to enable wildcard matching
####################################################
# CORS
####################################################
# Configure CORS headers
# If set to true, the wildcard value for the Access-Control-Allow-Origin header is avoided.
# Each origin listed in cors.allow-origin (comma-separated, '*' wildcards allowed) is matched against the request's Origin header.
# If a match is found, Access-Control-Allow-Origin echoes the value of the request's Origin header.
# default = false
#cors.avoid-wildcards=true
# default = *
#cors.allow-origin=*
# default = GET, POST, PUT, DELETE, OPTIONS
#cors.allow-methods=GET, POST, PUT, DELETE, OPTIONS
# default = *
#cors.allow-headers=*
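# default = true (browsers reject credentialed requests when Access-Control-Allow-Origin is '*'; enable cors.avoid-wildcards so the Origin is echoed instead)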
#cors.allow-credentials=true
# default = 3600
#cors.max-age=3600
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import javax.annotation.PostConstruct;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.core.Ordered;
import org.springframework.core.annotation.Order;
import org.springframework.stereotype.Component;
import org.springframework.web.filter.OncePerRequestFilter;
/**
 * Servlet filter that adds CORS response headers to every request and answers {@code OPTIONS}
 * (preflight) requests directly with HTTP 200 without invoking the rest of the filter chain.
 *
 * <p>Configuration properties (all optional):
 * <ul>
 *   <li>{@code cors.avoid-wildcards} (default {@code false}): when true, the request's
 *       {@code Origin} header is echoed back if it matches one of the configured origins,
 *       instead of sending the raw configured value.</li>
 *   <li>{@code cors.allow-origin} (default {@code *}): comma-separated origins; whitespace is
 *       ignored and each entry may contain {@code *} wildcards matching any character sequence.</li>
 *   <li>{@code cors.allow-methods}, {@code cors.allow-headers}, {@code cors.max-age}: copied
 *       verbatim into the corresponding preflight response headers.</li>
 *   <li>{@code cors.allow-credentials} (default {@code true}): copied verbatim into
 *       {@code Access-Control-Allow-Credentials} on every response.</li>
 * </ul>
 *
 * <p>NOTE: browsers reject credentialed requests when {@code Access-Control-Allow-Origin} is
 * {@code *}, so {@code cors.allow-credentials=true} only takes effect together with
 * {@code cors.avoid-wildcards=true}.
 *
 * @author Nejc Korasa
 */
@Component
@Order(Ordered.HIGHEST_PRECEDENCE + 1)
public class CustomCorsFilter extends OncePerRequestFilter
{
    private static final Logger LOG = LoggerFactory.getLogger(CustomCorsFilter.class);

    /** Tokenizes a wildcard origin: group 0 is a run of non-'*' text, group 1 is a single '*'. */
    private static final Pattern WILDCARD_TOKENS = Pattern.compile("[^*]+|(\\*)");

    /** Configured origins with whitespace removed and empty entries dropped; filled by {@link #init()}. */
    private final List<String> configuredOrigins = new ArrayList<>();

    /** Precompiled regexes, one per entry of {@link #configuredOrigins}; filled by {@link #init()}. */
    private final List<Pattern> originPatterns = new ArrayList<>();

    // Primitive boolean: avoids an auto-unboxing NPE if the value is ever not injected.
    @Value("${cors.avoid-wildcards:false}")
    private boolean avoidWildcards;

    @Value("${cors.allow-origin:*}")
    private String allowOrigin;

    @Value("${cors.allow-methods:GET, POST, PUT, DELETE, OPTIONS}")
    private String allowMethods;

    @Value("${cors.allow-headers:*}")
    private String allowHeaders;

    @Value("${cors.allow-credentials:true}")
    private String allowCredentials;

    @Value("${cors.max-age:3600}")
    private String maxAge;

    /**
     * Parses {@code cors.allow-origin} into its individual origins and compiles one regex per origin,
     * so that request handling does not recompile patterns.
     */
    @PostConstruct
    public void init()
    {
        //noinspection DynamicRegexReplaceableByCompiledPattern
        final List<String> entries = Arrays
            .stream(allowOrigin.replaceAll("\\s+", "").split(",")) // remove spaces, split by ','
            .filter(entry -> !entry.isEmpty())                     // ignore stray/trailing commas
            .collect(Collectors.toList());
        final List<String> regexes = entries.stream()
            .map(this::buildRegexFromWildcards)
            .collect(Collectors.toList());
        LOG.debug("Allow origin regexes: {}", regexes);
        configuredOrigins.addAll(entries);
        originPatterns.addAll(regexes.stream().map(Pattern::compile).collect(Collectors.toList()));
    }

    /**
     * Adds CORS headers to the response. {@code OPTIONS} requests are answered here with status 200
     * and the preflight headers; every other request continues down the filter chain.
     */
    @Override
    protected void doFilterInternal(
        final HttpServletRequest request,
        final HttpServletResponse response,
        final FilterChain filterChain)
        throws ServletException, IOException
    {
        response.addHeader("Access-Control-Allow-Origin", buildAllowOriginHeader(request.getHeader("Origin")));
        if (avoidWildcards)
        {
            // The Allow-Origin value may now depend on the request's Origin, so shared caches must key on it.
            response.addHeader("Vary", "Origin");
        }
        // Browsers check Allow-Credentials on the actual response of a credentialed request, not only
        // on the preflight, so it is added before the chain can commit the response.
        response.addHeader("Access-Control-Allow-Credentials", allowCredentials);

        if ("OPTIONS".equalsIgnoreCase(request.getMethod()))
        {
            response.addHeader("Access-Control-Allow-Methods", allowMethods);
            response.addHeader("Access-Control-Allow-Headers", allowHeaders);
            response.addHeader("Access-Control-Max-Age", maxAge);
            response.setStatus(HttpServletResponse.SC_OK);
        }
        else
        {
            filterChain.doFilter(request, response);
        }
    }

    /**
     * Converts a wildcard origin into an anchored-by-{@code matches} regex: each {@code *} becomes
     * {@code .*} and all other text is quoted literally.
     *
     * <p>Example: input {@code audio*2012*.wav} gives {@code \Qaudio\E.*\Q2012\E.*\Q.wav\E}.
     *
     * @param origin a single origin entry, possibly containing {@code *}; must not be null
     * @return a regex suitable for {@link String#matches(String)} / {@link Pattern#compile(String)}
     */
    String buildRegexFromWildcards(final String origin)
    {
        final StringBuilder regex = new StringBuilder();
        final Matcher matcher = WILDCARD_TOKENS.matcher(origin);
        // The token pattern matches every character, so consecutive finds cover the whole input.
        while (matcher.find())
        {
            // Pattern.quote (rather than a replacement string) keeps '$', '\' and "\E" in the text literal.
            regex.append(matcher.group(1) != null ? ".*" : Pattern.quote(matcher.group()));
        }
        return regex.toString();
    }

    /**
     * Chooses the {@code Access-Control-Allow-Origin} value for a request.
     *
     * @param origin the request's {@code Origin} header, or null when absent
     * @return the configured value, the single configured literal origin, or the echoed request origin
     */
    private String buildAllowOriginHeader(final String origin)
    {
        if (!avoidWildcards)
        {
            return allowOrigin;
        }
        // A single literal origin can be returned directly; a single wildcard entry must still be matched.
        if (configuredOrigins.size() == 1 && !configuredOrigins.get(0).contains("*"))
        {
            return configuredOrigins.get(0);
        }
        // NOTE(review): on no match the raw configured value is returned (unchanged behavior); a comma
        // list or wildcard pattern is not a valid header value, so the browser blocks the request.
        return matchesOrigin(origin) ? origin : allowOrigin;
    }

    /**
     * Returns whether {@code origin} is non-null and matches any configured origin pattern.
     */
    @SuppressWarnings("BooleanMethodNameMustStartWithQuestion")
    private boolean matchesOrigin(final String origin)
    {
        return origin != null
            && (configuredOrigins.contains("*")
                || originPatterns.stream().anyMatch(pattern -> pattern.matcher(origin).matches()));
    }
}
import org.junit.Test;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
/**
 * Unit tests for {@link CustomCorsFilter#buildRegexFromWildcards(String)}.
 *
 * @author Nejc Korasa
 */
public class CustomCorsFilterTest
{
    @Test
    public void wildcardToRegexSuffix()
    {
        final String wildcard = "my.domain.com*";
        final CustomCorsFilter ccf = new CustomCorsFilter();
        final String regex = ccf.buildRegexFromWildcards(wildcard);
        assertTrue("my.domain.com:8080".matches(regex));
        assertTrue("my.domain.com".matches(regex));
        assertFalse("my.not.the.same.domain.com:8080".matches(regex));
    }

    @Test
    public void wildcardToRegexPrefixAndSuffix()
    {
        final String wildcard = "*.my.domain.com*";
        final CustomCorsFilter ccf = new CustomCorsFilter();
        final String regex = ccf.buildRegexFromWildcards(wildcard);
        assertTrue("sub.my.domain.com:8080".matches(regex));
        assertTrue("sub.my.domain.com".matches(regex));
        assertFalse("sub.my.wrong.domain.com:8080".matches(regex));
    }

    @Test
    public void wildcardToRegexWithoutWildcardIsExactAndLiteral()
    {
        final CustomCorsFilter ccf = new CustomCorsFilter();
        final String regex = ccf.buildRegexFromWildcards("my.domain.com");
        assertTrue("my.domain.com".matches(regex));
        // '.' must be quoted, not treated as "any character".
        assertFalse("myXdomainXcom".matches(regex));
        // Without a '*' nothing extra (such as a port) may follow.
        assertFalse("my.domain.com:8080".matches(regex));
    }

    @Test
    public void wildcardToRegexMultipleWildcards()
    {
        final CustomCorsFilter ccf = new CustomCorsFilter();
        final String regex = ccf.buildRegexFromWildcards("audio*2012*.wav");
        assertTrue("audio2012.wav".matches(regex));
        assertTrue("audio_x_2012_y.wav".matches(regex));
        assertFalse("audio2013.wav".matches(regex));
        assertFalse("audio_2012_xwav".matches(regex));
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment