Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,7 @@ import dev.kdriver.cdp.*
import dev.kdriver.cdp.domain.*
import dev.kdriver.core.browser.Browser
import dev.kdriver.core.browser.Config.Defaults
import dev.kdriver.core.browser.WebSocketInfo
import io.ktor.client.*
import io.ktor.client.plugins.websocket.*
import io.ktor.http.*
import io.ktor.util.logging.*
import io.ktor.websocket.*
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.flow.Flow
Expand All @@ -35,17 +30,23 @@ open class DefaultConnection(

private val logger = KtorSimpleLogger("Connection")

private val client = HttpClient(getWebSocketClientEngine()) {
install(WebSockets)
}

private var wsSession: ClientWebSocketSession? = null
private val transport: WebSocketTransport by lazy { createTransport() }

private var socketSubscription: Job? = null

private val currentIdMutex = Mutex()
private var currentId = 0L

private val pendingRequestsMutex = Mutex()
private val pendingRequests = mutableMapOf<Long, CompletableDeferred<Message.Response>>()

/**
* Creates the [WebSocketTransport] used to talk to the browser.
*
* Overridable so tests can inject a fake transport without a real browser.
*/
protected open fun createTransport(): WebSocketTransport = KtorWebSocketTransport(websocketUrl)

private var prepareHeadlessDone = false
private var prepareExpertDone = false

Expand All @@ -61,29 +62,23 @@ open class DefaultConnection(
override val generatedDomains: MutableMap<KClass<out Domain>, Domain> = mutableMapOf()

private suspend fun connect() {
if (wsSession != null && wsSession?.isActive == true) return
wsSession = client.webSocketSession {
url {
val parsed = parseWebSocketUrl(websocketUrl)
this.protocol = URLProtocol.WS
this.host = parsed.host
this.port = parsed.port
this.path(parsed.path)
}
}
if (transport.isActive) return
transport.connect()
startListening()
}

private fun startListening() {
socketSubscription?.cancel()
socketSubscription = messageListeningScope.launch {
try {
for (frame in wsSession?.incoming ?: return@launch) {
transport.incoming().collect { text ->
try {
frame as? Frame.Text ?: continue
val text = frame.readText()
logger.debug("WS < CDP: ${text.take(owner?.config?.debugStringLimit ?: Defaults.DEBUG_STRING_LIMIT)}")
val received = Serialization.json.decodeFromString<Message>(text)
if (received is Message.Response) {
pendingRequestsMutex.withLock { pendingRequests.remove(received.id) }
?.complete(received)
}
allMessages.emit(received)
} catch (e: CancellationException) {
throw e
Expand All @@ -110,19 +105,27 @@ open class DefaultConnection(
}

val requestId = currentIdMutex.withLock { currentId++ }
val jsonString = Serialization.json.encodeToString(Request(requestId, method, parameter))
wsSession?.send(jsonString)
logger.debug("WS > CDP: ${jsonString.take(owner?.config?.debugStringLimit ?: Defaults.DEBUG_STRING_LIMIT)}")

val result = responses.first { it.id == requestId }
result.error?.throwAsException(method)
return result.result
// Register the response waiter *before* sending, so a reply that arrives before we start
// awaiting is still captured (the receive loop completes this deferred). Awaiting the
// response via a replay-0 shared flow after sending could miss it and hang (ISSUE-1).
val deferred = CompletableDeferred<Message.Response>()
pendingRequestsMutex.withLock { pendingRequests[requestId] = deferred }
try {
val jsonString = Serialization.json.encodeToString(Request(requestId, method, parameter))
transport.send(jsonString)
logger.debug("WS > CDP: ${jsonString.take(owner?.config?.debugStringLimit ?: Defaults.DEBUG_STRING_LIMIT)}")

val result = deferred.await()
result.error?.throwAsException(method)
return result.result
} finally {
pendingRequestsMutex.withLock { pendingRequests.remove(requestId) }
}
}

@InternalCdpApi
override suspend fun close() {
wsSession?.close()
wsSession = null
transport.close()
socketSubscription?.cancel()
socketSubscription = null
}
Expand Down Expand Up @@ -219,20 +222,6 @@ open class DefaultConnection(
}
}

private fun parseWebSocketUrl(url: String): WebSocketInfo {
val uri = Url(url)

val host = uri.host
val port = if (uri.port != -1) uri.port else when (uri.protocol) {
URLProtocol.WS -> 80
URLProtocol.WSS -> 443
else -> throw IllegalArgumentException("Unsupported scheme: ${uri.protocol}")
}
val path = uri.encodedPath

return WebSocketInfo(host, port, path)
}

override fun toString(): String {
return "Connection: ${targetInfo?.toString() ?: "no target"}"
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package dev.kdriver.core.connection

import dev.kdriver.core.browser.WebSocketInfo
import io.ktor.client.*
import io.ktor.client.plugins.websocket.*
import io.ktor.http.*
import io.ktor.websocket.*
import kotlinx.coroutines.isActive
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.flow

/**
* Default [WebSocketTransport] backed by a Ktor WebSocket session.
*/
class KtorWebSocketTransport(
private val websocketUrl: String,
) : WebSocketTransport {

private val client = HttpClient(getWebSocketClientEngine()) {
install(WebSockets)
}

private var session: ClientWebSocketSession? = null

override val isActive: Boolean
get() = session?.isActive == true

override suspend fun connect() {
if (isActive) return
session = client.webSocketSession {
url {
val parsed = parseWebSocketUrl(websocketUrl)
this.protocol = URLProtocol.WS
this.host = parsed.host
this.port = parsed.port
this.path(parsed.path)
}
}
}

override suspend fun send(message: String) {
session?.send(message)
}

override fun incoming(): Flow<String> = flow {
val session = session ?: return@flow
for (frame in session.incoming) {
val text = (frame as? Frame.Text)?.readText() ?: continue
emit(text)
}
}

override suspend fun close() {
session?.close()
session = null
}

private fun parseWebSocketUrl(url: String): WebSocketInfo {
val uri = Url(url)

val host = uri.host
val port = if (uri.port != -1) uri.port else when (uri.protocol) {
URLProtocol.WS -> 80

Check warning on line 63 in core/src/commonMain/kotlin/dev/kdriver/core/connection/KtorWebSocketTransport.kt

View check run for this annotation

codefactor.io / CodeFactor

core/src/commonMain/kotlin/dev/kdriver/core/connection/KtorWebSocketTransport.kt#L63

This expression contains a magic number. Consider defining it to a well named constant. (detekt.MagicNumber)
URLProtocol.WSS -> 443

Check warning on line 64 in core/src/commonMain/kotlin/dev/kdriver/core/connection/KtorWebSocketTransport.kt

View check run for this annotation

codefactor.io / CodeFactor

core/src/commonMain/kotlin/dev/kdriver/core/connection/KtorWebSocketTransport.kt#L64

This expression contains a magic number. Consider defining it to a well named constant. (detekt.MagicNumber)
else -> throw IllegalArgumentException("Unsupported scheme: ${uri.protocol}")
}
val path = uri.encodedPath

return WebSocketInfo(host, port, path)
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package dev.kdriver.core.connection

import kotlinx.coroutines.flow.Flow

/**
* Abstraction over the raw WebSocket connection used to talk to the browser.
*
* Extracting this lets [DefaultConnection]'s message plumbing (request/response correlation,
* event dispatch) be exercised without a real browser, by injecting a fake transport in tests.
*/
interface WebSocketTransport {

/**
* Whether the underlying connection is currently open.
*/
val isActive: Boolean

/**
* Opens the connection. Must be called before [send] or [incoming]. No-op if already open.
*/
suspend fun connect()

/**
* Sends a raw text payload to the browser.
*/
suspend fun send(message: String)

/**
* Cold stream of raw text payloads received from the browser. Collecting it starts consuming
* frames; the flow completes when the connection is closed.
*/
fun incoming(): Flow<String>

/**
* Closes the connection.
*/
suspend fun close()

}
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,7 @@ import dev.kdriver.cdp.domain.Fetch.HeaderEntry
import dev.kdriver.cdp.domain.Network
import dev.kdriver.cdp.domain.fetch
import dev.kdriver.core.tab.Tab
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Job
import kotlinx.coroutines.launch
import kotlin.coroutines.coroutineContext
import kotlinx.coroutines.*

/**
* Default implementation of [FetchInterception].
Expand All @@ -30,7 +26,12 @@ open class BaseFetchInterception(
}

private suspend fun setup() {
val coroutineScope = CoroutineScope(coroutineContext)
val coroutineScope = CoroutineScope(currentCoroutineContext())
// Subscribe before enabling fetch, so a requestPaused fired after enable() can't be missed.
// UNDISPATCHED guarantees the collector is subscribed before launch returns (ISSUE-2).
job = coroutineScope.launch(start = CoroutineStart.UNDISPATCHED) {
tab.fetch.requestPaused.collect { handler(it) }
}
tab.fetch.enable(
listOf(
Fetch.RequestPattern(
Expand All @@ -40,9 +41,6 @@ open class BaseFetchInterception(
)
)
)
job = coroutineScope.launch {
tab.fetch.requestPaused.collect { handler(it) }
}
}

private suspend fun teardown() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,7 @@ package dev.kdriver.core.network
import dev.kdriver.cdp.domain.Network
import dev.kdriver.cdp.domain.network
import dev.kdriver.core.tab.Tab
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Job
import kotlinx.coroutines.launch
import kotlin.coroutines.coroutineContext
import kotlinx.coroutines.*

/**
* Default implementation of [RequestExpectation].
Expand Down Expand Up @@ -53,17 +49,20 @@ open class BaseRequestExpectation(
}

private suspend fun setup() {
val coroutineScope = CoroutineScope(coroutineContext)
tab.network.enable()
requestJob = coroutineScope.launch {
val coroutineScope = CoroutineScope(currentCoroutineContext())
// Subscribe to the event flows before enabling the domain, so no event fired after
// enable() can be missed. UNDISPATCHED guarantees each collector is actually subscribed
// before launch returns (ISSUE-2).
requestJob = coroutineScope.launch(start = CoroutineStart.UNDISPATCHED) {
tab.network.requestWillBeSent.collect { requestHandler(it) }
}
responseJob = coroutineScope.launch {
responseJob = coroutineScope.launch(start = CoroutineStart.UNDISPATCHED) {
tab.network.responseReceived.collect { responseHandler(it) }
}
loadingFinishedJob = coroutineScope.launch {
loadingFinishedJob = coroutineScope.launch(start = CoroutineStart.UNDISPATCHED) {
tab.network.loadingFinished.collect { loadingFinishedHandler(it) }
}
tab.network.enable()
}

private fun teardown() {
Expand Down
Loading
Loading