Last active
March 14, 2024 20:10
-
-
Save samuela/a7b20be5019b359e358457bfcc4a3df1 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::collections::HashMap; | |
use std::sync::Arc; | |
use async_trait::async_trait; | |
use pty_process::OwnedWritePty; | |
use russh::server::{Auth, Msg, Session}; | |
use russh::*; | |
use russh_keys::*; | |
use tokio::io::{AsyncReadExt, AsyncWriteExt}; | |
use tokio::sync::Mutex; | |
/// Stateless SSH server factory; it creates one [`ServerHandler`] per incoming client.
#[derive(Debug, Default)]
struct Server {}
struct ServerHandler { | |
pty_writers: Arc<Mutex<HashMap<ChannelId, OwnedWritePty>>>, | |
pty_requested_sizes: Arc<Mutex<HashMap<ChannelId, (u32, u32)>>>, | |
} | |
impl server::Server for Server { | |
type Handler = ServerHandler; | |
fn new_client(&mut self, _: Option<std::net::SocketAddr>) -> ServerHandler { | |
log::info!("new client"); | |
ServerHandler { | |
pty_writers: Arc::new(Mutex::new(HashMap::new())), | |
pty_requested_sizes: Arc::new(Mutex::new(HashMap::new())), | |
} | |
} | |
} | |
#[async_trait] | |
impl server::Handler for ServerHandler { | |
type Error = anyhow::Error; | |
async fn channel_open_session( | |
self, | |
channel: Channel<Msg>, | |
session: Session, | |
) -> Result<(Self, bool, Session), Self::Error> { | |
log::info!("channel open session"); | |
Ok((self, true, session)) | |
} | |
async fn auth_publickey(self, user: &str, public_key: &key::PublicKey) -> Result<(Self, Auth), Self::Error> { | |
log::info!("auth_publickey: user: {user} public_key: {public_key:?}"); | |
Ok((self, server::Auth::Accept)) | |
} | |
async fn pty_request( | |
self, | |
channel_id: ChannelId, | |
term: &str, | |
col_width: u32, | |
row_height: u32, | |
pix_width: u32, | |
pix_height: u32, | |
modes: &[(russh::Pty, u32)], | |
session: Session, | |
) -> Result<(Self, Session), Self::Error> { | |
log::info!("pty_request, channel_id: {channel_id}, term: {term}, col_width: {col_width}, row_height: {row_height}, pix_width: {pix_width}, pix_height: {pix_height}, modes: {modes:?}"); | |
// TODO: We're currently ignoring the requested modes. We should probably do something with them. | |
// We save these values so we can resize the pty later in shell_request. Most clients seem to send a pty_request | |
// before a shell_request. | |
{ | |
let mut pty_requested_sizes = self.pty_requested_sizes.lock().await; | |
pty_requested_sizes.insert(channel_id, (col_width, row_height)); | |
} | |
Ok((self, session)) | |
} | |
async fn shell_request(self, channel_id: ChannelId, session: Session) -> Result<(Self, Session), Self::Error> { | |
log::info!("shell_request"); | |
let pty = pty_process::Pty::new().unwrap(); | |
// NOTE: we must get the pts before `.into_split()` because it consumes the pty. | |
let pts = pty.pts().unwrap(); | |
// split pty into reader + writer | |
let (mut pty_reader, pty_writer) = pty.into_split(); | |
self.pty_writers.lock().await.insert(channel_id, pty_writer); | |
// Read bytes from the PTY and send them to the SSH client | |
let session_handle = session.handle(); | |
tokio::spawn(async move { | |
let mut buffer = vec![0; 1024]; | |
while let Ok(size) = pty_reader.read(&mut buffer).await { | |
if size == 0 { | |
break; | |
} | |
session_handle | |
.data(channel_id, CryptoVec::from_slice(&buffer[0..size])) | |
.await | |
.unwrap(); | |
} | |
}); | |
// Spawn a new /bin/bash process in pty | |
let mut child = pty_process::Command::new("/bin/bash").spawn(&pts).unwrap(); | |
// Close the channel when child exits | |
let handle = session.handle(); | |
let channel_pty_writers_ = Arc::clone(&self.pty_writers); | |
let pty_requested_sizes_ = Arc::clone(&self.pty_requested_sizes); | |
tokio::spawn(async move { | |
let status = child.wait().await.unwrap().code().unwrap(); | |
handle | |
.data(channel_id, CryptoVec::from(format!("Exit status: {status}\r\n"))) | |
.await | |
.unwrap(); | |
handle | |
.exit_status_request(channel_id, status.try_into().unwrap_or(1)) | |
.await | |
.unwrap(); | |
handle.eof(channel_id).await.unwrap(); | |
handle.close(channel_id).await.unwrap(); | |
// Clean up things from our pty_writers and pty_requested_sizes `HashMap`s | |
channel_pty_writers_.lock().await.remove(&channel_id); | |
pty_requested_sizes_.lock().await.remove(&channel_id); | |
}); | |
// Can't resize the pty until after the child is spawned on macOS. See https://github.com/doy/pty-process/issues/7#issuecomment-1826196215 and https://github.com/pkgw/stund/issues/305. | |
if let Some((col_width, row_height)) = self.pty_requested_sizes.lock().await.get(&channel_id) { | |
if let Err(e) = self | |
.pty_writers | |
.lock() | |
.await | |
.get(&channel_id) | |
.unwrap() | |
.resize(pty_process::Size::new(*row_height as u16, *col_width as u16)) | |
{ | |
log::error!("pty.resize failed: {:?}", e); | |
} | |
} else { | |
log::warn!("pty_requested_sizes doesn't contain channel_id: {channel_id}, skipping pty resize") | |
} | |
Ok((self, session)) | |
} | |
async fn data(self, channel_id: ChannelId, data: &[u8], session: Session) -> Result<(Self, Session), Self::Error> { | |
// SSH client sends data, pipe it to the corresponding PTY | |
{ | |
let mut pty_writers = self.pty_writers.lock().await; | |
if let Some(pty_writer) = pty_writers.get_mut(&channel_id) { | |
log::info!("pty_writer: data = {data:02x?}"); | |
pty_writer.write_all(data).await.unwrap(); | |
} else { | |
log::warn!("pty_writers doesn't contain channel_id: {channel_id}, skipping pty write") | |
} | |
} | |
Ok((self, session)) | |
} | |
/// The client's pseudo-terminal window size has changed. | |
async fn window_change_request( | |
self, | |
channel_id: ChannelId, | |
col_width: u32, | |
row_height: u32, | |
pix_width: u32, | |
pix_height: u32, | |
session: Session, | |
) -> Result<(Self, Session), Self::Error> { | |
log::info!("window_change_request channel_id = {channel_id:?} col_width = {col_width} row_height = {row_height}, pix_width = {pix_width}, pix_height = {pix_height}"); | |
{ | |
let mut pty_writers = self.pty_writers.lock().await; | |
if let Some(pty_writer) = pty_writers.get_mut(&channel_id) { | |
if let Err(e) = pty_writer.resize(pty_process::Size::new(row_height as u16, col_width as u16)) { | |
log::error!("pty.resize failed: {:?}", e); | |
} | |
} else { | |
log::warn!("pty_writers doesn't contain channel_id: {channel_id}, skipping pty resize") | |
} | |
} | |
Ok((self, session)) | |
} | |
async fn channel_close(self, channel_id: ChannelId, session: Session) -> Result<(Self, Session), Self::Error> { | |
log::info!("channel_close channel_id = {channel_id:?}"); | |
// Clean up things from our pty_writers and pty_requested_sizes `HashMap`s | |
self.pty_writers.lock().await.remove(&channel_id); | |
self.pty_requested_sizes.lock().await.remove(&channel_id); | |
Ok((self, session)) | |
} | |
} | |
#[tokio::main] | |
async fn main() { | |
env_logger::builder().filter_level(log::LevelFilter::Debug).init(); | |
let config = russh::server::Config { | |
inactivity_timeout: Some(std::time::Duration::from_secs(60 * 60)), | |
auth_rejection_time: std::time::Duration::from_secs(5), | |
auth_rejection_time_initial: Some(std::time::Duration::from_secs(0)), | |
keys: vec![ | |
// TODO: only do this in dev | |
russh_keys::key::KeyPair::Ed25519(ed25519_dalek::SigningKey::from([0; 32])), | |
], | |
..Default::default() | |
}; | |
log::info!("Listening on 0.0.0.0:2222"); | |
russh::server::run(Arc::new(config), ("0.0.0.0", 2222), Server {}) | |
.await | |
.unwrap(); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment