Forked from SpencerPark/install_ipc_proxy_kernel.py
Last active
June 16, 2024 08:33
-
-
Save wiseaidev/bc102165f43db4ebd84fcdb4c5bfb129 to your computer and use it in GitHub Desktop.
A little proxy kernel (and installer) that manages a wrapped kernel connected over TCP. It was designed for the case where the server starts kernels with the IPC transport but the wrapped kernel only supports TCP (e.g. the Rust kernel).
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import json | |
import os | |
import os.path | |
import shutil | |
import sys | |
from jupyter_client.kernelspec import (KernelSpec, KernelSpecManager, | |
NoSuchKernel) | |
# Command-line interface of the installer.
parser = argparse.ArgumentParser(
    description=(
        "Move an installed kernelspec to <kernel>_tcp and install a proxy "
        "kernel under the original name that wraps it."
    )
)
parser.add_argument(
    "--kernel",
    type=str,
    required=True,
    help="Name of the installed kernelspec to wrap.",
)
parser.add_argument(
    "--implementation",
    type=str,
    required=True,
    help="Path to the proxy kernel script to copy into the new kernelspec.",
)
parser.add_argument(
    "--quiet",
    action="store_true",
    default=False,
    help="Suppress progress logging; errors and the final summary still print.",
)
args = parser.parse_args()
def log(*log_args):
    """Print ``log_args`` exactly as ``print`` would, unless ``--quiet`` was given."""
    if args.quiet:
        return
    print(*log_args)
# Resolve the kernelspec we are going to wrap. If the name is unknown, list
# what is installed and stop before touching anything on disk.
kernel_spec_manager = KernelSpecManager()
try:
    real_kernel_spec: KernelSpec = kernel_spec_manager.get_kernel_spec(args.kernel)
except NoSuchKernel:
    # Diagnostics go to stderr. sys.exit is used because the bare `exit`
    # builtin is only injected by the `site` module (absent under `python -S`
    # or in frozen interpreters).
    print(
        f"No kernel installed with name {args.kernel}. Available kernels:",
        file=sys.stderr,
    )
    for name, path in kernel_spec_manager.find_kernel_specs().items():
        print(f" - {name}\t{path}", file=sys.stderr)
    sys.exit(1)
log(f"Moving {args.kernel} kernel from {real_kernel_spec.resource_dir}...")
real_kernel_install_path = real_kernel_spec.resource_dir
new_kernel_name = f"{args.kernel}_tcp"
new_kernel_install_path = os.path.join(
    os.path.dirname(real_kernel_install_path), new_kernel_name
)
# shutil.move() moves the source *inside* the destination when the destination
# is an existing directory. Re-running the installer would therefore bury the
# current spec at <name>_tcp/<name> and break both kernels. Refuse instead.
if os.path.exists(new_kernel_install_path):
    print(
        f"{new_kernel_install_path} already exists; has the proxy already been "
        f"installed for {args.kernel}?",
        file=sys.stderr,
    )
    sys.exit(1)
shutil.move(real_kernel_install_path, new_kernel_install_path)

# Update the moved kernel name and args. We tag it _tcp because the proxy will
# impersonate it and should be the one using the real name.
new_kernel_json_path = os.path.join(new_kernel_install_path, "kernel.json")
with open(new_kernel_json_path, "r", encoding="utf-8") as in_:
    real_kernel_json = json.load(in_)
real_kernel_json["name"] = new_kernel_name
# Any argv entry that pointed into the old directory must follow the move.
real_kernel_json["argv"] = [
    arg.replace(real_kernel_install_path, new_kernel_install_path)
    for arg in real_kernel_json["argv"]
]
with open(new_kernel_json_path, "w", encoding="utf-8") as out:
    json.dump(real_kernel_json, out)
log(f"Wrote modified kernel.json for {new_kernel_name} in {new_kernel_json_path}")
log(
    f"Installing the proxy kernel in place of {args.kernel} in {real_kernel_install_path}"
)
# The original spec directory was just moved away; recreate it so the proxy
# takes over the kernel name the server already knows about.
os.makedirs(real_kernel_install_path)
proxy_kernel_implementation_path = os.path.join(
    real_kernel_install_path, "ipc_proxy_kernel.py"
)
# The proxy runs under this interpreter and is told which renamed (_tcp)
# kernelspec to start; presentation fields are copied from the wrapped spec.
proxy_kernel_spec = KernelSpec(
    argv=[
        sys.executable,
        proxy_kernel_implementation_path,
        "{connection_file}",
        f"--kernel={new_kernel_name}",
    ],
    display_name=real_kernel_spec.display_name,
    interrupt_mode=real_kernel_spec.interrupt_mode or "message",
    language=real_kernel_spec.language,
)
proxy_kernel_json_path = os.path.join(real_kernel_install_path, "kernel.json")
with open(proxy_kernel_json_path, "w") as out:
    out.write(json.dumps(proxy_kernel_spec.to_dict(), indent=2))
log(f"Installed proxy kernelspec: {proxy_kernel_spec.to_json()}")
shutil.copy(args.implementation, proxy_kernel_implementation_path)
# Name the kernel by the display name it actually has: the installer accepts
# any --kernel, so the wrapped kernel is not necessarily Rust.
print(
    "Proxy kernel installed. Go to 'Runtime > Change runtime type' and select "
    f"'{real_kernel_spec.display_name}'"
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse
import json
import signal
import sys
from threading import Thread

import zmq
from jupyter_client import KernelClient
from jupyter_client.channels import HBChannel
from jupyter_client.manager import KernelManager
from jupyter_client.session import Session
from traitlets.traitlets import Type
# Command-line interface of the proxy kernel (invoked by the Jupyter server
# through the argv written into the proxy's kernel.json).
parser = argparse.ArgumentParser(
    description=(
        "Bind the sockets described by a Jupyter connection file and forward "
        "all traffic to a wrapped kernel that is started over tcp."
    )
)
parser.add_argument(
    "connection_file",
    help="Path to the connection file the Jupyter server wrote for this kernel.",
)
parser.add_argument(
    "--kernel",
    type=str,
    required=True,
    help="Name of the installed kernelspec to start and wrap.",
)
args = parser.parse_args()
# Parse the connection file the server wrote for this proxy: the addresses the
# proxy must bind, plus the signing settings shared with the wrapped kernel.
with open(args.connection_file, "r") as connection_fp:
    connection_file_contents = json.load(connection_fp)
transport, ip = (str(connection_file_contents[field]) for field in ("transport", "ip"))
shell_port, stdin_port, control_port, iopub_port, hb_port = (
    int(connection_file_contents[f"{channel}_port"])
    for channel in ("shell", "stdin", "control", "iopub", "hb")
)
signature_scheme = str(connection_file_contents["signature_scheme"])
key = str(connection_file_contents["key"]).encode()
# ZeroMQ socket types per Jupyter channel. The proxy binds the kernel_type
# sockets (see the create_and_bind_socket calls); client_type is the peer
# socket type on the other end of each channel.
# channel | kernel_type | client_type
# shell   | ROUTER      | DEALER
# stdin   | ROUTER      | DEALER
# ctrl    | ROUTER      | DEALER
# iopub   | PUB         | SUB
# hb      | REP         | REQ
# One context owns every socket the proxy binds; it is destroyed at shutdown.
zmq_context = zmq.Context()
def create_and_bind_socket(port: int, socket_type: int):
    """Create a ``socket_type`` socket on the shared context and bind it.

    The address is derived from the module-level ``transport`` and ``ip``
    read from the connection file: ``tcp://<ip>:<port>`` or
    ``ipc://<ip>-<port>``.

    Raises:
        ValueError: if ``port`` is not positive, or ``transport`` is neither
            "tcp" nor "ipc".
    """
    if port <= 0:
        raise ValueError(f"Invalid port: {port}")
    address_formats = {"tcp": "tcp://{}:{}", "ipc": "ipc://{}-{}"}
    if transport not in address_formats:
        raise ValueError(f"Unknown transport: {transport}")
    bound_address = address_formats[transport].format(ip, port)

    sock: zmq.Socket = zmq_context.socket(socket_type)
    sock.linger = 1000  # same linger (ms) ipykernel uses
    sock.bind(bound_address)
    return sock
# Bind every server-facing socket at the addresses from the connection file.
shell_socket = create_and_bind_socket(shell_port, zmq.ROUTER)
stdin_socket = create_and_bind_socket(stdin_port, zmq.ROUTER)
control_socket = create_and_bind_socket(control_port, zmq.ROUTER)
iopub_socket = create_and_bind_socket(iopub_port, zmq.PUB)
hb_socket = create_and_bind_socket(hb_port, zmq.REP)

# Proxy and the real kernel have their own heartbeats. (shoutout to ipykernel
# for this neat little heartbeat implementation) A QUEUE device whose frontend
# and backend are the same REP socket sends every ping straight back.
# The device blocks for the life of the socket, so the thread is a daemon: it
# must not keep the process alive once the wrapped kernel has exited.
Thread(
    target=zmq.device,
    args=(zmq.QUEUE, hb_socket, hb_socket),
    daemon=True,
).start()
def ZMQProxyChannel_factory(proxy_server_socket: zmq.Socket):
    """Return a channel class that pipes ``proxy_server_socket`` to the kernel.

    The returned class is used as a KernelClient channel class. Instances are
    constructed with the client-side socket connected to the wrapped kernel;
    ``proxy_server_socket`` (bound for the Jupyter server) is captured by this
    closure so both ends can be joined with ``zmq.proxy``.
    """

    class ZMQProxyChannel(object):
        kernel_client_socket: zmq.Socket = None
        session: Session = None

        # The third positional argument is accepted and ignored
        # (presumably an ioloop passed by jupyter_client — verify).
        def __init__(self, socket: zmq.Socket, session: Session, _=None):
            super().__init__()
            self.kernel_client_socket = socket
            self.session = session

        def start(self):
            # zmq.proxy shuttles messages in both directions between the two
            # sockets and blocks while they are connected, so it runs in its
            # own thread. daemon=True so it cannot keep the process alive
            # after the wrapped kernel has exited.
            Thread(
                target=zmq.proxy,
                args=(proxy_server_socket, self.kernel_client_socket),
                daemon=True,
            ).start()

        def stop(self):
            if self.kernel_client_socket is not None:
                try:
                    self.kernel_client_socket.close(linger=0)
                except Exception:
                    # Best effort: this runs during shutdown and a failed
                    # close must not prevent the remaining channels stopping.
                    pass
                self.kernel_client_socket = None

        def is_alive(self):
            return self.kernel_client_socket is not None

    return ZMQProxyChannel
class ProxyKernelClient(KernelClient):
    """KernelClient whose data channels forward raw traffic to the server.

    The shell, stdin, control and iopub channel classes come from
    ZMQProxyChannel_factory. Each is bound to the matching server-facing
    socket, so starting the channels joins each server socket to its
    counterpart connected to the wrapped kernel. Messages are passed through
    without being decoded.

    The heartbeat channel is the stock HBChannel. The proxy answers the
    server's heartbeat on its own socket (NOTE(review): HBChannel presumably
    monitors the wrapped kernel's heartbeat — confirm).
    """

    shell_channel_class = Type(ZMQProxyChannel_factory(shell_socket))
    stdin_channel_class = Type(ZMQProxyChannel_factory(stdin_socket))
    control_channel_class = Type(ZMQProxyChannel_factory(control_socket))
    iopub_channel_class = Type(ZMQProxyChannel_factory(iopub_socket))
    hb_channel_class = Type(HBChannel)
# Start the wrapped kernel over tcp; the proxy itself owns restarts (none).
kernel_manager = KernelManager()
kernel_manager.kernel_name = args.kernel
kernel_manager.transport = "tcp"
kernel_manager.client_factory = ProxyKernelClient
kernel_manager.autorestart = False
# Make sure the wrapped kernel uses the same session info. This way we don't
# need to decode them before forwarding, we can directly pass everything
# through.
kernel_manager.session.signature_scheme = signature_scheme
kernel_manager.session.key = key
kernel_manager.start_kernel()


def _forward_interrupt(signum, frame):
    """Relay a SIGINT received by the proxy to the wrapped kernel.

    When the proxy's kernelspec uses interrupt_mode "signal", an interrupt
    request from the server arrives here as SIGINT. Python's default handler
    would raise KeyboardInterrupt in the main thread (blocked in wait() below),
    killing the proxy and leaving the wrapped kernel running unsupervised.
    """
    try:
        kernel_manager.interrupt_kernel()
    except Exception as err:
        # A failed interrupt must not take the proxy down with it.
        print(
            f"Could not forward interrupt to kernel {args.kernel}: {err}",
            file=sys.stderr,
        )


signal.signal(signal.SIGINT, _forward_interrupt)

# Connect to the real kernel we just started and start up all the proxies.
kernel_client: ProxyKernelClient = kernel_manager.client()
kernel_client.start_channels()

# Everything should be up and running. We now just wait for the managed kernel
# process to exit and when that happens, shutdown and exit with the same code.
# Older jupyter_client exposes the kernel's process as `kernel_manager.kernel`.
# NOTE(review): provisioner-based releases appear to keep it on
# `kernel_manager.provisioner.process` instead — confirm for the installed version.
kernel_process = getattr(kernel_manager, "kernel", None)
if kernel_process is None:
    kernel_process = kernel_manager.provisioner.process
try:
    exit_code = kernel_process.wait()
finally:
    kernel_client.stop_channels()
    zmq_context.destroy(0)
sys.exit(exit_code)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment