Last active
July 31, 2024 03:29
-
-
Save SpencerPark/e2732061ad19c1afa4a33a58cb8f18a9 to your computer and use it in GitHub Desktop.
A small proxy kernel (and installer) that manages a wrapped kernel connected over tcp. It is designed for the case where the server starts kernels with the ipc transport but the kernel itself only supports tcp (like IJava). See https://gist.github.com/SpencerPark/447de114fcd3e6a272dc140809462e30 for a sample notebook that installs it.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import json | |
import os | |
import os.path | |
import shutil | |
import sys | |
from jupyter_client.kernelspec import KernelSpec, KernelSpecManager, NoSuchKernel | |
# Command line for the installer. --kernel names an installed kernelspec to
# wrap; --implementation is the proxy kernel script copied next to the new spec.
parser = argparse.ArgumentParser(
    description=(
        "Replace an installed Jupyter kernel with a proxy kernel. The original "
        "kernelspec is renamed to <kernel>_tcp and the proxy is installed "
        "under the original name."
    )
)
parser.add_argument(
    "--kernel",
    type=str,
    required=True,
    help="Name of the installed kernelspec to wrap (e.g. 'java').",
)
parser.add_argument(
    "--implementation",
    type=str,
    required=True,
    help="Path to the proxy kernel script to install (ipc_proxy_kernel.py).",
)
parser.add_argument(
    "--quiet",
    action="store_true",
    default=False,
    help="Suppress progress output.",
)
args = parser.parse_args()

# Validate before touching anything: the installer moves the real kernelspec
# long before it copies the implementation, so a bad path discovered at the end
# would leave the kernel renamed and the proxy half-installed.
if not os.path.isfile(args.implementation):
    parser.error(f"--implementation {args.implementation!r} is not a file")
def log(*log_args):
    """Print ``log_args`` the way ``print`` would, unless ``--quiet`` was given."""
    if args.quiet:
        return
    print(*log_args)
# Look up the kernelspec to wrap. If the name is unknown, list the installed
# kernels and stop before anything on disk is modified.
kernel_spec_manager = KernelSpecManager()
try:
    real_kernel_spec: KernelSpec = kernel_spec_manager.get_kernel_spec(args.kernel)
except NoSuchKernel:
    # Diagnostics go to stderr. sys.exit is used instead of the site-provided
    # exit() helper, which is not guaranteed to exist (e.g. under python -S).
    print(
        f"No kernel installed with name {args.kernel}. Available kernels:",
        file=sys.stderr,
    )
    for name, path in kernel_spec_manager.find_kernel_specs().items():
        print(f" - {name}\t{path}", file=sys.stderr)
    sys.exit(1)
# Rename the real kernelspec directory to "<kernel>_tcp" (a sibling of the
# original) so the proxy can be installed under the original name.
log(f"Moving {args.kernel} kernel from {real_kernel_spec.resource_dir}...")
real_kernel_install_path = real_kernel_spec.resource_dir
new_kernel_name = f"{args.kernel}_tcp"
new_kernel_install_path = os.path.join(
    os.path.dirname(real_kernel_install_path), new_kernel_name
)
# shutil.move() moves the source *into* an existing destination directory, so
# re-running the installer would nest the spec (e.g. java_tcp/java) instead of
# failing. Refuse rather than produce a broken layout.
if os.path.exists(new_kernel_install_path):
    print(
        f"{new_kernel_install_path} already exists; is the proxy already "
        f"installed for {args.kernel}?",
        file=sys.stderr,
    )
    sys.exit(1)
shutil.move(real_kernel_install_path, new_kernel_install_path)
# Update the moved kernel name and args. We tag it _tcp because the proxy will
# impersonate it and should be the one using the real name.
new_kernel_json_path = os.path.join(new_kernel_install_path, "kernel.json")
with open(new_kernel_json_path, "r", encoding="utf-8") as in_:
    real_kernel_json = json.load(in_)
real_kernel_json["name"] = new_kernel_name
# Any argv entry that embeds the old resource directory has to follow the
# directory to its new location.
real_kernel_json["argv"] = [
    arg.replace(real_kernel_install_path, new_kernel_install_path)
    for arg in real_kernel_json["argv"]
]
# indent=2 matches how the proxy's own kernel.json is written.
with open(new_kernel_json_path, "w", encoding="utf-8") as out:
    json.dump(real_kernel_json, out, indent=2)
log(f"Wrote modified kernel.json for {new_kernel_name} in {new_kernel_json_path}")
# Install the proxy kernelspec in the original kernel's (now empty) directory.
# The proxy is launched by the server and itself starts "<kernel>_tcp".
log(
    f"Installing the proxy kernel in place of {args.kernel} in {real_kernel_install_path}"
)
os.makedirs(real_kernel_install_path)
proxy_kernel_implementation_path = os.path.join(
    real_kernel_install_path, "ipc_proxy_kernel.py"
)
# Copy the implementation first so kernel.json never references a script that
# is not there (e.g. if the copy fails).
shutil.copy(args.implementation, proxy_kernel_implementation_path)

proxy_kernel_spec = KernelSpec()
# "{connection_file}" is the kernelspec placeholder the server replaces with
# the path of the connection file it generated for this kernel.
proxy_kernel_spec.argv = [
    sys.executable,
    proxy_kernel_implementation_path,
    "{connection_file}",
    f"--kernel={new_kernel_name}",
]
proxy_kernel_spec.display_name = real_kernel_spec.display_name
# Mirror the wrapped kernel's interrupt mode, falling back to "message" when it
# is empty. NOTE(review): with "signal", the server signals the proxy process
# rather than the wrapped kernel -- confirm interrupts still reach it.
proxy_kernel_spec.interrupt_mode = real_kernel_spec.interrupt_mode or "message"
proxy_kernel_spec.language = real_kernel_spec.language
proxy_kernel_json_path = os.path.join(real_kernel_install_path, "kernel.json")
with open(proxy_kernel_json_path, "w", encoding="utf-8") as out:
    json.dump(proxy_kernel_spec.to_dict(), out, indent=2)
log(f"Installed proxy kernelspec: {proxy_kernel_spec.to_json()}")
# Name the kernel that was actually wrapped (previously hard-coded to 'java',
# which is what this prints for --kernel java).
print(
    f"Proxy kernel installed. Go to 'Runtime > Change runtime type' and select '{args.kernel}'"
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import json | |
from threading import Thread | |
import zmq | |
from jupyter_client import KernelClient | |
from jupyter_client.channels import HBChannel | |
from jupyter_client.manager import KernelManager | |
from jupyter_client.session import Session | |
from traitlets.traitlets import Type | |
# Command line: the server passes the connection file it generated for this
# (proxy) kernel; --kernel names the wrapped kernelspec to start over tcp.
parser = argparse.ArgumentParser(
    description=(
        "Proxy a tcp-only Jupyter kernel behind the sockets described by a "
        "connection file."
    )
)
parser.add_argument(
    "connection_file",
    help="Connection file written by the server for this proxy kernel.",
)
parser.add_argument(
    "--kernel",
    type=str,
    required=True,
    help="Name of the wrapped kernelspec to launch over tcp (e.g. 'java_tcp').",
)
args = parser.parse_args()
# Parse the connection details the server generated for this kernel. The
# server-facing sockets are bound exactly as described here.
# JSON is UTF-8 by specification; don't depend on the locale's default encoding.
with open(args.connection_file, "r", encoding="utf-8") as connection_file:
    connection_file_contents = json.load(connection_file)
transport = str(connection_file_contents["transport"])
ip = str(connection_file_contents["ip"])
shell_port = int(connection_file_contents["shell_port"])
stdin_port = int(connection_file_contents["stdin_port"])
control_port = int(connection_file_contents["control_port"])
iopub_port = int(connection_file_contents["iopub_port"])
hb_port = int(connection_file_contents["hb_port"])
signature_scheme = str(connection_file_contents["signature_scheme"])
# Message-signing key as bytes; it is later handed to the wrapped kernel's
# session so signed messages can be forwarded untouched.
key = str(connection_file_contents["key"]).encode()
# Socket type per channel. Towards the server the proxy plays the kernel side;
# towards the wrapped kernel it plays the client side.
#   channel | kernel_type | client_type
#   shell   | ROUTER      | DEALER
#   stdin   | ROUTER      | DEALER
#   ctrl    | ROUTER      | DEALER
#   iopub   | PUB         | SUB
#   hb      | REP         | REQ
zmq_context = zmq.Context()


def create_and_bind_socket(port: int, socket_type: int):
    """Create a zmq socket of ``socket_type`` and bind it for ``port``.

    The address comes from the module-level ``transport`` and ``ip`` read from
    the connection file: ``tcp://<ip>:<port>`` or ``ipc://<ip>-<port>``.

    Raises:
        ValueError: if ``port`` is not positive, or ``transport`` is neither
            "tcp" nor "ipc".
    """
    if not port > 0:
        raise ValueError(f"Invalid port: {port}")
    address_templates = {"tcp": "tcp://{}:{}", "ipc": "ipc://{}-{}"}
    template = address_templates.get(transport)
    if template is None:
        raise ValueError(f"Unknown transport: {transport}")
    bound: zmq.Socket = zmq_context.socket(socket_type)
    # Linger (ms) bounds how long close() waits for unsent messages;
    # ipykernel does this.
    bound.linger = 1000
    bound.bind(template.format(ip, port))
    return bound
# Server-facing sockets, bound with the transport from the connection file.
shell_socket = create_and_bind_socket(shell_port, zmq.ROUTER)
stdin_socket = create_and_bind_socket(stdin_port, zmq.ROUTER)
control_socket = create_and_bind_socket(control_port, zmq.ROUTER)
iopub_socket = create_and_bind_socket(iopub_port, zmq.PUB)
hb_socket = create_and_bind_socket(hb_port, zmq.REP)
# Proxy and the real kernel have their own heartbeats. (shoutout to ipykernel
# for this neat little heartbeat implementation) A QUEUE device with the REP
# socket on both ends echoes every ping straight back to the server.
# daemon=True: this thread blocks in zmq.device indefinitely and must not keep
# the interpreter alive once the wrapped kernel has exited.
Thread(
    target=zmq.device,
    args=(zmq.QUEUE, hb_socket, hb_socket),
    daemon=True,
).start()
def ZMQProxyChannel_factory(proxy_server_socket: zmq.Socket):
    """Return a channel class that splices ``proxy_server_socket`` to a kernel socket.

    The returned class is used as a KernelClient channel class: it is built
    from the socket the client connected to the wrapped kernel plus the client
    session. Instead of speaking the protocol, ``start()`` forwards raw zmq
    messages between that socket and ``proxy_server_socket`` (the socket the
    server talks to). One class is created per proxied server socket.
    """

    class ZMQProxyChannel(object):
        # Socket connected to the wrapped kernel; reset to None by stop().
        kernel_client_socket: zmq.Socket = None
        session: Session = None

        def __init__(self, socket: zmq.Socket, session: Session, _=None):
            # The optional third positional argument is accepted and ignored
            # (presumably an ioloop passed by some jupyter_client versions).
            super().__init__()
            self.kernel_client_socket = socket
            self.session = session

        def start(self):
            # Very convenient zmq device here, proxy will handle the actual zmq
            # proxying on each of our connected sockets (other than heartbeat).
            # It blocks while they are connected so stick it in a thread.
            # daemon=True so a still-blocked forwarder cannot keep the
            # interpreter alive after the wrapped kernel exits.
            Thread(
                target=zmq.proxy,
                args=(proxy_server_socket, self.kernel_client_socket),
                daemon=True,
            ).start()

        def stop(self):
            # Best-effort close of the kernel-side socket; failures are
            # deliberately ignored because this only runs during teardown.
            if self.kernel_client_socket is not None:
                try:
                    self.kernel_client_socket.close(linger=0)
                except Exception:
                    pass
                self.kernel_client_socket = None

        def is_alive(self):
            # Alive until stop() has released the socket.
            return self.kernel_client_socket is not None

    return ZMQProxyChannel
class ProxyKernelClient(KernelClient):
    """KernelClient whose shell/stdin/control/iopub channels are raw forwarders.

    Each of those channel classes splices one of the proxy's server-facing
    sockets to the socket this client opens to the wrapped kernel. The
    heartbeat channel stays jupyter_client's regular ``HBChannel``.
    """

    shell_channel_class = Type(default_value=ZMQProxyChannel_factory(shell_socket))
    stdin_channel_class = Type(default_value=ZMQProxyChannel_factory(stdin_socket))
    control_channel_class = Type(
        default_value=ZMQProxyChannel_factory(control_socket)
    )
    iopub_channel_class = Type(default_value=ZMQProxyChannel_factory(iopub_socket))
    hb_channel_class = Type(default_value=HBChannel)
def _build_kernel_manager() -> KernelManager:
    """Return a KernelManager set up to launch the wrapped kernel over tcp.

    Clients are ProxyKernelClient instances, auto-restart is off, and the
    session signing settings are copied from this proxy's connection file.
    """
    manager = KernelManager()
    manager.kernel_name = args.kernel
    manager.transport = "tcp"
    manager.client_factory = ProxyKernelClient
    manager.autorestart = False
    # Sharing the session key/scheme with the wrapped kernel means messages are
    # valid on both sides, so they can be passed through without decoding.
    session = manager.session
    session.signature_scheme = signature_scheme
    session.key = key
    return manager


kernel_manager = _build_kernel_manager()
kernel_manager.start_kernel()
# sys.exit is used below instead of the site-provided exit(), which is not
# guaranteed to exist; this script does not import sys at the top.
import sys

# Connect to the real kernel we just started and start up all the proxies.
kernel_client: ProxyKernelClient = kernel_manager.client()
kernel_client.start_channels()
# Everything should be up and running. We now just wait for the managed kernel
# process to exit and when that happens, shutdown and exit with the same code.
try:
    exit_code = kernel_manager.kernel.wait()
finally:
    # Tear down even if the wait is interrupted (e.g. KeyboardInterrupt).
    # NOTE(review): pyzmq documents socket close / Context.destroy() as not
    # thread-safe while other threads still use the sockets, and the forwarding
    # threads do -- consider a cleaner shutdown path.
    kernel_client.stop_channels()
    zmq_context.destroy(0)
sys.exit(exit_code)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment