Skip to content

Commit

Permalink
Support BiConsumer/BiFunction in Groovy
Browse files Browse the repository at this point in the history
  • Loading branch information
graemerocher committed Nov 3, 2017
1 parent 632e8be commit aa48927
Show file tree
Hide file tree
Showing 26 changed files with 870 additions and 148 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package org.particleframework.core.annotation;

import java.lang.annotation.Documented;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;

Expand All @@ -9,7 +10,8 @@
* @author Graeme Rocher
* @since 1.0
*/
@Retention(RetentionPolicy.SOURCE)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface Internal {

}
7 changes: 6 additions & 1 deletion function-groovy/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,9 @@ dependencies {
compile project(":inject-groovy")
compile project(":function")
runtime project(":configurations/jackson")
}

testCompile project(":function-web")
testRuntime project(":http-server-netty")
testCompile 'com.squareup.okhttp3:okhttp:3.8.1'
}
//compileTestGroovy.groovyOptions.forkOptions.jvmArgs = ['-Xdebug', '-Xrunjdwp:transport=dt_socket,server=y,suspend=y,address=5005']
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/*
* Copyright 2017 original authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.particleframework.function.groovy;

import org.particleframework.context.ApplicationContext;
import org.particleframework.context.env.PropertySource;
import org.particleframework.core.annotation.Internal;
import org.particleframework.function.executor.FunctionInitializer;

import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Map;

/**
* Base class for Function scripts
*
* @author Graeme Rocher
* @since 1.0
*/
public abstract class FunctionScript extends FunctionInitializer implements PropertySource {

public FunctionScript() {
}

protected FunctionScript(ApplicationContext applicationContext) {
super(applicationContext, false);
}

private Map<String, Object> props;

@Override
@Internal
public Object get(String key) {
return resolveProps().get(key);
}

@Override
@Internal
public Iterator<String> iterator() {
return resolveProps().keySet().iterator();
}

protected void addProperty(String name, Object value) {
resolveProps().put(name, value);
}

private Map<String, Object> resolveProps() {
if(props == null) {
props = new LinkedHashMap<>();
}
return props;
}

@Override
@Internal
protected void startThis(ApplicationContext applicationContext) {
// no-op this, equivalent behaviour will be called from the script constructor
}

@Override
@Internal
protected void injectThis(ApplicationContext applicationContext) {
// no-op this, equivalent behaviour will be called from the script constructor
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,26 +18,34 @@ package org.particleframework.function.groovy
import groovy.transform.CompileStatic
import groovy.transform.Field
import org.codehaus.groovy.ast.*
import org.codehaus.groovy.ast.expr.ArgumentListExpression
import org.codehaus.groovy.ast.expr.BinaryExpression
import org.codehaus.groovy.ast.expr.ConstructorCallExpression
import org.codehaus.groovy.ast.expr.DeclarationExpression
import org.codehaus.groovy.ast.expr.Expression
import org.codehaus.groovy.ast.expr.MethodCallExpression
import org.codehaus.groovy.ast.stmt.BlockStatement
import org.codehaus.groovy.ast.stmt.ExpressionStatement
import org.codehaus.groovy.ast.tools.GenericsUtils
import org.codehaus.groovy.control.CompilePhase
import org.codehaus.groovy.control.SourceUnit
import org.codehaus.groovy.transform.ASTTransformation
import org.codehaus.groovy.transform.FieldASTTransformation
import org.codehaus.groovy.transform.GroovyASTTransformation
import org.codehaus.groovy.transform.sc.transformers.StaticCompilationTransformer
import org.codehaus.groovy.transform.stc.StaticTypeCheckingVisitor
import org.particleframework.ast.groovy.InjectTransform
import org.particleframework.ast.groovy.annotation.AnnotationStereoTypeFinder
import org.particleframework.ast.groovy.utils.AstMessageUtils
import org.particleframework.ast.groovy.utils.AstUtils
import org.particleframework.context.ApplicationContext
import org.particleframework.function.executor.FunctionInitializer
import org.particleframework.context.env.groovy.SetPropertyTransformer
import org.particleframework.core.naming.NameUtils
import org.particleframework.function.FunctionBean

import javax.inject.Inject
import java.lang.reflect.Modifier
import java.util.function.BiConsumer
import java.util.function.BiFunction
import java.util.function.Consumer
import java.util.function.Function
import java.util.function.Supplier

import static org.codehaus.groovy.ast.tools.GeneralUtils.*

Expand All @@ -51,20 +59,20 @@ import static org.codehaus.groovy.ast.tools.GeneralUtils.*
@GroovyASTTransformation(phase = CompilePhase.CANONICALIZATION)
class FunctionTransform implements ASTTransformation{
public static final ClassNode FIELD_TYPE = ClassHelper.make(Field)
AnnotationStereoTypeFinder stereoTypeFinder = new AnnotationStereoTypeFinder();

@Override
void visit(ASTNode[] nodes, SourceUnit source) {

def uri = source.getSource().getURI()
if(uri != null) {
def file = uri.toString()
if(!file.contains("src/main/groovy") || file.endsWith("logback.groovy") || file.endsWith("application.groovy") || file ==~ (/\S+\/application-\S+.groovy/)) {
if( !file.endsWith("Function.groovy") && !file.toLowerCase(Locale.ENGLISH).endsWith("-function.groovy")) {
return
}
}
for(node in source.getAST().classes) {
if(node.isScript()) {
node.setSuperClass(ClassHelper.makeCached(FunctionInitializer))
node.setSuperClass(ClassHelper.makeCached(FunctionScript))
MethodNode functionMethod = node.methods.find() { method -> !method.isAbstract() && !method.isStatic() && method.isPublic() && method.name != 'run' }
if(functionMethod == null) {
AstMessageUtils.error(source, node, "Function must have at least one public method")
Expand All @@ -77,7 +85,28 @@ class FunctionTransform implements ASTTransformation{
MethodNode mainMethod = node.getMethod("main", new Parameter(ClassHelper.make(([] as String[]).class), "args"))
Parameter argParam = mainMethod.getParameters()[0]
def thisInstance = varX('$this')
def functionCall = callX(thisInstance, functionMethod.getName(), args(callX(varX("it"), "get", args(classX(functionMethod.parameters[0].type.plainNodeReference)))))
def parameters = functionMethod.parameters
int argLength = parameters.length
boolean isVoidReturn = functionMethod.returnType == ClassHelper.VOID_TYPE

if(argLength > 2) {
AstMessageUtils.error(source, node, "Functions can only have a maximum of 2 arguments")
continue
}
else if(argLength == 0 && isVoidReturn) {
AstMessageUtils.error(source, node, "Zero argument functions must return a value")
continue
}

MethodCallExpression functionCall

if(argLength == 1) {
functionCall = callX(thisInstance, functionMethod.getName(), args(callX(varX("it"), "get", args(classX(parameters[0].type.plainNodeReference)))))
}
else {
functionCall = callX(thisInstance, functionMethod.getName())
}

def closureExpression = closureX(stmt(functionCall))
mainMethod.variableScope.putDeclaredVariable(thisInstance)
closureExpression.setVariableScope(mainMethod.variableScope)
Expand All @@ -87,8 +116,11 @@ class FunctionTransform implements ASTTransformation{
stmt(callX(thisInstance, "run", args(varX(argParam), closureExpression)))
)
)
new StaticCompilationTransformer(source, new StaticTypeCheckingVisitor(source, node)).visitMethod(mainMethod)
def code = runMethod.getCode()
def appCtx = varX("applicationContext")
def constructorBody = block(

)
if(code instanceof BlockStatement) {
BlockStatement bs = (BlockStatement)code
for(st in bs.statements) {
Expand All @@ -99,36 +131,150 @@ class FunctionTransform implements ASTTransformation{
DeclarationExpression de = (DeclarationExpression)exp
def initial = de.getVariableExpression().getInitialExpression()
if ( initial == null) {
de.addAnnotation(new AnnotationNode(ClassHelper.make(Inject)))
de.addAnnotation(new AnnotationNode(AstUtils.INJECT_ANNOTATION))
new FieldASTTransformation().visit([new AnnotationNode(FIELD_TYPE), de] as ASTNode[], source)
}
}
else if(exp instanceof BinaryExpression || exp instanceof MethodCallExpression) {
def setPropertyTransformer = new SetPropertyTransformer(source)
setPropertyTransformer.setPropertyMethodName = "addProperty"
constructorBody.addStatement(
stmt(setPropertyTransformer.transform(exp))
)
}
}
}
}

node.addMethod(new MethodNode(
"injectThis",
Modifier.PROTECTED,
ClassHelper.VOID_TYPE,
params(param(ClassHelper.make(ApplicationContext),"ctx")),
null,
block()
))

ConstructorNode constructorNode = new ConstructorNode(Modifier.PUBLIC, block(
constructorBody.addStatement(block(
stmt(
callX(varX("this"), "startEnvironment", appCtx)
),
stmt(
callX(varX("applicationContext"), "inject", varX("this"))
callX(appCtx, "inject", varX("this"))
)
))
constructorNode.addAnnotation(new AnnotationNode(ClassHelper.make(Inject)))
ConstructorNode constructorNode = new ConstructorNode(Modifier.PUBLIC, constructorBody)
node.declaredConstructors.clear()
node.addConstructor(constructorNode)
def ctxParam = param(ClassHelper.make(ApplicationContext), "ctx")

def applicationContextConstructor = new ConstructorNode(
Modifier.PUBLIC,
params(ctxParam),
null,
stmt(
ctorX(ClassNode.SUPER, varX(ctxParam))
)

)
for(field in node.getFields()) {
field.addAnnotation(new AnnotationNode(AstUtils.INJECT_ANNOTATION))
def setterName = getSetterName(field.getName())
def setterMethod = node.getMethod(setterName, params(param(field.getType(), "arg")))
if(setterMethod != null) {
setterMethod.addAnnotation(new AnnotationNode(AstUtils.INTERNAL_ANNOTATION))
}
}

applicationContextConstructor.addAnnotation(new AnnotationNode(AstUtils.INJECT_ANNOTATION))
def functionBean = new AnnotationNode(ClassHelper.make(FunctionBean))
String functionName = NameUtils.hyphenate(node.nameWithoutPackage)
functionName -= '-function'

functionBean.setMember("value", constX(functionName))
node.addAnnotation(functionBean)
node.addConstructor(
applicationContextConstructor
)

if(isVoidReturn) {
if(argLength == 1) {
implementConsumer(functionMethod, node)
}
else {
implementBiConsumer(functionMethod, node)
}

}
else {
if(argLength == 0) {
def returnType = ClassHelper.getWrapper(functionMethod.returnType.plainNodeReference)
node.addInterface(GenericsUtils.makeClassSafeWithGenerics(
ClassHelper.make(Supplier).plainNodeReference,
new GenericsType(returnType)
))
def mn = new MethodNode("get", Modifier.PUBLIC, functionMethod.returnType.plainNodeReference, AstUtils.ZERO_PARAMETERS, null, stmt(
callX(varX("this"), functionMethod.getName())
))
mn.addAnnotation(new AnnotationNode(AstUtils.INTERNAL_ANNOTATION))
node.addMethod(mn)
}
else {
if(argLength == 1) {
implementFunction(functionMethod, node)
}
else {
implementBiFunction(functionMethod, node)
}

}
}
new InjectTransform().visit(nodes, source)
}

}
}
}

protected void implementConsumer(MethodNode functionMethod, ClassNode classNode) {
implementFunction(functionMethod, classNode, Consumer, ClassHelper.VOID_TYPE, "accept")
}

protected void implementBiConsumer(MethodNode functionMethod, ClassNode classNode) {
implementFunction(functionMethod, classNode, BiConsumer, ClassHelper.VOID_TYPE, "accept")
}

protected void implementFunction(MethodNode functionMethod, ClassNode classNode) {
ClassNode returnType = ClassHelper.getWrapper(functionMethod.returnType.plainNodeReference)
implementFunction(functionMethod, classNode, Function, returnType, "apply")
}

protected void implementBiFunction(MethodNode functionMethod, ClassNode classNode) {
ClassNode returnType = ClassHelper.getWrapper(functionMethod.returnType.plainNodeReference)
implementFunction(functionMethod, classNode, BiFunction, returnType, "apply")
}

protected void implementFunction(MethodNode functionMethod, ClassNode classNode, Class functionType, ClassNode returnType, String methodName) {
List<ClassNode> argTypes = []
for(p in functionMethod.parameters) {
argTypes.add(ClassHelper.getWrapper(p.type.plainNodeReference))
}
List<GenericsType> genericsTypes = []
for(type in argTypes) {
genericsTypes.add(new GenericsType(type))
}
if(returnType != ClassHelper.VOID_TYPE) {
genericsTypes.add(new GenericsType(returnType))
}
classNode.addInterface(GenericsUtils.makeClassSafeWithGenerics(
ClassHelper.make(functionType).plainNodeReference,
genericsTypes as GenericsType[]
))
List<Parameter> params = []
int i = 0
ArgumentListExpression argList = new ArgumentListExpression()
for(type in argTypes) {
def p = param(type, "arg${i++}")
params.add(p)
argList.addExpression(varX(p))
}
def mn = new MethodNode(methodName, Modifier.PUBLIC, returnType, params as Parameter[], null, stmt(
callX(varX("this"), functionMethod.getName(), argList))
)
mn.addAnnotation(new AnnotationNode(AstUtils.INTERNAL_ANNOTATION))
classNode.addMethod(mn)
}


}
Loading

0 comments on commit aa48927

Please sign in to comment.