Skip to content

Instantly share code, notes, and snippets.

@samuela
Last active March 14, 2024 20:10
Show Gist options
  • Save samuela/a7b20be5019b359e358457bfcc4a3df1 to your computer and use it in GitHub Desktop.
Save samuela/a7b20be5019b359e358457bfcc4a3df1 to your computer and use it in GitHub Desktop.
use std::collections::HashMap;
use std::sync::Arc;
use async_trait::async_trait;
use pty_process::OwnedWritePty;
use russh::server::{Auth, Msg, Session};
use russh::*;
use russh_keys::*;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::sync::Mutex;
/// Factory passed to `russh::server::run`; `new_client` builds a fresh `ServerHandler` per client.
struct Server {}

/// Per-channel state shared between a connection's handler and the tasks it spawns.
type ChannelMap<T> = Arc<Mutex<HashMap<ChannelId, T>>>;

/// Per-connection handler state.
struct ServerHandler {
    /// Write half of each channel's pty, keyed by SSH channel id.
    pty_writers: ChannelMap<OwnedWritePty>,
    /// `(columns, rows)` from the latest pty_request seen on each channel.
    pty_requested_sizes: ChannelMap<(u32, u32)>,
}
impl server::Server for Server {
    type Handler = ServerHandler;

    /// Creates empty per-connection state for an incoming client; the peer address is unused.
    fn new_client(&mut self, _peer_addr: Option<std::net::SocketAddr>) -> ServerHandler {
        log::info!("new client");
        let pty_writers = Arc::new(Mutex::new(HashMap::new()));
        let pty_requested_sizes = Arc::new(Mutex::new(HashMap::new()));
        ServerHandler {
            pty_writers,
            pty_requested_sizes,
        }
    }
}
#[async_trait]
impl server::Handler for ServerHandler {
type Error = anyhow::Error;
async fn channel_open_session(
self,
channel: Channel<Msg>,
session: Session,
) -> Result<(Self, bool, Session), Self::Error> {
log::info!("channel open session");
Ok((self, true, session))
}
async fn auth_publickey(self, user: &str, public_key: &key::PublicKey) -> Result<(Self, Auth), Self::Error> {
log::info!("auth_publickey: user: {user} public_key: {public_key:?}");
Ok((self, server::Auth::Accept))
}
async fn pty_request(
self,
channel_id: ChannelId,
term: &str,
col_width: u32,
row_height: u32,
pix_width: u32,
pix_height: u32,
modes: &[(russh::Pty, u32)],
session: Session,
) -> Result<(Self, Session), Self::Error> {
log::info!("pty_request, channel_id: {channel_id}, term: {term}, col_width: {col_width}, row_height: {row_height}, pix_width: {pix_width}, pix_height: {pix_height}, modes: {modes:?}");
// TODO: We're currently ignoring the requested modes. We should probably do something with them.
// We save these values so we can resize the pty later in shell_request. Most clients seem to send a pty_request
// before a shell_request.
{
let mut pty_requested_sizes = self.pty_requested_sizes.lock().await;
pty_requested_sizes.insert(channel_id, (col_width, row_height));
}
Ok((self, session))
}
async fn shell_request(self, channel_id: ChannelId, session: Session) -> Result<(Self, Session), Self::Error> {
log::info!("shell_request");
let pty = pty_process::Pty::new().unwrap();
// NOTE: we must get the pts before `.into_split()` because it consumes the pty.
let pts = pty.pts().unwrap();
// split pty into reader + writer
let (mut pty_reader, pty_writer) = pty.into_split();
self.pty_writers.lock().await.insert(channel_id, pty_writer);
// Read bytes from the PTY and send them to the SSH client
let session_handle = session.handle();
tokio::spawn(async move {
let mut buffer = vec![0; 1024];
while let Ok(size) = pty_reader.read(&mut buffer).await {
if size == 0 {
break;
}
session_handle
.data(channel_id, CryptoVec::from_slice(&buffer[0..size]))
.await
.unwrap();
}
});
// Spawn a new /bin/bash process in pty
let mut child = pty_process::Command::new("/bin/bash").spawn(&pts).unwrap();
// Close the channel when child exits
let handle = session.handle();
let channel_pty_writers_ = Arc::clone(&self.pty_writers);
let pty_requested_sizes_ = Arc::clone(&self.pty_requested_sizes);
tokio::spawn(async move {
let status = child.wait().await.unwrap().code().unwrap();
handle
.data(channel_id, CryptoVec::from(format!("Exit status: {status}\r\n")))
.await
.unwrap();
handle
.exit_status_request(channel_id, status.try_into().unwrap_or(1))
.await
.unwrap();
handle.eof(channel_id).await.unwrap();
handle.close(channel_id).await.unwrap();
// Clean up things from our pty_writers and pty_requested_sizes `HashMap`s
channel_pty_writers_.lock().await.remove(&channel_id);
pty_requested_sizes_.lock().await.remove(&channel_id);
});
// Can't resize the pty until after the child is spawned on macOS. See https://github.com/doy/pty-process/issues/7#issuecomment-1826196215 and https://github.com/pkgw/stund/issues/305.
if let Some((col_width, row_height)) = self.pty_requested_sizes.lock().await.get(&channel_id) {
if let Err(e) = self
.pty_writers
.lock()
.await
.get(&channel_id)
.unwrap()
.resize(pty_process::Size::new(*row_height as u16, *col_width as u16))
{
log::error!("pty.resize failed: {:?}", e);
}
} else {
log::warn!("pty_requested_sizes doesn't contain channel_id: {channel_id}, skipping pty resize")
}
Ok((self, session))
}
async fn data(self, channel_id: ChannelId, data: &[u8], session: Session) -> Result<(Self, Session), Self::Error> {
// SSH client sends data, pipe it to the corresponding PTY
{
let mut pty_writers = self.pty_writers.lock().await;
if let Some(pty_writer) = pty_writers.get_mut(&channel_id) {
log::info!("pty_writer: data = {data:02x?}");
pty_writer.write_all(data).await.unwrap();
} else {
log::warn!("pty_writers doesn't contain channel_id: {channel_id}, skipping pty write")
}
}
Ok((self, session))
}
/// The client's pseudo-terminal window size has changed.
async fn window_change_request(
self,
channel_id: ChannelId,
col_width: u32,
row_height: u32,
pix_width: u32,
pix_height: u32,
session: Session,
) -> Result<(Self, Session), Self::Error> {
log::info!("window_change_request channel_id = {channel_id:?} col_width = {col_width} row_height = {row_height}, pix_width = {pix_width}, pix_height = {pix_height}");
{
let mut pty_writers = self.pty_writers.lock().await;
if let Some(pty_writer) = pty_writers.get_mut(&channel_id) {
if let Err(e) = pty_writer.resize(pty_process::Size::new(row_height as u16, col_width as u16)) {
log::error!("pty.resize failed: {:?}", e);
}
} else {
log::warn!("pty_writers doesn't contain channel_id: {channel_id}, skipping pty resize")
}
}
Ok((self, session))
}
async fn channel_close(self, channel_id: ChannelId, session: Session) -> Result<(Self, Session), Self::Error> {
log::info!("channel_close channel_id = {channel_id:?}");
// Clean up things from our pty_writers and pty_requested_sizes `HashMap`s
self.pty_writers.lock().await.remove(&channel_id);
self.pty_requested_sizes.lock().await.remove(&channel_id);
Ok((self, session))
}
}
/// Starts the SSH server, which gives each shell request a `/bin/bash` running on a pty.
///
/// SECURITY: the host key below is built from an all-zero seed, so anyone can derive its
/// private half, and `auth_publickey` accepts every key. Do not expose this server beyond a
/// trusted development network.
#[tokio::main]
async fn main() {
    // Single source of truth for the listen address, so the log line cannot drift from the bind.
    const LISTEN_HOST: &str = "0.0.0.0";
    const LISTEN_PORT: u16 = 2222;

    env_logger::builder().filter_level(log::LevelFilter::Debug).init();
    let config = russh::server::Config {
        inactivity_timeout: Some(std::time::Duration::from_secs(60 * 60)),
        auth_rejection_time: std::time::Duration::from_secs(5),
        auth_rejection_time_initial: Some(std::time::Duration::from_secs(0)),
        keys: vec![
            // TODO: only do this in dev
            russh_keys::key::KeyPair::Ed25519(ed25519_dalek::SigningKey::from([0; 32])),
        ],
        ..Default::default()
    };
    log::info!("Listening on {LISTEN_HOST}:{LISTEN_PORT}");
    russh::server::run(Arc::new(config), (LISTEN_HOST, LISTEN_PORT), Server {})
        .await
        .expect("SSH server failed");
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment