Skip to content

Commit

Permalink
Merge pull request akka#19438 from 2m/wip-eager-zip-mergep
Browse files Browse the repository at this point in the history
akka#19271 Eager Zip and Merge Preferred
  • Loading branch information
drewhk committed Jan 19, 2016
2 parents 7c93b69 + edc119d commit 4e61673
Show file tree
Hide file tree
Showing 6 changed files with 163 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,7 @@ class GraphInterpreterSpec extends AkkaSpec with GraphInterpreterSpecKit {
lastEvents() should ===(Set.empty)

source2.onNext("Meaning of life")
lastEvents() should ===(Set(OnNext(sink, (42, "Meaning of life"))))

sink.requestOne()
lastEvents() should ===(Set(RequestOne(source1), RequestOne(source2)))
lastEvents() should ===(Set(OnNext(sink, (42, "Meaning of life")), RequestOne(source1), RequestOne(source2)))
}

"implement Broadcast" in new TestSetup {
Expand Down Expand Up @@ -169,13 +166,11 @@ class GraphInterpreterSpec extends AkkaSpec with GraphInterpreterSpecKit {
lastEvents() should ===(Set(RequestOne(source)))

source.onNext(1)
lastEvents() should ===(Set(OnNext(sink, (1, 1))))
lastEvents() should ===(Set(OnNext(sink, (1, 1)), RequestOne(source)))

sink.requestOne()
lastEvents() should ===(Set(RequestOne(source)))

source.onNext(2)
lastEvents() should ===(Set(OnNext(sink, (2, 2))))
lastEvents() should ===(Set(OnNext(sink, (2, 2)), RequestOne(source)))

}

Expand All @@ -198,16 +193,15 @@ class GraphInterpreterSpec extends AkkaSpec with GraphInterpreterSpecKit {
lastEvents() should ===(Set.empty)

sink1.requestOne()
lastEvents() should ===(Set.empty)
lastEvents() should ===(Set(RequestOne(source1), RequestOne(source2)))

sink2.requestOne()
lastEvents() should ===(Set(RequestOne(source1), RequestOne(source2)))

source1.onNext(1)
lastEvents() should ===(Set.empty)

source2.onNext(2)
lastEvents() should ===(Set(OnNext(sink1, (1, 2)), OnNext(sink2, (1, 2))))
lastEvents() should ===(Set(OnNext(sink1, (1, 2)), OnNext(sink2, (1, 2)), RequestOne(source1), RequestOne(source2)))

}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,18 @@
*/
package akka.stream.scaladsl

import akka.stream.{ FlowShape, ActorMaterializer, ActorMaterializerSettings }
import akka.stream.{ FlowShape, ActorMaterializer, ActorMaterializerSettings, OverflowStrategy }
import akka.stream.impl.fusing.GraphStages.Detacher
import akka.stream.testkit._
import akka.stream.testkit.Utils._
import akka.stream.testkit.scaladsl._
import com.typesafe.config.ConfigFactory
import org.scalatest.concurrent.ScalaFutures
import org.scalatest.time._

import scala.collection.immutable
import scala.concurrent.Await
import scala.concurrent.duration._
import akka.stream.OverflowStrategy
import org.scalatest.concurrent.ScalaFutures

class FlowJoinSpec extends AkkaSpec(ConfigFactory.parseString("akka.loglevel=INFO")) with ScalaFutures {

Expand All @@ -18,8 +23,11 @@ class FlowJoinSpec extends AkkaSpec(ConfigFactory.parseString("akka.loglevel=INF

implicit val materializer = ActorMaterializer(settings)

implicit val defaultPatience =
PatienceConfig(timeout = Span(2, Seconds), interval = Span(200, Millis))

"A Flow using join" must {
"allow for cycles" in {
"allow for cycles" in assertAllStagesStopped {
val end = 47
val (even, odd) = (0 to end).partition(_ % 2 == 0)
val result = Set() ++ even ++ odd ++ odd.map(_ * 10)
Expand Down Expand Up @@ -51,14 +59,103 @@ class FlowJoinSpec extends AkkaSpec(ConfigFactory.parseString("akka.loglevel=INF
sub.cancel()
}

"propagate one element" in {
"allow for merge cycle" in assertAllStagesStopped {
val source = Source.single("lonely traveler")

val flow1 = Flow.fromGraph(GraphDSL.create(Sink.head[String]) { implicit b
sink
import GraphDSL.Implicits._
val merge = b.add(Merge[String](2))
val broadcast = b.add(Broadcast[String](2))
val broadcast = b.add(Broadcast[String](2, eagerCancel = true))
source ~> merge.in(0)
merge.out ~> broadcast.in
broadcast.out(0) ~> sink

FlowShape(merge.in(1), broadcast.out(1))
})

whenReady(flow1.join(Flow[String]).run())(_ shouldBe "lonely traveler")
}

"allow for merge preferred cycle" in assertAllStagesStopped {
val source = Source.single("lonely traveler")

val flow1 = Flow.fromGraph(GraphDSL.create(Sink.head[String]) { implicit b
sink
import GraphDSL.Implicits._
val merge = b.add(MergePreferred[String](1))
val broadcast = b.add(Broadcast[String](2, eagerCancel = true))
source ~> merge.preferred
merge.out ~> broadcast.in
broadcast.out(0) ~> sink

FlowShape(merge.in(0), broadcast.out(1))
})

whenReady(flow1.join(Flow[String]).run())(_ shouldBe "lonely traveler")
}

"allow for zip cycle" in assertAllStagesStopped {
val source = Source(immutable.Seq("traveler1", "traveler2"))

val flow = Flow.fromGraph(GraphDSL.create(TestSink.probe[(String, String)]) { implicit b
sink
import GraphDSL.Implicits._
val zip = b.add(Zip[String, String])
val broadcast = b.add(Broadcast[(String, String)](2))
source ~> zip.in0
zip.out ~> broadcast.in
broadcast.out(0) ~> sink

FlowShape(zip.in1, broadcast.out(1))
})

val feedback = Flow.fromGraph(GraphDSL.create(Source.single("ignition")) { implicit b
ignition
import GraphDSL.Implicits._
val flow = b.add(Flow[(String, String)].map(_._1))
val merge = b.add(Merge[String](2))

ignition ~> merge.in(0)
flow ~> merge.in(1)

FlowShape(flow.in, merge.out)
})

val probe = flow.join(feedback).run()
probe.requestNext(("traveler1", "ignition"))
probe.requestNext(("traveler2", "traveler1"))
}

"allow for concat cycle" in assertAllStagesStopped {
val flow = Flow.fromGraph(GraphDSL.create(TestSource.probe[String](system), Sink.head[String])(Keep.both) { implicit b
(source, sink)
import GraphDSL.Implicits._
val concat = b.add(Concat[String](2))
val broadcast = b.add(Broadcast[String](2, eagerCancel = true))
source ~> concat.in(0)
concat.out ~> broadcast.in
broadcast.out(0) ~> sink

FlowShape(concat.in(1), broadcast.out(1))
})

val (probe, result) = flow.join(Flow[String]).run()
probe.sendNext("lonely traveler")
whenReady(result) { r
r shouldBe "lonely traveler"
probe.sendComplete()
}
}

"allow for interleave cycle" in assertAllStagesStopped {
val source = Source.single("lonely traveler")

val flow1 = Flow.fromGraph(GraphDSL.create(Sink.head[String]) { implicit b
sink
import GraphDSL.Implicits._
val merge = b.add(Interleave[String](2, 1))
val broadcast = b.add(Broadcast[String](2, eagerCancel = true))
source ~> merge.in(0)
merge.out ~> broadcast.in
broadcast.out(0) ~> sink
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,6 @@ class GraphZipSpec extends TwoStreamsSetup {

downstream.requestNext((1, "A"))
downstream.expectComplete()

upstream1.expectNoMsg(500.millis)
upstream2.expectNoMsg(500.millis)
}

"complete if one side complete before requested with elements pending" in {
Expand Down Expand Up @@ -159,9 +156,6 @@ class GraphZipSpec extends TwoStreamsSetup {

downstream.requestNext((1, "A"))
downstream.expectComplete()

upstream1.expectNoMsg(500.millis)
upstream2.expectNoMsg(500.millis)
}

"complete if one side complete before requested with elements pending 2" in {
Expand Down Expand Up @@ -190,9 +184,6 @@ class GraphZipSpec extends TwoStreamsSetup {
upstream2.sendComplete()
downstream.requestNext((1, "A"))
downstream.expectComplete()

upstream1.expectNoMsg(500.millis)
upstream2.expectNoMsg(500.millis)
}

commonTests()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,22 @@ class ZipWith1[[#A1#], O] (zipper: ([#A1#]) ⇒ O) extends GraphStage[FanInShape
]

override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) {
var pending = 1
var pending = ##0
// Without this field the completion signalling would take one extra pull
var willShutDown = false

private def pushAll(): Unit = {
push(out, zipper([#grab(in0)#]))
if (willShutDown) completeStage()
else {
[#pull(in0)#
]
}
}

override def preStart(): Unit = {
[#pull(in0)#
]
}

[#setHandler(in0, new InHandler {
Expand All @@ -56,17 +65,13 @@ class ZipWith1[[#A1#], O] (zipper: ([#A1#]) ⇒ O) extends GraphStage[FanInShape

setHandler(out, new OutHandler {
override def onPull(): Unit = {
pending = shape.inlets.size
if (willShutDown) completeStage()
else {
[#pull(in0)#
]
}
pending += shape.inlets.size
if (pending == ##0) pushAll()
}
})
}

override def toString = "Zip"
override def toString = "ZipWith1"

}
#
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import akka.actor.Cancellable
import akka.dispatch.ExecutionContexts
import akka.event.Logging
import akka.stream._
import akka.stream.scaladsl._
import akka.stream.impl.Stages.DefaultAttributes
import akka.stream.stage._
import scala.concurrent.{ Future, Promise }
Expand Down Expand Up @@ -65,14 +66,16 @@ object GraphStages {

def identity[T] = Identity.asInstanceOf[SimpleLinearGraphStage[T]]

private class Detacher[T] extends GraphStage[FlowShape[T, T]] {
/**
* INERNAL API
*/
private[stream] final class Detacher[T] extends GraphStage[FlowShape[T, T]] {
val in = Inlet[T]("in")
val out = Outlet[T]("out")
override def initialAttributes = Attributes.name("Detacher")
override val shape = FlowShape(in, out)

override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) {
var initialized = false

setHandler(in, new InHandler {
override def onPush(): Unit = {
Expand Down Expand Up @@ -220,4 +223,27 @@ object GraphStages {
}
override def toString: String = "FutureSource"
}

/**
* INTERNAL API.
*
* Fusing graphs that have cycles involving FanIn stages might lead to deadlocks if
* demand is not carefully managed.
*
* This means that FanIn stages need to early pull every relevant input on startup.
* This can either be implemented inside the stage itself, or this method can be used,
* which adds a detacher stage to every input.
*/
private[stream] def withDetachedInputs[T](stage: GraphStage[UniformFanInShape[T, T]]) =
GraphDSL.create() { implicit builder
import GraphDSL.Implicits._
val concat = builder.add(stage)
val ds = concat.inSeq.map { inlet
val detacher = builder.add(GraphStages.detacher[T])
detacher ~> inlet
detacher.in
}
UniformFanInShape(concat.out, ds: _*)
}

}
30 changes: 14 additions & 16 deletions akka-stream/src/main/scala/akka/stream/scaladsl/Graph.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,16 @@
*/
package akka.stream.scaladsl

import akka.stream.impl.Stages.{ StageModule, SymbolicStage }
import akka.stream._
import akka.stream.impl._
import akka.stream.impl.fusing.GraphStages
import akka.stream.impl.fusing.GraphStages.MaterializedValueSource
import akka.stream.impl.Stages.{ StageModule, SymbolicStage }
import akka.stream.impl.StreamLayout._
import akka.stream._
import akka.stream.stage.{ OutHandler, InHandler, GraphStageLogic, GraphStage }
import scala.annotation.unchecked.uncheckedVariance
import scala.annotation.tailrec
import scala.collection.immutable
import akka.stream.impl.fusing.GraphStages.MaterializedValueSource

object Merge {
/**
Expand Down Expand Up @@ -159,16 +160,12 @@ final class MergePreferred[T] private (val secondaryPorts: Int, val eagerComplet
if (eagerComplete || openInputs == 0) completeStage()
}

setHandler(out, new OutHandler {
private var first = true
override def onPull(): Unit = {
if (first) {
first = false
tryPull(preferred)
shape.inSeq.foreach(tryPull)
}
}
})
override def preStart(): Unit = {
tryPull(preferred)
shape.inSeq.foreach(tryPull)
}

setHandler(out, eagerTerminateOutput)

val pullMe = Array.tabulate(secondaryPorts)(i {
val port = in(i)
Expand Down Expand Up @@ -240,8 +237,8 @@ object Interleave {
* @param segmentSize number of elements to send downstream before switching to next input port
* @param eagerClose if true, interleave completes upstream if any of its upstream completes.
*/
def apply[T](inputPorts: Int, segmentSize: Int, eagerClose: Boolean = false): Interleave[T] =
new Interleave(inputPorts, segmentSize, eagerClose)
def apply[T](inputPorts: Int, segmentSize: Int, eagerClose: Boolean = false): Graph[UniformFanInShape[T, T], Unit] =
GraphStages.withDetachedInputs(new Interleave[T](inputPorts, segmentSize, eagerClose))
}

/**
Expand Down Expand Up @@ -644,7 +641,8 @@ object Concat {
/**
* Create a new `Concat`.
*/
def apply[T](inputPorts: Int = 2): Concat[T] = new Concat(inputPorts)
def apply[T](inputPorts: Int = 2): Graph[UniformFanInShape[T, T], Unit] =
GraphStages.withDetachedInputs(new Concat[T](inputPorts))
}

/**
Expand Down

0 comments on commit 4e61673

Please sign in to comment.