Skip to content

Commit

Permalink
[FLINK-15635][table] Improve implementation of user classloader
Browse files Browse the repository at this point in the history
  • Loading branch information
lsyldliu authored and wuchong committed Jun 8, 2022
1 parent 0fe20bc commit 7aa58b3
Show file tree
Hide file tree
Showing 15 changed files with 109 additions and 114 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def java_class(cls):
@classmethod
def excluded_methods(cls):
# internal interfaces, no need to expose to users.
return {'getPlanner', 'getExecutor'}
return {'getPlanner', 'getExecutor', 'getUserClassLoader'}


class EnvironmentSettingsBuilderCompletenessTests(PythonAPICompletenessTestCase, PyFlinkTestCase):
Expand All @@ -54,6 +54,10 @@ def python_class(cls):
def java_class(cls):
return "org.apache.flink.table.api.EnvironmentSettings$Builder"

@classmethod
def excluded_methods(cls):
# internal interfaces, no need to expose to users.
return {'withClassLoader'}

if __name__ == '__main__':
import unittest
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,11 @@ private StreamTableEnvironment createTableEnvironment() {
new StreamExecutionEnvironment(new Configuration(flinkConfig), classLoader);

final Executor executor = lookupExecutor(streamExecEnv);

// Updates the classloader of FunctionCatalog by the new classloader to solve ClassNotFound
// exception when use an udf created by add jar syntax, temporary solution until FLINK-14055
// is fixed
sessionState.functionCatalog.updateClassLoader(classLoader);
return createStreamTableEnvironment(
streamExecEnv,
settings,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,12 @@ public void setUp() {
Column.physical(
"timestamp", DataTypes.TIMESTAMP(6).bridgedTo(Timestamp.class)),
Column.physical("binary", DataTypes.BYTES()));
rowDataToStringConverter = new RowDataToStringConverterImpl(schema.toPhysicalRowDataType());
rowDataToStringConverter =
new RowDataToStringConverterImpl(
schema.toPhysicalRowDataType(),
DateTimeUtils.UTC_ZONE.toZoneId(),
Thread.currentThread().getContextClassLoader(),
false);

List<Row> rows =
Arrays.asList(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ public final class FunctionCatalog {
private final ReadableConfig config;
private final CatalogManager catalogManager;
private final ModuleManager moduleManager;
private final ClassLoader classLoader;
private ClassLoader classLoader;

private final Map<String, CatalogFunction> tempSystemFunctions = new LinkedHashMap<>();
private final Map<ObjectIdentifier, CatalogFunction> tempCatalogFunctions =
Expand All @@ -87,6 +87,11 @@ public FunctionCatalog(
this.classLoader = classLoader;
}

/** Updates the classloader, this is a temporary solution until FLINK-14055 is fixed. */
public void updateClassLoader(ClassLoader classLoader) {
this.classLoader = classLoader;
}

public void setPlannerTypeInferenceUtil(PlannerTypeInferenceUtil plannerTypeInferenceUtil) {
this.plannerTypeInferenceUtil = plannerTypeInferenceUtil;
}
Expand Down Expand Up @@ -151,7 +156,7 @@ public void registerTemporaryCatalogFunction(
normalizedIdentifier.toObjectPath(), catalogFunction);
}
try {
catalogFunction = validateAndPrepareFunction(catalogFunction);
validateAndPrepareFunction(catalogFunction);
} catch (Throwable t) {
throw new ValidationException(
String.format(
Expand Down Expand Up @@ -373,7 +378,7 @@ public Optional<ContextResolvedFunction> lookupFunction(UnresolvedIdentifier ide
*/
@Deprecated
public void registerTempSystemScalarFunction(String name, ScalarFunction function) {
function = UserDefinedFunctionHelper.prepareInstance(config, classLoader, function);
UserDefinedFunctionHelper.prepareInstance(config, function);

registerTempSystemFunction(name, new ScalarFunctionDefinition(name, function));
}
Expand All @@ -385,7 +390,7 @@ public void registerTempSystemScalarFunction(String name, ScalarFunction functio
@Deprecated
public <T> void registerTempSystemTableFunction(
String name, TableFunction<T> function, TypeInformation<T> resultType) {
function = UserDefinedFunctionHelper.prepareInstance(config, classLoader, function);
UserDefinedFunctionHelper.prepareInstance(config, function);

registerTempSystemFunction(name, new TableFunctionDefinition(name, function, resultType));
}
Expand All @@ -400,7 +405,7 @@ public <T, ACC> void registerTempSystemAggregateFunction(
ImperativeAggregateFunction<T, ACC> function,
TypeInformation<T> resultType,
TypeInformation<ACC> accType) {
function = UserDefinedFunctionHelper.prepareInstance(config, classLoader, function);
UserDefinedFunctionHelper.prepareInstance(config, function);

final FunctionDefinition definition;
if (function instanceof AggregateFunction) {
Expand All @@ -424,7 +429,7 @@ public <T, ACC> void registerTempSystemAggregateFunction(
*/
@Deprecated
public void registerTempCatalogScalarFunction(ObjectIdentifier oi, ScalarFunction function) {
function = UserDefinedFunctionHelper.prepareInstance(config, classLoader, function);
UserDefinedFunctionHelper.prepareInstance(config, function);

registerTempCatalogFunction(oi, new ScalarFunctionDefinition(oi.getObjectName(), function));
}
Expand Down Expand Up @@ -498,7 +503,7 @@ public void registerTemporarySystemFunction(
final String normalizedName = FunctionIdentifier.normalizeName(name);

try {
function = validateAndPrepareFunction(function);
validateAndPrepareFunction(function);
} catch (Throwable t) {
throw new ValidationException(
String.format(
Expand Down Expand Up @@ -626,7 +631,7 @@ private Optional<ContextResolvedFunction> resolveAmbiguousFunctionReference(Stri
}

@SuppressWarnings("unchecked")
private CatalogFunction validateAndPrepareFunction(CatalogFunction function)
private void validateAndPrepareFunction(CatalogFunction function)
throws ClassNotFoundException {
// If the input is instance of UserDefinedFunction, it means it uses the new type inference.
// In this situation the UDF have not been validated and cleaned, so we need to validate it
Expand All @@ -637,17 +642,14 @@ private CatalogFunction validateAndPrepareFunction(CatalogFunction function)
if (function instanceof InlineCatalogFunction) {
FunctionDefinition definition = ((InlineCatalogFunction) function).getDefinition();
if (definition instanceof UserDefinedFunction) {
return new InlineCatalogFunction(
UserDefinedFunctionHelper.prepareInstance(
config, classLoader, (UserDefinedFunction) definition));
UserDefinedFunctionHelper.prepareInstance(config, (UserDefinedFunction) definition);
}
// Skip validation if it's not a UserDefinedFunction.
} else if (function.getFunctionLanguage() == FunctionLanguage.JAVA) {
UserDefinedFunctionHelper.validateClass(
(Class<? extends UserDefinedFunction>)
classLoader.loadClass(function.getClassName()));
}
return function;
}

private FunctionDefinition getFunctionDefinition(String name, CatalogFunction function) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,12 @@
import org.apache.flink.table.expressions.TypeLiteralExpression;
import org.apache.flink.table.expressions.UnresolvedCallExpression;
import org.apache.flink.table.expressions.ValueLiteralExpression;
import org.apache.flink.table.functions.AggregateFunction;
import org.apache.flink.table.functions.AggregateFunctionDefinition;
import org.apache.flink.table.functions.BuiltInFunctionDefinitions;
import org.apache.flink.table.functions.FunctionDefinition;
import org.apache.flink.table.functions.FunctionIdentifier;
import org.apache.flink.table.functions.ScalarFunction;
import org.apache.flink.table.functions.ScalarFunctionDefinition;
import org.apache.flink.table.functions.TableAggregateFunction;
import org.apache.flink.table.functions.TableAggregateFunctionDefinition;
import org.apache.flink.table.functions.TableFunction;
import org.apache.flink.table.functions.TableFunctionDefinition;
import org.apache.flink.table.functions.UserDefinedFunction;
import org.apache.flink.table.functions.UserDefinedFunctionHelper;
Expand Down Expand Up @@ -321,50 +317,37 @@ private List<ResolvedExpression> adaptArguments(
private FunctionDefinition prepareInlineUserDefinedFunction(FunctionDefinition definition) {
if (definition instanceof ScalarFunctionDefinition) {
final ScalarFunctionDefinition sf = (ScalarFunctionDefinition) definition;
final ScalarFunction fnInstance =
UserDefinedFunctionHelper.prepareInstance(
resolutionContext.configuration(),
resolutionContext.userClassLoader(),
sf.getScalarFunction());
return new ScalarFunctionDefinition(sf.getName(), fnInstance);
UserDefinedFunctionHelper.prepareInstance(
resolutionContext.configuration(), sf.getScalarFunction());
return new ScalarFunctionDefinition(sf.getName(), sf.getScalarFunction());
} else if (definition instanceof TableFunctionDefinition) {
final TableFunctionDefinition tf = (TableFunctionDefinition) definition;
final TableFunction<?> fnInstance =
UserDefinedFunctionHelper.prepareInstance(
resolutionContext.configuration(),
resolutionContext.userClassLoader(),
tf.getTableFunction());
return new TableFunctionDefinition(tf.getName(), fnInstance, tf.getResultType());
UserDefinedFunctionHelper.prepareInstance(
resolutionContext.configuration(), tf.getTableFunction());
return new TableFunctionDefinition(
tf.getName(), tf.getTableFunction(), tf.getResultType());
} else if (definition instanceof AggregateFunctionDefinition) {
final AggregateFunctionDefinition af = (AggregateFunctionDefinition) definition;
final AggregateFunction<?, ?> afInstance =
UserDefinedFunctionHelper.prepareInstance(
resolutionContext.configuration(),
resolutionContext.userClassLoader(),
af.getAggregateFunction());
UserDefinedFunctionHelper.prepareInstance(
resolutionContext.configuration(), af.getAggregateFunction());
return new AggregateFunctionDefinition(
af.getName(),
afInstance,
af.getAggregateFunction(),
af.getResultTypeInfo(),
af.getAccumulatorTypeInfo());
} else if (definition instanceof TableAggregateFunctionDefinition) {
final TableAggregateFunctionDefinition taf =
(TableAggregateFunctionDefinition) definition;
final TableAggregateFunction<?, ?> tafInstance =
UserDefinedFunctionHelper.prepareInstance(
resolutionContext.configuration(),
resolutionContext.userClassLoader(),
taf.getTableAggregateFunction());
UserDefinedFunctionHelper.prepareInstance(
resolutionContext.configuration(), taf.getTableAggregateFunction());
return new TableAggregateFunctionDefinition(
taf.getName(),
tafInstance,
taf.getTableAggregateFunction(),
taf.getResultTypeInfo(),
taf.getAccumulatorTypeInfo());
} else if (definition instanceof UserDefinedFunction) {
return UserDefinedFunctionHelper.prepareInstance(
resolutionContext.configuration(),
resolutionContext.userClassLoader(),
(UserDefinedFunction) definition);
UserDefinedFunctionHelper.prepareInstance(
resolutionContext.configuration(), (UserDefinedFunction) definition);
}
return definition;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@

import javax.annotation.Nullable;

import java.io.IOException;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
Expand Down Expand Up @@ -248,29 +247,18 @@ public static UserDefinedFunction instantiateFunction(Class<?> functionClass) {
}
}

/**
* Prepares a {@link UserDefinedFunction} instance for usage in the API.
*
* @return A cloned instance of the function, to be used by the runtime, instantiated via {@code
* userClassLoader}.
*/
public static <T extends UserDefinedFunction> T prepareInstance(
ReadableConfig config, ClassLoader userClassLoader, T function) {
/** Prepares a {@link UserDefinedFunction} instance for usage in the API. */
public static void prepareInstance(ReadableConfig config, UserDefinedFunction function) {
validateClass(function.getClass(), false);
cleanFunction(config, function);
try {
return InstantiationUtil.clone(function, userClassLoader);
} catch (IOException | ClassNotFoundException e) {
throw new TableException("Error while cloning the function " + function, e);
}
}

/**
* Returns whether a {@link UserDefinedFunction} can be easily serialized and identified by only
* a fully qualified class name. It must have a default constructor and no serializable fields.
*
* <p>Other properties (such as checks for abstract classes) are validated at the entry points
* of the API, see {@link #prepareInstance(ReadableConfig, ClassLoader, UserDefinedFunction)}.
* of the API, see {@link #prepareInstance(ReadableConfig, UserDefinedFunction)}.
*/
public static boolean isClassNameSerializable(UserDefinedFunction function) {
final Class<?> functionClass = function.getClass();
Expand Down Expand Up @@ -347,9 +335,6 @@ public static void validateClassForRuntime(
* Creates the runtime implementation of a {@link FunctionDefinition} as an instance of {@link
* UserDefinedFunction}.
*
* <p>Note that the argument {@code builtInClassLoader} is going to be used only for built-in
* function, hence it requires the planner class loader, rather than the user classloader.
*
* @see SpecializedFunction
*/
public static UserDefinedFunction createSpecializedFunction(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,7 @@ void testValidation(TestSpec testSpec) {
if (testSpec.functionClass != null) {
runnable = () -> validateClass(testSpec.functionClass);
} else if (testSpec.functionInstance != null) {
runnable =
() ->
prepareInstance(
new Configuration(),
UserDefinedFunctionHelperTest.class.getClassLoader(),
testSpec.functionInstance);
runnable = () -> prepareInstance(new Configuration(), testSpec.functionInstance);
} else {
return;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ private ExpressionResolverBuilder createExpressionResolverBuilder(
FlinkContext context, Parser parser) {
return ExpressionResolver.resolverFor(
context.getTableConfig(),
context.getClassLoader(),
name -> Optional.empty(),
context.getFunctionCatalog().asLookup(parser::parseIdentifier),
context.getCatalogManager().getDataTypeFactory(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ public PlannerContext(
List<RelTraitDef> traitDefs,
ClassLoader classLoader) {
this.typeSystem = FlinkTypeSystem.INSTANCE;
this.typeFactory = new FlinkTypeFactory(typeSystem);
this.typeFactory = new FlinkTypeFactory(classLoader, typeSystem);
this.context =
new FlinkContextImpl(
isBatchMode,
Expand All @@ -114,7 +114,8 @@ public PlannerContext(
typeFactory,
this::createFlinkPlanner,
this::getCalciteSqlDialect,
this::createRelBuilder));
this::createRelBuilder),
classLoader);
this.rootSchema = rootSchema;
this.traitDefs = traitDefs;
// Make a framework config to initialize the RelOptCluster instance,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,9 +218,7 @@ public Transformation<RowData> translateToPlanInternal(
boolean isAsyncEnabled = false;
UserDefinedFunction userDefinedFunction =
LookupJoinUtil.getLookupFunction(temporalTable, lookupKeys.keySet());
userDefinedFunction =
UserDefinedFunctionHelper.prepareInstance(
config, planner.getFlinkContext().getClassLoader(), userDefinedFunction);
UserDefinedFunctionHelper.prepareInstance(config, userDefinedFunction);

if (userDefinedFunction instanceof AsyncTableFunction) {
isAsyncEnabled = true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,9 @@ class BatchCommonSubGraphBasedOptimizer(planner: BatchPlanner)

override def getModuleManager: ModuleManager = planner.moduleManager

override def getRexFactory: RexFactory = context.getRexFactory
override def getRexFactory: RexFactory = context.getRexFactory

override def getFlinkRelBuilder: FlinkRelBuilder = planner.createRelBuilder
override def getFlinkRelBuilder: FlinkRelBuilder = planner.createRelBuilder

override def needFinalTimeIndicatorConversion: Boolean = true

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.table.planner.plan.optimize

import org.apache.flink.table.api.TableConfig
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,9 @@
/** Tests for {@link MergeTableLikeUtil}. */
public class MergeTableLikeUtilTest {

private final FlinkTypeFactory typeFactory = new FlinkTypeFactory(FlinkTypeSystem.INSTANCE);
private final FlinkTypeFactory typeFactory =
new FlinkTypeFactory(
Thread.currentThread().getContextClassLoader(), FlinkTypeSystem.INSTANCE);
private final SqlValidator sqlValidator =
PlannerMocks.create().getPlanner().getOrCreateSqlValidator();
private final MergeTableLikeUtil util = new MergeTableLikeUtil(sqlValidator, SqlNode::toString);
Expand Down
Loading

0 comments on commit 7aa58b3

Please sign in to comment.