Last active
June 10, 2024 00:17
-
-
Save littletsu/09b9d3bf581b759e1dfbf40d2275df29 to your computer and use it in GitHub Desktop.
WebSocket and HTTP server for OBS Lua scripting
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- Based on https://github.com/stonetoad/obs-lua-httpd | |
local obs = obslua | |
-- From https://github.com/stonetoad/obs-lua-httpd/blob/e1c167f6c5231e605cf8531750153e728765f587/ljsocket.lua | |
local socket = require("ljsocket") | |
-- From https://gist.githubusercontent.com/PedroAlvesV/872a108f187f57c2a5b7b5bc34398496/raw/4ee8e36c9ee4b55a3d6bef768258ec8f9c6c3bc2/sha1.lua | |
local sha1 = require("sha1") | |
local bit = require("bit") | |
-- From https://devforum.roblox.com/t/base64-encoding-and-decoding-in-lua/1719860 | |
--- Encode a byte string as standard Base64 (RFC 4648 alphabet, "=" padding).
-- Input is consumed three bytes at a time; each group becomes four output
-- characters, and a final group of one or two bytes is padded with "==" / "=".
-- @tparam string data raw bytes to encode
-- @treturn string the Base64 text ("" for empty input)
function to_base64(data)
  local alphabet = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/'
  local function digit(index)
    return alphabet:sub(index + 1, index + 1)
  end
  local out = {}
  for pos = 1, #data, 3 do
    local b1, b2, b3 = data:byte(pos, pos + 2)
    -- Pack up to three bytes into one 24-bit integer (missing bytes are 0).
    local group = b1 * 65536 + (b2 or 0) * 256 + (b3 or 0)
    local quad = digit(math.floor(group / 262144) % 64)
      .. digit(math.floor(group / 4096) % 64)
    if b2 then
      quad = quad .. digit(math.floor(group / 64) % 64)
    else
      quad = quad .. '='
    end
    if b3 then
      quad = quad .. digit(group % 64)
    else
      quad = quad .. '='
    end
    out[#out + 1] = quad
  end
  return table.concat(out)
end
--- Base64-encode the binary SHA-1 digest of str.
-- Used for the Sec-WebSocket-Accept handshake value. sha1_binary is
-- presumably a global installed by require("sha1") -- TODO confirm.
-- @tparam string str input to hash
-- @treturn string Base64 text of the digest
function b64_sha1(str)
  local digest = sha1_binary(str)
  local encoded = to_base64(digest)
  return encoded
end
--- Logging helper used throughout the server; forwards every argument to print.
-- Takes varargs so multi-value calls such as debug_print(value, str) log all
-- of their arguments instead of silently dropping everything after the first.
function debug_print(...)
  print(...)
end
-- Handle returned by open_server() while the script is loaded; nil otherwise.
local server = nil

--- OBS hook: start the HTTP/WebSocket server when the script is loaded.
-- Serves "/" over plain HTTP and replies to the WebSocket text "hi" on "/ws".
function script_load()
  server = open_server(
    "127.0.0.1", 12464, "/ws",
    function(client, text, wss)
      if text == "hi" then
        wss.send_text(client, "Hi bro :3!")
      end
    end,
    {
      ["/"] = function()
        return "<h1>Index</h1>"
      end
    }
  )
end

--- OBS hook: shut the server down when the script is unloaded or reloaded.
-- When the handle provides close(), use it so the server can release its own
-- state (poller socket reference, client connections); otherwise fall back to
-- closing only the raw listening socket. Calling this twice is a no-op.
function script_unload()
  if server == nil then
    return
  end
  if server.close then
    server.close()
  else
    server.socket:close()
  end
  server = nil
end
--- Split a string on a set of separator characters.
-- sep is spliced into a Lua pattern character class ("[^" .. sep .. "]"), so
-- it names a set of characters (",", ":", or a class such as "%s"), not a
-- literal substring. Runs of separators yield no empty fields:
-- split("a,,b", ",") gives {"a", "b"}.
-- @tparam string inputstr text to split
-- @tparam[opt="%s"] string sep separator character set (whitespace by default)
-- @treturn table list of the non-empty fields, in order
function split(inputstr, sep)
  if sep == nil then
    sep = "%s"
  end
  local fields = {}
  local field_pattern = "[^" .. sep .. "]+"
  string.gsub(inputstr, field_pattern, function(field)
    fields[#fields + 1] = field
  end)
  return fields
end
--- Check whether a comma-separated HTTP header value lists a given token.
-- Each list element is trimmed of surrounding spaces/tabs and compared
-- ASCII case-insensitively; RFC 6455 sec. 4.2.1 requires the Connection
-- header's "Upgrade" token to be matched that way, and clients may send it
-- as e.g. "keep-alive, Upgrade" or "upgrade".
-- @tparam string value raw header value, e.g. "keep-alive, Upgrade"
-- @tparam string str token to look for, e.g. "Upgrade"
-- @treturn boolean true if some element of the list equals str
function read_header_list_includes(value, str)
  local wanted = str:lower()
  for element in string.gmatch(value, "[^,]+") do
    local token = element:match("^[ \t]*(.-)[ \t]*$")
    if token:lower() == wanted then
      return true
    end
  end
  return false
end
-- GUID that RFC 6455 (sec. 1.3) appends to Sec-WebSocket-Key before hashing.
local WS_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"

-- WebSocket frame opcodes handled by this server (RFC 6455 sec. 5.2).
local OP_CONTINUATION = 0
local OP_TEXT = 1
local OP_CLOSE = 8
local OP_PING = 9
local OP_PONG = 10

-- Largest client frame payload (bytes) accepted before the client is dropped.
local MAX_WS_PAYLOAD = 1024 * 1024
-- Largest HTTP request head (bytes) buffered while waiting for the blank line.
local MAX_HTTP_HEAD = 16 * 1024
-- Seconds a newly accepted client may take to send its complete request head.
local HTTP_HEAD_TIMEOUT = 5

-- Body of the 404 page.
local NOT_FOUND_BODY = [[
<div style="background-color: darkgrey; color: white">
<h1>Route was not found.</h1>
</div>
]]

--- Encode one unmasked, final (FIN=1) server-to-client WebSocket frame.
-- Picks the 7-bit, 16-bit or 64-bit payload-length form (RFC 6455 sec. 5.2),
-- so payloads longer than 125 bytes are supported.
-- @tparam number op frame opcode
-- @tparam string payload frame payload
-- @treturn string the encoded frame
local function encode_ws_frame(op, payload)
  local len = #payload
  local first = string.char(bit.bor(0x80, op))
  local length_bytes
  if len <= 125 then
    length_bytes = string.char(len)
  elseif len <= 0xFFFF then
    length_bytes = string.char(126, math.floor(len / 256), len % 256)
  else
    local bytes = {}
    local rest = len
    for i = 8, 1, -1 do
      bytes[i] = string.char(rest % 256)
      rest = math.floor(rest / 256)
    end
    length_bytes = string.char(127) .. table.concat(bytes)
  end
  return first .. length_bytes .. payload
end

--- Decode one client-to-server frame from the front of buf.
-- @tparam string buf bytes received so far on the connection
-- @return a table {fin = boolean, opcode = number, payload = string} plus the
--   number of bytes of buf the frame occupied; or nil when buf does not yet
--   hold a complete frame; or false plus a reason when the frame is rejected.
local function decode_ws_frame(buf)
  if #buf < 2 then
    return nil
  end
  local b1, b2 = buf:byte(1, 2)
  -- Clients must mask every frame (RFC 6455 sec. 5.1).
  if bit.band(b2, 0x80) == 0 then
    return false, "unmasked client frame"
  end
  local len = bit.band(b2, 0x7F)
  local pos = 3
  if len == 126 then
    if #buf < 4 then
      return nil
    end
    local hi, lo = buf:byte(3, 4)
    len = hi * 256 + lo
    pos = 5
  elseif len == 127 then
    if #buf < 10 then
      return nil
    end
    len = 0
    for i = 3, 10 do
      len = len * 256 + buf:byte(i)
    end
    pos = 11
  end
  if len > MAX_WS_PAYLOAD then
    return false, "payload too large (" .. len .. " bytes)"
  end
  local data_start = pos + 4 -- 4 masking-key bytes sit at pos..pos+3
  local frame_end = data_start + len - 1
  if #buf < frame_end then
    return nil
  end
  local m1, m2, m3, m4 = buf:byte(pos, pos + 3)
  local mask = { m1, m2, m3, m4 }
  local chars = {}
  for i = 0, len - 1 do
    chars[i + 1] = string.char(bit.bxor(buf:byte(data_start + i), mask[i % 4 + 1]))
  end
  return {
    fin = bit.band(b1, 0x80) ~= 0,
    opcode = bit.band(b1, 0x0F),
    payload = table.concat(chars),
  }, frame_end
end

--- Decode %XX escapes in a URL path.
local function url_decode(s)
  return (s:gsub("%%(%x%x)", function(hex)
    return string.char(tonumber(hex, 16))
  end))
end

--- Build a complete HTTP/1.1 response with CRLF line endings and Content-Length.
-- @tparam string status e.g. "404 Not Found"
-- @tparam table extra_headers list of additional "Name: value" strings
-- @tparam string body response body
local function http_response(status, extra_headers, body)
  local lines = {
    "HTTP/1.1 " .. status,
    "Connection: Close",
    "Content-Type: text/html; charset=utf-8",
    "Content-Length: " .. #body,
    "Access-Control-Allow-Origin: *",
    "Cross-Origin-Opener-Policy: same-origin",
    "Cross-Origin-Embedder-Policy: require-corp",
  }
  for _, header in ipairs(extra_headers) do
    lines[#lines + 1] = header
  end
  return table.concat(lines, "\r\n") .. "\r\n\r\n" .. body
end

--- Parse an HTTP request head (request line + header lines).
-- Header names are lowercased (HTTP header names are case-insensitive) and
-- values are trimmed; lines without ":" are ignored instead of crashing.
-- @treturn table|nil {method, path, version, headers}, or nil when the
--   request line is malformed. path is the URL-decoded part before "?".
local function parse_request_head(head)
  local next_line = head:gmatch("[^\r\n]+")
  local request_line = next_line()
  if not request_line then
    return nil
  end
  local method, target, version = request_line:match("^(%S+) (%S+) HTTP/(%S+)$")
  if not method then
    return nil
  end
  local headers = {}
  for line in next_line do
    local name, value = line:match("^([^:]+):[ \t]*(.-)[ \t]*$")
    if name then
      headers[name:lower()] = value
    end
  end
  return {
    method = method,
    path = url_decode(target:match("^[^?]*")),
    version = version,
    headers = headers,
  }
end

--- Start a non-blocking HTTP + WebSocket server polled from OBS timers.
-- Requests must be GET over HTTP/1.1. A request for `wspath` is upgraded to a
-- WebSocket; any other path is looked up in `routes`. Sockets are polled every
-- 500 ms, and every 30 ms for up to 20 polls after each newly accepted
-- connection, so no call ever blocks the OBS UI.
-- NOTE(review): like the original, this assumes ljsocket reports "no data yet"
-- on a non-blocking socket as err == "timeout" -- confirm against ljsocket.
-- NOTE(review): send() return values are not checked; if ljsocket can perform
-- short writes on non-blocking sockets, large frames could be truncated.
-- @tparam string bind_host address to bind, e.g. "127.0.0.1"
-- @tparam number port TCP port to listen on
-- @tparam string wspath request path that is upgraded to a WebSocket
-- @tparam function on_text called as on_text(client, text, wss) for every
--   complete text message received from a WebSocket client
-- @tparam table routes maps request paths to functions returning the HTML body
-- @treturn table handle with `socket` (the listening socket), `send(client, op,
--   payload)`, `send_text(client, text)`, `broadcast(op, payload)`,
--   `broadcast_text(text)` and `close()` (closes every client connection and the
--   listening socket, which also stops the poll timers on their next tick)
function open_server(bind_host, port, wspath, on_text, routes)
  routes = routes or {}
  on_text = on_text or function() end

  local poll_interval = 500 -- millisecond poll interval
  local poll_interval_fast = 30 -- ms, used for a while after a new connection
  local max_fast_idle = 20 -- fast polls without a new connection before stopping

  local sock = assert(socket.create("inet", "stream", "tcp"))
  assert(sock:set_blocking(false)) -- critical! don't hang obs UI!
  assert(sock:set_option("reuseaddr", true))
  assert(sock:bind(bind_host, port))
  assert(sock:listen())

  -- id -> { sock = client socket, buffer = unparsed bytes, fragments = text parts or nil }
  local connected_ws = {}
  -- client socket -> { buffer = request bytes so far, deadline = os.time() cutoff }
  local pending_http = {}
  local client_i = 1
  local fast_poll = false
  local idle_count = 0

  -- Forward declarations: referenced by functions defined before their assignment.
  local do_poll, wss

  local function do_slow_poll()
    do_poll()
  end

  local function do_fast_poll()
    idle_count = idle_count + 1
    if idle_count > max_fast_idle then
      obs.remove_current_callback()
      fast_poll = false
    else
      do_poll()
    end
  end

  local function send_ws(client, op, payload)
    client:send(encode_ws_frame(op, payload))
  end

  local function send_text_ws(client, text)
    send_ws(client, OP_TEXT, text)
  end

  local function broadcast(op, payload)
    for _, state in pairs(connected_ws) do
      send_ws(state.sock, op, payload)
    end
  end

  -- Forget a WebSocket client and release its socket.
  local function drop_ws(id, state)
    connected_ws[id] = nil
    state.sock:close()
  end

  local function close_server()
    for id, state in pairs(connected_ws) do
      -- Status 1001 "going away" (RFC 6455 sec. 7.4.1).
      send_ws(state.sock, OP_CLOSE, string.char(0x03, 0xE9))
      drop_ws(id, state)
    end
    for client in pairs(pending_http) do
      pending_http[client] = nil
      client:close()
    end
    if sock then
      assert(sock:close())
      sock = nil
    end
  end

  wss = {
    socket = sock,
    broadcast = broadcast,
    broadcast_text = function(text)
      return broadcast(OP_TEXT, text)
    end,
    send_text = send_text_ws,
    send = send_ws,
    close = close_server,
  }

  -- React to one decoded client frame.
  local function handle_ws_frame(id, state, frame)
    local op = frame.opcode
    debug_print(("fin: %d, op: %d, len: %d"):format(frame.fin and 1 or 0, op, #frame.payload))
    if op == OP_TEXT then
      state.fragments = { frame.payload }
    elseif op == OP_CONTINUATION and state.fragments then
      state.fragments[#state.fragments + 1] = frame.payload
    elseif op == OP_CLOSE then
      -- A received Close must be answered with a Close (RFC 6455 sec. 5.5.1);
      -- echo the status code the client sent.
      send_ws(state.sock, OP_CLOSE, frame.payload:sub(1, 2))
      drop_ws(id, state)
      return
    elseif op == OP_PING then
      -- A Pong must carry the Ping's application data (RFC 6455 sec. 5.5.3).
      send_ws(state.sock, OP_PONG, frame.payload)
      return
    else
      return -- pong, binary and unknown opcodes are ignored
    end
    if frame.fin then
      local text = table.concat(state.fragments)
      state.fragments = nil
      debug_print("decoded: " .. text)
      on_text(state.sock, text, wss)
    end
  end

  -- Consume every complete frame currently buffered for one client. Stops as
  -- soon as the client is dropped (close frame, bad frame, or on_text calling
  -- wss.close()); partial frames stay buffered until more bytes arrive.
  local function process_ws_buffer(id, state)
    while connected_ws[id] == state do
      local frame, info = decode_ws_frame(state.buffer)
      if frame == nil then
        return
      elseif frame == false then
        print("Dropping websocket client: " .. info)
        drop_ws(id, state)
        return
      end
      state.buffer = state.buffer:sub(info + 1)
      handle_ws_frame(id, state, frame)
    end
  end

  -- Send a complete HTTP response and close the connection.
  local function send_http(client, status, extra_headers, body)
    client:send(http_response(status, extra_headers, body))
    client:close()
  end

  local function response_bad(client, msg)
    debug_print(msg)
    send_http(client, "400 Bad Request",
      { "Sec-WebSocket-Version: 13", "Cache-Control: max-age=15" }, msg)
  end

  -- Validate the upgrade request, answer 101, and register the client.
  -- leftover holds any bytes the client sent right after its request head.
  local function upgrade_to_ws(client, headers, leftover)
    if not headers["upgrade"] or headers["upgrade"]:lower() ~= "websocket" then
      return response_bad(client, "bad upgrade")
    end
    if not headers["connection"]
      or not read_header_list_includes(headers["connection"], "Upgrade") then
      return response_bad(client, "bad connection")
    end
    if headers["sec-websocket-version"] ~= "13" then
      return response_bad(client, "bad version")
    end
    local key = headers["sec-websocket-key"]
    if not key then
      return response_bad(client, "no key")
    end
    local accept = b64_sha1(key .. WS_GUID)
    client:send("HTTP/1.1 101 Switching Protocols\r\n"
      .. "Upgrade: websocket\r\n"
      .. "Connection: Upgrade\r\n"
      .. "Sec-WebSocket-Accept: " .. accept .. "\r\n\r\n")
    local id = accept .. client_i
    client_i = client_i + 1
    local state = { sock = client, buffer = leftover }
    connected_ws[id] = state
    process_ws_buffer(id, state)
  end

  -- Dispatch one complete request head; always answers and/or hands off the socket.
  local function handle_http_request(client, head, leftover)
    debug_print(head)
    local request = parse_request_head(head)
    if not request then
      return response_bad(client, "malformed request")
    end
    if request.method ~= "GET" then
      print("Error: client requested unsupported http method")
      return send_http(client, "405 Method Not Allowed", { "Allow: GET" }, "unsupported method")
    end
    if request.version ~= "1.1" then
      print("Error: client requested unsupported http version")
      return send_http(client, "505 HTTP Version Not Supported", {}, "unsupported http version")
    end
    local path = request.path
    print("\trequest for " .. path)
    if path == wspath then
      return upgrade_to_ws(client, request.headers, leftover)
    end
    local route = routes[path]
    if route == nil then
      return send_http(client, "404 Not Found", { "Cache-Control: max-age=15" }, NOT_FOUND_BODY)
    end
    -- A failing route must still answer and close the socket.
    local ok, content = pcall(route)
    if not ok then
      print("Route " .. path .. " failed: " .. tostring(content))
      return send_http(client, "500 Internal Server Error", {}, "internal error")
    end
    content = tostring(content or "")
    print(string.format("Serving request for %s (text/html, %d bytes)", path, #content))
    send_http(client, "200 OK", {}, content)
  end

  -- Accept at most one new connection per poll; it waits in pending_http
  -- until its request head has fully arrived.
  local function accept_client()
    local client, err = sock:accept()
    if not client then
      if err ~= "timeout" then
        error("accept failed: " .. tostring(err))
      end
      return
    end
    if not client:is_connected() then
      client:close()
      return
    end
    idle_count = 0
    if not fast_poll then
      fast_poll = true
      obs.timer_add(do_fast_poll, poll_interval_fast)
    end
    debug_print("Got client " .. tostring(client))
    debug_print("\tName is " .. tostring(client:get_name()))
    debug_print("\tPeername is " .. tostring(client:get_peer_name()))
    assert(client:set_blocking(false)) -- critical! don't hang obs UI!
    pending_http[client] = { buffer = "", deadline = os.time() + HTTP_HEAD_TIMEOUT }
  end

  -- Read more request bytes from every client still sending its request head.
  local function poll_pending_http()
    for client, pending in pairs(pending_http) do
      local chunk, err = client:receive()
      if chunk then
        pending.buffer = pending.buffer .. chunk
      end
      local _, head_end = pending.buffer:find("\r?\n\r?\n")
      if head_end then
        pending_http[client] = nil
        handle_http_request(client, pending.buffer:sub(1, head_end),
          pending.buffer:sub(head_end + 1))
      elseif not chunk and err ~= "timeout" then
        pending_http[client] = nil
        print("Client read error: " .. tostring(err))
        client:close()
      elseif #pending.buffer > MAX_HTTP_HEAD then
        pending_http[client] = nil
        response_bad(client, "request head too large")
      elseif os.time() > pending.deadline then
        pending_http[client] = nil
        print("Client socket timeout before processing")
        client:close()
      end
    end
  end

  -- Read and process incoming frames from every WebSocket client.
  local function poll_ws_clients()
    for id, state in pairs(connected_ws) do
      if not state.sock:is_connected() then
        drop_ws(id, state)
      else
        local chunk, err = state.sock:receive()
        if chunk then
          state.buffer = state.buffer .. chunk
          process_ws_buffer(id, state)
        elseif err ~= "timeout" then
          drop_ws(id, state) -- ws probably closed
        end
      end
    end
  end

  do_poll = function()
    if sock == nil then
      -- Server was closed: unregister whichever timer invoked us.
      obs.remove_current_callback()
      return
    end
    accept_client()
    poll_pending_http()
    poll_ws_clients()
  end

  obs.timer_add(do_slow_poll, poll_interval)
  do_poll() -- no delay for testing
  return wss
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment