Skip to content

Commit

Permalink
+str Add StreamCollector api for unsafe transformation.
Browse files Browse the repository at this point in the history
  • Loading branch information
He-Pin committed Sep 10, 2023
1 parent b37b5cd commit c009807
Show file tree
Hide file tree
Showing 7 changed files with 245 additions and 6 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.pekko.stream

import org.apache.pekko
import pekko.actor.{ DeadLetterSuppression, NoSerializationVerificationNeeded }
import pekko.annotation.InternalApi

import scala.util.{ Failure, Success, Try }

object StreamCollectorOps {
def emit[T](value: T)(implicit collector: StreamCollector[T]): Unit = collector.emit(value)
def fail(throwable: Throwable)(implicit collector: StreamCollector[_]): Unit = collector.fail(throwable)
def complete()(implicit collector: StreamCollector[_]): Unit = collector.complete()
def handle[T](result: Try[T])(implicit collector: StreamCollector[T]): Unit = collector.handle(result)
def handle[T](result: Either[Throwable, T])(implicit collector: StreamCollector[T]): Unit = collector.handle(result)
}

object StreamCollectorUnsafeOps {
def emitSync[T](value: T)(implicit collector: StreamCollector[T]): Unit =
collector.asInstanceOf[UnsafeStreamCollector[T]].emitSync(value)
def failSync(throwable: Throwable)(implicit collector: StreamCollector[_]): Unit =
collector.asInstanceOf[UnsafeStreamCollector[_]].failSync(throwable)
def completeSync()(implicit collector: StreamCollector[_]): Unit =
collector.asInstanceOf[UnsafeStreamCollector[_]].completeSync()
def handleSync[T](result: Try[T])(implicit collector: StreamCollector[T]): Unit =
collector.asInstanceOf[UnsafeStreamCollector[T]].handleSync(result)
def handleSync[T](result: Either[Throwable, T])(implicit collector: StreamCollector[T]): Unit =
collector.asInstanceOf[UnsafeStreamCollector[T]].handleSync(result)
}

object StreamCollector {
sealed trait StreamCollectorCommand
extends DeadLetterSuppression
with NoSerializationVerificationNeeded

case class EmitNext[T](value: T) extends StreamCollectorCommand

case class Fail(throwable: Throwable) extends StreamCollectorCommand

object Complete extends StreamCollectorCommand

object TryPull extends StreamCollectorCommand
}

trait StreamCollector[T] {

def emit(value: T): Unit

def tryPull(): Unit = ()

def fail(throwable: Throwable): Unit

def complete(): Unit

def handle(result: Try[T]): Unit = result match {
case Success(value) => emit(value)
case Failure(ex) => fail(ex)
}

def handle(result: Either[Throwable, T]): Unit = result match {
case Right(value) => emit(value)
case Left(ex) => fail(ex)
}
}

@InternalApi
private[pekko] trait UnsafeStreamCollector[T] extends StreamCollector[T] {
def emitSync(value: T): Unit
def failSync(throwable: Throwable): Unit

def completeSync(): Unit
def handleSync(result: Try[T]): Unit = result match {
case Success(value) => emitSync(value)
case Failure(ex) => failSync(ex)
}
def handleSync(result: Either[Throwable, T]): Unit = result match {
case Right(value) => emitSync(value)
case Left(ex) => failSync(ex)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.pekko.stream.impl.fusing

import org.apache.pekko.stream.ActorAttributes.SupervisionStrategy
import org.apache.pekko.stream.Attributes.SourceLocation
import org.apache.pekko.stream.impl.Stages.DefaultAttributes
import org.apache.pekko.stream.stage.{ GraphStage, GraphStageLogic, InHandler, OutHandler }
import org.apache.pekko.stream._

import scala.util.control.NonFatal

private[pekko] class UnsafeTransformUnordered[In, Out](
parallelism: Int,
transform: (In, StreamCollector[Out]) => Unit)
extends GraphStage[FlowShape[In, Out]] {
private val in = Inlet[In]("UnsafeTransformOrdered.in")
private val out = Outlet[Out]("UnsafeTransformOrdered.out")

override def initialAttributes = DefaultAttributes.mapAsyncUnordered and SourceLocation.forLambda(transform)

override val shape = FlowShape(in, out)

override def createLogic(inheritedAttributes: Attributes): GraphStageLogic =
new GraphStageLogic(shape) with InHandler with OutHandler with UnsafeStreamCollector[Out] {
override def toString = s"UnsafeTransformOrdered.Logic(inFlight=$inFlight, buffer=$buffer)"
import org.apache.pekko.stream.impl.{ Buffer => BufferImpl }

private val decider =
inheritedAttributes.mandatoryAttribute[SupervisionStrategy].decider
private var inFlight = 0
private var buffer: org.apache.pekko.stream.impl.Buffer[Out] = _

import StreamCollector._

private val callback: StreamCollectorCommand => Unit = getAsyncCallback[StreamCollectorCommand](handle).invoke

override def emitSync(value: Out): Unit = handle(EmitNext(value))
override def failSync(throwable: Throwable): Unit = handle(Fail(throwable))
override def completeSync(): Unit = completeStage()

// TODO check permit
override final def emit(value: Out): Unit = callback(EmitNext(value))
override final def fail(throwable: Throwable): Unit = callback(Fail(throwable))
override final def complete(): Unit = callback(Complete)

//
private[this] def todo: Int = inFlight + buffer.used

override def preStart(): Unit = buffer = BufferImpl(parallelism, inheritedAttributes)

private def isCompleted = isClosed(in) && todo == 0

def handle(msg: StreamCollectorCommand): Unit = {
inFlight -= 1

msg match {
case EmitNext(elem: Out @unchecked) if elem != null =>
if (isAvailable(out)) {
if (!hasBeenPulled(in)) tryPull(in)
push(out, elem)
if (isCompleted) completeStage()
} else buffer.enqueue(elem)
case EmitNext(_) =>
if (isCompleted) completeStage()
else if (!hasBeenPulled(in)) tryPull(in)
case TryPull =>
if (!hasBeenPulled(in)) tryPull(in)
case Complete =>
completeStage()
case Fail(ex) =>
if (decider(ex) == Supervision.Stop) failStage(ex)
else if (isCompleted) completeStage()
else if (!hasBeenPulled(in)) tryPull(in)
}
}

override def onPush(): Unit = {
try {
val elem = grab(in)
transform(elem, this)
inFlight += 1
} catch {
case NonFatal(ex) => if (decider(ex) == Supervision.Stop) failStage(ex)
}
if (todo < parallelism && !hasBeenPulled(in)) tryPull(in)
}

override def onUpstreamFinish(): Unit = {
if (todo == 0) completeStage()
}

override def onPull(): Unit = {
if (!buffer.isEmpty) push(out, buffer.dequeue())

val leftTodo = todo
if (isClosed(in) && leftTodo == 0) completeStage()
else if (leftTodo < parallelism && !hasBeenPulled(in)) tryPull(in)
}

setHandlers(in, out, this)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -873,6 +873,10 @@ final class Flow[In, Out, Mat](delegate: scaladsl.Flow[In, Out, Mat]) extends Gr
def mapAsyncUnordered[T](parallelism: Int, f: function.Function[Out, CompletionStage[T]]): javadsl.Flow[In, T, Mat] =
new Flow(delegate.mapAsyncUnordered(parallelism)(x => f(x).asScala))

def unsafeTransformUnordered[T](
parallelism: Int, transform: function.Function2[Out, StreamCollector[T], Unit]): javadsl.Flow[In, T, Mat] =
new Flow(delegate.unsafeTransformUnordered[T](parallelism)(collector => out => transform(out, collector)))

/**
* Use the `ask` pattern to send a request-reply message to the target `ref` actor.
* If any of the asks times out it will fail the stream with a [[pekko.pattern.AskTimeoutException]].
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2527,6 +2527,10 @@ final class Source[Out, Mat](delegate: scaladsl.Source[Out, Mat]) extends Graph[
def mapAsyncUnordered[T](parallelism: Int, f: function.Function[Out, CompletionStage[T]]): javadsl.Source[T, Mat] =
new Source(delegate.mapAsyncUnordered(parallelism)(x => f(x).asScala))

def unsafeTransformUnordered[T](
parallelism: Int, transform: function.Function2[Out, StreamCollector[T], Unit]): javadsl.Source[T, Mat] =
new Source(delegate.unsafeTransformUnordered[T](parallelism)(collector => out => transform(out, collector)))

/**
* Use the `ask` pattern to send a request-reply message to the target `ref` actor.
* If any of the asks times out it will fail the stream with a [[pekko.pattern.AskTimeoutException]].
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,11 @@ package org.apache.pekko.stream.javadsl
import java.util.{ Comparator, Optional }
import java.util.concurrent.CompletionStage
import java.util.function.Supplier

import scala.annotation.{ nowarn, varargs }
import scala.annotation.unchecked.uncheckedVariance
import scala.collection.immutable
import scala.concurrent.duration.FiniteDuration
import scala.reflect.ClassTag

import org.apache.pekko
import pekko.NotUsed
import pekko.annotation.ApiMayChange
Expand Down Expand Up @@ -348,6 +346,10 @@ class SubFlow[In, Out, Mat](
def mapAsyncUnordered[T](parallelism: Int, f: function.Function[Out, CompletionStage[T]]): SubFlow[In, T, Mat] =
new SubFlow(delegate.mapAsyncUnordered(parallelism)(x => f(x).asScala))

def unsafeTransformUnordered[T](
parallelism: Int, transform: function.Function2[Out, StreamCollector[T], Unit]): javadsl.SubFlow[In, T, Mat] =
new SubFlow(delegate.unsafeTransformUnordered[T](parallelism)(emitter => out => transform(out, emitter)))

/**
* Only pass on those elements that satisfy the given predicate.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,10 @@ class SubSource[Out, Mat](
def mapAsyncUnordered[T](parallelism: Int, f: function.Function[Out, CompletionStage[T]]): SubSource[T, Mat] =
new SubSource(delegate.mapAsyncUnordered(parallelism)(x => f(x).asScala))

def unsafeTransformUnordered[T](
parallelism: Int, transform: function.Function2[Out, StreamCollector[T], Unit]): javadsl.SubSource[T, Mat] =
new SubSource(delegate.unsafeTransformUnordered[T](parallelism)(collector => out => transform(out, collector)))

/**
* Only pass on those elements that satisfy the given predicate.
*
Expand Down
19 changes: 15 additions & 4 deletions stream/src/main/scala/org/apache/pekko/stream/scaladsl/Flow.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,10 @@ import scala.collection.immutable
import scala.concurrent.Future
import scala.concurrent.duration.FiniteDuration
import scala.reflect.ClassTag

import org.reactivestreams.Processor
import org.reactivestreams.Publisher
import org.reactivestreams.Subscriber
import org.reactivestreams.Subscription

import org.apache.pekko
import pekko.Done
import pekko.NotUsed
Expand Down Expand Up @@ -829,7 +827,6 @@ final case class RunnableGraph[+Mat](override val traversalBuilder: TraversalBui
@ccompatUsedUntil213
trait FlowOps[+Out, +Mat] {
import GraphDSL.Implicits._

import org.apache.pekko.stream.impl.Stages._

type Repr[+O] <: FlowOps[O, Mat] {
Expand Down Expand Up @@ -1141,7 +1138,21 @@ trait FlowOps[+Out, +Mat] {
*
* @see [[#mapAsync]]
*/
def mapAsyncUnordered[T](parallelism: Int)(f: Out => Future[T]): Repr[T] = via(MapAsyncUnordered(parallelism, f))
def mapAsyncUnordered[T](parallelism: Int)(f: Out => Future[T]): Repr[T] =
unsafeTransformUnordered[T](parallelism) { implicit collector => out =>
import StreamCollectorOps._
import StreamCollectorUnsafeOps._

val future = f(out)
future.value match {
case Some(elem) => handleSync(elem)
case None => future.onComplete(handle(_))(pekko.dispatch.ExecutionContexts.parasitic)
}
}

@ApiMayChange
def unsafeTransformUnordered[T](parallelism: Int)(transform: StreamCollector[T] => Out => Unit): Repr[T] =
via(new UnsafeTransformUnordered[Out, T](parallelism, (out, emitter) => transform(emitter)(out)))

/**
* Use the `ask` pattern to send a request-reply message to the target `ref` actor.
Expand Down

0 comments on commit c009807

Please sign in to comment.