Created November 25, 2023 05:16
Save samuela/ecfbb04162f3697b2cd634dae7015be0 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use async_trait::async_trait; | |
use core::panic; | |
use russh::server::{Msg, Session}; | |
use russh::{Channel, ChannelId, ChannelStream, CryptoVec, Sig}; | |
use std::collections::HashMap; | |
use std::io::Read; | |
use std::net::SocketAddr; | |
use std::sync::Arc; | |
use tokio::io::AsyncReadExt; | |
use tokio::sync::Mutex; | |
#[tokio::main] | |
async fn main() { | |
// Default logging level does not show debug, info, etc. | |
env_logger::builder().filter_level(log::LevelFilter::Info).init(); | |
let config = russh::server::Config { | |
inactivity_timeout: Some(std::time::Duration::from_secs(60 * 60)), | |
auth_rejection_time: std::time::Duration::from_secs(5), | |
auth_rejection_time_initial: Some(std::time::Duration::from_secs(0)), | |
keys: vec![ | |
// TODO: only do this in dev | |
russh_keys::key::KeyPair::Ed25519(ed25519_dalek::SigningKey::from([0; 32])), | |
], | |
..Default::default() | |
}; | |
log::info!("Listening on 0.0.0.0:2222"); | |
russh::server::run(Arc::new(config), ("0.0.0.0", 2222), Server {}) | |
.await | |
.unwrap(); | |
} | |
/// Stateless server factory; russh asks it for a fresh handler per connection.
#[derive(Clone)]
struct Server;
// TODO: can we make this an enum such that the state of the Session is captured by the variants of the enum? | |
struct ServerHandler { | |
client_addr: SocketAddr, | |
username: Option<String>, | |
channel_streams: Arc<Mutex<HashMap<ChannelId, Arc<Mutex<ChannelStream>>>>>, | |
} | |
impl russh::server::Server for Server { | |
type Handler = ServerHandler; | |
fn new_client(&mut self, addr: Option<SocketAddr>) -> ServerHandler { | |
// TODO(docs): when is addr None? | |
log::info!("new client"); | |
let Some(a) = addr else { todo!() }; | |
ServerHandler { | |
client_addr: a, | |
username: None, | |
channel_streams: Arc::new(Mutex::new(HashMap::new())), | |
} | |
} | |
} | |
#[async_trait] | |
impl russh::server::Handler for ServerHandler { | |
type Error = anyhow::Error; | |
async fn channel_open_session( | |
self, | |
channel: Channel<Msg>, | |
session: Session, | |
) -> Result<(Self, bool, Session), Self::Error> { | |
log::info!("channel_open_session, channel_id: {}", channel.id()); | |
{ | |
let mut channel_streams = self.channel_streams.lock().await; | |
channel_streams.insert(channel.id(), Arc::new(Mutex::new(channel.into_stream()))); | |
} | |
// TODO(docs): what do these return values really mean? | |
Ok((self, true, session)) | |
} | |
async fn auth_publickey( | |
self, | |
username: &str, | |
public_key: &russh_keys::key::PublicKey, | |
) -> Result<(Self, russh::server::Auth), Self::Error> { | |
log::info!("auth_publickey, username: {:?}", username); | |
Ok(( | |
Self { | |
username: Some(username.into()), | |
..self | |
}, | |
russh::server::Auth::Accept, | |
)) | |
} | |
async fn shell_request(self, channel_id: ChannelId, session: Session) -> Result<(Self, Session), Self::Error> { | |
log::info!( | |
"shell_request, username: {:?}, channel_id: {}", | |
self.username, | |
channel_id | |
); | |
let pty = pty_process::blocking::Pty::new().unwrap(); | |
// TODO: get these values from the client, also use new_with_pixel if possible | |
if let Err(e) = pty.resize(pty_process::Size::new(24, 80)) { | |
// See https://github.com/doy/pty-process/issues/7#issuecomment-1826196215. | |
log::error!("pty.resize failed: {:?}", e); | |
} | |
let mut child = pty_process::blocking::Command::new("ranger") | |
.spawn(&pty.pts().unwrap()) | |
.unwrap(); | |
// Read output from pty and send it to the client | |
let handle = session.handle(); | |
tokio::spawn(async move { | |
let mut reader = std::io::BufReader::new(&pty); | |
let mut buffer = vec![0; 1024]; | |
while let Ok(n) = reader.read(&mut buffer) { | |
if n == 0 { | |
break; | |
} | |
handle | |
.data(channel_id, CryptoVec::from_slice(&buffer[..n])) | |
.await | |
.unwrap(); | |
} | |
}); | |
// Close the channel when child exits | |
let handle = session.handle(); | |
tokio::spawn(async move { | |
let status = child.wait().unwrap().code().unwrap(); | |
handle | |
.data(channel_id, CryptoVec::from(format!("Exit status: {}\r\n", status))) | |
.await | |
.unwrap(); | |
handle | |
.exit_status_request(channel_id, status.try_into().unwrap_or(1)) | |
.await | |
.unwrap(); | |
handle.eof(channel_id).await.unwrap(); | |
handle.close(channel_id).await.unwrap(); | |
}); | |
// Watch the shell Channel for data from the client and send it to the pty | |
// NOTE: I'm reading off of the Channel, but it's not clear if this is the correct thing to do in order to receive | |
// data from the client. | |
let channel_streams_arc = Arc::clone(&self.channel_streams); | |
tokio::spawn(async move { | |
let mut channel_streams = channel_streams_arc.lock().await; | |
let mut stream = channel_streams.get_mut(&channel_id).unwrap().lock().await; | |
let mut buffer = vec![0; 1024]; | |
while let Ok(n) = stream.read(&mut buffer).await { | |
panic!("this never happens :(") | |
// TODO: | |
// if n == 0 { | |
// break; | |
// } | |
// pty.write_all(&buffer[..n]).unwrap(); | |
} | |
}); | |
Ok((self, session)) | |
} | |
async fn signal(self, channel_id: ChannelId, signal: Sig, session: Session) -> Result<(Self, Session), Self::Error> { | |
panic!("this never happens :("); | |
} | |
async fn data( | |
mut self, | |
channel_id: ChannelId, | |
data: &[u8], | |
mut session: Session, | |
) -> Result<(Self, Session), Self::Error> { | |
panic!("this never happens :("); | |
} | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.