Skip to content

Commit

Permalink
WW-5350 Implement OGNL Allowlist capability
Browse files Browse the repository at this point in the history
  • Loading branch information
kusalk committed Nov 5, 2023
1 parent 9cbe10f commit bef9769
Show file tree
Hide file tree
Showing 5 changed files with 203 additions and 17 deletions.
78 changes: 66 additions & 12 deletions core/src/main/java/com/opensymphony/xwork2/ognl/OgnlUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@
import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.toSet;
import static org.apache.commons.lang3.StringUtils.strip;
import static org.apache.struts2.StrutsConstants.STRUTS_ALLOWLIST_CLASSES;
import static org.apache.struts2.StrutsConstants.STRUTS_ALLOWLIST_ENABLE;
import static org.apache.struts2.StrutsConstants.STRUTS_ALLOWLIST_PACKAGE_NAMES;
import static org.apache.struts2.ognl.OgnlGuard.EXPR_BLOCKED;


Expand Down Expand Up @@ -89,15 +92,19 @@ public class OgnlUtil {
private Set<String> excludedPackageNames = emptySet();
private Set<String> excludedPackageExemptClasses = emptySet();

private boolean enforceAllowlistEnabled = false;
private Set<String> allowlistClasses = emptySet();
private Set<String> allowlistPackageNames = emptySet();

private Set<String> devModeExcludedClasses = emptySet();
private Set<Pattern> devModeExcludedPackageNamePatterns = emptySet();
private Set<String> devModeExcludedPackageNames = emptySet();
private Set<String> devModeExcludedPackageExemptClasses = emptySet();

private Container container;
private boolean allowStaticFieldAccess = true;
private boolean disallowProxyMemberAccess;
private boolean disallowDefaultPackageAccess;
private boolean disallowProxyMemberAccess = false;
private boolean disallowDefaultPackageAccess = false;

/**
* Construct a new OgnlUtil instance for use with the framework
Expand Down Expand Up @@ -178,6 +185,12 @@ protected void setDevModeExcludedClasses(String commaDelimitedClasses) {
devModeExcludedClasses = toNewClassesSet(devModeExcludedClasses, commaDelimitedClasses);
}

private static Set<String> toClassesSet(String newDelimitedClasses) throws ConfigurationException {
Set<String> classNames = commaDelimitedStringToSet(newDelimitedClasses);
validateClasses(classNames, OgnlUtil.class.getClassLoader());
return unmodifiableSet(classNames);
}

private static Set<String> toNewClassesSet(Set<String> oldClasses, String newDelimitedClasses) throws ConfigurationException {
Set<String> classNames = commaDelimitedStringToSet(newDelimitedClasses);
validateClasses(classNames, OgnlUtil.class.getClassLoader());
Expand Down Expand Up @@ -229,25 +242,36 @@ protected void setDevModeExcludedPackageNames(String commaDelimitedPackageNames)
devModeExcludedPackageNames = toNewPackageNamesSet(devModeExcludedPackageNames, commaDelimitedPackageNames);
}

private static Set<String> toNewPackageNamesSet(Set<String> oldPackageNames, String newDelimitedPackageNames) throws ConfigurationException {
private static Set<String> toPackageNamesSet(String newDelimitedPackageNames) throws ConfigurationException {
Set<String> packageNames = commaDelimitedStringToSet(newDelimitedPackageNames)
.stream().map(s -> strip(s, ".")).collect(toSet());
if (packageNames.stream().anyMatch(s -> Pattern.compile("\\s").matcher(s).find())) {
throw new ConfigurationException("Excluded package names could not be parsed due to erroneous whitespace characters: " + newDelimitedPackageNames);
}
validatePackageNames(packageNames);
return unmodifiableSet(packageNames);
}

private static Set<String> toNewPackageNamesSet(Collection<String> oldPackageNames, String newDelimitedPackageNames) throws ConfigurationException {
Set<String> packageNames = commaDelimitedStringToSet(newDelimitedPackageNames)
.stream().map(s -> strip(s, ".")).collect(toSet());
validatePackageNames(packageNames);
Set<String> newPackageNames = new HashSet<>(oldPackageNames);
newPackageNames.addAll(packageNames);
return unmodifiableSet(newPackageNames);
}

private static void validatePackageNames(Collection<String> packageNames) {
if (packageNames.stream().anyMatch(s -> Pattern.compile("\\s").matcher(s).find())) {
throw new ConfigurationException("Excluded package names could not be parsed due to erroneous whitespace characters: " + packageNames);
}
}

@Inject(value = StrutsConstants.STRUTS_EXCLUDED_PACKAGE_EXEMPT_CLASSES, required = false)
public void setExcludedPackageExemptClasses(String commaDelimitedClasses) {
excludedPackageExemptClasses = toNewClassesSet(excludedPackageExemptClasses, commaDelimitedClasses);
excludedPackageExemptClasses = toClassesSet(commaDelimitedClasses);
}

@Inject(value = StrutsConstants.STRUTS_DEV_MODE_EXCLUDED_PACKAGE_EXEMPT_CLASSES, required = false)
public void setDevModeExcludedPackageExemptClasses(String commaDelimitedClasses) {
devModeExcludedPackageExemptClasses = toNewClassesSet(devModeExcludedPackageExemptClasses, commaDelimitedClasses);
devModeExcludedPackageExemptClasses = toClassesSet(commaDelimitedClasses);
}

public Set<String> getExcludedClasses() {
Expand All @@ -266,6 +290,33 @@ public Set<String> getExcludedPackageExemptClasses() {
return excludedPackageExemptClasses;
}

@Inject(value = STRUTS_ALLOWLIST_ENABLE, required = false)
protected void setEnforceAllowlistEnabled(String enforceAllowlistEnabled) {
this.enforceAllowlistEnabled = BooleanUtils.toBoolean(enforceAllowlistEnabled);
}

@Inject(value = STRUTS_ALLOWLIST_CLASSES, required = false)
protected void setAllowlistClasses(String commaDelimitedClasses) {
allowlistClasses = toClassesSet(commaDelimitedClasses);
}

@Inject(value = STRUTS_ALLOWLIST_PACKAGE_NAMES, required = false)
protected void setAllowlistPackageNames(String commaDelimitedPackageNames) {
allowlistPackageNames = toPackageNamesSet(commaDelimitedPackageNames);
}

public boolean isEnforceAllowlistEnabled() {
return enforceAllowlistEnabled;
}

public Set<String> getAllowlistClasses() {
return allowlistClasses;
}

public Set<String> getAllowlistPackageNames() {
return allowlistPackageNames;
}

@Inject
protected void setContainer(Container container) {
this.container = container;
Expand Down Expand Up @@ -884,10 +935,13 @@ protected Map<String, Object> createDefaultContext(Object root, ClassResolver cl
memberAccess.useExcludedPackageNames(devModeExcludedPackageNames);
memberAccess.useExcludedPackageExemptClasses(devModeExcludedPackageExemptClasses);
} else {
memberAccess.useExcludedClasses(excludedClasses);
memberAccess.useExcludedPackageNamePatterns(excludedPackageNamePatterns);
memberAccess.useExcludedPackageNames(excludedPackageNames);
memberAccess.useExcludedPackageExemptClasses(excludedPackageExemptClasses);
memberAccess.useExcludedClasses(getExcludedClasses());
memberAccess.useExcludedPackageNamePatterns(getExcludedPackageNamePatterns());
memberAccess.useExcludedPackageNames(getExcludedPackageNames());
memberAccess.useExcludedPackageExemptClasses(getExcludedPackageExemptClasses());
memberAccess.useEnforceAllowlistEnabled(isEnforceAllowlistEnabled());
memberAccess.useAllowlistClasses(getAllowlistClasses());
memberAccess.useAllowlistPackageNames(getAllowlistPackageNames());
}

return Ognl.createDefaultContext(root, memberAccess, resolver, defaultConverter);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ protected void setOgnlUtil(OgnlUtil ognlUtil) {
securityMemberAccess.useExcludedPackageNamePatterns(ognlUtil.getExcludedPackageNamePatterns());
securityMemberAccess.useExcludedPackageNames(ognlUtil.getExcludedPackageNames());
securityMemberAccess.useExcludedPackageExemptClasses(ognlUtil.getExcludedPackageExemptClasses());
securityMemberAccess.useEnforceAllowlistEnabled(ognlUtil.isEnforceAllowlistEnabled());
securityMemberAccess.useAllowlistClasses(ognlUtil.getAllowlistClasses());
securityMemberAccess.useAllowlistPackageNames(ognlUtil.getAllowlistPackageNames());
securityMemberAccess.disallowProxyMemberAccess(ognlUtil.isDisallowProxyMemberAccess());
securityMemberAccess.disallowDefaultPackageAccess(ognlUtil.isDisallowDefaultPackageAccess());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,11 @@ public class SecurityMemberAccess implements MemberAccess {
private Set<Pattern> excludedPackageNamePatterns = emptySet();
private Set<String> excludedPackageNames = emptySet();
private Set<String> excludedPackageExemptClasses = emptySet();
private boolean disallowProxyMemberAccess;
private boolean disallowDefaultPackageAccess;
private boolean enforceAllowlistEnabled = false;
private Set<String> allowlistClasses = emptySet();
private Set<String> allowlistPackageNames = emptySet();
private boolean disallowProxyMemberAccess = false;
private boolean disallowDefaultPackageAccess = false;

/**
* SecurityMemberAccess
Expand Down Expand Up @@ -149,13 +152,44 @@ public boolean isAccessible(Map context, Object target, Member member, String pr
return false;
}

if (!checkAllowlist(target, member)) {
return false;
}

if (!isAcceptableProperty(propertyName)) {
return false;
}

return true;
}

/**
* @return {@code true} if member access is allowed
*/
protected boolean checkAllowlist(Object target, Member member) {
Class<?> memberClass = member.getDeclaringClass();
if (!enforceAllowlistEnabled) {
return true;
}
if (!isClassAllowlisted(memberClass)) {
LOG.warn(format("Declaring class [{0}] of member type [{1}] is not allowlisted!", memberClass, member));
return false;
}
if (target == null || target.getClass() == memberClass) {
return true;
}
Class<?> targetClass = target.getClass();
if (!isClassAllowlisted(targetClass)) {
LOG.warn(format("Target class [{0}] of target [{1}] is not allowlisted!", targetClass, target));
return false;
}
return true;
}

protected boolean isClassAllowlisted(Class<?> clazz) {
return allowlistClasses.contains(clazz.getName()) || isClassBelongsToPackages(clazz, allowlistPackageNames);
}

/**
* @return {@code true} if member access is allowed
*/
Expand Down Expand Up @@ -286,14 +320,14 @@ protected boolean isExcludedPackageNamePatterns(Class<?> clazz) {
}

protected boolean isExcludedPackageNames(Class<?> clazz) {
return isExcludedPackageNamesStatic(clazz, excludedPackageNames);
return isClassBelongsToPackages(clazz, excludedPackageNames);
}

public static boolean isExcludedPackageNamesStatic(Class<?> clazz, Set<String> excludedPackageNames) {
public static boolean isClassBelongsToPackages(Class<?> clazz, Set<String> matchingPackages) {
List<String> packageParts = Arrays.asList(toPackageName(clazz).split("\\."));
for (int i = 0; i < packageParts.size(); i++) {
String parentPackage = String.join(".", packageParts.subList(0, i + 1));
if (excludedPackageNames.contains(parentPackage)) {
if (matchingPackages.contains(parentPackage)) {
return true;
}
}
Expand Down Expand Up @@ -399,6 +433,18 @@ public void useExcludedPackageExemptClasses(Set<String> excludedPackageExemptCla
this.excludedPackageExemptClasses = excludedPackageExemptClasses;
}

public void useEnforceAllowlistEnabled(boolean enforceAllowlistEnabled) {
this.enforceAllowlistEnabled = enforceAllowlistEnabled;
}

public void useAllowlistClasses(Set<String> allowlistClasses) {
this.allowlistClasses = allowlistClasses;
}

public void useAllowlistPackageNames(Set<String> allowlistPackageNames) {
this.allowlistPackageNames = allowlistPackageNames;
}

/**
* @deprecated please use {@link #disallowProxyMemberAccess(boolean)}
*/
Expand Down
6 changes: 6 additions & 0 deletions core/src/main/java/org/apache/struts2/StrutsConstants.java
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,12 @@ public final class StrutsConstants {
public static final String STRUTS_DEV_MODE_EXCLUDED_PACKAGE_NAMES = "struts.devMode.excludedPackageNames";
public static final String STRUTS_DEV_MODE_EXCLUDED_PACKAGE_EXEMPT_CLASSES = "struts.devMode.excludedPackageExemptClasses";

/** Boolean to enable strict allowlist processing of all OGNL expression calls. */
public static final String STRUTS_ALLOWLIST_ENABLE = "struts.allowlist.enable";
/** Comma delimited set of allowed classes which CAN be accessed via OGNL expressions. Both target and member classes of OGNL expression must be allowlisted. */
public static final String STRUTS_ALLOWLIST_CLASSES = "struts.allowlist.classes";
/** Comma delimited set of package names, of which all its classes, and all classes in its subpackages, CAN be accessed via OGNL expressions. Both target and member classes of OGNL expression must be allowlisted. */
public static final String STRUTS_ALLOWLIST_PACKAGE_NAMES = "struts.allowlist.packageNames";

/** Dedicated services to check if passed string is excluded/accepted */
public static final String STRUTS_EXCLUDED_PATTERNS_CHECKER = "struts.excludedPatterns.checker";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,15 @@
*/
package com.opensymphony.xwork2.ognl;

import com.opensymphony.xwork2.TestBean;
import com.opensymphony.xwork2.test.TestBean2;
import com.opensymphony.xwork2.util.TextParseUtil;
import org.junit.Before;
import org.junit.Test;

import java.lang.reflect.Field;
import java.lang.reflect.Member;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
Expand All @@ -32,6 +35,7 @@
import java.util.Set;
import java.util.regex.Pattern;

import static java.util.Arrays.asList;
import static java.util.Collections.singletonList;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
Expand Down Expand Up @@ -794,6 +798,79 @@ public void testPackageNameExclusionAsCommaDelimited() {
assertTrue("package java.lang. is accessible!", actual);
}

@Test
public void classInclusion() throws Exception {

sma.useEnforceAllowlistEnabled(true);

TestBean2 bean = new TestBean2();
Method method = TestBean2.class.getMethod("getData");

assertFalse(sma.checkAllowlist(bean, method));

sma.useAllowlistClasses(new HashSet<>(singletonList(TestBean2.class.getName())));

assertTrue(sma.checkAllowlist(bean, method));
}

@Test
public void packageInclusion() throws Exception {
sma.useEnforceAllowlistEnabled(true);

TestBean2 bean = new TestBean2();
Method method = TestBean2.class.getMethod("getData");

assertFalse(sma.checkAllowlist(bean, method));

sma.useAllowlistPackageNames(new HashSet<>(singletonList(TestBean2.class.getPackage().getName())));

assertTrue(sma.checkAllowlist(bean, method));
}

@Test
public void classInclusion_subclass() throws Exception {
sma.useEnforceAllowlistEnabled(true);
sma.useAllowlistClasses(new HashSet<>(singletonList(TestBean2.class.getName())));

TestBean2 bean = new TestBean2();
Method method = TestBean2.class.getMethod("getName");

assertFalse(sma.checkAllowlist(bean, method));
}

@Test
public void classInclusion_subclass_both() throws Exception {
sma.useEnforceAllowlistEnabled(true);
sma.useAllowlistClasses(new HashSet<>(asList(TestBean.class.getName(), TestBean2.class.getName())));

TestBean2 bean = new TestBean2();
Method method = TestBean2.class.getMethod("getName");

assertTrue(sma.checkAllowlist(bean, method));
}

@Test
public void packageInclusion_subclass() throws Exception {
sma.useEnforceAllowlistEnabled(true);
sma.useAllowlistPackageNames(new HashSet<>(singletonList(TestBean2.class.getPackage().getName())));

TestBean2 bean = new TestBean2();
Method method = TestBean2.class.getMethod("getName");

assertFalse(sma.checkAllowlist(bean, method));
}

@Test
public void packageInclusion_subclass_both() throws Exception {
sma.useEnforceAllowlistEnabled(true);
sma.useAllowlistPackageNames(new HashSet<>(asList(TestBean.class.getPackage().getName(), TestBean2.class.getPackage().getName())));

TestBean2 bean = new TestBean2();
Method method = TestBean2.class.getMethod("getName");

assertTrue(sma.checkAllowlist(bean, method));
}

private static String formGetterName(String propertyName) {
return "get" + propertyName.substring(0, 1).toUpperCase() + propertyName.substring(1);
}
Expand Down

0 comments on commit bef9769

Please sign in to comment.