Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .scalafmt.conf
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ spaces {
optIn.annotationNewlines = true

rewrite.rules = [SortImports, RedundantBraces]

18 changes: 17 additions & 1 deletion benchmarks/src/main/scala/com/devsisters/shardcake/Client.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package com.devsisters.shardcake

import com.devsisters.shardcake.Server.Message.Ping
import com.devsisters.shardcake.Server.Message.{ Ping, StreamPing }
import com.devsisters.shardcake.Server.PingPongEntity
import zio.{ Config => _, _ }

Expand All @@ -17,4 +17,20 @@ object Client {
} yield ()
)
.provide(config, Server.sharding)

def sendStream(streams: Int, messagesPerStream: Int, parallelism: Int): Task[Unit] =
ZIO
.scoped[Sharding](
for {
ping <- Sharding.messenger(PingPongEntity)
_ <- ZIO
.foreachParDiscard(1 to streams) { _ =>
ping
.sendAndReceiveStream("ping")(StreamPing("ping", messagesPerStream, _))
.flatMap(_.runDrain)
}
.withParallelism(parallelism)
} yield ()
)
.provide(config, Server.sharding)
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ import java.util.concurrent.TimeUnit
@State(Scope.Thread)
@BenchmarkMode(Array(Mode.Throughput))
@OutputTimeUnit(TimeUnit.SECONDS)
@Warmup(iterations = 5, time = 5, timeUnit = TimeUnit.SECONDS)
@Measurement(iterations = 5, time = 5, timeUnit = TimeUnit.SECONDS)
@Warmup(iterations = 3, time = 3, timeUnit = TimeUnit.SECONDS)
@Measurement(iterations = 3, time = 3, timeUnit = TimeUnit.SECONDS)
@Fork(1)
class SendBenchmark {
private var fiber: Fiber[Any, Any] = _
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package com.devsisters.shardcake

import org.openjdk.jmh.annotations._
import zio.{ durationInt, Fiber, Runtime, Unsafe, ZIO }

import java.util.concurrent.TimeUnit

@State(Scope.Thread)
@BenchmarkMode(Array(Mode.Throughput))
@OutputTimeUnit(TimeUnit.SECONDS)
@Warmup(iterations = 3, time = 3, timeUnit = TimeUnit.SECONDS)
@Measurement(iterations = 3, time = 3, timeUnit = TimeUnit.SECONDS)
@Fork(1)
class SendStreamBenchmark {
private var fiber: Fiber[Any, Any] = _

@Setup
def setup(): Unit =
fiber = Unsafe.unsafe(implicit unsafe =>
Runtime.default.unsafe.run(Server.run.forkDaemon <* ZIO.sleep(3.seconds)).getOrThrow()
)

@TearDown
def tearDown(): Unit =
Unsafe.unsafe(implicit unsafe => Runtime.default.unsafe.run(fiber.interrupt))

// 8 parallel server-streams, each receiving 100 messages → 800 messages per op
@Benchmark
def serverStream(): Unit =
Unsafe.unsafe(implicit unsafe => Runtime.default.unsafe.run(Client.sendStream(8, 100, 8)))
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,18 @@ object Server {
sealed trait Message

object Message {
case class Ping(msg: String, replier: Replier[String]) extends Message
case class Ping(msg: String, replier: Replier[String]) extends Message
case class StreamPing(msg: String, count: Int, replier: StreamReplier[String]) extends Message
}

object PingPongEntity extends EntityType[Message]("ping-pong")

private def behavior(entityId: String, messages: Dequeue[Message]): RIO[Sharding, Nothing] =
messages.take.flatMap { case Message.Ping(msg, replier) => replier.reply(msg) }.forever
messages.take.flatMap {
case Message.Ping(msg, replier) => replier.reply(msg)
case Message.StreamPing(msg, count, replier) =>
replier.replyStream(zio.stream.ZStream.repeat(msg).take(count.toLong))
}.forever

private val shardManagerClient: ZLayer[Config, Nothing, ShardManagerClient] =
ZLayer {
Expand Down
54 changes: 36 additions & 18 deletions build.sbt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
val scala3 = "3.3.7"

val zioVersion = "2.1.24"
val zioGrpcVersion = "0.6.3"
val proteusVersion = "0.4.1"
val grpcNettyVersion = "1.71.0"
val zioK8sVersion = "3.2.0"
val zioK8sSttpVersion = "3.11.0"
Expand Down Expand Up @@ -41,7 +41,7 @@ inThisBuild(

name := "shardcake"
addCommandAlias("fmt", "all scalafmtSbt scalafmt test:scalafmt")
addCommandAlias("check", "all scalafmtSbtCheck scalafmtCheck test:scalafmtCheck")
addCommandAlias("check", "all scalafmtSbtCheck scalafmtCheck test:scalafmtCheck grpcProtocol/checkProto")

lazy val root = project
.in(file("."))
Expand Down Expand Up @@ -151,25 +151,47 @@ lazy val serializationKryo = project
)
)

lazy val generateProto = taskKey[Unit]("Regenerate sharding.proto from the Scala protocol definition.")
lazy val checkProto = taskKey[Unit]("Fail if sharding.proto is out of sync with the Scala protocol definition.")

lazy val grpcProtocol = project
.in(file("protocol-grpc"))
.settings(name := "shardcake-protocol-grpc")
.settings(commonSettings)
.settings(protobuf: _*)
.settings(
Compile / PB.targets := Seq(
scalapb.gen(grpc = true) -> (Compile / sourceManaged).value,
scalapb.zio_grpc.ZioCodeGenerator -> (Compile / sourceManaged).value
)
)
.dependsOn(core, entities)
.settings(
libraryDependencies ++= Seq(
"com.thesamet.scalapb" %% "scalapb-runtime" % scalapb.compiler.Version.scalapbVersion % "protobuf",
"com.thesamet.scalapb" %% "scalapb-runtime-grpc" % scalapb.compiler.Version.scalapbVersion,
"com.thesamet.scalapb.zio-grpc" %% "zio-grpc-core" % zioGrpcVersion,
"io.grpc" % "grpc-netty" % grpcNettyVersion
)
"com.github.ghostdogpr" %% "proteus-grpc" % proteusVersion,
"com.github.ghostdogpr" %% "proteus-grpc-zio" % proteusVersion,
"io.grpc" % "grpc-netty" % grpcNettyVersion,
"io.grpc" % "grpc-services" % grpcNettyVersion
),
generateProto := {
val cp = (Compile / fullClasspath).value
val log = streams.value.log
val output = (Compile / sourceDirectory).value / "protobuf"
runner.value
.run(
"com.devsisters.shardcake.protocol.GenerateProto",
cp.files,
Seq(output.getAbsolutePath),
log
)
.get
log.info(s"Regenerated $output/sharding.proto")
},
checkProto := {
val log = streams.value.log
val proto = (Compile / sourceDirectory).value / "protobuf" / "sharding.proto"
val _ = generateProto.value
import scala.sys.process._
val diff = s"git diff --exit-code -- ${proto.getAbsolutePath}".!
if (diff != 0) {
sys.error(
"sharding.proto is out of sync with the Scala protocol definition. Run `sbt grpcProtocol/generateProto` and commit the result."
)
} else log.info("sharding.proto is in sync.")
}
)

lazy val examples = project
Expand All @@ -194,10 +216,6 @@ lazy val benchmarks = project
.enablePlugins(JmhPlugin)
.dependsOn(grpcProtocol, serializationKryo)

lazy val protobuf = Seq(
PB.protocVersion := "3.19.2"
) ++ Project.inConfig(Test)(sbtprotoc.ProtocPlugin.protobufConfigSettings)

lazy val commonSettings = Def.settings(
testFrameworks := Seq(new TestFramework("zio.test.sbt.ZTestFramework")),
libraryDependencies ++=
Expand Down
13 changes: 5 additions & 8 deletions examples/src/test/scala/example/GrpcAuthExampleSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,30 +3,27 @@ package example
import com.devsisters.shardcake._
import com.devsisters.shardcake.interfaces.{ Pods, Storage }
import io.grpc.{ Metadata, Status }
import scalapb.zio_grpc.{ ZClientInterceptor, ZTransform }
import zio.test._
import zio.{ Config => _, _ }

object GrpcAuthExampleSpec extends ZIOSpecDefault {

private val validAuthenticationKey = "validAuthenticationKey"

private val authKey = Metadata.Key.of("authentication-key", io.grpc.Metadata.ASCII_STRING_MARSHALLER)
private val authKey = Metadata.Key.of("authentication-key", Metadata.ASCII_STRING_MARSHALLER)

private val config = ZLayer.succeed(Config.default.copy(simulateRemotePods = true))

private def grpcConfigLayer(clientAuthKey: String): ULayer[GrpcConfig] =
ZLayer.succeed(
GrpcConfig.default.copy(
clientInterceptors = Seq(
ZClientInterceptor.headersUpdater((_, _, md) => md.put(authKey, clientAuthKey).unit)
ShardingClientInterceptor.headersUpdater(_.put(authKey, clientAuthKey))
),
serverInterceptors = Seq(
ZTransform { requestContext =>
for {
authenticated <- requestContext.metadata.get(authKey).map(_.contains(validAuthenticationKey))
_ <- ZIO.when(!authenticated)(ZIO.fail(Status.UNAUTHENTICATED.asException))
} yield requestContext
ShardingServerInterceptor.beforeEach { ctx =>
val authenticated = Option(ctx.requestMetadata.get(authKey)).contains(validAuthenticationKey)
ZIO.unless(authenticated)(ZIO.fail(Status.UNAUTHENTICATED.asException())).unit
}
)
)
Expand Down
4 changes: 0 additions & 4 deletions project/plugins.sbt
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
addSbtPlugin("org.scalameta" % "sbt-scalafmt" % "2.4.6")
addSbtPlugin("com.github.sbt" % "sbt-ci-release" % "1.11.1")
addSbtPlugin("com.thesamet" % "sbt-protoc" % "1.0.7")
addSbtPlugin("pl.project13.scala" % "sbt-jmh" % "0.4.7")

libraryDependencies += "com.thesamet.scalapb" %% "compilerplugin" % "0.11.17"
libraryDependencies += "com.thesamet.scalapb.zio-grpc" %% "zio-grpc-codegen" % "0.6.3"
2 changes: 0 additions & 2 deletions protocol-grpc/src/main/protobuf/sharding.proto
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
syntax = "proto3";

option java_package = "com.devsisters.shardcake.protobuf";

service ShardingService {
rpc AssignShards (AssignShardsRequest) returns (AssignShardsResponse) {}
rpc UnassignShards (UnassignShardsRequest) returns (UnassignShardsResponse) {}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,30 +1,37 @@
package com.devsisters.shardcake

import io.grpc.ClientInterceptor
import zio._
import scalapb.zio_grpc.RequestContext
import scalapb.zio_grpc.ZClientInterceptor
import scalapb.zio_grpc.ZTransform

import java.util.concurrent.Executor

/**
* The configuration for the gRPC client.
* The configuration for the gRPC client and server.
*
* @param maxInboundMessageSize the maximum message size allowed to be received by the grpc client
* @param maxInboundMessageSize the maximum message size allowed to be received by the grpc client and server
* @param executor a custom executor to pass to grpc-java when creating gRPC clients and servers
* @param shutdownTimeout the timeout to wait for the gRPC server to shutdown before forcefully shutting it down
* @param clientInterceptors the interceptors to be used by the gRPC client, e.g for adding tracing or logging
* @param serverInterceptors the interceptors to be used by the gRPC Server, e.g for adding tracing or logging
* @param streamingPrefetch the in-flight window for streaming RPCs (request/response messages fetched ahead of the consumer)
* @param clientInterceptors the interceptors to be used by the gRPC client, e.g. for adding tracing or logging
* @param serverInterceptors the interceptors to be used by the gRPC server, e.g. for adding tracing or logging
*/
case class GrpcConfig(
maxInboundMessageSize: Int,
executor: Option[Executor],
shutdownTimeout: Duration,
clientInterceptors: Seq[ZClientInterceptor],
serverInterceptors: Seq[ZTransform[RequestContext, RequestContext]]
streamingPrefetch: Int,
clientInterceptors: Seq[ClientInterceptor],
serverInterceptors: Seq[ShardingServerInterceptor]
)

object GrpcConfig {
val default: GrpcConfig =
GrpcConfig(maxInboundMessageSize = 32 * 1024 * 1024, None, 3.seconds, Seq.empty, Seq.empty)
GrpcConfig(
maxInboundMessageSize = 32 * 1024 * 1024,
executor = None,
shutdownTimeout = 3.seconds,
streamingPrefetch = 16,
clientInterceptors = Seq.empty,
serverInterceptors = Seq.empty
)
}
Loading