diff --git a/lib/src/main/java/com/example/utils/WebSocketUtil.kt b/lib/src/main/java/com/example/utils/WebSocketUtil.kt index 8804e79..a60df7b 100644 --- a/lib/src/main/java/com/example/utils/WebSocketUtil.kt +++ b/lib/src/main/java/com/example/utils/WebSocketUtil.kt @@ -10,9 +10,7 @@ import okhttp3.Request import okhttp3.Response import okhttp3.WebSocket import okhttp3.WebSocketListener -import okio.ByteString import org.json.JSONArray -import java.lang.Exception import java.net.URI typealias WebSocketHeader = MutableMap @@ -37,7 +35,7 @@ object WebSocketUtil { sealed class WebSocketEvent(val code: Int) { data object BuildConnectionSuccess : WebSocketEvent(110) - data object ConnectSuccess : WebSocketCode(11) + data object ConnectSuccess : WebSocketEvent(11) data object AllMessageSent : WebSocketEvent(112) } @@ -49,120 +47,188 @@ object WebSocketUtil { val wsResponse = WsResponse(startTime = System.currentTimeMillis()) val responseHeader: WebSocketHeader = mutableMapOf() val responseData = JSONArray() - - val webSockResponse = Result(WEB_SOCKET_CODE, "") + val result = Result(WEB_SOCKET_CODE, "") try { - - if (socket == null || !isOpen || uri?.toString() != request.url) { - if (socket != null) { - disconnect() - } - uri = URI(request.url) - val subProtocol = - request.headers.filter { - it.key.equals( - WEB_SOCKET_REQUEST_HEADER_PROTOCOL, - true - ) - } - .map { NameValue(it.key, it.value) }.firstOrNull() - connect(url = request.url, subProtocol) { - responseData.put(it) - } - responseHeader.put( - "${System.currentTimeMillis()}", - "${WebSocketEvent.BuildConnectionSuccess}" - ) - - try { - var waitCount = 0 - while (!isOpen && waitCount < 10) { - LogUtils.info("Connecting...") - Thread.sleep(1000) - waitCount++ - } - } catch (e: Exception) { - LogUtils.error(e) - webSockResponse.code = WebSocketCode.EstablishConnectionException.code - webSockResponse.message = "${e.message}" - return wsResponse - } - - if (isOpen) { - responseHeader["${System.currentTimeMillis()}"] = - "${WebSocketEvent.ConnectSuccess.code}" - } - } else { - LogUtils.info("use last time web socket connection") + ensureConnection(request, responseHeader, responseData, result) + + if (result.code != WEB_SOCKET_CODE) { + return buildResponse(wsResponse, responseHeader, responseData, result) } - val messages: List = request.message - if (messages.isNotEmpty() && socket != null && isOpen) { - try { - for (message in messages) { - //当前消息需要判断响应是否包含中断当前连接的关键字 - interrupt = message.interrupt.ifBlank { - "" - } - - if (isOpen) { - send(message.value) - responseHeader.put("${System.currentTimeMillis()}", message.value) - LogUtils.info("${System.currentTimeMillis()} send message ${message.value}") - //当前消息发送完需要等待 - if (message.waitTime > 0) { - Thread.sleep(message.waitTime) - } - - } - } - } catch (e: Exception) { - LogUtils.error(e) - webSockResponse.code = WebSocketCode.SendMessageFailed.code - webSockResponse.message = "${e.message}" - return wsResponse - } - responseHeader.put( - "${System.currentTimeMillis()}", - "${WebSocketEvent.AllMessageSent.code}" - ) - if (request.delay > 0) { - Thread.sleep(request.delay * 1000L) - } - } else { - webSockResponse.code = WebSocketCode.EstablishConnectionFailed.code + + sendMessages(request, responseHeader, result) + + if (result.code != WEB_SOCKET_CODE) { + return buildResponse(wsResponse, responseHeader, responseData, result) } + + applyDelay(request.delay) } catch (e: Exception) { - webSockResponse.code = WebSocketCode.BuildConnectionFailed.code - webSockResponse.message = "${e.message}" - LogUtils.error(throwable = e) - return wsResponse + handleBuildConnectionException(result, e) } finally { - if (responseData.length() < 1 && webSockResponse.message.isNotBlank()) { - responseData.put(webSockResponse.message) + finalizeResponse(wsResponse, responseHeader, responseData, result) + } + + LogUtils.info("${System.currentTimeMillis()} start connect web socket ${uri?.toString()}") + return wsResponse + } + + private fun ensureConnection( + request: WsRequest, + responseHeader: WebSocketHeader, + responseData: JSONArray, + result: Result + ) { + if (needsNewConnection(request)) { + if (socket != null) { + disconnect() + } + establishConnection(request, responseHeader, responseData, result) + } else { + LogUtils.info("use last time web socket connection") + } + } + + private fun needsNewConnection(request: WsRequest): Boolean { + return socket == null || !isOpen || uri?.toString() != request.url + } + + private fun establishConnection( + request: WsRequest, + responseHeader: WebSocketHeader, + responseData: JSONArray, + result: Result + ) { + uri = URI(request.url) + val subProtocol = extractSubProtocol(request.headers) + + connect(url = request.url, subProtocol) { + responseData.put(it) + } + + responseHeader[System.currentTimeMillis().toString()] = + WebSocketEvent.BuildConnectionSuccess.code.toString() + + waitForConnection(result) + + if (isOpen) { + responseHeader[System.currentTimeMillis().toString()] = + WebSocketEvent.ConnectSuccess.code.toString() + } + } + + private fun extractSubProtocol(headers: MutableMap): NameValue? { + return headers + .filter { it.key.equals(WEB_SOCKET_REQUEST_HEADER_PROTOCOL, true) } + .map { NameValue(it.key, it.value) } + .firstOrNull() + } + + private fun waitForConnection(result: Result, maxWaitCount: Int = 10) { + try { + var waitCount = 0 + while (!isOpen && waitCount < maxWaitCount) { + LogUtils.info("Connecting...") + Thread.sleep(1000) + waitCount++ + } + } catch (e: Exception) { + LogUtils.error(e) + result.code = WebSocketCode.EstablishConnectionException.code + result.message = e.message ?: "" + } + } + + private fun sendMessages( + request: WsRequest, + responseHeader: WebSocketHeader, + result: Result + ) { + val messages = request.message + if (messages.isEmpty() || socket == null || !isOpen) { + result.code = WebSocketCode.EstablishConnectionFailed.code + return + } + + try { + for (message in messages) { + if (!isOpen) break + + interrupt = message.interrupt.ifBlank { "" } + sendMessage(message, responseHeader) } + + responseHeader[System.currentTimeMillis().toString()] = + WebSocketEvent.AllMessageSent.code.toString() + } catch (e: Exception) { + LogUtils.error(e) + result.code = WebSocketCode.SendMessageFailed.code + result.message = e.message ?: "" + } + } - wsResponse.data = responseData - wsResponse.headers = responseHeader - wsResponse.endTime = System.currentTimeMillis() - wsResponse.code = webSockResponse.code + private fun sendMessage(message: WsRequestParam, responseHeader: WebSocketHeader) { + send(message.value) + responseHeader[System.currentTimeMillis().toString()] = message.value + LogUtils.info("${System.currentTimeMillis()} send message ${message.value}") + + if (message.waitTime > 0) { + Thread.sleep(message.waitTime) } + } + + private fun applyDelay(delaySeconds: Int) { + if (delaySeconds > 0) { + Thread.sleep(delaySeconds * 1000L) + } + } - LogUtils.info("${System.currentTimeMillis()} start connect web socket ${uri.toString()}") + private fun handleBuildConnectionException(result: Result, e: Exception) { + result.code = WebSocketCode.BuildConnectionFailed.code + result.message = e.message ?: "" + LogUtils.error(throwable = e) + } + + private fun finalizeResponse( + wsResponse: WsResponse, + responseHeader: WebSocketHeader, + responseData: JSONArray, + result: Result + ) { + if (responseData.length() < 1 && result.message.isNotBlank()) { + responseData.put(result.message) + } + wsResponse.data = responseData + wsResponse.headers = responseHeader + wsResponse.endTime = System.currentTimeMillis() + wsResponse.code = result.code + } + + private fun buildResponse( + wsResponse: WsResponse, + responseHeader: WebSocketHeader, + responseData: JSONArray, + result: Result + ): WsResponse { + finalizeResponse(wsResponse, responseHeader, responseData, result) return wsResponse } fun connect(url: String, head: NameValue? = null, callBack: (String) -> Unit) { - val request = Request.Builder() - .url(url) - .apply { - if (head != null) { - addHeader(name = head.name, value = head.name) - } - } - .build() - listener = object : WebSocketListener() { + val requestBuilder = Request.Builder().url(url) + + if (head != null) { + requestBuilder.addHeader(head.name, head.value) + } + + val request = requestBuilder.build() + listener = createWebSocketListener(callBack) + socket = client.newWebSocket(request, listener!!) + } + + private fun createWebSocketListener(callBack: (String) -> Unit): WebSocketListener { + return object : WebSocketListener() { override fun onOpen(webSocket: WebSocket, response: Response) { super.onOpen(webSocket, response) isOpen = true @@ -172,10 +238,7 @@ object WebSocketUtil { override fun onMessage(webSocket: WebSocket, text: String) { super.onMessage(webSocket, text) LogUtils.info("${System.currentTimeMillis()} receive message $text") - if ((interrupt.isNotBlank() && "*" == interrupt) || interrupt == text) { - disconnect() - } - callBack.invoke(text) + handleMessage(text, callBack) } override fun onClosed(webSocket: WebSocket, code: Int, reason: String) { @@ -189,7 +252,17 @@ object WebSocketUtil { LogUtils.info("${System.currentTimeMillis()} current web socket connection error") } } - socket = client.newWebSocket(request, listener!!) + } + + private fun handleMessage(text: String, callBack: (String) -> Unit) { + if (shouldInterrupt(text)) { + disconnect() + } + callBack.invoke(text) + } + + private fun shouldInterrupt(text: String): Boolean { + return interrupt.isNotBlank() && (interrupt == "*" || interrupt == text) } fun send(message: String) {