refactor: 重构 WebSocketUtil 并修复 connect 方法中的 bug

main
mojo 1 month ago
parent 347465ffa0
commit 61ac84ad3b

@ -10,9 +10,7 @@ import okhttp3.Request
import okhttp3.Response
import okhttp3.WebSocket
import okhttp3.WebSocketListener
import okio.ByteString
import org.json.JSONArray
import java.lang.Exception
import java.net.URI
typealias WebSocketHeader = MutableMap<String, String>
@ -37,7 +35,7 @@ object WebSocketUtil {
sealed class WebSocketEvent(val code: Int) {
data object BuildConnectionSuccess : WebSocketEvent(110)
data object ConnectSuccess : WebSocketCode(11)
data object ConnectSuccess : WebSocketEvent(11)
data object AllMessageSent : WebSocketEvent(112)
}
@ -49,120 +47,188 @@ object WebSocketUtil {
val wsResponse = WsResponse(startTime = System.currentTimeMillis())
val responseHeader: WebSocketHeader = mutableMapOf()
val responseData = JSONArray()
val webSockResponse = Result(WEB_SOCKET_CODE, "")
val result = Result(WEB_SOCKET_CODE, "")
try {
if (socket == null || !isOpen || uri?.toString() != request.url) {
if (socket != null) {
disconnect()
}
uri = URI(request.url)
val subProtocol =
request.headers.filter {
it.key.equals(
WEB_SOCKET_REQUEST_HEADER_PROTOCOL,
true
)
}
.map { NameValue(it.key, it.value) }.firstOrNull()
connect(url = request.url, subProtocol) {
responseData.put(it)
}
responseHeader.put(
"${System.currentTimeMillis()}",
"${WebSocketEvent.BuildConnectionSuccess}"
)
try {
var waitCount = 0
while (!isOpen && waitCount < 10) {
LogUtils.info("Connecting...")
Thread.sleep(1000)
waitCount++
}
} catch (e: Exception) {
LogUtils.error(e)
webSockResponse.code = WebSocketCode.EstablishConnectionException.code
webSockResponse.message = "${e.message}"
return wsResponse
}
if (isOpen) {
responseHeader["${System.currentTimeMillis()}"] =
"${WebSocketEvent.ConnectSuccess.code}"
}
} else {
LogUtils.info("use last time web socket connection")
ensureConnection(request, responseHeader, responseData, result)
if (result.code != WEB_SOCKET_CODE) {
return buildResponse(wsResponse, responseHeader, responseData, result)
}
val messages: List<WsRequestParam> = request.message
if (messages.isNotEmpty() && socket != null && isOpen) {
try {
for (message in messages) {
//当前消息需要判断响应是否包含中断当前连接的关键字
interrupt = message.interrupt.ifBlank {
""
}
if (isOpen) {
send(message.value)
responseHeader.put("${System.currentTimeMillis()}", message.value)
LogUtils.info("${System.currentTimeMillis()} send message ${message.value}")
//当前消息发送完需要等待
if (message.waitTime > 0) {
Thread.sleep(message.waitTime)
}
}
}
} catch (e: Exception) {
LogUtils.error(e)
webSockResponse.code = WebSocketCode.SendMessageFailed.code
webSockResponse.message = "${e.message}"
return wsResponse
}
responseHeader.put(
"${System.currentTimeMillis()}",
"${WebSocketEvent.AllMessageSent.code}"
)
if (request.delay > 0) {
Thread.sleep(request.delay * 1000L)
}
} else {
webSockResponse.code = WebSocketCode.EstablishConnectionFailed.code
sendMessages(request, responseHeader, result)
if (result.code != WEB_SOCKET_CODE) {
return buildResponse(wsResponse, responseHeader, responseData, result)
}
applyDelay(request.delay)
} catch (e: Exception) {
webSockResponse.code = WebSocketCode.BuildConnectionFailed.code
webSockResponse.message = "${e.message}"
LogUtils.error(throwable = e)
return wsResponse
handleBuildConnectionException(result, e)
} finally {
if (responseData.length() < 1 && webSockResponse.message.isNotBlank()) {
responseData.put(webSockResponse.message)
finalizeResponse(wsResponse, responseHeader, responseData, result)
}
LogUtils.info("${System.currentTimeMillis()} start connect web socket ${uri?.toString()}")
return wsResponse
}
private fun ensureConnection(
request: WsRequest,
responseHeader: WebSocketHeader,
responseData: JSONArray,
result: Result
) {
if (needsNewConnection(request)) {
if (socket != null) {
disconnect()
}
establishConnection(request, responseHeader, responseData, result)
} else {
LogUtils.info("use last time web socket connection")
}
}
private fun needsNewConnection(request: WsRequest): Boolean {
return socket == null || !isOpen || uri?.toString() != request.url
}
private fun establishConnection(
request: WsRequest,
responseHeader: WebSocketHeader,
responseData: JSONArray,
result: Result
) {
uri = URI(request.url)
val subProtocol = extractSubProtocol(request.headers)
connect(url = request.url, subProtocol) {
responseData.put(it)
}
responseHeader[System.currentTimeMillis().toString()] =
WebSocketEvent.BuildConnectionSuccess.code.toString()
waitForConnection(result)
if (isOpen) {
responseHeader[System.currentTimeMillis().toString()] =
WebSocketEvent.ConnectSuccess.code.toString()
}
}
private fun extractSubProtocol(headers: MutableMap<String, String>): NameValue? {
return headers
.filter { it.key.equals(WEB_SOCKET_REQUEST_HEADER_PROTOCOL, true) }
.map { NameValue(it.key, it.value) }
.firstOrNull()
}
private fun waitForConnection(result: Result, maxWaitCount: Int = 10) {
try {
var waitCount = 0
while (!isOpen && waitCount < maxWaitCount) {
LogUtils.info("Connecting...")
Thread.sleep(1000)
waitCount++
}
} catch (e: Exception) {
LogUtils.error(e)
result.code = WebSocketCode.EstablishConnectionException.code
result.message = e.message ?: ""
}
}
private fun sendMessages(
request: WsRequest,
responseHeader: WebSocketHeader,
result: Result
) {
val messages = request.message
if (messages.isEmpty() || socket == null || !isOpen) {
result.code = WebSocketCode.EstablishConnectionFailed.code
return
}
try {
for (message in messages) {
if (!isOpen) break
interrupt = message.interrupt.ifBlank { "" }
sendMessage(message, responseHeader)
}
responseHeader[System.currentTimeMillis().toString()] =
WebSocketEvent.AllMessageSent.code.toString()
} catch (e: Exception) {
LogUtils.error(e)
result.code = WebSocketCode.SendMessageFailed.code
result.message = e.message ?: ""
}
}
wsResponse.data = responseData
wsResponse.headers = responseHeader
wsResponse.endTime = System.currentTimeMillis()
wsResponse.code = webSockResponse.code
private fun sendMessage(message: WsRequestParam, responseHeader: WebSocketHeader) {
send(message.value)
responseHeader[System.currentTimeMillis().toString()] = message.value
LogUtils.info("${System.currentTimeMillis()} send message ${message.value}")
if (message.waitTime > 0) {
Thread.sleep(message.waitTime)
}
}
private fun applyDelay(delaySeconds: Int) {
if (delaySeconds > 0) {
Thread.sleep(delaySeconds * 1000L)
}
}
LogUtils.info("${System.currentTimeMillis()} start connect web socket ${uri.toString()}")
private fun handleBuildConnectionException(result: Result, e: Exception) {
result.code = WebSocketCode.BuildConnectionFailed.code
result.message = e.message ?: ""
LogUtils.error(throwable = e)
}
private fun finalizeResponse(
wsResponse: WsResponse,
responseHeader: WebSocketHeader,
responseData: JSONArray,
result: Result
) {
if (responseData.length() < 1 && result.message.isNotBlank()) {
responseData.put(result.message)
}
wsResponse.data = responseData
wsResponse.headers = responseHeader
wsResponse.endTime = System.currentTimeMillis()
wsResponse.code = result.code
}
private fun buildResponse(
wsResponse: WsResponse,
responseHeader: WebSocketHeader,
responseData: JSONArray,
result: Result
): WsResponse {
finalizeResponse(wsResponse, responseHeader, responseData, result)
return wsResponse
}
fun connect(url: String, head: NameValue? = null, callBack: (String) -> Unit) {
val request = Request.Builder()
.url(url)
.apply {
if (head != null) {
addHeader(name = head.name, value = head.name)
}
}
.build()
listener = object : WebSocketListener() {
val requestBuilder = Request.Builder().url(url)
if (head != null) {
requestBuilder.addHeader(head.name, head.value)
}
val request = requestBuilder.build()
listener = createWebSocketListener(callBack)
socket = client.newWebSocket(request, listener!!)
}
private fun createWebSocketListener(callBack: (String) -> Unit): WebSocketListener {
return object : WebSocketListener() {
override fun onOpen(webSocket: WebSocket, response: Response) {
super.onOpen(webSocket, response)
isOpen = true
@ -172,10 +238,7 @@ object WebSocketUtil {
override fun onMessage(webSocket: WebSocket, text: String) {
super.onMessage(webSocket, text)
LogUtils.info("${System.currentTimeMillis()} receive message $text")
if ((interrupt.isNotBlank() && "*" == interrupt) || interrupt == text) {
disconnect()
}
callBack.invoke(text)
handleMessage(text, callBack)
}
override fun onClosed(webSocket: WebSocket, code: Int, reason: String) {
@ -189,7 +252,17 @@ object WebSocketUtil {
LogUtils.info("${System.currentTimeMillis()} current web socket connection error")
}
}
socket = client.newWebSocket(request, listener!!)
}
private fun handleMessage(text: String, callBack: (String) -> Unit) {
if (shouldInterrupt(text)) {
disconnect()
}
callBack.invoke(text)
}
private fun shouldInterrupt(text: String): Boolean {
return interrupt.isNotBlank() && (interrupt == "*" || interrupt == text)
}
fun send(message: String) {

Loading…
Cancel
Save