forked from apache/flink
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added KMeans Java examples using the broadcast variables API.
- release-1.5.0-rc1
- release-1.4.2
- release-1.4.2-rc2
- release-1.4.2-rc1
- release-1.4.1
- release-1.4.1-rc1
- release-1.4.0
- release-1.4.0-rc3
- release-1.4.0-rc2
- release-1.4.0-rc1
- release-1.3.3
- release-1.3.2
- release-1.3.2-rc3
- release-1.3.2-rc2
- release-1.3.2-rc1
- release-1.3.1
- release-1.3.0
- release-1.2.1
- release-1.2.0
- release-1.1.5
- release-1.1.4
- release-1.1.3
- release-1.1.2
- release-1.1.1
- release-1.1.0
- release-1.0.3
- release-1.0.2
- release-1.0.1
- release-1.0.0
- release-0.10.2
- release-0.10.1
- release-0.10.0
- release-0.9.1
- release-0.9.0
- release-0.8.1
- release-0.8.0
- release-0.7.0
- release-0.6
- release-0.5
- release-0.5-rc3
- release-0.5-rc2
- release-0.5-rc1
- pre-apache-rename
1 parent
4c93530
commit f45357c
Showing
5 changed files
with
404 additions
and
0 deletions.
There are no files selected for viewing
80 changes: 80 additions & 0 deletions
80
...es/src/main/java/eu/stratosphere/example/java/record/kmeans/KMeansIterativeBroadcast.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
/*********************************************************************************************************************** | ||
* Copyright (C) 2010-2013 by the Stratosphere project (http://stratosphere.eu) | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on | ||
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations under the License. | ||
**********************************************************************************************************************/ | ||
|
||
package eu.stratosphere.example.java.record.kmeans; | ||
|
||
import eu.stratosphere.api.common.Plan; | ||
import eu.stratosphere.api.common.Program; | ||
import eu.stratosphere.api.common.ProgramDescription; | ||
import eu.stratosphere.api.common.operators.BulkIteration; | ||
import eu.stratosphere.api.common.operators.FileDataSink; | ||
import eu.stratosphere.api.common.operators.FileDataSource; | ||
import eu.stratosphere.api.java.record.operators.MapOperator; | ||
import eu.stratosphere.api.java.record.operators.ReduceOperator; | ||
import eu.stratosphere.example.java.record.kmeans.udfs.FindNearestCenterBroadcast; | ||
import eu.stratosphere.example.java.record.kmeans.udfs.PointInFormat; | ||
import eu.stratosphere.example.java.record.kmeans.udfs.PointOutFormat; | ||
import eu.stratosphere.example.java.record.kmeans.udfs.RecomputeClusterCenter; | ||
import eu.stratosphere.types.IntValue; | ||
|
||
|
||
public class KMeansIterativeBroadcast implements Program, ProgramDescription { | ||
|
||
@Override | ||
public Plan getPlan(String... args) { | ||
// parse job parameters | ||
final int numSubTasks = (args.length > 0 ? Integer.parseInt(args[0]) : 1); | ||
final String dataPointInput = (args.length > 1 ? args[1] : ""); | ||
final String clusterInput = (args.length > 2 ? args[2] : ""); | ||
final String output = (args.length > 3 ? args[3] : ""); | ||
final int numIterations = (args.length > 4 ? Integer.parseInt(args[4]) : 1); | ||
|
||
// create DataSourceContract for cluster center input | ||
FileDataSource initialClusterPoints = new FileDataSource(new PointInFormat(), clusterInput, "Centers"); | ||
initialClusterPoints.setDegreeOfParallelism(1); | ||
|
||
BulkIteration iteration = new BulkIteration("K-Means Loop"); | ||
iteration.setInput(initialClusterPoints); | ||
iteration.setMaximumNumberOfIterations(numIterations); | ||
|
||
// create DataSourceContract for data point input | ||
FileDataSource dataPoints = new FileDataSource(new PointInFormat(), dataPointInput, "Data Points"); | ||
|
||
// create MapOperator for finding the nearest cluster centers | ||
MapOperator findNearestClusterCenters = MapOperator.builder(new FindNearestCenterBroadcast()) | ||
.setBroadcastVariable("centers", iteration.getPartialSolution()) | ||
.input(dataPoints) | ||
.name("Find Nearest Centers") | ||
.build(); | ||
|
||
// create ReduceOperator for computing new cluster positions | ||
ReduceOperator recomputeClusterCenter = ReduceOperator.builder(new RecomputeClusterCenter(), IntValue.class, 0) | ||
.input(findNearestClusterCenters) | ||
.name("Recompute Center Positions") | ||
.build(); | ||
iteration.setNextPartialSolution(recomputeClusterCenter); | ||
|
||
// create DataSinkContract for writing the new cluster positions | ||
FileDataSink finalResult = new FileDataSink(new PointOutFormat(), output, iteration, "New Center Positions"); | ||
|
||
// return the PACT plan | ||
Plan plan = new Plan(finalResult, "Iterative KMeans"); | ||
plan.setDefaultParallelism(numSubTasks); | ||
return plan; | ||
} | ||
|
||
@Override | ||
public String getDescription() { | ||
return "Parameters: <numSubStasks> <dataPoints> <clusterCenters> <output> <numIterations>"; | ||
} | ||
} |
87 changes: 87 additions & 0 deletions
87
...n/java/eu/stratosphere/example/java/record/kmeans/KMeansIterativeWithParameterInputs.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
/*********************************************************************************************************************** | ||
* Copyright (C) 2010-2013 by the Stratosphere project (http://stratosphere.eu) | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on | ||
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations under the License. | ||
**********************************************************************************************************************/ | ||
|
||
package eu.stratosphere.example.java.record.kmeans; | ||
|
||
import eu.stratosphere.api.common.Plan; | ||
import eu.stratosphere.api.common.Program; | ||
import eu.stratosphere.api.common.ProgramDescription; | ||
import eu.stratosphere.api.common.operators.BulkIteration; | ||
import eu.stratosphere.api.common.operators.FileDataSink; | ||
import eu.stratosphere.api.common.operators.FileDataSource; | ||
import eu.stratosphere.api.java.record.operators.MapOperator; | ||
import eu.stratosphere.api.java.record.operators.ReduceOperator; | ||
import eu.stratosphere.example.java.record.kmeans.udfs.ComputeDistanceParameterized; | ||
import eu.stratosphere.example.java.record.kmeans.udfs.FindNearestCenter; | ||
import eu.stratosphere.example.java.record.kmeans.udfs.PointInFormat; | ||
import eu.stratosphere.example.java.record.kmeans.udfs.PointOutFormat; | ||
import eu.stratosphere.example.java.record.kmeans.udfs.RecomputeClusterCenter; | ||
import eu.stratosphere.types.IntValue; | ||
|
||
|
||
public class KMeansIterativeWithParameterInputs implements Program, ProgramDescription { | ||
|
||
@Override | ||
public Plan getPlan(String... args) { | ||
// parse job parameters | ||
final int numSubTasks = (args.length > 0 ? Integer.parseInt(args[0]) : 1); | ||
final String dataPointInput = (args.length > 1 ? args[1] : ""); | ||
final String clusterInput = (args.length > 2 ? args[2] : ""); | ||
final String output = (args.length > 3 ? args[3] : ""); | ||
final int numIterations = (args.length > 4 ? Integer.parseInt(args[4]) : 1); | ||
|
||
// create DataSourceContract for cluster center input | ||
FileDataSource initialClusterPoints = new FileDataSource(new PointInFormat(), clusterInput, "Centers"); | ||
initialClusterPoints.setDegreeOfParallelism(1); | ||
|
||
BulkIteration iteration = new BulkIteration("K-Means Loop"); | ||
iteration.setInput(initialClusterPoints); | ||
iteration.setMaximumNumberOfIterations(numIterations); | ||
|
||
// create DataSourceContract for data point input | ||
FileDataSource dataPoints = new FileDataSource(new PointInFormat(), dataPointInput, "Data Points"); | ||
|
||
// create CrossOperator for distance computation | ||
MapOperator computeDistance = MapOperator.builder(new ComputeDistanceParameterized()) | ||
.setBroadcastVariable("centers", iteration.getPartialSolution()) | ||
.input(dataPoints) | ||
.name("Compute Distances") | ||
.build(); | ||
|
||
// create ReduceOperator for finding the nearest cluster centers | ||
ReduceOperator findNearestClusterCenters = ReduceOperator.builder(new FindNearestCenter(), IntValue.class, 0) | ||
.input(computeDistance) | ||
.name("Find Nearest Centers") | ||
.build(); | ||
|
||
// create ReduceOperator for computing new cluster positions | ||
ReduceOperator recomputeClusterCenter = ReduceOperator.builder(new RecomputeClusterCenter(), IntValue.class, 0) | ||
.input(findNearestClusterCenters) | ||
.name("Recompute Center Positions") | ||
.build(); | ||
iteration.setNextPartialSolution(recomputeClusterCenter); | ||
|
||
// create DataSinkContract for writing the new cluster positions | ||
FileDataSink finalResult = new FileDataSink(new PointOutFormat(), output, iteration, "New Center Positions"); | ||
|
||
// return the PACT plan | ||
Plan plan = new Plan(finalResult, "Iterative KMeans"); | ||
plan.setDefaultParallelism(numSubTasks); | ||
return plan; | ||
} | ||
|
||
@Override | ||
public String getDescription() { | ||
return "Parameters: <numSubStasks> <dataPoints> <clusterCenters> <output> <numIterations>"; | ||
} | ||
} |
86 changes: 86 additions & 0 deletions
86
...s/src/main/java/eu/stratosphere/example/java/record/kmeans/KMeansSingleStepBroadcast.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
/*********************************************************************************************************************** | ||
* Copyright (C) 2010-2013 by the Stratosphere project (http://stratosphere.eu) | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on | ||
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations under the License. | ||
**********************************************************************************************************************/ | ||
|
||
package eu.stratosphere.example.java.record.kmeans; | ||
|
||
|
||
import eu.stratosphere.api.common.Plan; | ||
import eu.stratosphere.api.common.Program; | ||
import eu.stratosphere.api.common.ProgramDescription; | ||
import eu.stratosphere.api.common.operators.FileDataSink; | ||
import eu.stratosphere.api.common.operators.FileDataSource; | ||
import eu.stratosphere.api.java.record.operators.MapOperator; | ||
import eu.stratosphere.api.java.record.operators.ReduceOperator; | ||
import eu.stratosphere.example.java.record.kmeans.udfs.FindNearestCenterBroadcast; | ||
import eu.stratosphere.example.java.record.kmeans.udfs.PointInFormat; | ||
import eu.stratosphere.example.java.record.kmeans.udfs.PointOutFormat; | ||
import eu.stratosphere.example.java.record.kmeans.udfs.RecomputeClusterCenter; | ||
import eu.stratosphere.types.IntValue; | ||
|
||
/** | ||
* The K-Means cluster algorithm is well-known (see | ||
* http://en.wikipedia.org/wiki/K-means_clustering). KMeansIteration is a PACT | ||
* program that computes a single iteration of the k-means algorithm. The job | ||
* has two inputs, a set of data points and a set of cluster centers. A Cross | ||
* PACT is used to compute all distances from all centers to all points. A | ||
* following Reduce PACT assigns each data point to the cluster center that is | ||
* next to it. Finally, a second Reduce PACT compute the new locations of all | ||
* cluster centers. | ||
*/ | ||
public class KMeansSingleStepBroadcast implements Program, ProgramDescription { | ||
|
||
|
||
@Override | ||
public Plan getPlan(String... args) { | ||
// parse job parameters | ||
int numSubTasks = (args.length > 0 ? Integer.parseInt(args[0]) : 1); | ||
String dataPointInput = (args.length > 1 ? args[1] : ""); | ||
String clusterInput = (args.length > 2 ? args[2] : ""); | ||
String output = (args.length > 3 ? args[3] : ""); | ||
|
||
// create DataSourceContract for data point input | ||
FileDataSource dataPoints = new FileDataSource(new PointInFormat(), dataPointInput, "Data Points"); | ||
dataPoints.getCompilerHints().addUniqueField(0); | ||
|
||
// create DataSourceContract for cluster center input | ||
FileDataSource clusterPoints = new FileDataSource(new PointInFormat(), clusterInput, "Centers"); | ||
clusterPoints.setDegreeOfParallelism(1); | ||
clusterPoints.getCompilerHints().addUniqueField(0); | ||
|
||
// create CrossOperator for distance computation | ||
MapOperator findNearestClusterCenters = MapOperator.builder(new FindNearestCenterBroadcast()) | ||
.setBroadcastVariable("centers", clusterPoints) | ||
.input(dataPoints) | ||
.name("Find Nearest Centers") | ||
.build(); | ||
|
||
// create ReduceOperator for computing new cluster positions | ||
ReduceOperator recomputeClusterCenter = ReduceOperator.builder(new RecomputeClusterCenter(), IntValue.class, 0) | ||
.input(findNearestClusterCenters) | ||
.name("Recompute Center Positions") | ||
.build(); | ||
|
||
// create DataSinkContract for writing the new cluster positions | ||
FileDataSink newClusterPoints = new FileDataSink(new PointOutFormat(), output, recomputeClusterCenter, "New Center Positions"); | ||
|
||
// return the PACT plan | ||
Plan plan = new Plan(newClusterPoints, "KMeans Iteration"); | ||
plan.setDefaultParallelism(numSubTasks); | ||
return plan; | ||
} | ||
|
||
@Override | ||
public String getDescription() { | ||
return "Parameters: [numSubStasks] [dataPoints] [clusterCenters] [output]"; | ||
} | ||
} |
70 changes: 70 additions & 0 deletions
70
...in/java/eu/stratosphere/example/java/record/kmeans/udfs/ComputeDistanceParameterized.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
/*********************************************************************************************************************** | ||
* Copyright (C) 2010-2013 by the Stratosphere project (http://stratosphere.eu) | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on | ||
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations under the License. | ||
**********************************************************************************************************************/ | ||
package eu.stratosphere.example.java.record.kmeans.udfs; | ||
|
||
import java.io.Serializable; | ||
import java.util.Collection; | ||
|
||
import eu.stratosphere.api.java.record.functions.FunctionAnnotation.ConstantFieldsFirst; | ||
import eu.stratosphere.api.java.record.functions.MapFunction; | ||
import eu.stratosphere.configuration.Configuration; | ||
import eu.stratosphere.types.DoubleValue; | ||
import eu.stratosphere.types.IntValue; | ||
import eu.stratosphere.types.Record; | ||
import eu.stratosphere.util.Collector; | ||
|
||
/** | ||
* Cross PACT computes the distance of all data points to all cluster | ||
* centers. | ||
*/ | ||
@ConstantFieldsFirst({0,1}) | ||
public class ComputeDistanceParameterized extends MapFunction implements Serializable { | ||
private static final long serialVersionUID = 1L; | ||
|
||
private final DoubleValue distance = new DoubleValue(); | ||
|
||
private Collection<Record> clusterCenters; | ||
|
||
@Override | ||
public void open(Configuration parameters) throws Exception { | ||
this.clusterCenters = this.getRuntimeContext().getBroadcastVariable("centers"); | ||
} | ||
|
||
/** | ||
* Computes the distance of one data point to one cluster center. | ||
* | ||
* Output Format: | ||
* 0: pointID | ||
* 1: pointVector | ||
* 2: clusterID | ||
* 3: distance | ||
*/ | ||
@Override | ||
public void map(Record dataPointRecord, Collector<Record> out) { | ||
|
||
CoordVector dataPoint = dataPointRecord.getField(1, CoordVector.class); | ||
|
||
for (Record clusterCenterRecord : this.clusterCenters) { | ||
IntValue clusterCenterId = clusterCenterRecord.getField(0, IntValue.class); | ||
CoordVector clusterPoint = clusterCenterRecord.getField(1, CoordVector.class); | ||
|
||
this.distance.setValue(dataPoint.computeEuclidianDistance(clusterPoint)); | ||
|
||
// add cluster center id and distance to the data point record | ||
dataPointRecord.setField(2, clusterCenterId); | ||
dataPointRecord.setField(3, this.distance); | ||
|
||
out.collect(dataPointRecord); | ||
} | ||
} | ||
} |
81 changes: 81 additions & 0 deletions
81
...main/java/eu/stratosphere/example/java/record/kmeans/udfs/FindNearestCenterBroadcast.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
/*********************************************************************************************************************** | ||
* Copyright (C) 2010-2013 by the Stratosphere project (http://stratosphere.eu) | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on | ||
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations under the License. | ||
**********************************************************************************************************************/ | ||
package eu.stratosphere.example.java.record.kmeans.udfs; | ||
|
||
import java.io.Serializable; | ||
import java.util.Collection; | ||
|
||
import eu.stratosphere.api.java.record.functions.FunctionAnnotation.ConstantFieldsFirst; | ||
import eu.stratosphere.api.java.record.functions.MapFunction; | ||
import eu.stratosphere.configuration.Configuration; | ||
import eu.stratosphere.types.IntValue; | ||
import eu.stratosphere.types.Record; | ||
import eu.stratosphere.util.Collector; | ||
|
||
/** | ||
* Determines the closest cluster center for a data point. | ||
*/ | ||
@ConstantFieldsFirst({0,1}) | ||
public class FindNearestCenterBroadcast extends MapFunction implements Serializable { | ||
private static final long serialVersionUID = 1L; | ||
|
||
private final IntValue centerId = new IntValue(); | ||
private final CoordVector dataPoint = new CoordVector(); | ||
private final CoordVector centerPoint = new CoordVector(); | ||
private final IntValue one = new IntValue(1); | ||
|
||
private final Record result = new Record(3); | ||
|
||
private Collection<Record> clusterCenters; | ||
|
||
@Override | ||
public void open(Configuration parameters) throws Exception { | ||
this.clusterCenters = this.getRuntimeContext().getBroadcastVariable("centers"); | ||
} | ||
|
||
/** | ||
* Computes a minimum aggregation on the distance of a data point to cluster centers. | ||
* | ||
* Output Format: | ||
* 0: centerID | ||
* 1: pointVector | ||
* 2: constant(1) (to enable combinable average computation in the following reducer) | ||
*/ | ||
@Override | ||
public void map(Record dataPointRecord, Collector<Record> out) { | ||
dataPointRecord.getFieldInto(1, this.dataPoint); | ||
|
||
double nearestDistance = Double.MAX_VALUE; | ||
|
||
// check all cluster centers | ||
for (Record clusterCenterRecord : this.clusterCenters) { | ||
clusterCenterRecord.getFieldInto(1, this.centerPoint); | ||
|
||
// compute distance | ||
double distance = this.dataPoint.computeEuclidianDistance(this.centerPoint); | ||
// update nearest cluster if necessary | ||
if (distance < nearestDistance) { | ||
nearestDistance = distance; | ||
clusterCenterRecord.getFieldInto(0, this.centerId); | ||
} | ||
} | ||
|
||
// emit a new record with the center id and the data point. add a one to ease the | ||
// implementation of the average function with a combiner | ||
this.result.setField(0, this.centerId); | ||
this.result.setField(1, this.dataPoint); | ||
this.result.setField(2, this.one); | ||
|
||
out.collect(this.result); | ||
} | ||
} |