Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 37 additions & 6 deletions java-runtime/src/main/scala/client/Fs2ClientCall.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,27 @@ package java_runtime
package client

import cats.effect._
import cats.effect.concurrent.{Deferred, Ref}
import cats.implicits._
import io.grpc.{Metadata, _}
import fs2._

final case class UnaryResult[A](value: Option[A], status: Option[GrpcStatus])
final case class GrpcStatus(status: Status, trailers: Metadata)

class Fs2ClientCall[F[_], Request, Response] private[client] (val call: ClientCall[Request, Response]) extends AnyVal {
class Fs2ClientCall[F[_], Request, Response] private[client] (val call: ClientCall[Request, Response],
val wakeOnReady: Ref[F, Option[Deferred[F, Unit]]]) {
def onReady()(implicit F: Sync[F]): F[Unit] = {
wakeOnReady
.modify({
case None => (None, F.unit)
case Some(wake) => (None, wake.complete(()))
})
.flatten
}

private def isReady(implicit F: Sync[F]): F[Boolean] =
F.delay(call.isReady)

private def cancel(message: Option[String], cause: Option[Throwable])(implicit F: Sync[F]): F[Unit] =
F.delay(call.cancel(message.orNull, cause.orNull))
Expand All @@ -21,22 +34,35 @@ class Fs2ClientCall[F[_], Request, Response] private[client] (val call: ClientCa
private def request(numMessages: Int)(implicit F: Sync[F]): F[Unit] =
F.delay(call.request(numMessages))

private def sendMessage(message: Request)(implicit F: Sync[F]): F[Unit] =
private def sendMessage(message: Request)(implicit F: Sync[F]): F[Unit] = {
F.delay(call.sendMessage(message))
}

private def sendMessageOrDelay(message: Request)(implicit F: Concurrent[F]): F[Unit] = {
isReady.ifM(
sendMessage(message), {
Deferred[F, Unit].flatMap { wakeup =>
wakeOnReady.set(wakeup.some) *>
isReady.ifM(sendMessage(message), wakeup.get *> sendMessage(message))
}
}
)
}

private def start(listener: ClientCall.Listener[Response], metadata: Metadata)(implicit F: Sync[F]): F[Unit] =
F.delay(call.start(listener, metadata))

def startListener[A <: ClientCall.Listener[Response]](createListener: F[A], headers: Metadata)(implicit F: Sync[F]): F[A] = {
def startListener[A <: ClientCall.Listener[Response]](createListener: F[A], headers: Metadata)(
implicit F: Sync[F]): F[A] = {
createListener.flatTap(start(_, headers)) <* request(1)
}

def sendSingleMessage(message: Request)(implicit F: Sync[F]): F[Unit] = {
sendMessage(message) *> halfClose
}

def sendStream(stream: Stream[F, Request])(implicit F: Sync[F]): Stream[F, Unit] = {
stream.evalMap(sendMessage) ++ Stream.eval(halfClose)
def sendStream(stream: Stream[F, Request])(implicit F: Concurrent[F]): Stream[F, Unit] = {
stream.evalMap(sendMessageOrDelay) ++ Stream.eval(halfClose)
}

def handleCallError(
Expand Down Expand Up @@ -82,8 +108,13 @@ object Fs2ClientCall {
channel: Channel,
methodDescriptor: MethodDescriptor[Request, Response],
callOptions: CallOptions)(implicit F: Sync[F]): F[Fs2ClientCall[F, Request, Response]] =
F.delay(new Fs2ClientCall(channel.newCall[Request, Response](methodDescriptor, callOptions)))
apply(channel.newCall[Request, Response](methodDescriptor, callOptions))

def apply[Request, Response](call: ClientCall[Request, Response])(
implicit F: Sync[F]): F[Fs2ClientCall[F, Request, Response]] =
for {
wakeOnReady <- Ref[F].of(none[Deferred[F, Unit]])
} yield new Fs2ClientCall(call, wakeOnReady)
}

def apply[F[_]]: PartiallyAppliedClientCall[F] =
Expand Down
34 changes: 31 additions & 3 deletions java-runtime/src/main/scala/server/Fs2ServerCall.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,25 @@ package java_runtime
package server

import cats.effect._
import cats.effect.concurrent.{Deferred, Ref}
import cats.implicits._
import io.grpc._

// TODO: Add attributes, compression, message compression.
private[server] class Fs2ServerCall[F[_], Request, Response](val call: ServerCall[Request, Response]) extends AnyVal {
private[server] class Fs2ServerCall[F[_], Request, Response](val call: ServerCall[Request, Response],
val wakeOnReady: Ref[F, Option[Deferred[F, Unit]]]) {
def onReady()(implicit F: Sync[F]): F[Unit] = {
wakeOnReady
.modify({
case None => (None, F.unit)
case Some(wake) => (None, wake.complete(()))
})
.flatten
}

def isReady(implicit F: Sync[F]): F[Boolean] =
F.delay(call.isReady)

def sendHeaders(headers: Metadata)(implicit F: Sync[F]): F[Unit] =
F.delay(call.sendHeaders(headers))

Expand All @@ -16,11 +31,24 @@ private[server] class Fs2ServerCall[F[_], Request, Response](val call: ServerCal
def sendMessage(message: Response)(implicit F: Sync[F]): F[Unit] =
F.delay(call.sendMessage(message))

def sendMessageOrDelay(message: Response)(implicit F: Concurrent[F]): F[Unit] =
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this method be called inside Fs2ServerCallListener#handleStreamResponse instead of Fs2ServerCall#sendMessage? It looks like current implementation of server streaming doesn't take backpressure into account.

isReady.ifM(
sendMessage(message), {
Deferred[F, Unit].flatMap { wakeup =>
wakeOnReady.set(wakeup.some) *>
isReady.ifM(sendMessage(message), wakeup.get *> sendMessage(message))
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we check isReady after wakeup.get completes to deal with missed wakeup?

}
}
)

def request(numMessages: Int)(implicit F: Sync[F]): F[Unit] =
F.delay(call.request(numMessages))
}

private[server] object Fs2ServerCall {
def apply[F[_], Request, Response](call: ServerCall[Request, Response]): Fs2ServerCall[F, Request, Response] =
new Fs2ServerCall[F, Request, Response](call)
def apply[F[_], Request, Response](call: ServerCall[Request, Response])(
implicit F: Concurrent[F]): F[Fs2ServerCall[F, Request, Response]] =
for {
wakeOnReady <- Ref[F].of(none[Deferred[F, Unit]])
} yield new Fs2ServerCall[F, Request, Response](call, wakeOnReady)
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ class Fs2StreamServerCallListener[F[_], Request, Response] private (
val call: Fs2ServerCall[F, Request, Response])(implicit F: Effect[F])
extends ServerCall.Listener[Request]
with Fs2ServerCallListener[F, Stream[F, ?], Request, Response] {
override def onReady(): Unit = {
call.onReady().unsafeRun()
}

override def onCancel(): Unit = {
isCancelled.complete(()).unsafeRun()
Expand All @@ -40,10 +43,8 @@ object Fs2StreamServerCallListener {
for {
inputQ <- Queue.unbounded[F, Option[Request]]
isCancelled <- Deferred[F, Unit]
} yield
new Fs2StreamServerCallListener[F, Request, Response](inputQ,
isCancelled,
Fs2ServerCall[F, Request, Response](call))
serverCall <- Fs2ServerCall[F, Request, Response](call)
} yield new Fs2StreamServerCallListener[F, Request, Response](inputQ, isCancelled, serverCall)
}

def apply[F[_]] = new PartialFs2StreamServerCallListener[F]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ class Fs2UnaryServerCallListener[F[_], Request, Response] private (

import Fs2UnaryServerCallListener._

override def onReady(): Unit = {
call.onReady().unsafeRun()
}

override def onCancel(): Unit = {
isCancelled.complete(()).unsafeRun()
}
Expand Down Expand Up @@ -62,11 +66,8 @@ object Fs2UnaryServerCallListener {
request <- Ref.of[F, Option[Request]](none)
isComplete <- Deferred[F, Unit]
isCancelled <- Deferred[F, Unit]
} yield
new Fs2UnaryServerCallListener[F, Request, Response](request,
isComplete,
isCancelled,
Fs2ServerCall[F, Request, Response](call))
serverCall <- Fs2ServerCall[F, Request, Response](call)
} yield new Fs2UnaryServerCallListener[F, Request, Response](request, isComplete, isCancelled, serverCall)
}

def apply[F[_]] = new PartialFs2UnaryServerCallListener[F]
Expand Down
25 changes: 11 additions & 14 deletions java-runtime/src/test/scala/client/ClientSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ object ClientSuite extends SimpleTestSuite {
implicit val cs: ContextShift[IO] = IO.contextShift(ec)

val dummy = new DummyClientCall()
val client = new Fs2ClientCall[IO, String, Int](dummy)
val client = Fs2ClientCall[IO](dummy).unsafeRunSync()
val result = client.unaryToUnaryCall("hello", new Metadata()).unsafeToFuture()
dummy.listener.get.onMessage(5)

Expand All @@ -44,7 +44,7 @@ object ClientSuite extends SimpleTestSuite {
implicit val timer: Timer[IO] = ec.timer

val dummy = new DummyClientCall()
val client = new Fs2ClientCall[IO, String, Int](dummy)
val client = Fs2ClientCall[IO](dummy).unsafeRunSync()
val result = client.unaryToUnaryCall("hello", new Metadata()).timeout(1.second).unsafeToFuture()

ec.tick()
Expand All @@ -68,7 +68,7 @@ object ClientSuite extends SimpleTestSuite {
implicit val cs: ContextShift[IO] = IO.contextShift(ec)

val dummy = new DummyClientCall()
val client = new Fs2ClientCall[IO, String, Int](dummy)
val client = Fs2ClientCall[IO](dummy).unsafeRunSync()
val result = client.unaryToUnaryCall("hello", new Metadata()).unsafeToFuture()

dummy.listener.get.onClose(Status.OK, new Metadata())
Expand All @@ -87,9 +87,8 @@ object ClientSuite extends SimpleTestSuite {
implicit val ec: TestContext = TestContext()
implicit val cs: ContextShift[IO] = IO.contextShift(ec)


val dummy = new DummyClientCall()
val client = new Fs2ClientCall[IO, String, Int](dummy)
val client = Fs2ClientCall[IO](dummy).unsafeRunSync()
val result = client.unaryToUnaryCall("hello", new Metadata()).unsafeToFuture()
dummy.listener.get.onMessage(5)

Expand All @@ -113,9 +112,8 @@ object ClientSuite extends SimpleTestSuite {
implicit val ec: TestContext = TestContext()
implicit val cs: ContextShift[IO] = IO.contextShift(ec)


val dummy = new DummyClientCall()
val client = new Fs2ClientCall[IO, String, Int](dummy)
val client = Fs2ClientCall[IO](dummy).unsafeRunSync()
val result = client
.streamingToUnaryCall(Stream.emits(List("a", "b", "c")), new Metadata())
.unsafeToFuture()
Expand All @@ -140,9 +138,8 @@ object ClientSuite extends SimpleTestSuite {
implicit val ec: TestContext = TestContext()
implicit val cs: ContextShift[IO] = IO.contextShift(ec)


val dummy = new DummyClientCall()
val client = new Fs2ClientCall[IO, String, Int](dummy)
val client = Fs2ClientCall[IO](dummy).unsafeRunSync()
val result = client
.streamingToUnaryCall(Stream.empty, new Metadata())
.unsafeToFuture()
Expand All @@ -168,7 +165,7 @@ object ClientSuite extends SimpleTestSuite {
implicit val cs: ContextShift[IO] = IO.contextShift(ec)

val dummy = new DummyClientCall()
val client = new Fs2ClientCall[IO, String, Int](dummy)
val client = Fs2ClientCall[IO](dummy).unsafeRunSync()
val result = client.unaryToStreamingCall("hello", new Metadata()).compile.toList.unsafeToFuture()

dummy.listener.get.onMessage(1)
Expand All @@ -194,7 +191,7 @@ object ClientSuite extends SimpleTestSuite {
implicit val cs: ContextShift[IO] = IO.contextShift(ec)

val dummy = new DummyClientCall()
val client = new Fs2ClientCall[IO, String, Int](dummy)
val client = Fs2ClientCall[IO](dummy).unsafeRunSync()
val result =
client
.streamingToStreamingCall(Stream.emits(List("a", "b", "c", "d", "e")), new Metadata())
Expand Down Expand Up @@ -225,7 +222,7 @@ object ClientSuite extends SimpleTestSuite {
implicit val timer: Timer[IO] = ec.timer

val dummy = new DummyClientCall()
val client = new Fs2ClientCall[IO, String, Int](dummy)
val client = Fs2ClientCall[IO](dummy).unsafeRunSync()
val result =
client
.streamingToStreamingCall(Stream.emits(List("a", "b", "c", "d", "e")), new Metadata())
Expand Down Expand Up @@ -255,7 +252,7 @@ object ClientSuite extends SimpleTestSuite {
implicit val cs: ContextShift[IO] = IO.contextShift(ec)

val dummy = new DummyClientCall()
val client = new Fs2ClientCall[IO, String, Int](dummy)
val client = Fs2ClientCall[IO](dummy).unsafeRunSync()
val result =
client
.streamingToStreamingCall(Stream.emits(List("a", "b", "c", "d", "e")), new Metadata())
Expand Down Expand Up @@ -287,7 +284,7 @@ object ClientSuite extends SimpleTestSuite {
}

test("resource awaits termination of managed channel") {
implicit val ec: TestContext = TestContext()
implicit val ec: TestContext = TestContext()

import implicits._
val result = ManagedChannelBuilder.forAddress("127.0.0.1", 0).resource[IO].use(IO.pure).unsafeToFuture()
Expand Down