Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ScalaNative version of AwsLambdaRuntime #2906

Draft
wants to merge 5 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -1433,12 +1433,19 @@ lazy val awsLambda: ProjectMatrix = (projectMatrix in file("serverless/aws/lambd
name := "tapir-aws-lambda",
libraryDependencies ++= loggerDependencies,
libraryDependencies ++= Seq(
"com.softwaremill.sttp.client3" %% "fs2" % Versions.sttp,
"com.softwaremill.sttp.client3" %%% "fs2" % Versions.sttp,
"com.amazonaws" % "aws-lambda-java-runtime-interface-client" % Versions.awsLambdaInterface
)
)
.jvmPlatform(scalaVersions = scala2And3Versions)
.jsPlatform(scalaVersions = scala2Versions)
.jsPlatform(
scalaVersions = scala2Versions,
Seq(
// Cross compiles only on JVM and Native
Test / unmanagedSources / excludeFilter ~= { _ || "AwsLambdaRuntimeInvocationTest.scala" }
)
)
.nativePlatform(scalaVersions = scala2And3Versions)
.dependsOn(serverCore, cats, catsEffect, circeJson, tests % "test")

// integration tests for lambda interpreter
Expand Down Expand Up @@ -1639,6 +1646,10 @@ lazy val awsExamples: ProjectMatrix = (projectMatrix in file("serverless/aws/exa
scalaJSLinkerConfig ~= { _.withModuleKind(ModuleKind.CommonJSModule) }
)
)
.nativePlatform(
scalaVersions = scala2Versions,
settings = commonNativeSettings
)
.dependsOn(awsLambda)

lazy val awsExamples2_12 = awsExamples.jvm(scala2_12).dependsOn(awsSam.jvm(scala2_12), awsTerraform.jvm(scala2_12), awsCdk.jvm(scala2_12))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package sttp.tapir.serverless.aws.examples

import cats.effect.IO
import cats.syntax.all._
import sttp.tapir._
import sttp.tapir.server.ServerEndpoint
import sttp.tapir.serverless.aws.lambda.runtime._

object LambdaRuntime extends AwsLambdaIORuntime {
val helloEndpoint: ServerEndpoint[Any, IO] = endpoint.get
.in("api" / "hello")
.out(stringBody)
.serverLogic { _ => IO.pure(s"Hello!".asRight[Unit]) }

override val endpoints = Seq(helloEndpoint)
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package sttp.tapir.serverless.aws.lambda.runtime

import cats.effect.{Resource, Sync}
import cats.syntax.either._
import com.typesafe.scalalogging.StrictLogging
import PlatformCompat.StrictLogging
import io.circe.Printer
import io.circe.generic.auto._
import io.circe.parser.decode
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package sttp.tapir.serverless.aws.lambda.runtime

private[runtime] object PlatformCompat {
// Compiles, but would not link, scalalogging is not cross-compiled for ScalaJS
type StrictLogging = com.typesafe.scalalogging.StrictLogging
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package sttp.tapir.serverless.aws.lambda.runtime

private[runtime] object PlatformCompat {
type StrictLogging = com.typesafe.scalalogging.StrictLogging
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
package sttp.tapir.serverless.aws.lambda

import sttp.capabilities
import sttp.model.HasHeaders
import sttp.tapir.capabilities.NoStreams
import sttp.tapir.server.interpreter.ToResponseBody
import sttp.tapir.{CodecFormat, RawBodyType, WebSocketBodyOutput}

import java.io.InputStream
import java.nio.ByteBuffer
import java.nio.charset.Charset
import java.util.Base64

// The same as for the JVM
private[lambda] class AwsToResponseBody[F[_]](options: AwsServerOptions[F]) extends ToResponseBody[LambdaResponseBody, NoStreams] {
override val streams: capabilities.Streams[NoStreams] = NoStreams

override def fromRawValue[R](v: R, headers: HasHeaders, format: CodecFormat, bodyType: RawBodyType[R]): LambdaResponseBody =
bodyType match {
case RawBodyType.StringBody(charset) =>
val str = v.asInstanceOf[String]
val r = if (options.encodeResponseBody) Base64.getEncoder.encodeToString(str.getBytes(charset)) else str
(r, Some(str.length.toLong))

case RawBodyType.ByteArrayBody =>
val bytes = v.asInstanceOf[Array[Byte]]
val r = if (options.encodeResponseBody) Base64.getEncoder.encodeToString(bytes) else new String(bytes)
(r, Some(bytes.length.toLong))

case RawBodyType.ByteBufferBody =>
val byteBuffer = v.asInstanceOf[ByteBuffer]
val r = if (options.encodeResponseBody) Base64.getEncoder.encodeToString(safeRead(byteBuffer)) else new String(safeRead(byteBuffer))
(r, None)

case RawBodyType.InputStreamBody =>
val stream = v.asInstanceOf[InputStream]
val r =
if (options.encodeResponseBody) Base64.getEncoder.encodeToString(stream.readAllBytes()) else new String(stream.readAllBytes())
(r, None)
case RawBodyType.InputStreamRangeBody =>
val bytes: Array[Byte] = v.range
.map(r => v.inputStreamFromRangeStart().readNBytes(r.contentLength.toInt))
.getOrElse(v.inputStream().readAllBytes())
val body =
if (options.encodeResponseBody) Base64.getEncoder.encodeToString(bytes) else new String(bytes)
(body, Some(bytes.length.toLong))

case RawBodyType.FileBody => throw new UnsupportedOperationException
case _: RawBodyType.MultipartBody => throw new UnsupportedOperationException
}

private def safeRead(byteBuffer: ByteBuffer): Array[Byte] = {
if (byteBuffer.hasArray) {
if (byteBuffer.array().length != byteBuffer.limit()) {
val array = new Array[Byte](byteBuffer.limit())
byteBuffer.get(array, 0, byteBuffer.limit())
array
} else byteBuffer.array()
} else {
val array = new Array[Byte](byteBuffer.remaining())
byteBuffer.get(array)
array
}
}

override def fromStreamValue(
v: streams.BinaryStream,
headers: HasHeaders,
format: CodecFormat,
charset: Option[Charset]
): LambdaResponseBody =
throw new UnsupportedOperationException

override def fromWebSocketPipe[REQ, RESP](
pipe: streams.Pipe[REQ, RESP],
o: WebSocketBodyOutput[streams.Pipe[REQ, RESP], REQ, RESP, _, NoStreams]
): LambdaResponseBody = throw new UnsupportedOperationException
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package sttp.tapir.serverless.aws.lambda.runtime

import cats.effect.{Sync, IO, Resource}
import cats.effect.unsafe.implicits.global
import cats.syntax.all._
import sttp.client3.{SttpBackend, AbstractCurlBackend, FollowRedirectsBackend}
import sttp.client3.impl.cats.CatsMonadError
import sttp.tapir.server.ServerEndpoint
import sttp.tapir.serverless.aws.lambda._

import PlatformCompat.IOPlatformOps

object AwsLambdaRuntime {
// TODO: Move to sttp
private class CurlCatsBackend[F[_]: Sync](verbose: Boolean) extends AbstractCurlBackend[F](new CatsMonadError, verbose)
object CurlCatsBackend {
def apply[F[_]: Sync](verbose: Boolean = false): SttpBackend[F, Any] =
new FollowRedirectsBackend(new CurlCatsBackend(verbose))
}

def apply[F[_]: Sync](endpoints: Iterable[ServerEndpoint[Any, F]], serverOptions: AwsServerOptions[F]): F[Unit] = {
val backend = Resource.pure[F, SttpBackend[F, Any]](CurlCatsBackend(verbose = false))
val route: Route[F] = AwsCatsEffectServerInterpreter(serverOptions).toRoute(endpoints.toList)
AwsLambdaRuntimeInvocation.handleNext(route, sys.env("AWS_LAMBDA_RUNTIME_API"), backend).foreverM
}
}

/** A runtime which uses the [[IO]] effect */
abstract class AwsLambdaIORuntime {
def endpoints: Iterable[ServerEndpoint[Any, IO]]
def serverOptions: AwsServerOptions[IO] = AwsCatsEffectServerOptions.default[IO]

def main(args: Array[String]): Unit =
AwsLambdaRuntime(endpoints, serverOptions).unsafeRunSync()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
package sttp.tapir.serverless.aws.lambda.runtime

import cats.effect.IO
import cats.effect.unsafe.IORuntime
import scala.concurrent.Promise

private[runtime] object PlatformCompat {
// No unsafeRunSync in Native version of IOPlatform
implicit class IOPlatformOps[T](io: IO[T]) {
def unsafeRunSync()(implicit runtime: IORuntime): T = {
val result = Promise[T]()
io.unsafeRunAsync(v => result.complete(v.toTry))
while (!result.isCompleted) scala.scalanative.runtime.loop()
result.future.value.get.get
}
}

trait StrictLogging {
val logger = new ConsoleLogger({
val parts = getClass.getName().split('.')
val shortPackageName = parts.init.map(_.take(1)).mkString(".")
s"$shortPackageName.${parts.last}"
})
class ConsoleLogger(loggerName: String) {
private sealed abstract class Severity(val name: String)
private object Error extends Severity("ERROR")
private object Warn extends Severity("WARN")
private object Info extends Severity("INFO")
private object Debug extends Severity("DEBUG")
private object Trace extends Severity("TRACE")

private def getTimestamp(): String = {
import scala.scalanative.meta.LinktimeInfo.isWindows
import scala.scalanative.posix.time._
import scala.scalanative.posix.timeOps._
import scala.scalanative.unsafe._
import scala.scalanative.unsigned._

if (isWindows) ""
else {
val currentTime = stackalloc[timespec]()
val timeInfo = stackalloc[tm]()

clock_gettime(CLOCK_REALTIME, currentTime)
localtime_r(currentTime.at1, timeInfo)

val length = 25.toUInt
val timestamp = stackalloc[CChar](length)
strftime(timestamp, length, c"%Y-%m-%d %H:%M:%S", timeInfo)
val milliseconds = currentTime.tv_nsec / 1000000
f"${fromCString(timestamp)},$milliseconds%03d"
}
}

private def log(severity: Severity, msg: String): Unit = {
val timestamp = getTimestamp()
val thread = Thread.currentThread().getName()
println(s"$timestamp [$thread] ${severity.name} ${loggerName} - $msg")
}
private def log(severity: Severity, msg: String, cause: Throwable): Unit = log(severity, s"$msg coused by ${cause.getMessage()}")
private def log(severity: Severity, msg: String, args: Any*): Unit =
log(severity, args.map(_.toString()).foldLeft(msg)(_.replaceFirst(raw"\{\}", _)))

// Error
def error(message: String): Unit = log(Error, message)
def error(message: String, cause: Throwable): Unit = log(Error, message, cause)
def error(message: String, args: Any*): Unit = log(Error, message, args)
def whenErrorEnabled(body: => Unit): Unit = body

// Warn
def warn(message: String): Unit = log(Warn, message)
def warn(message: String, cause: Throwable): Unit = log(Warn, message, cause)
def warn(message: String, args: Any*): Unit = log(Warn, message, args)
def whenWarnEnabled(body: => Unit): Unit = body

// Info
def info(message: String): Unit = log(Info, message)
def info(message: String, cause: Throwable): Unit = log(Info, message, cause)
def info(message: String, args: Any*): Unit = log(Info, message, args)
def whenInfoEnabled(body: => Unit): Unit = body

// Debug
def debug(message: String): Unit = log(Debug, message)
def debug(message: String, cause: Throwable): Unit = log(Debug, message, cause)
def debug(message: String, args: Any*): Unit = log(Debug, message, args)
def whenDebugEnabled(body: => Unit): Unit = body

// Trace
def trace(message: String): Unit = log(Trace, message)
def trace(message: String, cause: Throwable): Unit = log(Trace, message, cause)
def trace(message: String, args: Any*): Unit = log(Trace, message, args)
def whenTraceEnabled(body: => Unit): Unit = body
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import sttp.tapir._
import sttp.tapir.integ.cats.effect.CatsMonadError
import sttp.tapir.serverless.aws.lambda.runtime.AwsLambdaRuntimeInvocationTest._
import sttp.tapir.serverless.aws.lambda.{AwsCatsEffectServerInterpreter, AwsCatsEffectServerOptions, AwsServerOptions}
import PlatformCompat._

import scala.collection.immutable.Seq

Expand Down