Skip to content

Commit ef8ed44

Browse files
feat: new middlewares (#8)
1 parent a905e62 commit ef8ed44

4 files changed

Lines changed: 104 additions & 9 deletions

File tree

src/main/kotlin/net/ccbluex/netty/http/HttpServerHandler.kt

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import io.netty.handler.codec.http.HttpRequest
2727
import io.netty.handler.codec.http.LastHttpContent
2828
import io.netty.handler.codec.http.websocketx.WebSocketServerHandshakerFactory
2929
import net.ccbluex.netty.http.HttpServer.Companion.logger
30+
import net.ccbluex.netty.http.middleware.Middleware
3031
import net.ccbluex.netty.http.model.RequestContext
3132
import net.ccbluex.netty.http.websocket.WebSocketHandler
3233
import java.net.URLDecoder
@@ -66,6 +67,11 @@ internal class HttpServerHandler(private val server: HttpServer) : ChannelInboun
6667

6768
if (connection.equals("Upgrade", ignoreCase = true) &&
6869
upgrade.equals("WebSocket", ignoreCase = true)) {
70+
71+
if (server.middlewares.any {
72+
it is Middleware.OnWebSocketUpgrade && !it.invoke(ctx, msg)
73+
}) return
74+
6975
// Takes out Http Request Handler from the pipeline and replaces it with WebSocketHandler
7076
ctx.pipeline().replace(this, "websocketHandler", WebSocketHandler(server))
7177

@@ -90,6 +96,10 @@ internal class HttpServerHandler(private val server: HttpServer) : ChannelInboun
9096
msg.headers().associate { it.key to it.value },
9197
)
9298

99+
if (server.middlewares.any {
100+
it is Middleware.OnRequestStart && !it.invoke(ctx, msg, requestContext)
101+
}) return
102+
93103
localRequestContext.set(requestContext)
94104
}
95105
}
@@ -109,9 +119,13 @@ internal class HttpServerHandler(private val server: HttpServer) : ChannelInboun
109119
if (msg is LastHttpContent) {
110120
localRequestContext.remove()
111121

112-
val response = server.processRequestContext(requestContext)
113-
val httpResponse = server.middlewares.fold(response) { acc, f -> f(requestContext, acc) }
114-
ctx.writeAndFlush(httpResponse)
122+
var response = server.processRequestContext(requestContext)
123+
server.middlewares.forEach {
124+
if (it is Middleware.OnFullHttpResponse) {
125+
response = it.invoke(requestContext, response)
126+
}
127+
}
128+
ctx.writeAndFlush(response)
115129
}
116130
}
117131

src/main/kotlin/net/ccbluex/netty/http/middleware/CorsMiddleware.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ class CorsMiddleware(
2424
listOf("GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"),
2525
private val allowedHeaders: List<String> =
2626
listOf("Content-Type", "Content-Length", "Authorization", "Accept", "X-Requested-With")
27-
): Middleware {
27+
): Middleware.OnFullHttpResponse {
2828

2929
/**
3030
* Middleware to handle CORS requests.
Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,45 @@
1+
/*
2+
* This file is part of Netty-Rest (https://github.com/CCBlueX/netty-rest)
3+
*
4+
* Copyright (c) 2024 CCBlueX
5+
*
6+
* LiquidBounce is free software: you can redistribute it and/or modify
7+
* it under the terms of the GNU General Public License as published by
8+
* the Free Software Foundation, either version 3 of the License, or
9+
* (at your option) any later version.
10+
*
11+
* Netty-Rest is distributed in the hope that it will be useful,
12+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
13+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14+
* GNU General Public License for more details.
15+
*
16+
* You should have received a copy of the GNU General Public License
17+
* along with Netty-Rest. If not, see <https://www.gnu.org/licenses/>.
18+
*
19+
*/
120
package net.ccbluex.netty.http.middleware
221

22+
import io.netty.channel.ChannelHandlerContext
323
import io.netty.handler.codec.http.FullHttpResponse
24+
import io.netty.handler.codec.http.HttpRequest
425
import net.ccbluex.netty.http.model.RequestContext
526

6-
fun interface Middleware {
7-
operator fun invoke(context: RequestContext, response: FullHttpResponse): FullHttpResponse
8-
}
27+
sealed interface Middleware {
28+
fun interface OnWebSocketUpgrade : Middleware {
29+
/**
30+
* @return if it's accepted
31+
*/
32+
operator fun invoke(ctx: ChannelHandlerContext, request: HttpRequest): Boolean
33+
}
34+
35+
fun interface OnRequestStart : Middleware {
36+
/**
37+
* @return if it's accepted
38+
*/
39+
operator fun invoke(ctx: ChannelHandlerContext, request: HttpRequest, requestContext: RequestContext): Boolean
40+
}
41+
42+
fun interface OnFullHttpResponse : Middleware {
43+
operator fun invoke(context: RequestContext, response: FullHttpResponse): FullHttpResponse
44+
}
45+
}

src/test/kotlin/HttpMiddlewareServerTest.kt

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,23 @@
11
import com.google.gson.JsonObject
22
import io.netty.handler.codec.http.FullHttpResponse
3+
import io.netty.handler.codec.http.HttpResponseStatus
34
import net.ccbluex.netty.http.HttpServer
5+
import net.ccbluex.netty.http.middleware.Middleware
46
import net.ccbluex.netty.http.model.RequestObject
7+
import net.ccbluex.netty.http.util.httpBadRequest
58
import net.ccbluex.netty.http.util.httpOk
69
import okhttp3.OkHttpClient
710
import okhttp3.Request
811
import okhttp3.Response
12+
import okhttp3.WebSocket
13+
import okhttp3.WebSocketListener
14+
import okio.Buffer
915
import org.junit.jupiter.api.*
16+
import java.net.ProtocolException
17+
import java.util.concurrent.CompletableFuture
18+
import java.util.concurrent.TimeUnit
1019
import kotlin.test.assertEquals
20+
import kotlin.test.assertIs
1121
import kotlin.test.assertNotNull
1222
import kotlin.test.assertTrue
1323

@@ -55,7 +65,7 @@ class HttpMiddlewareServerTest {
5565
get("/", ::static)
5666
}
5767

58-
server.middleware { requestContext, fullHttpResponse ->
68+
server.middleware(Middleware.OnFullHttpResponse { requestContext, fullHttpResponse ->
5969
// Add custom headers to the response
6070
fullHttpResponse.headers().add("X-Custom-Header", "Custom Value")
6171

@@ -66,7 +76,13 @@ class HttpMiddlewareServerTest {
6676
}
6777

6878
fullHttpResponse
69-
}
79+
}).middleware(Middleware.OnWebSocketUpgrade { context, _ ->
80+
context.writeAndFlush(
81+
httpBadRequest("WebSocket unsupported")
82+
)
83+
84+
false
85+
})
7086

7187
server.start(8080) // Start the server on port 8080
7288
return server
@@ -125,4 +141,32 @@ class HttpMiddlewareServerTest {
125141
"Query parameter should be present in the response")
126142
}
127143

144+
@Test
145+
fun testWebSocketShouldBeBadRequest() {
146+
val future = CompletableFuture<Boolean>()
147+
148+
client.newWebSocket(
149+
Request.Builder()
150+
.url("http://localhost:8080")
151+
.build(),
152+
object : WebSocketListener() {
153+
override fun onFailure(webSocket: WebSocket, t: Throwable, response: Response?) {
154+
assertIs<ProtocolException>(t)
155+
assertNotNull(response)
156+
assertEquals(HttpResponseStatus.BAD_REQUEST.code(), response.code())
157+
val exceptedResponseBody = httpBadRequest("WebSocket unsupported")
158+
val buffer = Buffer()
159+
buffer.write(exceptedResponseBody.content().nioBuffer())
160+
assertEquals(
161+
buffer.readUtf8(),
162+
response.body()!!.string()
163+
)
164+
future.complete(true)
165+
}
166+
}
167+
)
168+
169+
assertTrue(future.get(10, TimeUnit.SECONDS))
170+
}
171+
128172
}

0 commit comments

Comments
 (0)