Merge pull request apache#21809 from ihji/BEAM-14506
[BEAM-14506] Adding testcases and examples for xlang Python RunInference
ihji authored Jul 13, 2022
2 parents b78a080 + 1dc015f commit 316dfae
Showing 6 changed files with 200 additions and 6 deletions.
buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy
@@ -351,6 +351,8 @@ class BeamModulePlugin implements Plugin<Project> {
Integer numParallelTests = 1
// Whether the pipeline needs --sdk_location option
boolean needsSdkLocation = false
// semi_persist_dir for SDK containers
String semiPersistDir = "/tmp"
// classpath for running tests.
FileCollection classpath
}
@@ -2353,6 +2355,7 @@ class BeamModulePlugin implements Plugin<Project> {
systemProperty "beamTestPipelineOptions", JsonOutput.toJson(config.javaPipelineOptions)
systemProperty "expansionJar", expansionJar
systemProperty "expansionPort", port
systemProperty "semiPersistDir", config.semiPersistDir
classpath = config.classpath + project.files(
project.project(":runners:core-construction-java").sourceSets.test.runtimeClasspath,
project.project(":sdks:java:extensions:python").sourceSets.test.runtimeClasspath
1 change: 1 addition & 0 deletions runners/google-cloud-dataflow-java/build.gradle
@@ -413,6 +413,7 @@ createCrossLanguageValidatesRunnerTask(
classpath: configurations.validatesRunner,
numParallelTests: Integer.MAX_VALUE,
needsSdkLocation: true,
semiPersistDir: "/var/opt/google",
pythonPipelineOptions: [
"--runner=TestDataflowRunner",
"--project=${dataflowProject}",
sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/transforms/RunInference.java
@@ -0,0 +1,102 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.extensions.python.transforms;

import java.util.Map;
import org.apache.beam.sdk.coders.RowCoder;
import org.apache.beam.sdk.extensions.python.PythonExternalTransform;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.util.PythonCallableSource;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;

/** Wrapper for invoking external Python RunInference. */
public class RunInference extends PTransform<PCollection<?>, PCollection<Row>> {
private final String modelLoader;
private final Schema schema;
private final Map<String, Object> kwargs;
private final String expansionService;

/**
* Instantiates a multi-language wrapper for a Python RunInference with a given model loader.
*
* @param modelLoader A Python callable for a model loader class object.
* @param exampleType A schema field type for the example column in output rows.
* @param inferenceType A schema field type for the inference column in output rows.
* @return A {@link RunInference} for the given model loader.
*/
public static RunInference of(
String modelLoader, Schema.FieldType exampleType, Schema.FieldType inferenceType) {
Schema schema =
Schema.of(
Schema.Field.of("example", exampleType), Schema.Field.of("inference", inferenceType));
return new RunInference(modelLoader, schema, ImmutableMap.of(), "");
}

/**
* Instantiates a multi-language wrapper for a Python RunInference with a given model loader.
*
* @param modelLoader A Python callable for a model loader class object.
* @param schema A schema for output rows.
* @return A {@link RunInference} for the given model loader.
*/
public static RunInference of(String modelLoader, Schema schema) {
return new RunInference(modelLoader, schema, ImmutableMap.of(), "");
}

  /**
   * Adds a keyword argument for the model loader.
   *
   * @param key a keyword argument name.
   * @param arg a keyword argument value.
   * @return A {@link RunInference} with the additional keyword argument.
   */
public RunInference withKwarg(String key, Object arg) {
ImmutableMap.Builder<String, Object> builder =
ImmutableMap.<String, Object>builder().putAll(kwargs).put(key, arg);
return new RunInference(modelLoader, schema, builder.build(), expansionService);
}

/**
* Sets an expansion service endpoint for RunInference.
*
* @param expansionService A URL for a Python expansion service.
* @return A {@link RunInference} for the given expansion service endpoint.
*/
public RunInference withExpansionService(String expansionService) {
return new RunInference(modelLoader, schema, kwargs, expansionService);
}

private RunInference(
String modelLoader, Schema schema, Map<String, Object> kwargs, String expansionService) {
this.modelLoader = modelLoader;
this.schema = schema;
this.kwargs = kwargs;
this.expansionService = expansionService;
}

@Override
public PCollection<Row> expand(PCollection<?> input) {
return input.apply(
PythonExternalTransform.<PCollection<?>, PCollection<Row>>from(
"apache_beam.ml.inference.base.RunInference.from_callable", expansionService)
.withKwarg("model_handler_provider", PythonCallableSource.of(modelLoader))
.withKwargs(kwargs)
.withOutputCoder(RowCoder.of(schema)));
}
}
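
For reviewers, a minimal usage sketch of the new wrapper. The expansion service address and the staged model path below are illustrative assumptions, not part of this diff (the wrapper defaults to an empty expansion address when withExpansionService is not called):

import java.util.Arrays;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.IterableCoder;
import org.apache.beam.sdk.coders.VarLongCoder;
import org.apache.beam.sdk.extensions.python.transforms.RunInference;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.Row;

// Sketch: run two-element examples through a pickled sklearn model via the
// Python SDK's RunInference.
Pipeline p = Pipeline.create();
PCollection<Row> predictions =
    p.apply(Create.<Iterable<Long>>of(Arrays.asList(0L, 0L), Arrays.asList(1L, 1L)))
        .setCoder(IterableCoder.of(VarLongCoder.of()))
        .apply(
            RunInference.of(
                    "apache_beam.ml.inference.sklearn_inference.SklearnModelHandlerNumpy",
                    Schema.FieldType.array(Schema.FieldType.INT64),
                    Schema.FieldType.INT32)
                // model_uri must point at a model file visible to the Python
                // workers; this path is a placeholder.
                .withKwarg("model_uri", "/tmp/staged/sklearn_model")
                // Address of a running Python expansion service (placeholder).
                .withExpansionService("localhost:8097"));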
sdks/java/extensions/python/src/test/java/org/apache/beam/sdk/extensions/python/transforms/RunInferenceTransformTest.java
@@ -0,0 +1,68 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.extensions.python.transforms;

import java.util.Arrays;
import java.util.Optional;
import org.apache.beam.runners.core.construction.BaseExternalTest;
import org.apache.beam.sdk.coders.IterableCoder;
import org.apache.beam.sdk.coders.VarLongCoder;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.UsesPythonExpansionService;
import org.apache.beam.sdk.testing.ValidatesRunner;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.Row;
import org.junit.Test;
import org.junit.experimental.categories.Category;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

@RunWith(JUnit4.class)
public class RunInferenceTransformTest extends BaseExternalTest {
@Test
@Category({ValidatesRunner.class, UsesPythonExpansionService.class})
public void testRunInference() {
String stagingLocation =
Optional.ofNullable(System.getProperty("semiPersistDir")).orElse("/tmp");
Schema schema =
Schema.of(
Schema.Field.of("example", Schema.FieldType.array(Schema.FieldType.INT64)),
Schema.Field.of("inference", Schema.FieldType.INT32));
Row row0 = Row.withSchema(schema).addArray(0L, 0L).addValue(0).build();
Row row1 = Row.withSchema(schema).addArray(1L, 1L).addValue(1).build();
PCollection<Row> col =
testPipeline
.apply(Create.<Iterable<Long>>of(Arrays.asList(0L, 0L), Arrays.asList(1L, 1L)))
.setCoder(IterableCoder.of(VarLongCoder.of()))
.apply(
RunInference.of(
"apache_beam.ml.inference.sklearn_inference.SklearnModelHandlerNumpy",
schema)
.withKwarg(
// The test expansion service creates the test model and saves it to the
// returning external environment as a dependency.
// (sdks/python/apache_beam/runners/portability/expansion_service_test.py)
// The dependencies for Python SDK harness are supposed to be staged to
// $SEMI_PERSIST_DIR/staged directory.
"model_uri", String.format("%s/staged/sklearn_model", stagingLocation))
.withExpansionService(expansionAddr));
PAssert.that(col).containsInAnyOrder(row0, row1);
}
}
2 changes: 1 addition & 1 deletion sdks/python/apache_beam/ml/inference/base.py
@@ -272,7 +272,7 @@ def __init__(

# TODO(BEAM-14046): Add and link to help documentation.
@classmethod
-  def create(cls, model_handler_provider, **kwargs):
+  def from_callable(cls, model_handler_provider, **kwargs):
"""Multi-language friendly constructor.
This constructor can be used with fully_qualified_named_transform to
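
The rename matters because the Java wrapper above addresses this constructor by its fully qualified name. A sketch of the direct call that RunInference.expand() makes, spelled out against a hypothetical input PCollection (the expansion address and schema are placeholders):

import org.apache.beam.sdk.coders.RowCoder;
import org.apache.beam.sdk.extensions.python.PythonExternalTransform;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.util.PythonCallableSource;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.Row;

// Output schema matching the test above: one example column, one inference column.
Schema schema =
    Schema.of(
        Schema.Field.of("example", Schema.FieldType.array(Schema.FieldType.INT64)),
        Schema.Field.of("inference", Schema.FieldType.INT32));
// Given some input PCollection (e.g., Iterable<Long> examples as in the test above):
PCollection<Row> out =
    input.apply(
        PythonExternalTransform.<PCollection<?>, PCollection<Row>>from(
                "apache_beam.ml.inference.base.RunInference.from_callable",
                "localhost:8097") // placeholder expansion service address
            .withKwarg(
                "model_handler_provider",
                PythonCallableSource.of(
                    "apache_beam.ml.inference.sklearn_inference.SklearnModelHandlerNumpy"))
            .withOutputCoder(RowCoder.of(schema)));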
sdks/python/apache_beam/runners/portability/expansion_service_test.py
@@ -18,6 +18,7 @@

import argparse
import logging
import pickle
import signal
import sys
import typing
@@ -33,6 +34,7 @@
from apache_beam.portability.api import external_transforms_pb2
from apache_beam.runners.portability import artifact_service
from apache_beam.runners.portability import expansion_service
from apache_beam.runners.portability.stager import Stager
from apache_beam.transforms import fully_qualified_named_transform
from apache_beam.transforms import ptransform
from apache_beam.transforms.environments import PyPIArtifactRegistry
@@ -347,6 +349,24 @@ def parse_string_payload(input_byte):
return RowCoder(payload.schema).decode(payload.payload)._asdict()


def create_test_sklearn_model(file_name):
from sklearn import svm
x = [[0, 0], [1, 1]]
y = [0, 1]
model = svm.SVC()
model.fit(x, y)
with open(file_name, 'wb') as file:
pickle.dump(model, file)


def update_sklearn_model_dependency(env):
model_file = "/tmp/sklearn_test_model"
staged_name = "sklearn_model"
create_test_sklearn_model(model_file)
env._artifacts.append(
Stager._create_file_stage_to_artifact(model_file, staged_name))


server = None


@@ -367,12 +387,12 @@ def main(unused_argv):
with fully_qualified_named_transform.FullyQualifiedNamedTransform.with_filter(
options.fully_qualified_name_glob):
     server = grpc.server(thread_pool_executor.shared_unbounded_instance())
+    expansion_servicer = expansion_service.ExpansionServiceServicer(
+        PipelineOptions(
+            ["--experiments", "beam_fn_api", "--sdk_location", "container"]))
+    update_sklearn_model_dependency(expansion_servicer._default_environment)
     beam_expansion_api_pb2_grpc.add_ExpansionServiceServicer_to_server(
-        expansion_service.ExpansionServiceServicer(
-            PipelineOptions(
-                ["--experiments", "beam_fn_api", "--sdk_location",
-                 "container"])),
-        server)
+        expansion_servicer, server)
beam_artifact_api_pb2_grpc.add_ArtifactRetrievalServiceServicer_to_server(
artifact_service.ArtifactRetrievalService(
artifact_service.BeamFilesystemHandler(None).file_reader),
