diff --git a/stream-tests/src/test/scala/org/apache/pekko/stream/scaladsl/FromMaterializationSpec.scala b/stream-tests/src/test/scala/org/apache/pekko/stream/scaladsl/FromMaterializationSpec.scala index 68ad985e45f..fd922effc26 100644 --- a/stream-tests/src/test/scala/org/apache/pekko/stream/scaladsl/FromMaterializationSpec.scala +++ b/stream-tests/src/test/scala/org/apache/pekko/stream/scaladsl/FromMaterializationSpec.scala @@ -14,6 +14,7 @@ package org.apache.pekko.stream.scaladsl import org.apache.pekko +import org.apache.pekko.stream.impl.fusing.GraphInterpreter import pekko.NotUsed import pekko.stream.Attributes import pekko.stream.Attributes.Attribute @@ -143,6 +144,14 @@ class FromMaterializerSpec extends StreamSpec { Source.empty.via(flow).runWith(Sink.head).futureValue should not be empty } + "expose interpreter" in { + val flow = Flow.fromMaterializer { (_, _) => + Flow.fromSinkAndSource(Sink.ignore, Source.single(GraphInterpreter.currentInterpreter)) + } + + Source.empty.via(flow).runWith(Sink.head).futureValue should not be null + } + "propagate materialized value" in { val flow = Flow.fromMaterializer { (_, _) => Flow.fromSinkAndSourceMat(Sink.ignore, Source.maybe[NotUsed])(Keep.right) diff --git a/stream/src/main/scala/org/apache/pekko/stream/impl/fusing/GraphInterpreter.scala b/stream/src/main/scala/org/apache/pekko/stream/impl/fusing/GraphInterpreter.scala index b3be4fbe5bd..bb7f6a51629 100644 --- a/stream/src/main/scala/org/apache/pekko/stream/impl/fusing/GraphInterpreter.scala +++ b/stream/src/main/scala/org/apache/pekko/stream/impl/fusing/GraphInterpreter.scala @@ -308,20 +308,27 @@ import pekko.stream.stage._ */ def init(subMat: Materializer): Unit = { _subFusingMaterializer = if (subMat == null) materializer else subMat - var i = 0 - while (i < logics.length) { - val logic = logics(i) - logic.interpreter = this - try { - logic.beforePreStart() - logic.preStart() - } catch { - case NonFatal(e) => - log.error(e, "Error during preStart in [{}]: {}", logic.toString, e.getMessage) - logic.failStage(e) + val currentInterpreterHolder = _currentInterpreter.get() + val previousInterpreter = currentInterpreterHolder(0) + currentInterpreterHolder(0) = this + try { + var i = 0 + while (i < logics.length) { + val logic = logics(i) + logic.interpreter = this + try { + logic.beforePreStart() + logic.preStart() + } catch { + case NonFatal(e) => + log.error(e, "Error during preStart in [{}]: {}", logic.toString, e.getMessage) + logic.failStage(e) + } + afterStageHasRun(logic) + i += 1 } - afterStageHasRun(logic) - i += 1 + } finally { + currentInterpreterHolder(0) = previousInterpreter } }