Skip to content

Commit

Permalink
Properly combine spark-submit arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
dszeto committed Apr 9, 2017
1 parent 070f179 commit 9deca1a
Show file tree
Hide file tree
Showing 2 changed files with 170 additions and 4 deletions.
101 changes: 97 additions & 4 deletions tools/src/main/scala/org/apache/predictionio/tools/Runner.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.predictionio.tools.ReturnTypes._
import org.apache.predictionio.workflow.WorkflowUtils

import scala.collection.mutable
import scala.sys.process._

case class SparkArgs(
Expand Down Expand Up @@ -95,6 +96,92 @@ object Runner extends EitherLogging {
}
}

/** Group argument values by argument names
*
* This only works with long argument names immediately followed by a value
*
* Input:
* Seq("--foo", "bar", "--flag", "--dead", "beef baz", "n00b", "--foo", "jeez")
*
* Output:
* Map("--foo" -> Seq("bar", "jeez"), "--dead" -> Seq("beef baz"))
*
* @param arguments Sequence of argument names and values
* @return A map with argument values keyed by the same argument name
*/
def groupByArgumentName(arguments: Seq[String]): Map[String, Seq[String]] = {
  // Pair every element with its successor: (name candidate, value candidate).
  val adjacent = arguments.zip(arguments.drop(1))
  adjacent.foldLeft(Map.empty[String, Seq[String]]) {
    // Only a "--name" directly followed by something that is not itself a "--name" counts;
    // values are appended in the order they appear in the input.
    case (grouped, (name, value)) if name.startsWith("--") && !value.startsWith("--") =>
      grouped.updated(name, grouped.getOrElse(name, Seq.empty[String]) :+ value)
    // Flags without a value and stray positional values contribute nothing.
    case (grouped, _) => grouped
  }
}

/** Remove argument names and values
*
* This only works with long argument names immediately followed by a value
*
* Input:
* Seq("--foo", "bar", "--flag", "--dead", "beef baz", "n00b", "--foo", "jeez")
* Set("--flag", "--foo")
*
* Output:
* Seq("--flag", "--dead", "beef baz", "n00b")
*
* @param arguments Sequence of argument names and values
* @param remove Names of the arguments to remove, each together with the value that immediately follows it
* @return Sequence of argument names and values with targets removed
*/
def removeArguments(arguments: Seq[String], remove: Set[String]): Seq[String] = {
  // Walks the *input* left to right, looking one token ahead. A name listed in `remove` is
  // dropped together with its value only when that value immediately follows it in the input.
  // Deciding on the input (rather than on the tail of the output built so far) prevents an
  // earlier, value-less occurrence of a removable name from swallowing an unrelated later
  // value once the pair in between has been dropped.
  @scala.annotation.tailrec
  def scan(pending: List[String], keptReversed: List[String]): List[String] = pending match {
    case name :: value :: rest if remove.contains(name) && !value.startsWith("--") =>
      // Name/value pair targeted for removal: skip both tokens.
      scan(rest, keptReversed)
    case token :: rest =>
      // Anything else (other arguments, flags without a value, positionals) is kept.
      // Prepending and reversing once at the end keeps the scan linear.
      scan(rest, token :: keptReversed)
    case Nil =>
      keptReversed.reverse
  }

  // Nothing to remove: hand back the input untouched.
  if (remove.isEmpty) arguments else scan(arguments.toList, Nil)
}

/** Combine repeated arguments together
*
* Input:
* Seq("--foo", "bar", "--flag", "--dead", "beef baz", "n00b", "--foo", "jeez")
* Map("--foo" -> ((a: String, b: String) => s"$a $b"))
*
* Output:
* Seq("--flag", "--dead", "beef baz", "n00b", "--foo", "bar jeez")
*
* @param arguments Sequence of argument names and values
* @param combinators Map of argument name to combinator function
* @return Sequence of argument names and values with specific argument values combined
*/
def combineArguments(
    arguments: Seq[String],
    combinators: Map[String, (String, String) => String]): Seq[String] = {
  // Names for which a combinator was supplied.
  val targets = combinators.keySet
  // Everything except the name/value pairs of the targeted arguments.
  val untouched = removeArguments(arguments, targets)
  // Values collected per targeted name, in order of appearance.
  val collected: Map[String, Seq[String]] =
    groupByArgumentName(arguments).filterKeys(targets.contains(_))
  // Each targeted name is re-emitted once, with its values folded by that name's combinator.
  val merged = collected.flatMap { case (name, values) =>
    Seq(name, values.reduce(combinators(name)))
  }
  // Combined arguments are appended after the untouched ones.
  untouched ++ merged
}

def runOnSpark(
className: String,
classArgs: Seq[String],
Expand Down Expand Up @@ -189,17 +276,23 @@ object Runner extends EitherLogging {
}

val verboseArg = if (verbose) Seq("--verbose") else Nil
val pioLogDir = Option(System.getProperty("pio.log.dir")).getOrElse(s"${pioHome}/log")
val pioLogDir = Option(System.getProperty("pio.log.dir")).getOrElse(s"$pioHome/log")

val sparkSubmit = Seq(
sparkSubmitCommand,
val sparkSubmitArgs = Seq(
sa.sparkPassThrough,
Seq("--class", className),
sparkSubmitJars,
sparkSubmitFiles,
sparkSubmitExtraClasspaths,
sparkSubmitKryo,
Seq("--driver-java-options", s"-Dpio.log.dir=${pioLogDir}"),
Seq("--driver-java-options", s"-Dpio.log.dir=$pioLogDir")).flatten

val whitespaceCombinator = (a: String, b: String) => s"$a $b"
val combinators = Map("--driver-java-options" -> whitespaceCombinator)

val sparkSubmit = Seq(
sparkSubmitCommand,
combineArguments(sparkSubmitArgs, combinators),
Seq(mainJar),
detectFilePaths(fs, sa.scratchUri, classArgs),
Seq("--env", pioEnvVars),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.predictionio.tools

import org.specs2.mutable.Specification

/** Specification for the argument helpers of `Runner`: `groupByArgumentName`,
  * `removeArguments` and `combineArguments`.
  */
class RunnerSpec extends Specification {
  // Fixture 1: "--foo" appears twice, each time followed by a value; "--flag" has no value.
  private val pairedArgs =
    Seq("--foo", "bar", "--flag", "--dead", "beef baz", "n00b", "--foo", "jeez")

  // Fixture 2: the first "--foo" and the trailing "--flag" are not followed by a value.
  private val danglingArgs =
    Seq("--foo", "--bar", "flag", "--dead", "beef baz", "n00b", "--foo", "jeez", "--flag")

  private val namesToRemove = Set("--flag", "--foo")

  // Joins repeated "--foo" values with a single space.
  private val fooCombinators: Map[String, (String, String) => String] =
    Map("--foo" -> ((left: String, right: String) => s"$left $right"))

  "groupByArgumentName" >> {
    "test1" >> {
      val grouped = Runner.groupByArgumentName(pairedArgs)
      grouped must havePairs(
        "--foo" -> Seq("bar", "jeez"),
        "--dead" -> Seq("beef baz"))
    }

    "test2" >> {
      val grouped = Runner.groupByArgumentName(danglingArgs)
      grouped must havePairs(
        "--foo" -> Seq("jeez"),
        "--bar" -> Seq("flag"),
        "--dead" -> Seq("beef baz"))
    }
  }

  "removeArguments" >> {
    "test1" >> {
      val remaining = Runner.removeArguments(pairedArgs, namesToRemove)
      remaining === Seq("--flag", "--dead", "beef baz", "n00b")
    }

    "test2" >> {
      val remaining = Runner.removeArguments(danglingArgs, namesToRemove)
      remaining === Seq("--foo", "--bar", "flag", "--dead", "beef baz", "n00b", "--flag")
    }
  }

  "combineArguments" >> {
    "test1" >> {
      val combined = Runner.combineArguments(pairedArgs, fooCombinators)
      combined === Seq("--flag", "--dead", "beef baz", "n00b", "--foo", "bar jeez")
    }

    "test2" >> {
      val combined = Runner.combineArguments(danglingArgs, fooCombinators)
      combined ===
        Seq("--foo", "--bar", "flag", "--dead", "beef baz", "n00b", "--flag", "--foo", "jeez")
    }
  }
}

0 comments on commit 9deca1a

Please sign in to comment.