Created
March 6, 2023 12:26
-
-
Save tylerreisinger/e1e26c1852c46e0a3b93d136608571ca to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@file:Suppress("WeakerAccess", "MemberVisibilityCanBePrivate") | |
package com.tyler.bitburner.websocket | |
import com.intellij.notification.NotificationGroupManager | |
import com.intellij.openapi.Disposable | |
import com.intellij.openapi.diagnostic.Logger | |
import com.intellij.openapi.project.Project | |
import com.intellij.openapi.ui.MessageType | |
import com.intellij.openapi.util.Disposer | |
import com.intellij.util.io.toByteArray | |
import kotlinx.coroutines.* | |
import kotlinx.coroutines.channels.Channel | |
import java.io.* | |
import java.net.ServerSocket | |
import java.net.Socket | |
import java.nio.ByteBuffer | |
import java.nio.ByteOrder | |
import java.nio.charset.Charset | |
import java.nio.charset.CharsetDecoder | |
import java.nio.charset.CharsetEncoder | |
import java.nio.charset.StandardCharsets | |
import java.security.MessageDigest | |
import java.util.* | |
import kotlin.concurrent.timer | |
import kotlin.experimental.and | |
import kotlin.experimental.or | |
import kotlin.experimental.xor | |
/**
 * State for the single connected WebSocket client.
 *
 * @param socket the accepted client socket; closed by [dispose].
 * @param input buffered stream the server reads HTTP request lines and frames from.
 * @param output stream the server writes the handshake response and frames to.
 */
internal class Client(val socket: Socket, input: BufferedInputStream, output: OutputStream) : Disposable {
    val inputStream = input
    val outputStream = output

    /** Coroutine reading frames from this client; cancelled by [dispose]. */
    var runnerJob: Job? = null

    // Placeholder heartbeat timer (its task is currently empty). It runs on a daemon thread so that a
    // client which is never disposed cannot keep the JVM from exiting.
    private var heartbeatTimer: Timer? = timer(name = "bitburner-ws-heartbeat", daemon = true, period = 30000) {
    }

    var lastPing: Long = System.nanoTime()
    var lastPong: Long = System.nanoTime()

    /** Stops the heartbeat timer, cancels [runnerJob] and closes [socket]. Safe to call more than once. */
    override fun dispose() {
        heartbeatTimer?.cancel()
        heartbeatTimer = null
        runnerJob?.cancel()
        socket.close()
    }
}
/**
 * An [InputStreamReader] whose [close] does nothing, so the wrapped [InputStream] remains open after the
 * reader is closed.
 *
 * Like any [InputStreamReader], it may pull more bytes from the wrapped stream than it has turned into
 * characters. After reading through it, do not rely on the wrapped stream's position.
 *
 * @see java.io.InputStreamReader
 */
class NonOwningInputStreamReader : InputStreamReader {
    constructor(stream: InputStream, charsetName: String) : super(stream, charsetName)

    constructor(stream: InputStream, charset: Charset = Charset.defaultCharset()) : super(stream, charset)

    constructor(stream: InputStream, dec: CharsetDecoder) : super(stream, dec)

    /** Deliberately a no-op: the wrapped stream is owned by someone else. */
    override fun close() = Unit
}
/**
 * An [OutputStreamWriter] that does not close the wrapped [OutputStream] when it is closed.
 *
 * [close] still flushes. Without that, characters buffered by the writer's encoder would never reach the
 * wrapped stream, and output written just before closing would be silently dropped.
 *
 * @see java.io.OutputStreamWriter
 */
class NonOwningOutputStreamWriter : OutputStreamWriter {
    constructor(stream: OutputStream, charsetName: String) : super(stream, charsetName)

    constructor(stream: OutputStream, charset: Charset = Charset.defaultCharset()) : super(stream, charset)

    constructor(stream: OutputStream, enc: CharsetEncoder) : super(stream, enc)

    /** Pushes pending output to the wrapped stream but leaves that stream open. */
    override fun close() {
        flush()
    }
}
/**
 * Header of a single WebSocket frame, laid out as in RFC 6455, section 5.2.
 *
 * @property length payload length in bytes; must be non-negative.
 * @property opcode 4-bit opcode in 0x0..0xF.
 * @property final FIN bit: set on the last fragment of a message.
 * @property masked whether the payload is masked; when set, [mask] is encoded after the length.
 * @property mask 4-byte masking key.
 * @property resv1 RSV1 bit.
 * @property resv2 RSV2 bit.
 * @property resv3 RSV3 bit.
 * @throws IllegalArgumentException if [opcode] is outside 0..15, [length] is negative or [mask] is not 4 bytes.
 */
data class WebSocketHeader(
    val length: Long,
    val opcode: Int,
    val final: Boolean = true,
    val masked: Boolean = false,
    val mask: ByteArray = ByteArray(4),
    val resv1: Boolean = false,
    val resv2: Boolean = false,
    val resv3: Boolean = false,
) {
    init {
        require(opcode in 0x0..0xF) { "Opcode must be between 0..15" }
        require(length >= 0) { "Length must be non-negative" }
        require(mask.size == 4) { "Mask must be exactly 4 bytes" }
    }

    fun isTextFrame() = opcode == 0x1
    fun isBinaryFrame() = opcode == 0x2
    fun isCustomOp() = opcode in 0x3..0x7
    fun isCloseFrame() = opcode == 0x8
    fun isPing() = opcode == 0x9
    fun isPong() = opcode == 0xA
    fun isReserved() = opcode >= 0xB

    /** True when [length] does not fit the 7-bit length field and an extended length follows. */
    fun hasExtendedLen() = length > 125

    /** True when [length] needs the 64-bit extended length instead of the 16-bit one. */
    fun hasLongExtendedLen() = length > 65535

    /**
     * Number of bytes [encode] produces: 2 fixed bytes, then 0, 2 or 8 extended-length bytes, then 4 mask
     * bytes if [masked].
     */
    fun byteSize(): Int {
        val extendedLengthBytes = when {
            hasLongExtendedLen() -> 8
            hasExtendedLen() -> 2
            else -> 0
        }
        val maskBytes = if (masked) 4 else 0
        return 2 + extendedLengthBytes + maskBytes
    }

    // Custom equals/hashCode because ByteArray properties only compare by identity in generated ones.
    override fun equals(other: Any?): Boolean {
        if (this === other) return true
        if (javaClass != other?.javaClass) return false
        other as WebSocketHeader
        return length == other.length &&
            opcode == other.opcode &&
            final == other.final &&
            masked == other.masked &&
            mask.contentEquals(other.mask) &&
            resv1 == other.resv1 &&
            resv2 == other.resv2 &&
            resv3 == other.resv3
    }

    override fun hashCode(): Int {
        var result = length.hashCode()
        result = 31 * result + opcode
        result = 31 * result + final.hashCode()
        result = 31 * result + masked.hashCode()
        result = 31 * result + mask.contentHashCode()
        result = 31 * result + resv1.hashCode()
        result = 31 * result + resv2.hashCode()
        result = 31 * result + resv3.hashCode()
        return result
    }

    /**
     * Get a ByteArray of the header in the format needed to go over the wire (RFC 6455, section 5.2).
     * The payload itself is not included.
     */
    fun encode(): ByteArray {
        // Byte 0: FIN | RSV1 | RSV2 | RSV3 | 4-bit opcode.
        var first = opcode and 0x0F
        if (final) first = first or 0x80
        if (resv1) first = first or 0x40
        if (resv2) first = first or 0x20
        if (resv3) first = first or 0x10
        // Byte 1: MASK | 7-bit length, where 126/127 announce a 16-/64-bit extended length.
        val lengthCode = when {
            hasLongExtendedLen() -> 127
            hasExtendedLen() -> 126
            else -> length.toInt()
        }
        val second = if (masked) lengthCode or 0x80 else lengthCode

        val buffer = ByteBuffer.allocate(byteSize()).order(ByteOrder.BIG_ENDIAN)
        buffer.put(first.toByte())
        buffer.put(second.toByte())
        if (hasLongExtendedLen()) {
            buffer.putLong(length)
        } else if (hasExtendedLen()) {
            // Lengths 126..65535 are sent as an unsigned 16-bit value.
            buffer.putShort(length.toInt().toShort())
        }
        if (masked) {
            buffer.put(mask)
        }
        return buffer.array()
    }
}
/**
 * A WebSocket frame: its [header] together with the payload bytes in [data].
 */
class WebSocketFrame(val header: WebSocketHeader, val data: ByteArray) {
    /** [data] decoded as UTF-8 when [header] marks a text frame; `null` for every other opcode. */
    val textContent: String?
        get() = data
            .takeIf { header.isTextFrame() }
            ?.let { bytes -> String(bytes, StandardCharsets.UTF_8) }
}
/**
 * Minimal WebSocket server (RFC 6455) that serves one client at a time.
 *
 * [start] listens on a port and performs the HTTP upgrade handshake for each accepted connection. While a
 * client is connected, no further connections are accepted. Frames are sent with [write] / [ping], and a
 * background worker coroutine reads frames from the client.
 */
@Suppress("PrivatePropertyName", "PropertyName")
class WebsocketServer : Disposable {
    private val logger = Logger.getInstance(WebsocketServer::class.java)
    val CHANNEL_BACKLOG = 100
    private var _listenSocket: ServerSocket? = null

    // Set by the accept loop, cleared by the client worker or stop(), and read by write()/ping() callers.
    // These run on different threads, so changes are published via @Volatile.
    @Volatile
    private var _client: Client? = null
    private var _coroutineScope: CoroutineScope? = null
    private var _proj: Project? = null
    private var _listenJob: Job? = null
    private var _eventChannel = Channel<String>(CHANNEL_BACKLOG)
    private val HTTP_REQUEST_REGEX = Regex("GET (/\\w*)\\s+HTTP/[.\\d]+")

    // HTTP allows optional whitespace after the colon; the captured value is trimmed after matching.
    private val HTTP_HEADER_REGEX = Regex("([-\\w]+):\\s*(.*)")

    /**
     * Opens a listening socket on [port] and starts accepting clients in the background.
     *
     * @param proj project that user notifications are attached to.
     * @throws IOException if the listening socket cannot be opened.
     */
    fun start(port: Int, proj: Project) {
        _proj = proj
        val listenSocket = ServerSocket(port)
        _listenSocket = listenSocket
        // SupervisorJob: a failing client worker must not cancel the accept loop.
        val scope = CoroutineScope(SupervisorJob() + Dispatchers.Default)
        _coroutineScope = scope
        _listenJob = scope.launch { acceptConnection(listenSocket) }
    }

    /** Disconnects the current client, closes the listening socket and cancels all background work. */
    fun stop() {
        _listenJob?.cancel()
        _listenJob = null
        _client?.let { Disposer.dispose(it) }
        _client = null
        _proj = null
        _eventChannel.close()
        _eventChannel = Channel(CHANNEL_BACKLOG)
        // Closing the socket unblocks a pending accept() in the accept loop.
        _listenSocket?.close()
        _listenSocket = null
        _coroutineScope?.cancel()
        _coroutineScope = null
    }

    /**
     * Sends [str] to the client as a single UTF-8 text frame.
     *
     * @return true if the frame was written, false if no client is connected or the write failed.
     */
    fun write(str: String): Boolean {
        if (_client == null) {
            logger.warn("Trying to send data over WebSocket when the client connection is not open.")
            return false
        }
        val bytes = str.encodeToByteArray()
        return write(bytes, WebSocketHeader(bytes.size.toLong(), 0x1))
    }

    /**
     * Send a ping to the client.
     *
     * @return true if the ping frame was written.
     */
    fun ping(): Boolean {
        if (_client == null) {
            logger.warn("Trying to send ping over WebSocket when the client connection is not open.")
            return false
        }
        logger.info("Sending ping to client")
        return write(ByteArray(0), WebSocketHeader(0, 0x9))
    }

    /**
     * Writes one frame, [header] followed by [data], to the connected client and flushes it.
     *
     * The header is encoded exactly as given. Frames sent by a server should not be masked (RFC 6455, 5.1).
     *
     * @return true if the frame was written, false if no client is connected or an I/O error occurred.
     * @throws IllegalArgumentException if `header.length` differs from `data.size`.
     */
    fun write(data: ByteArray, header: WebSocketHeader): Boolean {
        require(header.length == data.size.toLong()) {
            "Header length ${header.length} does not match payload size ${data.size}"
        }
        val client = _client
        if (client == null) {
            logger.warn("Trying to send data over WebSocket when the client connection is not open.")
            return false
        }
        val headerData = header.encode()
        return try {
            // Callers and the client worker (pong/close replies) may write concurrently; frames must not interleave.
            synchronized(client.outputStream) {
                client.outputStream.write(headerData)
                client.outputStream.write(data)
                client.outputStream.flush()
            }
            true
        } catch (e: IOException) {
            logger.warn("Failed to write WebSocket frame", e)
            false
        }
    }

    /** Accept loop: waits while a client is connected, otherwise accepts and upgrades the next connection. */
    private suspend fun acceptConnection(listenSocket: ServerSocket) {
        while (!listenSocket.isClosed) {
            if (_client != null) {
                // Only one client is served at a time; poll until the current one goes away.
                delay(20)
                continue
            }
            val clientSocket: Socket = try {
                withContext(Dispatchers.IO) { listenSocket.accept() }
            } catch (e: IOException) {
                // stop() closes the listening socket to unblock accept(); that ends the loop normally.
                if (listenSocket.isClosed) return
                logger.warn("Failed to accept WebSocket connection", e)
                delay(20)
                continue
            }
            serveConnection(clientSocket)
        }
    }

    /** Performs the HTTP upgrade on a freshly accepted socket and, on success, hands it to [clientWorker]. */
    private suspend fun serveConnection(clientSocket: Socket) {
        val client = try {
            withContext(Dispatchers.IO) {
                Client(clientSocket, BufferedInputStream(clientSocket.getInputStream()), clientSocket.getOutputStream())
            }
        } catch (e: IOException) {
            logger.warn("Could not open streams for WebSocket client", e)
            clientSocket.close()
            return
        }
        notify(
            "Connection established with external client: ${clientSocket.inetAddress}:${clientSocket.port}.",
            MessageType.INFO
        )
        var handedOff = false
        try {
            if (withContext(Dispatchers.IO) { doHandshake(client) }) {
                // Publish the client only after the 101 response went out, so write() cannot send frames earlier.
                _client = client
                client.runnerJob = _coroutineScope?.launch(Dispatchers.IO) { clientWorker(client) }
                handedOff = client.runnerJob != null
            }
        } catch (e: IOException) {
            logger.warn("WebSocket handshake failed", e)
        } finally {
            // A rejected or failed handshake only drops this connection; the server keeps listening.
            if (!handedOff) dropClient(client)
        }
    }

    /** Disposes [client]; if it is still the current client, clears it so a new connection can be accepted. */
    private fun dropClient(client: Client) {
        if (_client === client) {
            _client = null
        }
        Disposer.dispose(client)
    }

    override fun dispose() {
        stop()
    }

    /** Reads messages from [client] until it closes or disconnects, then drops it. */
    private suspend fun clientWorker(client: Client) {
        try {
            while (!client.socket.isClosed && _client === client) {
                val frame = parseFrame(client)
                val header = frame.header
                when {
                    header.isCloseFrame() -> {
                        // Answer with a close frame echoing the status payload, then drop the connection.
                        write(frame.data, WebSocketHeader(frame.data.size.toLong(), 0x8))
                        logger.info("WebSocket client closed the connection")
                        break
                    }
                    header.isPing() -> {
                        write(frame.data, WebSocketHeader(frame.data.size.toLong(), 0xA))
                    }
                    header.isPong() -> {
                        client.lastPong = System.nanoTime()
                    }
                    header.isTextFrame() -> {
                        // TODO: hand the message to a consumer, e.g. Json.parseToJsonElement(data).jsonObject
                        logger.debug("Received WebSocket text message: ${frame.textContent}")
                    }
                }
            }
        } catch (e: IOException) {
            // A locally closed socket (stop()/dispose) is expected; anything else means the connection was lost.
            if (!client.socket.isClosed) {
                logger.info("WebSocket connection to client lost: ${e.message}")
            }
        } finally {
            dropClient(client)
        }
    }

    /** Reads exactly [count] bytes from [client], or throws [EOFException] if the stream ends first. */
    private fun readExactly(client: Client, count: Int): ByteArray {
        val bytes = client.inputStream.readNBytes(count)
        if (bytes.size < count) {
            throw EOFException("WebSocket stream ended after ${bytes.size} of $count bytes")
        }
        return bytes
    }

    /** Reads and decodes one frame header (RFC 6455, section 5.2), including the masking key if present. */
    private suspend fun parseFrameHeader(client: Client): WebSocketHeader = withContext(Dispatchers.IO) {
        val headerStart = readExactly(client, 2)
        // Work on unsigned values: Byte is signed, so e.g. `byte and 0x80` would be negative.
        val b0 = headerStart[0].toInt() and 0xFF
        val b1 = headerStart[1].toInt() and 0xFF
        val length: Long = when (val shortLen = b1 and 0x7F) {
            0x7E -> (ByteBuffer.wrap(readExactly(client, 2)).order(ByteOrder.BIG_ENDIAN).short.toInt() and 0xFFFF).toLong()
            0x7F -> ByteBuffer.wrap(readExactly(client, 8)).order(ByteOrder.BIG_ENDIAN).long
            else -> shortLen.toLong()
        }
        // RFC 6455 requires the most significant bit of a 64-bit length to be 0.
        if (length < 0) throw IOException("Invalid WebSocket frame length")
        val masked = (b1 and 0x80) != 0
        val mask = if (masked) readExactly(client, 4) else ByteArray(4)
        WebSocketHeader(
            length = length,
            opcode = b0 and 0x0F,
            final = (b0 and 0x80) != 0,
            masked = masked,
            mask = mask,
            resv1 = (b0 and 0x40) != 0,
            resv2 = (b0 and 0x20) != 0,
            resv3 = (b0 and 0x10) != 0,
        )
    }

    /**
     * Reads frames until one has FIN set and returns the reassembled, unmasked message.
     *
     * The returned header carries the first fragment's opcode, since continuation frames use opcode 0.
     * Limitation: control frames interleaved inside a fragmented message are not separated out.
     */
    private suspend fun parseFrame(client: Client): WebSocketFrame {
        val payload = ByteArrayOutputStream()
        var messageHeader: WebSocketHeader? = null
        while (true) {
            val header = parseFrameHeader(client)
            if (header.length > Int.MAX_VALUE - payload.size()) {
                notify("Got a huge packet of size ${header.length} bytes. This is unsupported")
                throw IOException("Packet size over INT_MAX")
            }
            val part = withContext(Dispatchers.IO) { readExactly(client, header.length.toInt()) }
            if (header.masked) {
                applyMasking(part, header.mask)
            }
            payload.write(part)
            val first = messageHeader ?: header
            messageHeader = first
            if (header.final) {
                val data = payload.toByteArray()
                return WebSocketFrame(first.copy(length = data.size.toLong(), final = true, masked = false), data)
            }
        }
    }

    /** XORs [data] in place with the repeating masking key [mask] (RFC 6455, section 5.3). */
    private fun applyMasking(data: ByteArray, mask: ByteArray) {
        for (i in data.indices) {
            data[i] = (data[i].toInt() xor mask[i % mask.size].toInt()).toByte()
        }
    }

    /**
     * Reads one LF-terminated line (a trailing CR is stripped) byte-by-byte from [input], decoding each byte
     * as a Latin-1 character.
     *
     * No bytes past the line terminator are consumed, so WebSocket frames following the HTTP request stay
     * in [input]. Returns null if the stream ends before any byte is read.
     */
    private fun readHttpLine(input: InputStream): String? {
        val line = StringBuilder()
        while (true) {
            when (val b = input.read()) {
                -1 -> return if (line.isEmpty()) null else line.toString()
                0x0A -> return line.removeSuffix("\r").toString()
                else -> {
                    if (line.length >= 8192) throw IOException("HTTP request line too long")
                    line.append(b.toChar())
                }
            }
        }
    }

    /**
     * Reads the client's HTTP upgrade request and, if it asks for a WebSocket, sends the 101 response.
     *
     * @return true if the connection was upgraded; false if the request was malformed, not a WebSocket
     *   upgrade, or the stream ended early.
     */
    private fun doHandshake(client: Client): Boolean {
        val requestStr = readHttpLine(client.inputStream) ?: return false
        if (HTTP_REQUEST_REGEX.matchEntire(requestStr) == null) {
            logger.warn("Unexpected HTTP request line from WebSocket client: $requestStr")
            return false
        }
        // HTTP header names are case-insensitive.
        val headers = TreeMap<String, String>(String.CASE_INSENSITIVE_ORDER)
        while (true) {
            val line = readHttpLine(client.inputStream) ?: return false
            if (line.isEmpty()) break
            val match = HTTP_HEADER_REGEX.matchEntire(line) ?: continue
            val name = match.groupValues[1]
            val value = match.groupValues[2].trim()
            headers[name] = value
            logger.info("$name: $value")
        }
        val upgrade = headers["Upgrade"]
        val isWebSocket = upgrade != null && upgrade.equals("websocket", ignoreCase = true)
        if (!isWebSocket || !headers.containsKey("Sec-WebSocket-Key")) {
            notify("Client request did not switch to websocket", MessageType.WARNING)
            return false
        }
        sendHandshakeResponse(client, headers)
        return true
    }

    /** Writes and flushes the `101 Switching Protocols` response (RFC 6455, section 4.2.2). */
    private fun sendHandshakeResponse(client: Client, headers: Map<String, String>) {
        val magicString = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
        val secKey = headers["Sec-WebSocket-Key"].orEmpty()
        logger.debug(headers.entries.joinToString("\n") { (key, value) -> "$key: $value" })
        val digest = MessageDigest.getInstance("SHA-1")
            .digest((secKey + magicString).toByteArray(StandardCharsets.US_ASCII))
        val responseKeyStr = Base64.getEncoder().encodeToString(digest)
        val lang = headers["Content-Language"]
        // Every header line ends in CRLF, and an empty line (the extra CRLF) terminates the header block.
        val responseStr = listOf(
            "HTTP/1.1 101 Switching Protocols",
            "Upgrade: websocket",
            "Connection: Upgrade",
            "Content-Language: ${lang ?: "en-US"}",
            "Sec-WebSocket-Version: 13",
            "Sec-WebSocket-Accept: $responseKeyStr",
        ).joinToString("\r\n", "", "\r\n\r\n")
        logger.debug("\n" + responseStr)
        client.outputStream.write(responseStr.toByteArray(StandardCharsets.US_ASCII))
        client.outputStream.flush()
        val clientAddr = client.socket.inetAddress.toString()
        notify("Client established websocket connection.\n$clientAddr", MessageType.INFO)
    }

    private fun notify(message: String, type: MessageType = MessageType.INFO) {
        NotificationGroupManager.getInstance()?.getNotificationGroup("Bitburner")
            ?.createNotification(message, type)
            ?.setTitle("Bitburner")
            ?.setSubtitle(_proj?.name)
            ?.notify(_proj)
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment