Skip to content

Commit

Permalink
[FLINK-3138] [types] Method References are not supported as lambda ex…
Browse files Browse the repository at this point in the history
…pressions

This closes apache#2329.
  • Loading branch information
twalthr committed Aug 8, 2016
1 parent 8d25c64 commit ff77708
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -345,12 +345,22 @@ public static <IN, OUT> TypeInformation<OUT> getUnaryOperatorReturnType(
if (m != null) {
// check for lambda type erasure
validateLambdaGenericParameters(m);

// parameters must be accessed from behind, since JVM can add additional parameters e.g. when using local variables inside lambda function
final int paramLen = m.getGenericParameterTypes().length - 1;
final Type input = (outputTypeArgumentIndex >= 0) ? m.getGenericParameterTypes()[paramLen - 1] : m.getGenericParameterTypes()[paramLen];
validateInputType((inputTypeArgumentIndex >= 0) ? extractTypeArgument(input, inputTypeArgumentIndex) : input, inType);
if(function instanceof ResultTypeQueryable) {

// method references "this" implicitly
if (paramLen < 0) {
// methods declaring class can also be a super class of the input type
// we only validate if the method exists in input type
validateInputContainsMethod(m, inType);
}
else {
final Type input = (outputTypeArgumentIndex >= 0) ? m.getGenericParameterTypes()[paramLen - 1] : m.getGenericParameterTypes()[paramLen];
validateInputType((inputTypeArgumentIndex >= 0) ? extractTypeArgument(input, inputTypeArgumentIndex) : input, inType);
}

if (function instanceof ResultTypeQueryable) {
return ((ResultTypeQueryable<OUT>) function).getProducedType();
}
return new TypeExtractor().privateCreateTypeInfo(
Expand Down Expand Up @@ -1234,7 +1244,17 @@ else if (typeInfo instanceof GenericTypeInfo<?>) {
}
}
}


private static void validateInputContainsMethod(Method m, TypeInformation<?> typeInfo) {
List<Method> methods = getAllDeclaredMethods(typeInfo.getTypeClass());
for (Method method : methods) {
if (method.equals(m)) {
return;
}
}
throw new InvalidTypesException("Type contains no method '" + m.getName() + "'.");
}

// --------------------------------------------------------------------------------------------
// Utility methods
// --------------------------------------------------------------------------------------------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,17 @@ public Tuple2<Integer, Long> map(Integer value) {
}
};

MapFunction<String, Integer> lambda = Integer::parseInt;
MapFunction<String, Integer> staticLambda = Integer::parseInt;
MapFunction<Integer, String> instanceLambda = Object::toString;
MapFunction<String, Integer> constructorLambda = Integer::new;

assertNull(FunctionUtils.checkAndExtractLambdaMethod(anonymousFromInterface));
assertNull(FunctionUtils.checkAndExtractLambdaMethod(anonymousFromClass));
assertNull(FunctionUtils.checkAndExtractLambdaMethod(fromProperClass));
assertNull(FunctionUtils.checkAndExtractLambdaMethod(fromDerived));
assertNotNull(FunctionUtils.checkAndExtractLambdaMethod(lambda));
assertNotNull(FunctionUtils.checkAndExtractLambdaMethod(staticLambda));
assertNotNull(FunctionUtils.checkAndExtractLambdaMethod(instanceLambda));
assertNotNull(FunctionUtils.checkAndExtractLambdaMethod(constructorLambda));
assertNotNull(FunctionUtils.checkAndExtractLambdaMethod(STATIC_LAMBDA));
}
catch (Exception e) {
Expand Down Expand Up @@ -248,4 +252,52 @@ public void testLambdaTypeErasure() {
Assert.assertTrue(ti instanceof MissingTypeInfo);
}

public static class MyType {
private int key;

public int getKey() {
return key;
}

public void setKey(int key) {
this.key = key;
}

protected int getKey2() {
return 0;
}
}

@Test
public void testInstanceMethodRefSameType() {
MapFunction<MyType, Integer> f = MyType::getKey;
TypeInformation<?> ti = TypeExtractor.getMapReturnTypes(f, TypeExtractor.createTypeInfo(MyType.class));
Assert.assertEquals(ti, BasicTypeInfo.INT_TYPE_INFO);
}

@Test
public void testInstanceMethodRefSuperType() {
MapFunction<Integer, String> f = Object::toString;
TypeInformation<?> ti = TypeExtractor.getMapReturnTypes(f, BasicTypeInfo.INT_TYPE_INFO);
Assert.assertEquals(ti, BasicTypeInfo.STRING_TYPE_INFO);
}

public static class MySubtype extends MyType {
public boolean test;
}

@Test
public void testInstanceMethodRefSuperTypeProtected() {
MapFunction<MySubtype, Integer> f = MyType::getKey2;
TypeInformation<?> ti = TypeExtractor.getMapReturnTypes(f, TypeExtractor.createTypeInfo(MySubtype.class));
Assert.assertEquals(ti, BasicTypeInfo.INT_TYPE_INFO);
}

@Test
public void testConstructorMethodRef() {
MapFunction<String, Integer> f = Integer::new;
TypeInformation<?> ti = TypeExtractor.getMapReturnTypes(f, BasicTypeInfo.STRING_TYPE_INFO);
Assert.assertEquals(ti, BasicTypeInfo.INT_TYPE_INFO);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@

public class MapITCase extends JavaProgramTestBase {

private static final String EXPECTED_RESULT = "bb\n" +
"bb\n" +
"bc\n" +
"bd\n";
private static final String EXPECTED_RESULT = "22\n" +
"22\n" +
"23\n" +
"24\n";

private String resultPath;

Expand All @@ -40,8 +40,8 @@ protected void preSubmit() throws Exception {
protected void testProgram() throws Exception {
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

DataSet<String> stringDs = env.fromElements("aa", "ab", "ac", "ad");
DataSet<String> mappedDs = stringDs.map (s -> s.replace("a", "b"));
DataSet<Integer> stringDs = env.fromElements(11, 12, 13, 14);
DataSet<String> mappedDs = stringDs.map(Object::toString).map (s -> s.replace("1", "2"));
mappedDs.writeAsText(resultPath);
env.execute();
}
Expand Down

0 comments on commit ff77708

Please sign in to comment.