refactor: 重构 WebSocketUtil 并修复 connect 方法中的 bug

main
mojo 1 month ago
parent 347465ffa0
commit 61ac84ad3b

@ -10,9 +10,7 @@ import okhttp3.Request
import okhttp3.Response import okhttp3.Response
import okhttp3.WebSocket import okhttp3.WebSocket
import okhttp3.WebSocketListener import okhttp3.WebSocketListener
import okio.ByteString
import org.json.JSONArray import org.json.JSONArray
import java.lang.Exception
import java.net.URI import java.net.URI
typealias WebSocketHeader = MutableMap<String, String> typealias WebSocketHeader = MutableMap<String, String>
@ -37,7 +35,7 @@ object WebSocketUtil {
sealed class WebSocketEvent(val code: Int) { sealed class WebSocketEvent(val code: Int) {
data object BuildConnectionSuccess : WebSocketEvent(110) data object BuildConnectionSuccess : WebSocketEvent(110)
data object ConnectSuccess : WebSocketCode(11) data object ConnectSuccess : WebSocketEvent(11)
data object AllMessageSent : WebSocketEvent(112) data object AllMessageSent : WebSocketEvent(112)
} }
@ -49,120 +47,188 @@ object WebSocketUtil {
val wsResponse = WsResponse(startTime = System.currentTimeMillis()) val wsResponse = WsResponse(startTime = System.currentTimeMillis())
val responseHeader: WebSocketHeader = mutableMapOf() val responseHeader: WebSocketHeader = mutableMapOf()
val responseData = JSONArray() val responseData = JSONArray()
val result = Result(WEB_SOCKET_CODE, "")
val webSockResponse = Result(WEB_SOCKET_CODE, "")
try { try {
ensureConnection(request, responseHeader, responseData, result)
if (socket == null || !isOpen || uri?.toString() != request.url) {
if (socket != null) { if (result.code != WEB_SOCKET_CODE) {
disconnect() return buildResponse(wsResponse, responseHeader, responseData, result)
}
uri = URI(request.url)
val subProtocol =
request.headers.filter {
it.key.equals(
WEB_SOCKET_REQUEST_HEADER_PROTOCOL,
true
)
}
.map { NameValue(it.key, it.value) }.firstOrNull()
connect(url = request.url, subProtocol) {
responseData.put(it)
}
responseHeader.put(
"${System.currentTimeMillis()}",
"${WebSocketEvent.BuildConnectionSuccess}"
)
try {
var waitCount = 0
while (!isOpen && waitCount < 10) {
LogUtils.info("Connecting...")
Thread.sleep(1000)
waitCount++
}
} catch (e: Exception) {
LogUtils.error(e)
webSockResponse.code = WebSocketCode.EstablishConnectionException.code
webSockResponse.message = "${e.message}"
return wsResponse
}
if (isOpen) {
responseHeader["${System.currentTimeMillis()}"] =
"${WebSocketEvent.ConnectSuccess.code}"
}
} else {
LogUtils.info("use last time web socket connection")
} }
val messages: List<WsRequestParam> = request.message
if (messages.isNotEmpty() && socket != null && isOpen) { sendMessages(request, responseHeader, result)
try {
for (message in messages) { if (result.code != WEB_SOCKET_CODE) {
//当前消息需要判断响应是否包含中断当前连接的关键字 return buildResponse(wsResponse, responseHeader, responseData, result)
interrupt = message.interrupt.ifBlank {
""
}
if (isOpen) {
send(message.value)
responseHeader.put("${System.currentTimeMillis()}", message.value)
LogUtils.info("${System.currentTimeMillis()} send message ${message.value}")
//当前消息发送完需要等待
if (message.waitTime > 0) {
Thread.sleep(message.waitTime)
}
}
}
} catch (e: Exception) {
LogUtils.error(e)
webSockResponse.code = WebSocketCode.SendMessageFailed.code
webSockResponse.message = "${e.message}"
return wsResponse
}
responseHeader.put(
"${System.currentTimeMillis()}",
"${WebSocketEvent.AllMessageSent.code}"
)
if (request.delay > 0) {
Thread.sleep(request.delay * 1000L)
}
} else {
webSockResponse.code = WebSocketCode.EstablishConnectionFailed.code
} }
applyDelay(request.delay)
} catch (e: Exception) { } catch (e: Exception) {
webSockResponse.code = WebSocketCode.BuildConnectionFailed.code handleBuildConnectionException(result, e)
webSockResponse.message = "${e.message}"
LogUtils.error(throwable = e)
return wsResponse
} finally { } finally {
if (responseData.length() < 1 && webSockResponse.message.isNotBlank()) { finalizeResponse(wsResponse, responseHeader, responseData, result)
responseData.put(webSockResponse.message) }
LogUtils.info("${System.currentTimeMillis()} start connect web socket ${uri?.toString()}")
return wsResponse
}
private fun ensureConnection(
request: WsRequest,
responseHeader: WebSocketHeader,
responseData: JSONArray,
result: Result
) {
if (needsNewConnection(request)) {
if (socket != null) {
disconnect()
}
establishConnection(request, responseHeader, responseData, result)
} else {
LogUtils.info("use last time web socket connection")
}
}
private fun needsNewConnection(request: WsRequest): Boolean {
return socket == null || !isOpen || uri?.toString() != request.url
}
private fun establishConnection(
request: WsRequest,
responseHeader: WebSocketHeader,
responseData: JSONArray,
result: Result
) {
uri = URI(request.url)
val subProtocol = extractSubProtocol(request.headers)
connect(url = request.url, subProtocol) {
responseData.put(it)
}
responseHeader[System.currentTimeMillis().toString()] =
WebSocketEvent.BuildConnectionSuccess.code.toString()
waitForConnection(result)
if (isOpen) {
responseHeader[System.currentTimeMillis().toString()] =
WebSocketEvent.ConnectSuccess.code.toString()
}
}
private fun extractSubProtocol(headers: MutableMap<String, String>): NameValue? {
return headers
.filter { it.key.equals(WEB_SOCKET_REQUEST_HEADER_PROTOCOL, true) }
.map { NameValue(it.key, it.value) }
.firstOrNull()
}
private fun waitForConnection(result: Result, maxWaitCount: Int = 10) {
try {
var waitCount = 0
while (!isOpen && waitCount < maxWaitCount) {
LogUtils.info("Connecting...")
Thread.sleep(1000)
waitCount++
}
} catch (e: Exception) {
LogUtils.error(e)
result.code = WebSocketCode.EstablishConnectionException.code
result.message = e.message ?: ""
}
}
private fun sendMessages(
request: WsRequest,
responseHeader: WebSocketHeader,
result: Result
) {
val messages = request.message
if (messages.isEmpty() || socket == null || !isOpen) {
result.code = WebSocketCode.EstablishConnectionFailed.code
return
}
try {
for (message in messages) {
if (!isOpen) break
interrupt = message.interrupt.ifBlank { "" }
sendMessage(message, responseHeader)
} }
responseHeader[System.currentTimeMillis().toString()] =
WebSocketEvent.AllMessageSent.code.toString()
} catch (e: Exception) {
LogUtils.error(e)
result.code = WebSocketCode.SendMessageFailed.code
result.message = e.message ?: ""
}
}
wsResponse.data = responseData private fun sendMessage(message: WsRequestParam, responseHeader: WebSocketHeader) {
wsResponse.headers = responseHeader send(message.value)
wsResponse.endTime = System.currentTimeMillis() responseHeader[System.currentTimeMillis().toString()] = message.value
wsResponse.code = webSockResponse.code LogUtils.info("${System.currentTimeMillis()} send message ${message.value}")
if (message.waitTime > 0) {
Thread.sleep(message.waitTime)
} }
}
private fun applyDelay(delaySeconds: Int) {
if (delaySeconds > 0) {
Thread.sleep(delaySeconds * 1000L)
}
}
LogUtils.info("${System.currentTimeMillis()} start connect web socket ${uri.toString()}") private fun handleBuildConnectionException(result: Result, e: Exception) {
result.code = WebSocketCode.BuildConnectionFailed.code
result.message = e.message ?: ""
LogUtils.error(throwable = e)
}
private fun finalizeResponse(
wsResponse: WsResponse,
responseHeader: WebSocketHeader,
responseData: JSONArray,
result: Result
) {
if (responseData.length() < 1 && result.message.isNotBlank()) {
responseData.put(result.message)
}
wsResponse.data = responseData
wsResponse.headers = responseHeader
wsResponse.endTime = System.currentTimeMillis()
wsResponse.code = result.code
}
private fun buildResponse(
wsResponse: WsResponse,
responseHeader: WebSocketHeader,
responseData: JSONArray,
result: Result
): WsResponse {
finalizeResponse(wsResponse, responseHeader, responseData, result)
return wsResponse return wsResponse
} }
fun connect(url: String, head: NameValue? = null, callBack: (String) -> Unit) { fun connect(url: String, head: NameValue? = null, callBack: (String) -> Unit) {
val request = Request.Builder() val requestBuilder = Request.Builder().url(url)
.url(url)
.apply { if (head != null) {
if (head != null) { requestBuilder.addHeader(head.name, head.value)
addHeader(name = head.name, value = head.name) }
}
} val request = requestBuilder.build()
.build() listener = createWebSocketListener(callBack)
listener = object : WebSocketListener() { socket = client.newWebSocket(request, listener!!)
}
private fun createWebSocketListener(callBack: (String) -> Unit): WebSocketListener {
return object : WebSocketListener() {
override fun onOpen(webSocket: WebSocket, response: Response) { override fun onOpen(webSocket: WebSocket, response: Response) {
super.onOpen(webSocket, response) super.onOpen(webSocket, response)
isOpen = true isOpen = true
@ -172,10 +238,7 @@ object WebSocketUtil {
override fun onMessage(webSocket: WebSocket, text: String) { override fun onMessage(webSocket: WebSocket, text: String) {
super.onMessage(webSocket, text) super.onMessage(webSocket, text)
LogUtils.info("${System.currentTimeMillis()} receive message $text") LogUtils.info("${System.currentTimeMillis()} receive message $text")
if ((interrupt.isNotBlank() && "*" == interrupt) || interrupt == text) { handleMessage(text, callBack)
disconnect()
}
callBack.invoke(text)
} }
override fun onClosed(webSocket: WebSocket, code: Int, reason: String) { override fun onClosed(webSocket: WebSocket, code: Int, reason: String) {
@ -189,7 +252,17 @@ object WebSocketUtil {
LogUtils.info("${System.currentTimeMillis()} current web socket connection error") LogUtils.info("${System.currentTimeMillis()} current web socket connection error")
} }
} }
socket = client.newWebSocket(request, listener!!) }
private fun handleMessage(text: String, callBack: (String) -> Unit) {
if (shouldInterrupt(text)) {
disconnect()
}
callBack.invoke(text)
}
private fun shouldInterrupt(text: String): Boolean {
return interrupt.isNotBlank() && (interrupt == "*" || interrupt == text)
} }
fun send(message: String) { fun send(message: String) {

Loading…
Cancel
Save