Skip to content

Commit

Permalink
Add testShardDebug
Browse files Browse the repository at this point in the history
  • Loading branch information
reibitto committed Jul 12, 2022
1 parent 63ed564 commit cd3c61b
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 20 deletions.
30 changes: 14 additions & 16 deletions src/main/scala/sbttestshards/ShardingAlgorithm.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,24 @@ import java.time.Duration

// This trait is open so that users can implement a custom `ShardingAlgorithm` if they'd like
trait ShardingAlgorithm {
def isInShard(specName: String, shardContext: ShardContext): Boolean

/** Determines whether the specified spec will run on this shard or not. */
def shouldRun(specName: String, shardContext: ShardContext): Boolean
}

object ShardingAlgorithm {
final case object Always extends ShardingAlgorithm {
override def isInShard(specName: String, shardContext: ShardContext): Boolean = true
override def shouldRun(specName: String, shardContext: ShardContext): Boolean = true
}

final case object Never extends ShardingAlgorithm {
override def isInShard(specName: String, shardContext: ShardContext): Boolean = false
override def shouldRun(specName: String, shardContext: ShardContext): Boolean = false
}

final case object SuiteName extends ShardingAlgorithm {
override def isInShard(specName: String, shardContext: ShardContext): Boolean = {
val shouldRun = specName.hashCode % shardContext.testShardCount == shardContext.testShard

println(s"${specName} will run? ${shouldRun}")

shouldRun
}
override def shouldRun(specName: String, shardContext: ShardContext): Boolean =
// TODO: Test whether `hashCode` gets a good distribution. Otherwise implement a different hash algorithm.
specName.hashCode.abs % shardContext.testShardCount == shardContext.testShard
}

final case class Balance(
Expand All @@ -35,14 +33,11 @@ object ShardingAlgorithm {
private val averageTime: Option[Duration] = {
val allTimeTaken = tests.flatMap(_.timeTaken)
allTimeTaken.reduceOption(_.plus(_)).map { d =>
if (d.isZero) Duration.ZERO
if (allTimeTaken.isEmpty) Duration.ZERO
else d.dividedBy(allTimeTaken.length)
}
}

private final case class TestSuiteInfoSimple(name: String, timeTaken: Duration)
private final case class TestBucket(var tests: List[TestSuiteInfoSimple], var sum: Duration)

private def createBucketMap(testShardCount: Int) = {
val durationOrdering: Ordering[Duration] = (a: Duration, b: Duration) => a.compareTo(b)

Expand Down Expand Up @@ -70,10 +65,13 @@ object ShardingAlgorithm {
// TODO: Maybe print a warning if it's not a multiple of it.
private val bucketMap: Map[String, Int] = createBucketMap(bucketCount)

def isInShard(specName: String, shardContext: ShardContext): Boolean =
def shouldRun(specName: String, shardContext: ShardContext): Boolean =
bucketMap.get(specName) match {
case Some(bucketIndex) => bucketIndex == shardContext.testShard
case None => fallbackShardingAlgorithm.isInShard(specName, shardContext)
case None => fallbackShardingAlgorithm.shouldRun(specName, shardContext)
}
}

private final case class TestSuiteInfoSimple(name: String, timeTaken: Duration)
private final case class TestBucket(var tests: List[TestSuiteInfoSimple], var sum: Duration)
}
27 changes: 23 additions & 4 deletions src/main/scala/sbttestshards/TestShardsPlugin.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,39 @@ object TestShardsPlugin extends AutoPlugin {
val testShard = settingKey[Int]("testShard")
val testShardCount = settingKey[Int]("testShardCount")
val shardingAlgorithm = settingKey[ShardingAlgorithm]("shardingAlgorithm")
val testShardDebug = settingKey[Boolean]("testShardDebug")
}

import autoImport.*

override def trigger = allRequirements

def stringConfig(key: String, default: String): String = {
val propertyKey = key.replace('_', '.').toLowerCase
sys.props.get(propertyKey).orElse(sys.env.get(key)).getOrElse(default)
}

override lazy val projectSettings: Seq[Def.Setting[?]] =
Seq(
testShard := 0,
testShardCount := 1,
testShard := stringConfig("TEST_SHARD", "0").toInt,
testShardCount := stringConfig("TEST_SHARD_COUNT", "1").toInt,
shardingAlgorithm := ShardingAlgorithm.SuiteName,
testShardDebug := false,
Test / testOptions += {
val shardContext = ShardContext(testShardCount.value, testShard.value, sLog.value)
Tests.Filter(specName => shardingAlgorithm.value.isInShard(specName, shardContext))
val shardContext = ShardContext(testShard.value, testShardCount.value, sLog.value)
Tests.Filter { specName =>
val isInShard = shardingAlgorithm.value.shouldRun(specName, shardContext)

if (testShardDebug.value) {
if (isInShard) {
sLog.value.info(s"`$specName` set to run on this shard.")
} else {
sLog.value.warn(s"`$specName` skipped because it will run on another shard.")
}
}

isInShard
}
}
)

Expand Down

0 comments on commit cd3c61b

Please sign in to comment.