Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

POC +str Add StreamCollector api for unsafe transformation. #628

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.pekko.stream

import org.apache.pekko
import pekko.actor.{ DeadLetterSuppression, NoSerializationVerificationNeeded }
import pekko.annotation.InternalApi

import scala.util.{ Failure, Success, Try }

object StreamCollectorOps {
def emit[T](value: T)(implicit collector: StreamCollector[T]): Unit = collector.emit(value)
def fail(throwable: Throwable)(implicit collector: StreamCollector[_]): Unit = collector.fail(throwable)
def complete()(implicit collector: StreamCollector[_]): Unit = collector.complete()
def handle[T](result: Try[T])(implicit collector: StreamCollector[T]): Unit = collector.handle(result)
def handle[T](result: Either[Throwable, T])(implicit collector: StreamCollector[T]): Unit = collector.handle(result)
}

object StreamCollectorUnsafeOps {
def emitSync[T](value: T)(implicit collector: StreamCollector[T]): Unit =
collector.asInstanceOf[UnsafeStreamCollector[T]].emitSync(value)
def failSync(throwable: Throwable)(implicit collector: StreamCollector[_]): Unit =
collector.asInstanceOf[UnsafeStreamCollector[_]].failSync(throwable)
def completeSync()(implicit collector: StreamCollector[_]): Unit =
collector.asInstanceOf[UnsafeStreamCollector[_]].completeSync()
def handleSync[T](result: Try[T])(implicit collector: StreamCollector[T]): Unit =
collector.asInstanceOf[UnsafeStreamCollector[T]].handleSync(result)
def handleSync[T](result: Either[Throwable, T])(implicit collector: StreamCollector[T]): Unit =
collector.asInstanceOf[UnsafeStreamCollector[T]].handleSync(result)
}

object StreamCollector {
sealed trait StreamCollectorCommand
extends DeadLetterSuppression
with NoSerializationVerificationNeeded

case class EmitNext[T](value: T) extends StreamCollectorCommand

case class Fail(throwable: Throwable) extends StreamCollectorCommand

object Complete extends StreamCollectorCommand

object TryPull extends StreamCollectorCommand
}

trait StreamCollector[T] {

def emit(value: T): Unit

def tryPull(): Unit = ()

def fail(throwable: Throwable): Unit

def complete(): Unit

def handle(result: Try[T]): Unit = result match {
case Success(value) => emit(value)
case Failure(ex) => fail(ex)
}

def handle(result: Either[Throwable, T]): Unit = result match {
case Right(value) => emit(value)
case Left(ex) => fail(ex)
}
}

@InternalApi
private[pekko] trait UnsafeStreamCollector[T] extends StreamCollector[T] {
def emitSync(value: T): Unit
def failSync(throwable: Throwable): Unit

def completeSync(): Unit
def handleSync(result: Try[T]): Unit = result match {
case Success(value) => emitSync(value)
case Failure(ex) => failSync(ex)
}
def handleSync(result: Either[Throwable, T]): Unit = result match {
case Right(value) => emitSync(value)
case Left(ex) => failSync(ex)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.pekko.stream.impl.fusing

import org.apache.pekko.stream.ActorAttributes.SupervisionStrategy
import org.apache.pekko.stream.Attributes.SourceLocation
import org.apache.pekko.stream.impl.Stages.DefaultAttributes
import org.apache.pekko.stream.stage.{ GraphStage, GraphStageLogic, InHandler, OutHandler }
import org.apache.pekko.stream._

import scala.util.control.NonFatal

private[pekko] class UnsafeTransformUnordered[In, Out](
parallelism: Int,
transform: (In, StreamCollector[Out]) => Unit)
extends GraphStage[FlowShape[In, Out]] {
private val in = Inlet[In]("UnsafeTransformOrdered.in")
private val out = Outlet[Out]("UnsafeTransformOrdered.out")

override def initialAttributes = DefaultAttributes.mapAsyncUnordered and SourceLocation.forLambda(transform)

override val shape = FlowShape(in, out)

override def createLogic(inheritedAttributes: Attributes): GraphStageLogic =
new GraphStageLogic(shape) with InHandler with OutHandler with UnsafeStreamCollector[Out] {
override def toString = s"UnsafeTransformOrdered.Logic(inFlight=$inFlight, buffer=$buffer)"
import org.apache.pekko.stream.impl.{ Buffer => BufferImpl }

private val decider =
inheritedAttributes.mandatoryAttribute[SupervisionStrategy].decider
private var inFlight = 0
private var buffer: org.apache.pekko.stream.impl.Buffer[Out] = _

import StreamCollector._

private val callback: StreamCollectorCommand => Unit = getAsyncCallback[StreamCollectorCommand](handle).invoke

override def emitSync(value: Out): Unit = handle(EmitNext(value))
override def failSync(throwable: Throwable): Unit = handle(Fail(throwable))
override def completeSync(): Unit = completeStage()

// TODO check permit
override final def emit(value: Out): Unit = callback(EmitNext(value))
override final def fail(throwable: Throwable): Unit = callback(Fail(throwable))
override final def complete(): Unit = callback(Complete)

//
private[this] def todo: Int = inFlight + buffer.used

override def preStart(): Unit = buffer = BufferImpl(parallelism, inheritedAttributes)

private def isCompleted = isClosed(in) && todo == 0

def handle(msg: StreamCollectorCommand): Unit = {
inFlight -= 1

msg match {
case EmitNext(elem: Out @unchecked) if elem != null =>
if (isAvailable(out)) {
if (!hasBeenPulled(in)) tryPull(in)
push(out, elem)
if (isCompleted) completeStage()
} else buffer.enqueue(elem)
case EmitNext(_) =>
if (isCompleted) completeStage()
else if (!hasBeenPulled(in)) tryPull(in)
case TryPull =>
if (!hasBeenPulled(in)) tryPull(in)
case Complete =>
completeStage()
case Fail(ex) =>
if (decider(ex) == Supervision.Stop) failStage(ex)
else if (isCompleted) completeStage()
else if (!hasBeenPulled(in)) tryPull(in)
}
}

override def onPush(): Unit = {
try {
val elem = grab(in)
transform(elem, this)
inFlight += 1
} catch {
case NonFatal(ex) => if (decider(ex) == Supervision.Stop) failStage(ex)
}
if (todo < parallelism && !hasBeenPulled(in)) tryPull(in)
}

override def onUpstreamFinish(): Unit = {
if (todo == 0) completeStage()
}

override def onPull(): Unit = {
if (!buffer.isEmpty) push(out, buffer.dequeue())

val leftTodo = todo
if (isClosed(in) && leftTodo == 0) completeStage()
else if (leftTodo < parallelism && !hasBeenPulled(in)) tryPull(in)
}

setHandlers(in, out, this)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -873,6 +873,10 @@ final class Flow[In, Out, Mat](delegate: scaladsl.Flow[In, Out, Mat]) extends Gr
def mapAsyncUnordered[T](parallelism: Int, f: function.Function[Out, CompletionStage[T]]): javadsl.Flow[In, T, Mat] =
new Flow(delegate.mapAsyncUnordered(parallelism)(x => f(x).asScala))

def unsafeTransformUnordered[T](
parallelism: Int, transform: function.Function2[Out, StreamCollector[T], Unit]): javadsl.Flow[In, T, Mat] =
new Flow(delegate.unsafeTransformUnordered[T](parallelism)(collector => out => transform(out, collector)))

/**
* Use the `ask` pattern to send a request-reply message to the target `ref` actor.
* If any of the asks times out it will fail the stream with a [[pekko.pattern.AskTimeoutException]].
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2527,6 +2527,10 @@ final class Source[Out, Mat](delegate: scaladsl.Source[Out, Mat]) extends Graph[
def mapAsyncUnordered[T](parallelism: Int, f: function.Function[Out, CompletionStage[T]]): javadsl.Source[T, Mat] =
new Source(delegate.mapAsyncUnordered(parallelism)(x => f(x).asScala))

def unsafeTransformUnordered[T](
parallelism: Int, transform: function.Function2[Out, StreamCollector[T], Unit]): javadsl.Source[T, Mat] =
new Source(delegate.unsafeTransformUnordered[T](parallelism)(collector => out => transform(out, collector)))

/**
* Use the `ask` pattern to send a request-reply message to the target `ref` actor.
* If any of the asks times out it will fail the stream with a [[pekko.pattern.AskTimeoutException]].
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,11 @@ package org.apache.pekko.stream.javadsl
import java.util.{ Comparator, Optional }
import java.util.concurrent.CompletionStage
import java.util.function.Supplier

import scala.annotation.{ nowarn, varargs }
import scala.annotation.unchecked.uncheckedVariance
import scala.collection.immutable
import scala.concurrent.duration.FiniteDuration
import scala.reflect.ClassTag

import org.apache.pekko
import pekko.NotUsed
import pekko.annotation.ApiMayChange
Expand Down Expand Up @@ -348,6 +346,10 @@ class SubFlow[In, Out, Mat](
def mapAsyncUnordered[T](parallelism: Int, f: function.Function[Out, CompletionStage[T]]): SubFlow[In, T, Mat] =
new SubFlow(delegate.mapAsyncUnordered(parallelism)(x => f(x).asScala))

def unsafeTransformUnordered[T](
parallelism: Int, transform: function.Function2[Out, StreamCollector[T], Unit]): javadsl.SubFlow[In, T, Mat] =
new SubFlow(delegate.unsafeTransformUnordered[T](parallelism)(emitter => out => transform(out, emitter)))

/**
* Only pass on those elements that satisfy the given predicate.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,10 @@ class SubSource[Out, Mat](
def mapAsyncUnordered[T](parallelism: Int, f: function.Function[Out, CompletionStage[T]]): SubSource[T, Mat] =
new SubSource(delegate.mapAsyncUnordered(parallelism)(x => f(x).asScala))

def unsafeTransformUnordered[T](
parallelism: Int, transform: function.Function2[Out, StreamCollector[T], Unit]): javadsl.SubSource[T, Mat] =
new SubSource(delegate.unsafeTransformUnordered[T](parallelism)(collector => out => transform(out, collector)))

/**
* Only pass on those elements that satisfy the given predicate.
*
Expand Down
19 changes: 15 additions & 4 deletions stream/src/main/scala/org/apache/pekko/stream/scaladsl/Flow.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,10 @@ import scala.collection.immutable
import scala.concurrent.Future
import scala.concurrent.duration.FiniteDuration
import scala.reflect.ClassTag

import org.reactivestreams.Processor
import org.reactivestreams.Publisher
import org.reactivestreams.Subscriber
import org.reactivestreams.Subscription

import org.apache.pekko
import pekko.Done
import pekko.NotUsed
Expand Down Expand Up @@ -829,7 +827,6 @@ final case class RunnableGraph[+Mat](override val traversalBuilder: TraversalBui
@ccompatUsedUntil213
trait FlowOps[+Out, +Mat] {
import GraphDSL.Implicits._

import org.apache.pekko.stream.impl.Stages._

type Repr[+O] <: FlowOps[O, Mat] {
Expand Down Expand Up @@ -1141,7 +1138,21 @@ trait FlowOps[+Out, +Mat] {
*
* @see [[#mapAsync]]
*/
def mapAsyncUnordered[T](parallelism: Int)(f: Out => Future[T]): Repr[T] = via(MapAsyncUnordered(parallelism, f))
def mapAsyncUnordered[T](parallelism: Int)(f: Out => Future[T]): Repr[T] =
unsafeTransformUnordered[T](parallelism) { implicit collector => out =>
import StreamCollectorOps._
import StreamCollectorUnsafeOps._

val future = f(out)
future.value match {
case Some(elem) => handleSync(elem)
case None => future.onComplete(handle(_))(pekko.dispatch.ExecutionContexts.parasitic)
}
}
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for demo


@ApiMayChange
def unsafeTransformUnordered[T](parallelism: Int)(transform: StreamCollector[T] => Out => Unit): Repr[T] =
via(new UnsafeTransformUnordered[Out, T](parallelism, (out, emitter) => transform(emitter)(out)))

/**
* Use the `ask` pattern to send a request-reply message to the target `ref` actor.
Expand Down
Loading