Skip to content

Commit

Permalink
[FLINK-2590] Fixes Scala's DataSetUtilsITCase
Browse files Browse the repository at this point in the history
  • Loading branch information
tillrohrmann committed Sep 1, 2015
1 parent ab14f90 commit 6a58aad
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import _root_.scala.reflect.ClassTag
* or with a unique identifier.
*/

class DataSetUtils[T](val self: DataSet[T]) extends AnyVal {
class DataSetUtils[T](val self: DataSet[T]) {

/**
* Method that takes a set of subtask index, total number of elements mappings
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.tuple.Tuple2;
Expand Down Expand Up @@ -69,7 +70,14 @@ public void testZipWithUniqueId() throws Exception {
long expectedSize = 100L;
DataSet<Long> numbers = env.generateSequence(1L, expectedSize);

Set<Tuple2<Long, Long>> result = Sets.newHashSet(DataSetUtils.zipWithUniqueId(numbers).collect());
DataSet<Long> ids = DataSetUtils.zipWithUniqueId(numbers).map(new MapFunction<Tuple2<Long,Long>, Long>() {
@Override
public Long map(Tuple2<Long, Long> value) throws Exception {
return value.f0;
}
});

Set<Long> result = Sets.newHashSet(ids.collect());

Assert.assertEquals(expectedSize, result.size());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,63 +19,46 @@
package org.apache.flink.api.scala.util

import org.apache.flink.api.scala._
import org.apache.flink.test.util.{MultipleProgramsTestBase, TestBaseUtils}
import org.junit.rules.TemporaryFolder
import org.apache.flink.test.util.{MultipleProgramsTestBase}
import org.junit.runner.RunWith
import org.junit.runners.Parameterized
import org.junit.{After, Before, Rule, Test}
import org.junit._
import org.apache.flink.api.scala.DataSetUtils.utilsToDataSet

@RunWith(classOf[Parameterized])
class DataSetUtilsITCase (mode: MultipleProgramsTestBase.TestExecutionMode) extends
MultipleProgramsTestBase(mode){

private var resultPath: String = null
private var expectedResult: String = null
private val tempFolder: TemporaryFolder = new TemporaryFolder()

@Rule
def getFolder = tempFolder

@Before
@throws(classOf[Exception])
def before(): Unit = {
resultPath = tempFolder.newFile.toURI.toString
}
class DataSetUtilsITCase (
mode: MultipleProgramsTestBase.TestExecutionMode)
extends MultipleProgramsTestBase(mode) {

@Test
@throws(classOf[Exception])
def testZipWithIndex(): Unit = {
val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
env.setParallelism(1)

val input: DataSet[String] = env.fromElements("A", "B", "C", "D", "E", "F")
val result: DataSet[(Long, String)] = input.zipWithIndex
val expectedSize = 100L

result.writeAsCsv(resultPath, "\n", ",")
env.execute()
val numbers = env.generateSequence(0, expectedSize - 1)

expectedResult = "0,A\n" + "1,B\n" + "2,C\n" + "3,D\n" + "4,E\n" + "5,F"
val result = numbers.zipWithIndex.collect()

Assert.assertEquals(expectedSize, result.size)

for( ((index, _), expected) <- result.sortBy(_._1).zipWithIndex) {
Assert.assertEquals(expected, index)
}
}

@Test
@throws(classOf[Exception])
def testZipWithUniqueId(): Unit = {
val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
env.setParallelism(1)

val input: DataSet[String] = env.fromElements("A", "B", "C", "D", "E", "F")
val result: DataSet[(Long, String)] = input.zipWithUniqueId
val expectedSize = 100L

result.writeAsCsv(resultPath, "\n", ",")
env.execute()
val numbers = env.generateSequence(1L, expectedSize)

expectedResult = "0,A\n" + "2,B\n" + "4,C\n" + "6,D\n" + "8,E\n" + "10,F"
}
val result = numbers.zipWithUniqueId.collect().map(_._1).toSet

@After
@throws(classOf[Exception])
def after(): Unit = {
TestBaseUtils.compareResultsByLinesInMemory(expectedResult, resultPath)
Assert.assertEquals(expectedSize, result.size)
}
}

0 comments on commit 6a58aad

Please sign in to comment.