diff --git a/core/src/main/java/com/opensymphony/xwork2/ognl/OgnlUtil.java b/core/src/main/java/com/opensymphony/xwork2/ognl/OgnlUtil.java index 7c9b5fbbbc..1f019f64a2 100644 --- a/core/src/main/java/com/opensymphony/xwork2/ognl/OgnlUtil.java +++ b/core/src/main/java/com/opensymphony/xwork2/ognl/OgnlUtil.java @@ -59,6 +59,9 @@ import static java.util.Objects.requireNonNull; import static java.util.stream.Collectors.toSet; import static org.apache.commons.lang3.StringUtils.strip; +import static org.apache.struts2.StrutsConstants.STRUTS_ALLOWLIST_CLASSES; +import static org.apache.struts2.StrutsConstants.STRUTS_ALLOWLIST_ENABLE; +import static org.apache.struts2.StrutsConstants.STRUTS_ALLOWLIST_PACKAGE_NAMES; import static org.apache.struts2.ognl.OgnlGuard.EXPR_BLOCKED; @@ -89,6 +92,10 @@ public class OgnlUtil { private Set excludedPackageNames = emptySet(); private Set excludedPackageExemptClasses = emptySet(); + private boolean enforceAllowlistEnabled = false; + private Set allowlistClasses = emptySet(); + private Set allowlistPackageNames = emptySet(); + private Set devModeExcludedClasses = emptySet(); private Set devModeExcludedPackageNamePatterns = emptySet(); private Set devModeExcludedPackageNames = emptySet(); @@ -96,8 +103,8 @@ public class OgnlUtil { private Container container; private boolean allowStaticFieldAccess = true; - private boolean disallowProxyMemberAccess; - private boolean disallowDefaultPackageAccess; + private boolean disallowProxyMemberAccess = false; + private boolean disallowDefaultPackageAccess = false; /** * Construct a new OgnlUtil instance for use with the framework @@ -178,6 +185,12 @@ protected void setDevModeExcludedClasses(String commaDelimitedClasses) { devModeExcludedClasses = toNewClassesSet(devModeExcludedClasses, commaDelimitedClasses); } + private static Set toClassesSet(String newDelimitedClasses) throws ConfigurationException { + Set classNames = commaDelimitedStringToSet(newDelimitedClasses); + validateClasses(classNames, OgnlUtil.class.getClassLoader()); + return unmodifiableSet(classNames); + } + private static Set toNewClassesSet(Set oldClasses, String newDelimitedClasses) throws ConfigurationException { Set classNames = commaDelimitedStringToSet(newDelimitedClasses); validateClasses(classNames, OgnlUtil.class.getClassLoader()); @@ -229,25 +242,36 @@ protected void setDevModeExcludedPackageNames(String commaDelimitedPackageNames) devModeExcludedPackageNames = toNewPackageNamesSet(devModeExcludedPackageNames, commaDelimitedPackageNames); } - private static Set toNewPackageNamesSet(Set oldPackageNames, String newDelimitedPackageNames) throws ConfigurationException { + private static Set toPackageNamesSet(String newDelimitedPackageNames) throws ConfigurationException { Set packageNames = commaDelimitedStringToSet(newDelimitedPackageNames) .stream().map(s -> strip(s, ".")).collect(toSet()); - if (packageNames.stream().anyMatch(s -> Pattern.compile("\\s").matcher(s).find())) { - throw new ConfigurationException("Excluded package names could not be parsed due to erroneous whitespace characters: " + newDelimitedPackageNames); - } + validatePackageNames(packageNames); + return unmodifiableSet(packageNames); + } + + private static Set toNewPackageNamesSet(Collection oldPackageNames, String newDelimitedPackageNames) throws ConfigurationException { + Set packageNames = commaDelimitedStringToSet(newDelimitedPackageNames) + .stream().map(s -> strip(s, ".")).collect(toSet()); + validatePackageNames(packageNames); Set newPackageNames = new HashSet<>(oldPackageNames); newPackageNames.addAll(packageNames); return unmodifiableSet(newPackageNames); } + private static void validatePackageNames(Collection packageNames) { + if (packageNames.stream().anyMatch(s -> Pattern.compile("\\s").matcher(s).find())) { + throw new ConfigurationException("Excluded package names could not be parsed due to erroneous whitespace characters: " + packageNames); + } + } + @Inject(value = StrutsConstants.STRUTS_EXCLUDED_PACKAGE_EXEMPT_CLASSES, required = false) public void setExcludedPackageExemptClasses(String commaDelimitedClasses) { - excludedPackageExemptClasses = toNewClassesSet(excludedPackageExemptClasses, commaDelimitedClasses); + excludedPackageExemptClasses = toClassesSet(commaDelimitedClasses); } @Inject(value = StrutsConstants.STRUTS_DEV_MODE_EXCLUDED_PACKAGE_EXEMPT_CLASSES, required = false) public void setDevModeExcludedPackageExemptClasses(String commaDelimitedClasses) { - devModeExcludedPackageExemptClasses = toNewClassesSet(devModeExcludedPackageExemptClasses, commaDelimitedClasses); + devModeExcludedPackageExemptClasses = toClassesSet(commaDelimitedClasses); } public Set getExcludedClasses() { @@ -266,6 +290,33 @@ public Set getExcludedPackageExemptClasses() { return excludedPackageExemptClasses; } + @Inject(value = STRUTS_ALLOWLIST_ENABLE, required = false) + protected void setEnforceAllowlistEnabled(String enforceAllowlistEnabled) { + this.enforceAllowlistEnabled = BooleanUtils.toBoolean(enforceAllowlistEnabled); + } + + @Inject(value = STRUTS_ALLOWLIST_CLASSES, required = false) + protected void setAllowlistClasses(String commaDelimitedClasses) { + allowlistClasses = toClassesSet(commaDelimitedClasses); + } + + @Inject(value = STRUTS_ALLOWLIST_PACKAGE_NAMES, required = false) + protected void setAllowlistPackageNames(String commaDelimitedPackageNames) { + allowlistPackageNames = toPackageNamesSet(commaDelimitedPackageNames); + } + + public boolean isEnforceAllowlistEnabled() { + return enforceAllowlistEnabled; + } + + public Set getAllowlistClasses() { + return allowlistClasses; + } + + public Set getAllowlistPackageNames() { + return allowlistPackageNames; + } + @Inject protected void setContainer(Container container) { this.container = container; @@ -884,10 +935,13 @@ protected Map createDefaultContext(Object root, ClassResolver cl memberAccess.useExcludedPackageNames(devModeExcludedPackageNames); memberAccess.useExcludedPackageExemptClasses(devModeExcludedPackageExemptClasses); } else { - memberAccess.useExcludedClasses(excludedClasses); - memberAccess.useExcludedPackageNamePatterns(excludedPackageNamePatterns); - memberAccess.useExcludedPackageNames(excludedPackageNames); - memberAccess.useExcludedPackageExemptClasses(excludedPackageExemptClasses); + memberAccess.useExcludedClasses(getExcludedClasses()); + memberAccess.useExcludedPackageNamePatterns(getExcludedPackageNamePatterns()); + memberAccess.useExcludedPackageNames(getExcludedPackageNames()); + memberAccess.useExcludedPackageExemptClasses(getExcludedPackageExemptClasses()); + memberAccess.useEnforceAllowlistEnabled(isEnforceAllowlistEnabled()); + memberAccess.useAllowlistClasses(getAllowlistClasses()); + memberAccess.useAllowlistPackageNames(getAllowlistPackageNames()); } return Ognl.createDefaultContext(root, memberAccess, resolver, defaultConverter); diff --git a/core/src/main/java/com/opensymphony/xwork2/ognl/OgnlValueStack.java b/core/src/main/java/com/opensymphony/xwork2/ognl/OgnlValueStack.java index 936619ae45..01b6af81d7 100644 --- a/core/src/main/java/com/opensymphony/xwork2/ognl/OgnlValueStack.java +++ b/core/src/main/java/com/opensymphony/xwork2/ognl/OgnlValueStack.java @@ -93,6 +93,9 @@ protected void setOgnlUtil(OgnlUtil ognlUtil) { securityMemberAccess.useExcludedPackageNamePatterns(ognlUtil.getExcludedPackageNamePatterns()); securityMemberAccess.useExcludedPackageNames(ognlUtil.getExcludedPackageNames()); securityMemberAccess.useExcludedPackageExemptClasses(ognlUtil.getExcludedPackageExemptClasses()); + securityMemberAccess.useEnforceAllowlistEnabled(ognlUtil.isEnforceAllowlistEnabled()); + securityMemberAccess.useAllowlistClasses(ognlUtil.getAllowlistClasses()); + securityMemberAccess.useAllowlistPackageNames(ognlUtil.getAllowlistPackageNames()); securityMemberAccess.disallowProxyMemberAccess(ognlUtil.isDisallowProxyMemberAccess()); securityMemberAccess.disallowDefaultPackageAccess(ognlUtil.isDisallowDefaultPackageAccess()); } diff --git a/core/src/main/java/com/opensymphony/xwork2/ognl/SecurityMemberAccess.java b/core/src/main/java/com/opensymphony/xwork2/ognl/SecurityMemberAccess.java index e993f3adb1..beb45a2ddc 100644 --- a/core/src/main/java/com/opensymphony/xwork2/ognl/SecurityMemberAccess.java +++ b/core/src/main/java/com/opensymphony/xwork2/ognl/SecurityMemberAccess.java @@ -56,8 +56,11 @@ public class SecurityMemberAccess implements MemberAccess { private Set excludedPackageNamePatterns = emptySet(); private Set excludedPackageNames = emptySet(); private Set excludedPackageExemptClasses = emptySet(); - private boolean disallowProxyMemberAccess; - private boolean disallowDefaultPackageAccess; + private boolean enforceAllowlistEnabled = false; + private Set allowlistClasses = emptySet(); + private Set allowlistPackageNames = emptySet(); + private boolean disallowProxyMemberAccess = false; + private boolean disallowDefaultPackageAccess = false; /** * SecurityMemberAccess @@ -149,6 +152,10 @@ public boolean isAccessible(Map context, Object target, Member member, String pr return false; } + if (!checkAllowlist(target, member)) { + return false; + } + if (!isAcceptableProperty(propertyName)) { return false; } @@ -156,6 +163,33 @@ public boolean isAccessible(Map context, Object target, Member member, String pr return true; } + /** + * @return {@code true} if member access is allowed + */ + protected boolean checkAllowlist(Object target, Member member) { + Class memberClass = member.getDeclaringClass(); + if (!enforceAllowlistEnabled) { + return true; + } + if (!isClassAllowlisted(memberClass)) { + LOG.warn(format("Declaring class [{0}] of member type [{1}] is not allowlisted!", memberClass, member)); + return false; + } + if (target == null || target.getClass() == memberClass) { + return true; + } + Class targetClass = target.getClass(); + if (!isClassAllowlisted(targetClass)) { + LOG.warn(format("Target class [{0}] of target [{1}] is not allowlisted!", targetClass, target)); + return false; + } + return true; + } + + protected boolean isClassAllowlisted(Class clazz) { + return allowlistClasses.contains(clazz.getName()) || isClassBelongsToPackages(clazz, allowlistPackageNames); + } + /** * @return {@code true} if member access is allowed */ @@ -286,14 +320,14 @@ protected boolean isExcludedPackageNamePatterns(Class clazz) { } protected boolean isExcludedPackageNames(Class clazz) { - return isExcludedPackageNamesStatic(clazz, excludedPackageNames); + return isClassBelongsToPackages(clazz, excludedPackageNames); } - public static boolean isExcludedPackageNamesStatic(Class clazz, Set excludedPackageNames) { + public static boolean isClassBelongsToPackages(Class clazz, Set matchingPackages) { List packageParts = Arrays.asList(toPackageName(clazz).split("\\.")); for (int i = 0; i < packageParts.size(); i++) { String parentPackage = String.join(".", packageParts.subList(0, i + 1)); - if (excludedPackageNames.contains(parentPackage)) { + if (matchingPackages.contains(parentPackage)) { return true; } } @@ -399,6 +433,18 @@ public void useExcludedPackageExemptClasses(Set excludedPackageExemptCla this.excludedPackageExemptClasses = excludedPackageExemptClasses; } + public void useEnforceAllowlistEnabled(boolean enforceAllowlistEnabled) { + this.enforceAllowlistEnabled = enforceAllowlistEnabled; + } + + public void useAllowlistClasses(Set allowlistClasses) { + this.allowlistClasses = allowlistClasses; + } + + public void useAllowlistPackageNames(Set allowlistPackageNames) { + this.allowlistPackageNames = allowlistPackageNames; + } + /** * @deprecated please use {@link #disallowProxyMemberAccess(boolean)} */ diff --git a/core/src/main/java/org/apache/struts2/StrutsConstants.java b/core/src/main/java/org/apache/struts2/StrutsConstants.java index bfe639d2e7..bcb07a69a3 100644 --- a/core/src/main/java/org/apache/struts2/StrutsConstants.java +++ b/core/src/main/java/org/apache/struts2/StrutsConstants.java @@ -432,6 +432,12 @@ public final class StrutsConstants { public static final String STRUTS_DEV_MODE_EXCLUDED_PACKAGE_NAMES = "struts.devMode.excludedPackageNames"; public static final String STRUTS_DEV_MODE_EXCLUDED_PACKAGE_EXEMPT_CLASSES = "struts.devMode.excludedPackageExemptClasses"; + /** Boolean to enable strict allowlist processing of all OGNL expression calls. */ + public static final String STRUTS_ALLOWLIST_ENABLE = "struts.allowlist.enable"; + /** Comma delimited set of allowed classes which CAN be accessed via OGNL expressions. Both target and member classes of OGNL expression must be allowlisted. */ + public static final String STRUTS_ALLOWLIST_CLASSES = "struts.allowlist.classes"; + /** Comma delimited set of package names, of which all its classes, and all classes in its subpackages, CAN be accessed via OGNL expressions. Both target and member classes of OGNL expression must be allowlisted. */ + public static final String STRUTS_ALLOWLIST_PACKAGE_NAMES = "struts.allowlist.packageNames"; /** Dedicated services to check if passed string is excluded/accepted */ public static final String STRUTS_EXCLUDED_PATTERNS_CHECKER = "struts.excludedPatterns.checker"; diff --git a/core/src/test/java/com/opensymphony/xwork2/ognl/SecurityMemberAccessTest.java b/core/src/test/java/com/opensymphony/xwork2/ognl/SecurityMemberAccessTest.java index 7ffc47b9f8..08a3b919ea 100644 --- a/core/src/test/java/com/opensymphony/xwork2/ognl/SecurityMemberAccessTest.java +++ b/core/src/test/java/com/opensymphony/xwork2/ognl/SecurityMemberAccessTest.java @@ -18,12 +18,15 @@ */ package com.opensymphony.xwork2.ognl; +import com.opensymphony.xwork2.TestBean; +import com.opensymphony.xwork2.test.TestBean2; import com.opensymphony.xwork2.util.TextParseUtil; import org.junit.Before; import org.junit.Test; import java.lang.reflect.Field; import java.lang.reflect.Member; +import java.lang.reflect.Method; import java.util.Arrays; import java.util.HashMap; import java.util.HashSet; @@ -32,6 +35,7 @@ import java.util.Set; import java.util.regex.Pattern; +import static java.util.Arrays.asList; import static java.util.Collections.singletonList; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; @@ -794,6 +798,79 @@ public void testPackageNameExclusionAsCommaDelimited() { assertTrue("package java.lang. is accessible!", actual); } + @Test + public void classInclusion() throws Exception { + + sma.useEnforceAllowlistEnabled(true); + + TestBean2 bean = new TestBean2(); + Method method = TestBean2.class.getMethod("getData"); + + assertFalse(sma.checkAllowlist(bean, method)); + + sma.useAllowlistClasses(new HashSet<>(singletonList(TestBean2.class.getName()))); + + assertTrue(sma.checkAllowlist(bean, method)); + } + + @Test + public void packageInclusion() throws Exception { + sma.useEnforceAllowlistEnabled(true); + + TestBean2 bean = new TestBean2(); + Method method = TestBean2.class.getMethod("getData"); + + assertFalse(sma.checkAllowlist(bean, method)); + + sma.useAllowlistPackageNames(new HashSet<>(singletonList(TestBean2.class.getPackage().getName()))); + + assertTrue(sma.checkAllowlist(bean, method)); + } + + @Test + public void classInclusion_subclass() throws Exception { + sma.useEnforceAllowlistEnabled(true); + sma.useAllowlistClasses(new HashSet<>(singletonList(TestBean2.class.getName()))); + + TestBean2 bean = new TestBean2(); + Method method = TestBean2.class.getMethod("getName"); + + assertFalse(sma.checkAllowlist(bean, method)); + } + + @Test + public void classInclusion_subclass_both() throws Exception { + sma.useEnforceAllowlistEnabled(true); + sma.useAllowlistClasses(new HashSet<>(asList(TestBean.class.getName(), TestBean2.class.getName()))); + + TestBean2 bean = new TestBean2(); + Method method = TestBean2.class.getMethod("getName"); + + assertTrue(sma.checkAllowlist(bean, method)); + } + + @Test + public void packageInclusion_subclass() throws Exception { + sma.useEnforceAllowlistEnabled(true); + sma.useAllowlistPackageNames(new HashSet<>(singletonList(TestBean2.class.getPackage().getName()))); + + TestBean2 bean = new TestBean2(); + Method method = TestBean2.class.getMethod("getName"); + + assertFalse(sma.checkAllowlist(bean, method)); + } + + @Test + public void packageInclusion_subclass_both() throws Exception { + sma.useEnforceAllowlistEnabled(true); + sma.useAllowlistPackageNames(new HashSet<>(asList(TestBean.class.getPackage().getName(), TestBean2.class.getPackage().getName()))); + + TestBean2 bean = new TestBean2(); + Method method = TestBean2.class.getMethod("getName"); + + assertTrue(sma.checkAllowlist(bean, method)); + } + private static String formGetterName(String propertyName) { return "get" + propertyName.substring(0, 1).toUpperCase() + propertyName.substring(1); }