Skip to content

Instantly share code, notes, and snippets.

@tylerreisinger
Created March 6, 2023 12:26
Show Gist options
  • Save tylerreisinger/e1e26c1852c46e0a3b93d136608571ca to your computer and use it in GitHub Desktop.
Save tylerreisinger/e1e26c1852c46e0a3b93d136608571ca to your computer and use it in GitHub Desktop.
@file:Suppress("WeakerAccess", "MemberVisibilityCanBePrivate")
package com.tyler.bitburner.websocket
import com.intellij.notification.NotificationGroupManager
import com.intellij.openapi.Disposable
import com.intellij.openapi.diagnostic.Logger
import com.intellij.openapi.project.Project
import com.intellij.openapi.ui.MessageType
import com.intellij.openapi.util.Disposer
import com.intellij.util.io.toByteArray
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.Channel
import java.io.*
import java.net.ServerSocket
import java.net.Socket
import java.nio.ByteBuffer
import java.nio.ByteOrder
import java.nio.charset.Charset
import java.nio.charset.CharsetDecoder
import java.nio.charset.CharsetEncoder
import java.nio.charset.StandardCharsets
import java.security.MessageDigest
import java.util.*
import kotlin.concurrent.timer
import kotlin.experimental.and
import kotlin.experimental.or
import kotlin.experimental.xor
/**
 * State kept for one connected socket: the socket itself, its two streams, the coroutine job
 * servicing it, and ping/pong timestamps.
 *
 * @param socket the accepted connection; closed by [dispose]
 * @param input stream that client data is read from
 * @param output stream that data for the client is written to
 */
internal class Client(val socket: Socket, input: BufferedInputStream, output: OutputStream) : Disposable {
    /** Stream that client data is read from. */
    val inputStream: BufferedInputStream = input

    /** Stream that data for the client is written to. */
    val outputStream: OutputStream = output

    /** Job of the coroutine servicing this client, if one has been assigned; cancelled by [dispose]. */
    var runnerJob: Job? = null

    // Fires every 30 seconds; the action is currently empty.
    private var heartbeatTimer: Timer? = timer(period = 30_000L) { }

    /** `System.nanoTime()` reading; initialised when the client is constructed. */
    var lastPing: Long = System.nanoTime()

    /** `System.nanoTime()` reading; initialised when the client is constructed. */
    var lastPong: Long = System.nanoTime()

    /** Cancels the heartbeat timer and the runner job (when present), then closes the socket. */
    override fun dispose() {
        val timerToCancel = heartbeatTimer
        if (timerToCancel != null) {
            timerToCancel.cancel()
        }
        val jobToCancel = runnerJob
        if (jobToCancel != null) {
            jobToCancel.cancel()
        }
        socket.close()
    }
}
/**
 * An [InputStreamReader] whose [close] is a no-op, so closing the reader leaves the wrapped
 * stream open. It acts as a view over a stream that is owned by someone else.
 *
 * @see java.io.InputStreamReader
 */
class NonOwningInputStreamReader : InputStreamReader {
    /** Decodes [stream] using the charset named [charsetName]. */
    constructor(stream: InputStream, charsetName: String) : super(stream, charsetName)

    /** Decodes [stream] using [charset]; the platform default charset is used when omitted. */
    constructor(stream: InputStream, charset: Charset = Charset.defaultCharset()) : super(stream, charset)

    /** Decodes [stream] using the supplied decoder [dec]. */
    constructor(stream: InputStream, dec: CharsetDecoder) : super(stream, dec)

    /** Intentionally does nothing: the wrapped stream is not closed. */
    override fun close() = Unit
}
/**
 * An [OutputStreamWriter] that does not close the underlying stream when it is closed.
 *
 * [close] still flushes, so characters buffered by the writer reach the underlying stream
 * instead of being silently dropped (for example when the writer is used with `use { }`).
 *
 * @see java.io.OutputStreamWriter
 */
class NonOwningOutputStreamWriter : OutputStreamWriter {
    /** Encodes to [stream] using the charset named [charsetName]. */
    constructor(stream: OutputStream, charsetName: String) : super(stream, charsetName)

    /** Encodes to [stream] using [charset]; the platform default charset is used when omitted. */
    constructor(stream: OutputStream, charset: Charset = Charset.defaultCharset()) : super(stream, charset)

    /** Encodes to [stream] using the supplied encoder [enc]. */
    constructor(stream: OutputStream, enc: CharsetEncoder) : super(stream, enc)

    /**
     * Flushes pending output but leaves the underlying stream open.
     *
     * @throws IOException if flushing fails
     */
    override fun close() {
        // OutputStreamWriter buffers encoded bytes internally; without this flush, anything
        // written since the last explicit flush() would never reach the stream.
        flush()
    }
}
/**
 * The header of a single WebSocket frame, laid out on the wire as described by the base framing
 * protocol of RFC 6455 (section 5.2).
 *
 * @property length payload length in bytes; must not be negative
 * @property opcode 4-bit frame opcode (0..15)
 * @property final whether the FIN bit is set (last fragment of a message)
 * @property masked whether the payload is masked with [mask]
 * @property mask the 4-byte masking key; only sent on the wire when [masked] is true
 * @property resv1 RSV1 bit
 * @property resv2 RSV2 bit
 * @property resv3 RSV3 bit
 * @throws IllegalArgumentException if [opcode] is outside 0..15, [length] is negative, or
 * [masked] is true and [mask] is not exactly 4 bytes long
 */
data class WebSocketHeader(
    val length: Long,
    val opcode: Int,
    val final: Boolean = true,
    val masked: Boolean = false,
    val mask: ByteArray = ByteArray(4),
    val resv1: Boolean = false,
    val resv2: Boolean = false,
    val resv3: Boolean = false
) {
    init {
        require(opcode in 0..0xF) { "Opcode must be between 0..15" }
        require(length >= 0) { "Length must not be negative" }
        require(!masked || mask.size == 4) { "Mask must be exactly 4 bytes" }
    }

    fun isTextFrame() = opcode == 0x1
    fun isBinaryFrame() = opcode == 0x2

    /** Opcodes 0x3..0x7, which RFC 6455 reserves for further non-control frames. */
    fun isCustomOp() = opcode in 0x3..0x7
    fun isCloseFrame() = opcode == 0x8
    fun isPing() = opcode == 0x9
    fun isPong() = opcode == 0xA

    /** Opcodes 0xB and above, which RFC 6455 reserves for further control frames. */
    fun isReserved() = opcode >= 0xB

    /** True when the length does not fit in the 7-bit field and an extended length follows. */
    fun hasExtendedLen() = length > 125

    /** True when the length needs the 64-bit extended length field. */
    fun hasLongExtendedLen() = length > 65535

    /**
     * Number of bytes [encode] produces: 2 base bytes, plus 2 or 8 bytes of extended length,
     * plus 4 bytes of masking key when [masked] is set.
     */
    fun byteSize(): Int {
        // The long form must be tested first: every "long" length is also an "extended" one.
        val baseSize = when {
            hasLongExtendedLen() -> 10
            hasExtendedLen() -> 4
            else -> 2
        }
        return if (masked) baseSize + 4 else baseSize
    }

    override fun equals(other: Any?): Boolean {
        if (this === other) return true
        if (javaClass != other?.javaClass) return false
        other as WebSocketHeader
        if (length != other.length) return false
        if (opcode != other.opcode) return false
        if (final != other.final) return false
        if (masked != other.masked) return false
        // Arrays need a content comparison; the generated data-class equals would compare references.
        if (!mask.contentEquals(other.mask)) return false
        if (resv1 != other.resv1) return false
        if (resv2 != other.resv2) return false
        if (resv3 != other.resv3) return false
        return true
    }

    override fun hashCode(): Int {
        var result = length.hashCode()
        result = 31 * result + opcode
        result = 31 * result + final.hashCode()
        result = 31 * result + masked.hashCode()
        result = 31 * result + mask.contentHashCode()
        result = 31 * result + resv1.hashCode()
        result = 31 * result + resv2.hashCode()
        result = 31 * result + resv3.hashCode()
        return result
    }

    /**
     * Get a ByteArray of the header that is in the format needed to go over the wire.
     *
     * Layout: byte 0 is FIN | RSV1 | RSV2 | RSV3 | opcode (low 4 bits); byte 1 is MASK | 7-bit
     * length (or the marker 126 / 127); then the big-endian 16- or 64-bit extended length if
     * needed; then the 4-byte masking key if [masked].
     *
     * @return exactly [byteSize] bytes
     */
    fun encode(): ByteArray {
        val flagsAndOpcode = (if (final) 0x80 else 0) or
            (if (resv1) 0x40 else 0) or
            (if (resv2) 0x20 else 0) or
            (if (resv3) 0x10 else 0) or
            (opcode and 0x0F)
        val shortLen = when {
            hasLongExtendedLen() -> 127
            hasExtendedLen() -> 126
            else -> length.toInt()
        }
        val maskBit = if (masked) 0x80 else 0

        val buffer = ByteBuffer.allocate(byteSize()).order(ByteOrder.BIG_ENDIAN)
        buffer.put(flagsAndOpcode.toByte())
        buffer.put((maskBit or shortLen).toByte())
        when {
            hasLongExtendedLen() -> buffer.putLong(length)
            // Lengths up to 65535 keep their low 16 bits intact through the narrowing conversion.
            hasExtendedLen() -> buffer.putShort(length.toShort())
        }
        if (masked) {
            buffer.put(mask)
        }
        return buffer.array()
    }
}
/**
 * A frame header together with its payload bytes.
 *
 * @property header the header describing the frame
 * @property data the payload bytes
 */
class WebSocketFrame(val header: WebSocketHeader, val data: ByteArray) {
    /** The payload decoded as UTF-8 when [header] marks a text frame; `null` for any other frame. */
    val textContent: String?
        get() = if (header.isTextFrame()) data.toString(StandardCharsets.UTF_8) else null
}
/**
 * A small WebSocket server that serves one client at a time.
 *
 * [start] opens a listening socket and launches a coroutine that accepts a connection, performs
 * the HTTP upgrade handshake and then runs a worker that reads frames from the client. [write]
 * and [ping] send frames to the currently connected client. [stop] (also invoked by [dispose])
 * tears everything down.
 *
 * A client that sends a malformed request, an oversized frame, or that disconnects is dropped;
 * the server keeps listening so that another client can connect.
 */
@Suppress("PrivatePropertyName", "PropertyName")
class WebsocketServer : Disposable {
    private val logger = Logger.getInstance(WebsocketServer::class.java)

    /** Capacity of the internal event channel. */
    val CHANNEL_BACKLOG = 100

    private var _listenSocket: ServerSocket? = null

    // Read and written from the accept loop, the client worker and callers of write()/stop().
    @Volatile
    private var _client: Client? = null
    private var _coroutineScope: CoroutineScope? = null
    private var _proj: Project? = null
    private var _listenJob: Job? = null
    private var _eventChannel = Channel<String>(CHANNEL_BACKLOG)

    private val HTTP_REQUEST_REGEX = Regex("GET (/\\w*)\\s+HTTP/[.\\d]+")

    // Whitespace after the colon is optional in HTTP; surrounding blanks are trimmed after matching.
    private val HTTP_HEADER_REGEX = Regex("([-\\w]+):\\s*(.*)")

    // GUID that RFC 6455 requires to be appended to the client's key when computing the accept key.
    private val WEBSOCKET_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"

    // Upper bound for a single handshake line so that a misbehaving peer cannot exhaust memory.
    private val MAX_HTTP_LINE_LENGTH = 8192

    /**
     * Starts listening on [port] and accepting a client in the background. If the server is
     * already running it is stopped first so the previous socket and coroutines are not leaked.
     *
     * @param port TCP port to listen on
     * @param proj project used for the notifications this server shows
     * @throws IOException if the listening socket cannot be opened
     */
    fun start(port: Int, proj: Project) {
        if (_listenSocket != null) {
            stop()
        }
        val listenSocket = ServerSocket(port)
        // SupervisorJob: a failing client worker must not cancel the accept loop.
        val scope = CoroutineScope(SupervisorJob() + Dispatchers.Default)
        _proj = proj
        _listenSocket = listenSocket
        _coroutineScope = scope
        _listenJob = scope.launch { acceptConnection(listenSocket) }
    }

    /**
     * Stops accepting connections, disposes the connected client (if any), closes the listening
     * socket and cancels all coroutines started by this server. Safe to call when not running.
     */
    fun stop() {
        _listenJob?.cancel()
        _listenJob = null

        val client = _client
        _client = null
        if (client != null) {
            Disposer.dispose(client)
        }

        _proj = null
        _eventChannel.close()
        _eventChannel = Channel(CHANNEL_BACKLOG)

        try {
            _listenSocket?.close()
        } catch (e: IOException) {
            logger.warn("Failed to close the WebSocket listen socket", e)
        }
        _listenSocket = null

        _coroutineScope?.cancel()
        _coroutineScope = null
    }

    /**
     * Sends [str] to the client as a single UTF-8 text frame.
     *
     * @return `true` if the frame was written, `false` if no client is connected or the write failed
     */
    fun write(str: String): Boolean {
        if (_client == null) {
            logger.warn("Trying to send data over WebSocket when the client connection is not open.")
            return false
        }
        val bytes = str.encodeToByteArray()
        return write(bytes, WebSocketHeader(bytes.size.toLong(), 0x1))
    }

    /**
     * Send a ping to the client
     *
     * @return `true` if the ping frame was written, `false` if no client is connected or the write failed
     */
    fun ping(): Boolean {
        if (_client == null) {
            logger.warn("Trying to send ping over WebSocket when the client connection is not open.")
            return false
        }
        logger.info("Sending ping to client")
        return write(ByteArray(0), WebSocketHeader(0, 0x9))
    }

    /**
     * Writes one frame, consisting of the encoded [header] followed by [data], to the client.
     *
     * @return `true` if the frame was written and flushed, `false` if no client is connected or
     * the write failed
     */
    fun write(data: ByteArray, header: WebSocketHeader): Boolean {
        val client = _client
        if (client == null) {
            logger.warn("Trying to write a WebSocket frame when the client connection is not open.")
            return false
        }
        return sendFrame(client, data, header)
    }

    override fun dispose() {
        stop()
    }

    /** Writes one frame to [client]; header and payload are written under a lock so frames never interleave. */
    private fun sendFrame(client: Client, data: ByteArray, header: WebSocketHeader): Boolean =
        try {
            val headerData = header.encode()
            synchronized(client) {
                client.outputStream.write(headerData)
                client.outputStream.write(data)
                client.outputStream.flush()
            }
            true
        } catch (e: IOException) {
            logger.warn("Failed to write WebSocket frame to client", e)
            false
        }

    /** Forgets and disposes [client] if it is still the current client; otherwise does nothing. */
    private fun dropClient(client: Client) {
        if (_client === client) {
            _client = null
            Disposer.dispose(client)
        }
    }

    /**
     * Accept loop: waits while a client is connected, otherwise accepts the next connection and
     * sets it up. Returns when [listenSocket] is closed or accepting fails.
     */
    private suspend fun acceptConnection(listenSocket: ServerSocket) {
        while (!listenSocket.isClosed) {
            if (_client != null) {
                // Only one client is served at a time; poll until the current one goes away.
                delay(20)
                continue
            }
            val clientSocket = try {
                withContext(Dispatchers.IO) { listenSocket.accept() }
            } catch (e: IOException) {
                // accept() fails with an IOException when stop() closes the socket; that is expected.
                if (!listenSocket.isClosed) {
                    logger.warn("Failed to accept a WebSocket connection", e)
                }
                return
            }
            setUpClient(clientSocket)
            yield()
        }
    }

    /** Registers the new connection as the current client, runs the handshake and starts the worker. */
    private suspend fun setUpClient(clientSocket: Socket) {
        val client = try {
            Client(clientSocket, BufferedInputStream(clientSocket.getInputStream()), clientSocket.getOutputStream())
        } catch (e: IOException) {
            logger.warn("Could not open the streams of a new client connection", e)
            try {
                clientSocket.close()
            } catch (closeError: IOException) {
                logger.debug("Closing the failed client socket also failed: ${closeError.message}")
            }
            return
        }
        _client = client

        val addr = clientSocket.inetAddress.toString()
        val port = clientSocket.port.toString()
        notify("Connection established with external client: ${addr}:${port}.", MessageType.INFO)

        val handshakeOk = try {
            doHandshake(client)
        } catch (e: IOException) {
            logger.warn("WebSocket handshake failed", e)
            false
        }

        val scope = _coroutineScope
        if (handshakeOk && scope != null && _client === client) {
            // Exactly one worker per client: two readers on the same stream would corrupt framing.
            client.runnerJob = scope.launch(Dispatchers.IO) { clientWorker(client) }
        } else {
            dropClient(client)
        }
    }

    /**
     * Reads frames from [client] until the connection ends, a close frame arrives or the client
     * is replaced. Pings are answered with a pong carrying the same payload. The client is
     * always dropped on exit.
     */
    private suspend fun clientWorker(client: Client) {
        try {
            while (!client.socket.isClosed && _client === client) {
                val frame = parseFrame(client)
                when {
                    frame.header.isCloseFrame() -> {
                        // Acknowledge the close before the connection is torn down in `finally`.
                        sendFrame(client, ByteArray(0), WebSocketHeader(0, 0x8))
                        return
                    }
                    frame.header.isPing() ->
                        sendFrame(client, frame.data, WebSocketHeader(frame.data.size.toLong(), 0xA))
                    else -> {
                        // Data frames are parsed but not yet consumed.
                        //Json.parseToJsonElement(data).jsonObject
                    }
                }
                yield()
            }
        } catch (e: IOException) {
            // Includes EOF and the SocketException raised when the socket is closed during a read.
            logger.info("WebSocket client connection ended: ${e.message}")
        } finally {
            dropClient(client)
        }
    }

    /**
     * Reads exactly [count] bytes from [input].
     *
     * @throws EOFException if the stream ends before [count] bytes were read
     */
    private fun readExactly(input: InputStream, count: Int): ByteArray {
        val bytes = input.readNBytes(count)
        if (bytes.size < count) {
            throw EOFException("Stream ended after ${bytes.size} of $count expected bytes")
        }
        return bytes
    }

    /**
     * Reads and decodes one frame header from [client].
     *
     * @throws IOException on EOF, read failure, or a 64-bit length with the top bit set
     */
    private suspend fun parseFrameHeader(client: Client): WebSocketHeader = withContext(Dispatchers.IO) {
        val input = client.inputStream
        val headerStart = readExactly(input, 2)
        // Work on unsigned ints: Kotlin bytes are signed, so 0x80 would otherwise be negative.
        val first = headerStart[0].toInt() and 0xFF
        val second = headerStart[1].toInt() and 0xFF

        val length: Long = when (val shortLen = second and 0x7F) {
            0x7E -> ByteBuffer.wrap(readExactly(input, 2)).order(ByteOrder.BIG_ENDIAN).short.toLong() and 0xFFFFL
            0x7F -> ByteBuffer.wrap(readExactly(input, 8)).order(ByteOrder.BIG_ENDIAN).long
            else -> shortLen.toLong()
        }
        if (length < 0) {
            throw IOException("Invalid frame length: most significant bit is set")
        }

        val final = (first and 0x80) != 0
        val rsv1 = (first and 0x40) != 0
        val rsv2 = (first and 0x20) != 0
        val rsv3 = (first and 0x10) != 0
        val opcode = first and 0x0F
        // The mask flag is the top bit of the second byte, next to the 7-bit length.
        val masked = (second and 0x80) != 0
        val mask = if (masked) readExactly(input, 4) else ByteArray(4)

        WebSocketHeader(length, opcode, final, masked, mask, rsv1, rsv2, rsv3)
    }

    /**
     * Reads one complete message from [client], concatenating fragments until a frame with the
     * FIN bit arrives. For a fragmented message the returned header carries the first fragment's
     * opcode and the total payload length.
     *
     * NOTE(review): a control frame interleaved between the fragments of a message is not
     * handled separately; it ends the current message.
     *
     * @throws IOException on EOF, read failure, or when the payload would exceed `Int.MAX_VALUE` bytes
     */
    private suspend fun parseFrame(client: Client): WebSocketFrame {
        var firstHeader: WebSocketHeader? = null
        var data = ByteArray(0)
        while (true) {
            val header = parseFrameHeader(client)
            // Check the size before reading so an oversized length is never narrowed to Int.
            if (data.size + header.length > Int.MAX_VALUE) {
                notify("Got a huge packet of size ${header.length} bytes. This is unsupported")
                throw IOException("Packet size over INT_MAX")
            }
            if (header.length > 0L) {
                val dataPart = withContext(Dispatchers.IO) {
                    readExactly(client.inputStream, header.length.toInt())
                }
                if (header.masked) {
                    applyMasking(dataPart, header.mask)
                }
                data += dataPart
            }
            val first = firstHeader ?: header
            firstHeader = first
            if (header.final) {
                val messageHeader =
                    if (first === header) header
                    else first.copy(length = data.size.toLong(), final = true)
                return WebSocketFrame(messageHeader, data)
            }
        }
    }

    /** XORs [data] in place with the 4-byte [mask], repeating the mask every four bytes. */
    private fun applyMasking(data: ByteArray, mask: ByteArray) {
        for (i in data.indices) {
            data[i] = (data[i].toInt() xor mask[i % 4].toInt()).toByte()
        }
    }

    /**
     * Reads one line of the HTTP handshake directly from [input], without its line terminator.
     * Reading byte by byte avoids consuming bytes that belong to the frames following the
     * handshake, which a buffering reader could swallow.
     *
     * @return the line, or `null` if the stream ended before any byte was read
     * @throws IOException if the line exceeds [MAX_HTTP_LINE_LENGTH] bytes or reading fails
     */
    private fun readHttpLine(input: InputStream): String? {
        val line = ByteArrayOutputStream()
        while (true) {
            val next = input.read()
            if (next == -1) {
                if (line.size() == 0) return null
                break
            }
            if (next == 0x0A) break
            if (line.size() >= MAX_HTTP_LINE_LENGTH) {
                throw IOException("HTTP handshake line is longer than $MAX_HTTP_LINE_LENGTH bytes")
            }
            line.write(next)
        }
        return String(line.toByteArray(), StandardCharsets.ISO_8859_1).removeSuffix("\r")
    }

    /**
     * Reads the client's HTTP upgrade request and, if it asks for a WebSocket upgrade, sends the
     * handshake response.
     *
     * @return `true` if the connection was upgraded, `false` if the request was malformed,
     * incomplete, or not a WebSocket upgrade
     * @throws IOException if reading the request or writing the response fails
     */
    private suspend fun doHandshake(client: Client): Boolean = withContext(Dispatchers.IO) {
        val requestLine = readHttpLine(client.inputStream)
        val getPath = requestLine?.let { HTTP_REQUEST_REGEX.matchEntire(it) }?.groups?.get(1)?.value
        if (getPath == null) {
            logger.warn("Client sent a malformed HTTP request line")
            return@withContext false
        }

        // HTTP header names are case-insensitive.
        val headers = TreeMap<String, String>(String.CASE_INSENSITIVE_ORDER)
        while (true) {
            val line = readHttpLine(client.inputStream)
            if (line == null) {
                logger.warn("Client closed the connection before finishing its HTTP request")
                return@withContext false
            }
            if (line.isEmpty()) break
            val match = HTTP_HEADER_REGEX.matchEntire(line) ?: continue
            headers[match.groupValues[1]] = match.groupValues[2].trim()
        }
        logger.debug(headers.entries.joinToString("\n") { "${it.key}: ${it.value}" })

        if (headers["Upgrade"].equals("websocket", ignoreCase = true) && headers.containsKey("Sec-WebSocket-Key")) {
            sendHandshakeResponse(client, headers)
            true
        } else {
            notify("Client request did not switch to websocket", MessageType.WARNING)
            false
        }
    }

    /**
     * Sends the "101 Switching Protocols" response, with the accept key derived from the
     * client's `Sec-WebSocket-Key` header.
     *
     * @throws IOException if the key header is missing or the response cannot be written
     */
    private suspend fun sendHandshakeResponse(client: Client, headers: Map<String, String>) {
        val secKey = headers["Sec-WebSocket-Key"]
            ?: throw IOException("Handshake request has no Sec-WebSocket-Key header")
        val digest = MessageDigest.getInstance("SHA-1")
            .digest((secKey + WEBSOCKET_GUID).toByteArray(StandardCharsets.US_ASCII))
        val responseKeyStr = Base64.getEncoder().encodeToString(digest)
        val lang = headers["Content-Language"]

        // The header block must end with an empty line, hence the doubled CRLF at the end.
        val responseStr = listOf(
            "HTTP/1.1 101 Switching Protocols",
            "Upgrade: websocket",
            "Connection: Upgrade",
            "Content-Language: ${lang ?: "en-US"}",
            "Sec-WebSocket-Version: 13",
            "Sec-WebSocket-Accept: $responseKeyStr",
        ).joinToString(separator = "\r\n", postfix = "\r\n\r\n")
        logger.debug("\n" + responseStr)

        withContext(Dispatchers.IO) {
            client.outputStream.write(responseStr.toByteArray(StandardCharsets.US_ASCII))
            // Without the flush the response may sit in a buffer and the client waits forever.
            client.outputStream.flush()
        }

        val clientAddr = client.socket.inetAddress.toString()
        notify("Client established websocket connection.\n$clientAddr", MessageType.INFO)
    }

    /** Shows a "Bitburner" notification with [message]; does nothing if the notification group is unavailable. */
    private fun notify(message: String, type: MessageType = MessageType.INFO) {
        NotificationGroupManager.getInstance()?.getNotificationGroup("Bitburner")
            ?.createNotification(message, type)
            ?.setTitle("Bitburner")
            ?.setSubtitle(_proj?.name)
            ?.notify(_proj)
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment