# Ruby websocket client
# GitHub gist allenhwkim/2416740, created April 18, 2012 21:31.
# GitHub notes that this file contains bidirectional Unicode text that may be
# interpreted or compiled differently than it appears; to review it, open the
# file in an editor that reveals hidden Unicode characters.
#
# This is a modified version of https://github.com/gimite/web-socket-ruby/blob/master/lib/web_socket.rb
#
# License: New BSD License
#
require "base64"
require "digest/md5"
require "digest/sha1"
require "openssl"
require "securerandom"
require "socket"
require "stringio"
require "uri"
# WebSocket client speaking the draft-hixie-76 protocol.
#
# #initialize connects (plain TCP for ws://, TLS for wss://) and performs the
# hixie-76 opening handshake. @web_socket_version is always "hixie-76", so
# #send, #receive and #close use hixie framing: a text message is
# "\x00" + UTF-8 bytes + "\xff" and the closing frame is "\xff\x00". The
# RFC 6455 (hybi) framing branches are kept but are not reached with the
# current handshake.
#
# Example:
#   client = WebSocketClient.new("ws://example.com/chat")
#   client.send("hello")
#   reply = client.receive   # => UTF-8 String, or nil once the peer closed
#   client.close
class WebSocketClient
  class << self
    # When true, #gets, #read and #write dump every chunk to $stderr.
    attr_accessor(:debug)
  end

  # NOTE(review): the I/O helpers below consult only the class-level
  # WebSocketClient.debug; this per-instance flag is not read in this class.
  attr_accessor(:debug)
  # server - never assigned in this class, so it stays nil and outgoing hybi
  #          frames are masked as a client must.
  # header - handshake response headers, stored under the name as received
  #          and under its downcased form.
  # path   - request path including "?query" ("/" when the URL has no path).
  attr_reader(:server, :header, :path)

  # Raised for handshake failures and protocol violations.
  class Error < RuntimeError; end

  # RFC 6455 GUID used by #security_digest (hybi handshake only).
  WEB_SOCKET_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11".freeze

  # RFC 6455 frame opcodes (hybi framing only).
  OPCODE_CONTINUATION = 0x00
  OPCODE_TEXT = 0x01
  OPCODE_BINARY = 0x02
  OPCODE_CLOSE = 0x08
  OPCODE_PING = 0x09
  OPCODE_PONG = 0x0a

  # hixie-75/76 frame delimiters, built with pack so they are binary
  # (ASCII-8BIT) on every Ruby. A "\xff" string literal is an invalid UTF-8
  # string on Ruby >= 2.0: IO#gets on a binary socket rejects it as a separator
  # ("encoding mismatch") and it never compares equal to bytes read back.
  HIXIE_FRAME_START = [0x00].pack("C").freeze
  HIXIE_FRAME_END = [0xff].pack("C").freeze
  HIXIE_CLOSING_FRAME = [0xff, 0x00].pack("C*").freeze

  # Opens the connection and performs the hixie-76 opening handshake.
  #
  # arg    - "ws://" or "wss://" URL as a String, or an already parsed URI.
  # params - :origin      => Origin header value (default "http://<host>").
  #          :ssl_context => OpenSSL::SSL::SSLContext used for wss://
  #                          (default: OpenSSL::SSL::SSLContext.new).
  #
  # Raises WebSocketClient::Error for an unsupported scheme, a URL without a
  # host, or a handshake reply that does not validate. Errors from
  # TCPSocket.new / OpenSSL propagate unchanged. If anything fails after the
  # TCP connection is opened, the socket is closed before re-raising.
  def initialize(arg, params = {})
    @web_socket_version = "hixie-76"
    uri = arg.is_a?(String) ? URI.parse(arg) : arg
    case uri.scheme
    when "ws"
      default_port = 80
    when "wss"
      default_port = 443
    else
      raise(WebSocketClient::Error, "unsupported scheme: #{uri.scheme}")
    end
    if uri.host.nil? || uri.host.empty?
      raise(WebSocketClient::Error, "missing host: #{uri}")
    end
    @secure = (uri.scheme == "wss")
    @path = (uri.path.empty? ? "/" : uri.path) + (uri.query ? "?" + uri.query : "")
    # Host request header: the port is only spelled out when it differs from
    # the scheme's default.
    @host = uri.host + ((!uri.port || uri.port == default_port) ? "" : ":#{uri.port}")
    origin = params[:origin] || "http://#{uri.host}"
    # URI#hostname strips the brackets of IPv6 literals, which TCPSocket needs.
    connect_host = uri.respond_to?(:hostname) ? uri.hostname : uri.host

    socket = TCPSocket.new(connect_host, uri.port || default_port)
    handshaked = false
    begin
      @socket = @secure ? ssl_handshake(socket, connect_host, params[:ssl_context]) : socket
      hixie_76_handshake(origin)
      handshaked = true
    ensure
      # Do not leak the connection when TLS setup or the handshake fails.
      socket.close() if !handshaked && !socket.closed?
    end

    @handshaked = true
    @received = []
    @buffer = ""
    @closing_started = false
  end

  # Sends +data+ (a String, treated as UTF-8 text) as one message and flushes.
  #
  # NOTE: this overrides Object#send for instances of this class; use
  # __send__ or public_send for dynamic dispatch on a client.
  def send(data)
    if !@handshaked
      raise(WebSocketClient::Error, "WebSocket handshake has not completed")
    end
    case @web_socket_version
    when "hixie-75", "hixie-76"
      data = force_encoding(data.dup(), "ASCII-8BIT")
      write(HIXIE_FRAME_START + data + HIXIE_FRAME_END)
      flush()
    else
      send_frame(OPCODE_TEXT, data, !@server)
    end
  end

  # Blocks until the next text message arrives and returns it as a UTF-8
  # String. Returns nil when the connection hits EOF or the peer sends a
  # closing frame (in which case the closing handshake is answered and the
  # socket is closed). Raises WebSocketClient::Error for malformed or
  # unsupported frames.
  def receive()
    if !@handshaked
      raise(WebSocketClient::Error, "WebSocket handshake has not completed")
    end
    case @web_socket_version
    when "hixie-75", "hixie-76"
      receive_hixie_message()
    else
      receive_hybi_message()
    end
  end

  # Host of the connection: the Host header in @header if present, otherwise
  # the Host value this client sent in its handshake request.
  def host
    return @header["host"] || @host
  end

  # Origin recorded in the handshake headers. The request this client sends
  # carries "Origin", while the hixie-76 response stored in @header carries
  # "Sec-WebSocket-Origin", so both names are tried (only the latter for
  # versions "7"/"8"). Raises WebSocketClient::Error if none is present.
  def origin
    case @web_socket_version
    when "7", "8"
      names = ["sec-websocket-origin"]
    else
      names = ["origin", "sec-websocket-origin"]
    end
    name = names.find { |candidate| @header[candidate] }
    if !name
      raise(WebSocketClient::Error, "%s header is missing" % names.first)
    end
    return @header[name]
  end

  # URL of this connection: "ws" or "wss" (matching how it was opened),
  # followed by #host and #path.
  def location
    return "#{@secure ? 'wss' : 'ws'}://#{self.host}#{@path}"
  end

  # Does closing handshake.
  #
  # Sends the closing frame once (later calls skip it) and flushes it.
  # code/reason are only used by hybi framing; 1005 means "no status", i.e.
  # an empty close payload. origin - :peer when answering the peer's closing
  # frame; the socket is then closed, even if sending the frame fails.
  # Returns true.
  def close(code = 1005, reason = "", origin = :self)
    begin
      if !@closing_started
        case @web_socket_version
        when "hixie-75", "hixie-76"
          write(HIXIE_CLOSING_FRAME)
        else
          if code == 1005
            payload = ""
          else
            payload = [code].pack("n") + force_encoding(reason.dup(), "ASCII-8BIT")
          end
          # Clients must mask every frame, including close (matches #send).
          send_frame(OPCODE_CLOSE, payload, !@server)
        end
        # TLS sockets buffer writes; make sure the closing frame goes out.
        flush()
      end
    ensure
      @socket.close() if origin == :peer && !@socket.closed?
    end
    @closing_started = true
  end

  # Closes the underlying socket without a closing handshake.
  def close_socket()
    @socket.close()
  end

  private

  # Characters mixed into hixie-76 Sec-WebSocket-Key1/Key2 values.
  NOISE_CHARS = (("\x21".."\x2f").to_a() + ("\x3a".."\x7e").to_a()).freeze

  # Sends the hixie-76 upgrade request for @path/@host and validates the
  # reply: an "HTTP/1.1 101" status line, the Upgrade/Connection headers
  # (checked by #read_header), a Sec-WebSocket-Origin equal to +origin+
  # (case-insensitively) and the 16-byte MD5 challenge response.
  # Raises WebSocketClient::Error on any mismatch; EOFError if the stream
  # ends inside the challenge response.
  def hixie_76_handshake(origin)
    key1 = generate_key()
    key2 = generate_key()
    key3 = generate_key3()
    request =
      "GET #{@path} HTTP/1.1\r\n" +
      "Upgrade: WebSocket\r\n" +
      "Connection: Upgrade\r\n" +
      "Host: #{@host}\r\n" +
      "Origin: #{origin}\r\n" +
      "Sec-WebSocket-Key1: #{key1}\r\n" +
      "Sec-WebSocket-Key2: #{key2}\r\n" +
      "\r\n"
    # key3 is 8 random bytes; tag the header text as binary so appending it
    # cannot raise Encoding::CompatibilityError.
    write(force_encoding(request, "ASCII-8BIT") + key3)
    flush()

    line = gets()
    if !line
      raise(WebSocketClient::Error, "connection closed before handshake response")
    end
    line = line.chomp()
    raise(WebSocketClient::Error, "bad response: #{line}") if !(line =~ /\AHTTP\/1\.1 101 /n)
    read_header()
    if (@header["sec-websocket-origin"] || "").downcase() != origin.downcase()
      raise(WebSocketClient::Error,
        "origin doesn't match: '#{@header["sec-websocket-origin"]}' != '#{origin}'")
    end
    reply_digest = read(16)
    expected_digest = hixie_76_security_digest(key1, key2, key3)
    if reply_digest != expected_digest
      raise(WebSocketClient::Error,
        "security digest doesn't match: %p != %p" % [reply_digest, expected_digest])
    end
  end

  # Reads one hixie frame. Returns the text of a "\x00...\xff" frame as UTF-8,
  # or nil on EOF / after answering a "\xff\x00" closing frame.
  def receive_hixie_message()
    packet = gets(HIXIE_FRAME_END)
    return nil if !packet
    if packet.bytesize >= 2 && packet.start_with?(HIXIE_FRAME_START) &&
        packet.end_with?(HIXIE_FRAME_END)
      force_encoding(packet[1...-1], "UTF-8")
    elsif packet == HIXIE_FRAME_END && read(1) == HIXIE_FRAME_START # closing
      close(1005, "", :peer)
      nil
    else
      raise(WebSocketClient::Error, "input must be either '\\x00...\\xff' or '\\xff\\x00'")
    end
  rescue EOFError
    nil
  end

  # Reads RFC 6455 frames until a text or close frame arrives. Unsolicited
  # pongs are skipped. The FIN bit is not inspected: fragmented messages are
  # not supported. Returns nil on EOF or after a close frame.
  def receive_hybi_message()
    loop do
      bytes = read(2).unpack("C*")
      opcode = bytes[0] & 0x0f
      mask = (bytes[1] & 0x80) != 0
      plength = bytes[1] & 0x7f
      if plength == 126
        plength = read(2).unpack("n")[0]
      elsif plength == 127
        (high, low) = read(8).unpack("NN")
        plength = high * (2 ** 32) + low
      end
      if @server && !mask
        # Masking is required.
        @socket.close()
        raise(WebSocketClient::Error, "received unmasked data")
      end
      mask_key = mask ? read(4).unpack("C*") : nil
      payload = read(plength)
      payload = apply_mask(payload, mask_key) if mask
      case opcode
      when OPCODE_TEXT
        return force_encoding(payload, "UTF-8")
      when OPCODE_BINARY
        raise(WebSocketClient::Error, "received binary data, which is not supported")
      when OPCODE_CLOSE
        close(1005, "", :peer)
        return nil
      when OPCODE_PING
        raise(WebSocketClient::Error, "received ping, which is not supported")
      when OPCODE_PONG
        next
      else
        raise(WebSocketClient::Error, "received unknown opcode: %d" % opcode)
      end
    end
  rescue EOFError
    nil
  end

  # Reads response header lines up to the blank line into @header (under the
  # name as received and its downcased form), then requires
  # "Upgrade: WebSocket" and a Connection header containing "Upgrade".
  def read_header()
    @header = {}
    while line = gets()
      line = line.chomp()
      break if line.empty?
      # Optional whitespace after the colon is allowed by HTTP.
      if !(line =~ /\A([^:\s]+):\s*(.*)\z/n)
        raise(WebSocketClient::Error, "invalid request: #{line}")
      end
      @header[$1] = $2
      @header[$1.downcase()] = $2
    end
    if !@header["upgrade"]
      raise(WebSocketClient::Error, "Upgrade header is missing")
    end
    if !(@header["upgrade"] =~ /\AWebSocket\z/i)
      raise(WebSocketClient::Error, "invalid Upgrade: " + @header["upgrade"])
    end
    if !@header["connection"]
      raise(WebSocketClient::Error, "Connection header is missing")
    end
    if @header["connection"].split(/,/).grep(/\A\s*Upgrade\s*\z/i).empty?
      raise(WebSocketClient::Error, "invalid Connection: " + @header["connection"])
    end
  end

  # Writes one unfragmented RFC 6455 frame (FIN set) with +opcode+ and
  # +payload+, masking it when +mask+ is true.
  def send_frame(opcode, payload, mask)
    payload = force_encoding(payload.dup(), "ASCII-8BIT")
    # Setting StringIO's encoding to ASCII-8BIT.
    buffer = StringIO.new(force_encoding("", "ASCII-8BIT"))
    write_byte(buffer, 0x80 | opcode)
    masked_byte = mask ? 0x80 : 0x00
    if payload.bytesize <= 125
      write_byte(buffer, masked_byte | payload.bytesize)
    elsif payload.bytesize < 2 ** 16
      write_byte(buffer, masked_byte | 126)
      buffer.write([payload.bytesize].pack("n"))
    else
      write_byte(buffer, masked_byte | 127)
      buffer.write([payload.bytesize / (2 ** 32), payload.bytesize % (2 ** 32)].pack("NN"))
    end
    if mask
      # RFC 6455 section 5.3: the masking key must be unpredictable.
      mask_key = SecureRandom.random_bytes(4).unpack("C*")
      buffer.write(mask_key.pack("C*"))
      payload = apply_mask(payload, mask_key)
    end
    buffer.write(payload)
    write(buffer.string)
  end

  # Socket#gets with optional debug logging.
  def gets(rs = $/)
    line = @socket.gets(rs)
    $stderr.printf("recv> %p\n", line) if WebSocketClient.debug
    return line
  end

  # Reads exactly +num_bytes+ bytes; raises EOFError if fewer are available.
  def read(num_bytes)
    str = @socket.read(num_bytes)
    $stderr.printf("recv> %p\n", str) if WebSocketClient.debug
    if str && str.bytesize == num_bytes
      return str
    else
      raise(EOFError)
    end
  end

  # Socket#write with optional line-by-line debug logging.
  def write(data)
    if WebSocketClient.debug
      data.scan(/\G(.*?(\n|\z))/n) do
        $stderr.printf("send> %p\n", $&) if !$&.empty?
      end
    end
    @socket.write(data)
  end

  def flush()
    @socket.flush()
  end

  def write_byte(buffer, byte)
    buffer.write([byte].pack("C"))
  end

  # RFC 6455 Sec-WebSocket-Accept value for +key+ (unused by the hixie-76
  # handshake).
  def security_digest(key)
    return Base64.encode64(Digest::SHA1.digest(key + WEB_SOCKET_GUID)).gsub(/\n/, "")
  end

  # hixie-76 challenge response: MD5 of the two decoded keys and key3.
  def hixie_76_security_digest(key1, key2, key3)
    bytes1 = websocket_key_to_bytes(key1)
    bytes2 = websocket_key_to_bytes(key2)
    return Digest::MD5.digest(bytes1 + bytes2 + key3)
  end

  # XORs each payload byte with the 4-byte mask key, cycling through it.
  def apply_mask(payload, mask_key)
    payload.unpack("C*").each_with_index.map { |byte, i| byte ^ mask_key[i % 4] }.pack("C*")
  end

  # Builds a hixie-76 key: a multiple of the space count, decorated with
  # random noise characters and that many spaces (never at either end).
  def generate_key()
    spaces = 1 + rand(12)
    max = 0xffffffff / spaces
    number = rand(max + 1)
    key = (number * spaces).to_s()
    (1 + rand(12)).times() do
      char = NOISE_CHARS[rand(NOISE_CHARS.size)]
      pos = rand(key.size + 1)
      key[pos...pos] = char
    end
    spaces.times() do
      pos = 1 + rand(key.size - 1)
      key[pos...pos] = " "
    end
    return key
  end

  # Eight random bytes sent after the request headers (hixie-76 key3).
  def generate_key3()
    return [rand(0x100000000)].pack("N") + [rand(0x100000000)].pack("N")
  end

  # Decodes a hixie-76 key: its digits divided by its space count, as a
  # big-endian 32-bit integer.
  def websocket_key_to_bytes(key)
    num = key.gsub(/[^\d]/n, "").to_i() / key.scan(/ /).size
    return [num].pack("N")
  end

  # String#force_encoding where available (Ruby >= 1.9); identity otherwise.
  def force_encoding(str, encoding)
    if str.respond_to?(:force_encoding)
      return str.force_encoding(encoding)
    else
      return str
    end
  end

  # Wraps +socket+ in TLS and connects.
  #
  # hostname    - sent as SNI (skipped for IP literals) and, when the context
  #               verifies certificates, checked against the peer certificate.
  # ssl_context - caller-supplied OpenSSL::SSL::SSLContext; defaults to
  #               SSLContext.new.
  # NOTE(review): the default context sets no verify_mode, so the server
  # certificate is not verified unless the caller passes a verifying context.
  def ssl_handshake(socket, hostname = nil, ssl_context = nil)
    ssl_context ||= OpenSSL::SSL::SSLContext.new()
    ssl_socket = OpenSSL::SSL::SSLSocket.new(socket, ssl_context)
    ssl_socket.sync_close = true
    if hostname && ssl_socket.respond_to?(:hostname=) &&
        hostname !~ /\A[\d.]+\z/ && !hostname.include?(":")
      ssl_socket.hostname = hostname
    end
    ssl_socket.connect()
    verify_mode = ssl_context.verify_mode
    if hostname && verify_mode && verify_mode != OpenSSL::SSL::VERIFY_NONE
      ssl_socket.post_connection_check(hostname)
    end
    return ssl_socket
  end
end
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.