Skip to content

Commit 53265a9

Browse files
authored
Merge pull request #83 from dsrees/feature/accept-params-closure
Allow providing params via a closure
2 parents edd2bf3 + 3019390 commit 53265a9

4 files changed

Lines changed: 188 additions & 87 deletions

File tree

src/main/kotlin/org/phoenixframework/Defaults.kt

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ package org.phoenixframework
2525
import com.google.gson.FieldNamingPolicy
2626
import com.google.gson.Gson
2727
import com.google.gson.GsonBuilder
28+
import okhttp3.HttpUrl
29+
import java.net.URL
2830

2931
object Defaults {
3032

@@ -36,19 +38,56 @@ object Defaults {
3638

3739
/** Default reconnect algorithm for the socket */
3840
val reconnectSteppedBackOff: (Int) -> Long = { tries ->
39-
if (tries > 9) 5_000 else listOf(10L, 50L, 100L, 150L, 200L, 250L, 500L, 1_000L, 2_000L)[tries - 1]
41+
if (tries > 9) 5_000 else listOf(
42+
10L, 50L, 100L, 150L, 200L, 250L, 500L, 1_000L, 2_000L
43+
)[tries - 1]
4044
}
4145

4246
/** Default rejoin algorithm for individual channels */
4347
val rejoinSteppedBackOff: (Int) -> Long = { tries ->
4448
if (tries > 3) 10_000 else listOf(1_000L, 2_000L, 5_000L)[tries - 1]
4549
}
4650

47-
4851
/** The default Gson configuration to use when parsing messages */
4952
val gson: Gson
5053
get() = GsonBuilder()
51-
.setLenient()
52-
.setFieldNamingPolicy(FieldNamingPolicy.LOWER_CASE_WITH_UNDERSCORES)
53-
.create()
54+
.setLenient()
55+
.setFieldNamingPolicy(FieldNamingPolicy.LOWER_CASE_WITH_UNDERSCORES)
56+
.create()
57+
58+
/**
59+
* Takes an endpoint and a params closure given by the User and constructs a URL that
60+
* is ready to be sent to the Socket connection.
61+
*
62+
* Will convert "ws://" and "wss://" to http/s which is what OkHttp expects.
63+
*
64+
* @throws IllegalArgumentException if [endpoint] is not a valid URL endpoint.
65+
*/
66+
internal fun buildEndpointUrl(
67+
endpoint: String,
68+
paramsClosure: PayloadClosure
69+
): URL {
70+
var mutableUrl = endpoint
71+
// Silently replace web socket URLs with HTTP URLs.
72+
if (endpoint.regionMatches(0, "ws:", 0, 3, ignoreCase = true)) {
73+
mutableUrl = "http:" + endpoint.substring(3)
74+
} else if (endpoint.regionMatches(0, "wss:", 0, 4, ignoreCase = true)) {
75+
mutableUrl = "https:" + endpoint.substring(4)
76+
}
77+
78+
// If there are query params, append them now
79+
var httpUrl =
80+
HttpUrl.parse(mutableUrl) ?: throw IllegalArgumentException("invalid url: $endpoint")
81+
paramsClosure.invoke()?.let {
82+
val httpBuilder = httpUrl.newBuilder()
83+
it.forEach { (key, value) ->
84+
httpBuilder.addQueryParameter(key, value.toString())
85+
}
86+
87+
httpUrl = httpBuilder.build()
88+
}
89+
90+
// Store the URL that will be used to establish a connection
91+
return httpUrl.url()
92+
}
5493
}

src/main/kotlin/org/phoenixframework/Socket.kt

Lines changed: 90 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -45,22 +45,34 @@ internal class StateChangeCallbacks {
4545
private set
4646

4747
/** Safely adds an onOpen callback */
48-
fun onOpen(ref: String, callback: () -> Unit) {
48+
fun onOpen(
49+
ref: String,
50+
callback: () -> Unit
51+
) {
4952
this.open = this.open + Pair(ref, callback)
5053
}
5154

5255
/** Safely adds an onClose callback */
53-
fun onClose(ref: String, callback: () -> Unit) {
56+
fun onClose(
57+
ref: String,
58+
callback: () -> Unit
59+
) {
5460
this.close = this.close + Pair(ref, callback)
5561
}
5662

5763
/** Safely adds an onError callback */
58-
fun onError(ref: String, callback: (Throwable, Response?) -> Unit) {
64+
fun onError(
65+
ref: String,
66+
callback: (Throwable, Response?) -> Unit
67+
) {
5968
this.error = this.error + Pair(ref, callback)
6069
}
6170

6271
/** Safely adds an onMessage callback */
63-
fun onMessage(ref: String, callback: (Message) -> Unit) {
72+
fun onMessage(
73+
ref: String,
74+
callback: (Message) -> Unit
75+
) {
6476
this.message = this.message + Pair(ref, callback)
6577
}
6678

@@ -87,12 +99,31 @@ const val WS_CLOSE_NORMAL = 1000
8799
/** RFC 6455: indicates that the connection was closed abnormally */
88100
const val WS_CLOSE_ABNORMAL = 1006
89101

102+
/**
103+
* A closure that will return an optional Payload
104+
*/
105+
typealias PayloadClosure = () -> Payload?
106+
90107
/**
91108
* Connects to a Phoenix Server
92109
*/
110+
111+
/**
112+
* A [Socket] which connects to a Phoenix Server. Takes a closure to allow for changing parameters
113+
* to be sent to the server when connecting.
114+
*
115+
* ## Example
116+
* ```
117+
* val socket = Socket("https://example.com/socket", { mapOf("token" to mAuthToken) })
118+
* ```
119+
* @param url Url to connect to such as https://example.com/socket
120+
* @param paramsClosure Closure which allows to change parameters sent during connection.
121+
* @param gson Default GSON Client to parse JSON. You can provide your own if needed.
122+
* @param client Default OkHttpClient to connect with. You can provide your own if needed.
123+
*/
93124
class Socket(
94125
url: String,
95-
params: Payload? = null,
126+
val paramsClosure: PayloadClosure,
96127
private val gson: Gson = Defaults.gson,
97128
private val client: OkHttpClient = OkHttpClient.Builder().build()
98129
) {
@@ -109,13 +140,8 @@ class Socket(
109140
val endpoint: String
110141

111142
/** The fully qualified socket URL */
112-
val endpointUrl: URL
113-
114-
/**
115-
* The optional params to pass when connecting. Must be set when
116-
* initializing the Socket. These will be appended to the URL.
117-
*/
118-
val params: Payload? = params
143+
var endpointUrl: URL
144+
private set
119145

120146
/** Timeout to use when opening a connection */
121147
var timeout: Long = Defaults.TIMEOUT
@@ -189,6 +215,27 @@ class Socket(
189215
//------------------------------------------------------------------------------
190216
// Initialization
191217
//------------------------------------------------------------------------------
218+
/**
219+
* A [Socket] which connects to a Phoenix Server. Takes a constant parameter to be sent to the
220+
* server when connecting. Defaults to null if excluded.
221+
*
222+
* ## Example
223+
* ```
224+
* val socket = Socket("https://example.com/socket", mapOf("token" to mAuthToken))
225+
* ```
226+
*
227+
* @param url Url to connect to such as https://example.com/socket
228+
* @param params Constant parameters to send when connecting. Defaults to null
229+
* @param gson Default GSON Client to parse JSON. You can provide your own if needed.
230+
* @param client Default OkHttpClient to connect with. You can provide your own if needed.
231+
*/
232+
constructor(
233+
url: String,
234+
params: Payload? = null,
235+
gson: Gson = Defaults.gson,
236+
client: OkHttpClient = OkHttpClient.Builder().build()
237+
) : this(url, { params }, gson, client)
238+
192239
init {
193240
var mutableUrl = url
194241

@@ -206,35 +253,18 @@ class Socket(
206253
// Store the endpoint before changing the protocol
207254
this.endpoint = mutableUrl
208255

209-
// Silently replace web socket URLs with HTTP URLs.
210-
if (url.regionMatches(0, "ws:", 0, 3, ignoreCase = true)) {
211-
mutableUrl = "http:" + url.substring(3)
212-
} else if (url.regionMatches(0, "wss:", 0, 4, ignoreCase = true)) {
213-
mutableUrl = "https:" + url.substring(4)
214-
}
215-
216-
// If there are query params, append them now
217-
var httpUrl = HttpUrl.parse(mutableUrl) ?: throw IllegalArgumentException("invalid url: $url")
218-
params?.let {
219-
val httpBuilder = httpUrl.newBuilder()
220-
it.forEach { (key, value) ->
221-
httpBuilder.addQueryParameter(key, value.toString())
222-
}
223-
224-
httpUrl = httpBuilder.build()
225-
}
226-
227-
// Store the URL that will be used to establish a connection
228-
this.endpointUrl = httpUrl.url()
256+
// Store the URL that will be used to establish a connection. Could potentially be
257+
// different at the time connect() is called based on a changing params closure.
258+
this.endpointUrl = Defaults.buildEndpointUrl(this.endpoint, this.paramsClosure)
229259

230260
// Create reconnect timer
231261
this.reconnectTimer = TimeoutTimer(
232-
dispatchQueue = dispatchQueue,
233-
timerCalculation = reconnectAfterMs,
234-
callback = {
235-
this.logItems("Socket attempting to reconnect")
236-
this.teardown { this.connect() }
237-
})
262+
dispatchQueue = dispatchQueue,
263+
timerCalculation = reconnectAfterMs,
264+
callback = {
265+
this.logItems("Socket attempting to reconnect")
266+
this.teardown { this.connect() }
267+
})
238268
}
239269

240270
//------------------------------------------------------------------------------
@@ -262,6 +292,11 @@ class Socket(
262292
// Reset the clean close flag when attempting to connect
263293
this.closeWasClean = false
264294

295+
// Build the new endpointUrl with the params closure. The payload returned
296+
// from the closure could be different such as a changing authToken.
297+
this.endpointUrl = Defaults.buildEndpointUrl(this.endpoint, this.paramsClosure)
298+
299+
// Now create the connection transport and attempt to connect
265300
this.connection = this.transport(endpointUrl)
266301
this.connection?.onOpen = { onConnectionOpened() }
267302
this.connection?.onClose = { code -> onConnectionClosed(code) }
@@ -281,7 +316,6 @@ class Socket(
281316
// Reset any reconnects and teardown the socket connection
282317
this.reconnectTimer.reset()
283318
this.teardown(code, reason, callback)
284-
285319
}
286320

287321
fun onOpen(callback: (() -> Unit)): String {
@@ -304,7 +338,10 @@ class Socket(
304338
this.stateChangeCallbacks.release()
305339
}
306340

307-
fun channel(topic: String, params: Payload = mapOf()): Channel {
341+
fun channel(
342+
topic: String,
343+
params: Payload = mapOf()
344+
): Channel {
308345
val channel = Channel(topic, params, this)
309346
this.channels = this.channels + channel
310347

@@ -318,7 +355,7 @@ class Socket(
318355
// removed instead of calling .remove() on the list, thus returning a new list
319356
// that does not contain the channel that was removed.
320357
this.channels = channels
321-
.filter { it.joinRef != channel.joinRef }
358+
.filter { it.joinRef != channel.joinRef }
322359
}
323360

324361
/**
@@ -449,7 +486,7 @@ class Socket(
449486
val period = heartbeatIntervalMs
450487

451488
heartbeatTask =
452-
dispatchQueue.queueAtFixedRate(delay, period, TimeUnit.MILLISECONDS) { sendHeartbeat() }
489+
dispatchQueue.queueAtFixedRate(delay, period, TimeUnit.MILLISECONDS) { sendHeartbeat() }
453490
}
454491

455492
internal fun sendHeartbeat() {
@@ -471,10 +508,11 @@ class Socket(
471508
// The last heartbeat was acknowledged by the server. Send another one
472509
this.pendingHeartbeatRef = this.makeRef()
473510
this.push(
474-
topic = "phoenix",
475-
event = Channel.Event.HEARTBEAT.value,
476-
payload = mapOf(),
477-
ref = pendingHeartbeatRef)
511+
topic = "phoenix",
512+
event = Channel.Event.HEARTBEAT.value,
513+
payload = mapOf(),
514+
ref = pendingHeartbeatRef
515+
)
478516
}
479517

480518
private fun abnormalClose(reason: String) {
@@ -538,14 +576,17 @@ class Socket(
538576

539577
// Dispatch the message to all channels that belong to the topic
540578
this.channels
541-
.filter { it.isMember(message) }
542-
.forEach { it.trigger(message) }
579+
.filter { it.isMember(message) }
580+
.forEach { it.trigger(message) }
543581

544582
// Inform all onMessage callbacks of the message
545583
this.stateChangeCallbacks.message.forEach { it.second.invoke(message) }
546584
}
547585

548-
internal fun onConnectionError(t: Throwable, response: Response?) {
586+
internal fun onConnectionError(
587+
t: Throwable,
588+
response: Response?
589+
) {
549590
this.logItems("Transport: error $t")
550591

551592
// Send an error to all channels
@@ -554,5 +595,4 @@ class Socket(
554595
// Inform any state callbacks of the error
555596
this.stateChangeCallbacks.error.forEach { it.second.invoke(t, response) }
556597
}
557-
558598
}

src/test/kotlin/org/phoenixframework/ChannelTest.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ class ChannelTest {
155155

156156
@BeforeEach
157157
internal fun setUp() {
158-
socket = spy(Socket(url ="https://localhost:4000/socket", client = okHttpClient))
158+
socket = spy(Socket(url = "https://localhost:4000/socket", client = okHttpClient))
159159
socket.dispatchQueue = fakeClock
160160
channel = Channel("topic", kDefaultPayload, socket)
161161
}

0 commit comments

Comments
 (0)