Last active
December 16, 2024 03:54
-
-
Save xring/98241f48db4b4fa24be0b3d7ea339974 to your computer and use it in GitHub Desktop.
Gemini English Teacher, powered by GenerativeService and BidiGenerateContent. Rust version of https://github.com/nishuzumi/gemini-teacher
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
[package]
name = "gemini-teacher"
version = "0.1.0"
edition = "2021"
default-run = "gemini-teacher"

[dependencies]
tokio = { version = "1.32", features = ["full"] }
# tokio-tungstenite 0.17 re-exports tungstenite 0.17; the direct tungstenite
# dependency previously pinned 0.20, which pulls a second, incompatible copy
# of the crate into the build. Keep the two in lockstep.
tokio-tungstenite = { version = "0.17", features = ["rustls-tls-native-roots"] }
tungstenite = { version = "0.17", features = ["rustls-tls-native-roots"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
base64 = "0.21"
cpal = "0.15"
colored = "2.0"
futures-util = "0.3.31"
# NOTE(review): main.rs calls SincFixedIn::new(...).expect(...) and
# process(..., None); confirm those signatures exist in the resolved rubato
# version — they changed across the 0.11–0.14 releases.
rubato = "0.11"
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::io::Write; | |
use std::sync::{ | |
atomic::{AtomicUsize, Ordering}, | |
Arc, | |
}; | |
use tokio::sync::{broadcast, mpsc}; | |
use tokio::task::JoinSet; | |
use futures_util::{SinkExt, StreamExt}; | |
use tokio_tungstenite::{connect_async, tungstenite::Message}; | |
use base64::engine::general_purpose::STANDARD as base64_engine; | |
use base64::Engine as _; | |
use colored::*; | |
use serde_json::json; | |
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait}; | |
use rubato::Resampler; | |
use rubato::{InterpolationParameters, InterpolationType, SincFixedIn, WindowFunction}; | |
// Host of the Gemini realtime (BidiGenerateContent) websocket endpoint.
const HOST: &str = "generativelanguage.googleapis.com";
// Model name; interpolated into "models/{}" by AudioLoop::startup.
const MODEL: &str = "gemini-2.0-flash-exp";
// System prompt (Chinese): tells the model to act as an English pronunciation
// coach, always answer in Chinese, and reply "OK" once it has understood —
// receive_audio watches for that "OK" prefix to confirm initialization.
const PROMPT: &str = "你是一名专业的英语口语指导老师,你需要帮助用户纠正语法发音,用户将会说一句英文,然后你会给出识别出来的英语是什么,并且告诉他发音中有什么问题,语法有什么错误,并且一步一步的纠正他的发音,当一次发音正确后,根据当前语句提出下一个场景的语句,然后一直循环这个过程,直到用户说OK,我要退出。你的回答永远要保持中文。如果明白了请回答OK两个字";
#[tokio::main] | |
async fn main() { | |
let api_key = env::var("GEMINI_API_KEY").expect("Missing GOOGLE_API_KEY"); | |
let uri = format!( | |
"wss://{}/ws/google.ai.generativelanguage.v1alpha.GenerativeService.BidiGenerateContent?key={}", | |
HOST, api_key | |
); | |
println!("{}", "Gemini 英语口语助手".green()); | |
println!("{}", "Made by twitter: @BoxMrChen".green()); | |
println!("{}", "Rewritten to Rust by xring: @xringxie".green()); | |
println!( | |
"{}", | |
"============================================".yellow() | |
); | |
match connect_async(&uri).await { | |
Ok((ws_stream, response)) => { | |
println!("Handshake response status: {:?}", response.status()); | |
let (mut ws_sink, mut ws_stream) = ws_stream.split(); | |
let (tx_ws, mut rx_ws_out) = mpsc::channel::<Message>(100); | |
let (tx_audio, _) = broadcast::channel::<Vec<u8>>(100); | |
let audio_loop = AudioLoop::new(tx_ws.clone(), tx_audio); | |
audio_loop.startup(MODEL).await; | |
let (tx_ws_in, rx_ws_in) = mpsc::channel::<Message>(100); | |
let mut tasks = JoinSet::new(); | |
{ | |
let audio_loop = audio_loop.clone(); | |
tasks.spawn_blocking(move || { | |
audio_loop.listen_audio_blocking(); | |
}); | |
} | |
{ | |
let audio_loop = audio_loop.clone(); | |
tasks.spawn(async move { | |
audio_loop.send_audio().await; | |
}); | |
} | |
{ | |
let step = audio_loop.running_step.clone(); | |
tasks.spawn(async move { | |
receive_audio(rx_ws_in, step).await; | |
}); | |
} | |
tasks.spawn(async move { | |
while let Some(Ok(msg)) = ws_stream.next().await { | |
if let Err(e) = tx_ws_in.send(msg).await { | |
eprintln!("Forward to rx_ws_in error: {}", e); | |
break; | |
} | |
} | |
}); | |
tasks.spawn(async move { | |
while let Some(msg) = rx_ws_out.recv().await { | |
match ws_sink.send(msg).await { | |
Ok(_) => {} | |
Err(e) => { | |
eprintln!("Failed to send WebSocket message {}", e); | |
break; | |
} | |
} | |
} | |
}); | |
while let Some(res) = tasks.join_next().await { | |
if let Err(e) = res { | |
eprintln!("A task panicked: {:?}", e); | |
} | |
} | |
} | |
Err(e) => { | |
eprintln!("Failed to connect Gemini: {}", e); | |
return; | |
} | |
} | |
} | |
/// Shared handles for the capture → websocket audio pipeline.
/// Cloned into each spawned task; all fields are cheap to clone.
#[derive(Debug, Clone)]
struct AudioLoop {
    tx_ws: mpsc::Sender<Message>, // outgoing websocket JSON text frames
    tx_audio: broadcast::Sender<Vec<u8>>, // 16 kHz s16le PCM chunks from the mic callback
    // UI state machine: 0 = idle, 1 = speech detected, 2 = model responding.
    running_step: Arc<AtomicUsize>,
}
impl AudioLoop { | |
fn new(tx_ws: mpsc::Sender<Message>, tx_audio: broadcast::Sender<Vec<u8>>) -> Self { | |
Self { | |
tx_ws, | |
tx_audio, | |
running_step: Arc::new(AtomicUsize::new(0)), | |
} | |
} | |
async fn startup(&self, model: &str) { | |
let setup_msg = json!({ | |
"setup": { | |
"model": format!("models/{}", model), | |
"generation_config": { | |
"response_modalities": ["TEXT"] | |
} | |
} | |
}); | |
let _ = self.tx_ws.send(Message::Text(setup_msg.to_string())).await; | |
let initial_msg = json!({ | |
"client_content": { | |
"turns": [ | |
{ | |
"role": "user", | |
"parts": [{ | |
"text": PROMPT | |
}] | |
} | |
], | |
"turn_complete": true | |
} | |
}); | |
let _ = self | |
.tx_ws | |
.send(Message::Text(initial_msg.to_string())) | |
.await; | |
} | |
fn listen_audio_blocking(&self) { | |
let host = cpal::default_host(); | |
let input_device = host | |
.default_input_device() | |
.expect("No default input device available"); | |
let mut supported_configs_range = input_device | |
.supported_input_configs() | |
.expect("Error while querying supported input configs"); | |
let supported_config = supported_configs_range | |
.next() | |
.expect("No supported input configs on this device"); | |
let config = supported_config.with_max_sample_rate(); | |
let src_rate = config.sample_rate().0 as f64; | |
println!("{}", "🎤 说一句英语吧!比如: What is Rust language?".yellow()); | |
let tx_audio = self.tx_audio.clone(); | |
let running_step = self.running_step.clone(); | |
println!("{:?}", config); | |
let stream = input_device | |
.build_input_stream( | |
&config.into(), | |
move |data: &[f32], _cbinfo| { | |
audio_callback_f32(src_rate, data, &running_step, &tx_audio); | |
}, | |
move |err| { | |
eprintln!("cpal input stream error: {}", err); | |
}, | |
None, | |
) | |
.expect("Failed to build input stream"); | |
stream.play().expect("Failed to start input stream"); | |
loop { | |
std::thread::sleep(std::time::Duration::from_secs(1)); | |
} | |
} | |
async fn send_audio(&self) { | |
let mut rx_audio = self.tx_audio.subscribe(); | |
let tx_ws = self.tx_ws.clone(); | |
while let Ok(chunk) = rx_audio.recv().await { | |
let encoded = base64_engine.encode(&chunk); | |
let msg = json!({ | |
"realtime_input": { | |
"media_chunks": [ | |
{"data": encoded, "mime_type": "audio/pcm"} | |
] | |
} | |
}); | |
let _ = tx_ws.send(Message::Text(msg.to_string())).await; | |
} | |
} | |
} | |
async fn receive_audio(mut rx_ws_in: mpsc::Receiver<Message>, running_step: Arc<AtomicUsize>) { | |
let mut current_response = String::new(); | |
while let Some(msg) = rx_ws_in.recv().await { | |
if let Message::Binary(text) = msg { | |
let string_text = String::from_utf8(text.clone()).unwrap(); | |
match serde_json::from_str::<serde_json::Value>(&string_text) { | |
Ok(js) => { | |
let step = running_step.load(Ordering::SeqCst); | |
if step == 1 { | |
print!("\n{}", "♻️ 处理中:".yellow()); | |
let _ = std::io::stdout().flush(); | |
running_step.store(2, Ordering::SeqCst); | |
} | |
if let Some(server_content) = js.get("serverContent") { | |
if let Some(model_turn) = server_content.get("modelTurn") { | |
if let Some(parts) = model_turn.get("parts").and_then(|p| p.as_array()) | |
{ | |
for part in parts { | |
if let Some(text_part) = | |
part.get("text").and_then(|v| v.as_str()) | |
{ | |
current_response.push_str(text_part); | |
print!("{}", "-".blue()); | |
let _ = std::io::stdout().flush(); | |
} | |
} | |
} | |
} | |
if let Some(turn_complete) = server_content | |
.get("turnComplete") | |
.and_then(|tc| tc.as_bool()) | |
{ | |
if turn_complete { | |
if !current_response.is_empty() { | |
println!( | |
"\n{}", | |
"🤖 =============================================".yellow() | |
); | |
println!("{}", current_response); | |
current_response.clear(); | |
running_step.store(0, Ordering::SeqCst); | |
} | |
} | |
} | |
} | |
if let Some(server_content) = js.get("serverContent") { | |
if let Some(model_turn) = server_content.get("modelTurn") { | |
if let Some(parts) = model_turn.get("parts").and_then(|p| p.as_array()) | |
{ | |
let mut combined = String::new(); | |
for part in parts { | |
if let Some(text_part) = | |
part.get("text").and_then(|v| v.as_str()) | |
{ | |
combined.push_str(text_part); | |
} | |
} | |
if combined.starts_with("OK") { | |
println!("初始化完成 ✅"); | |
} | |
} | |
} | |
} | |
} | |
Err(e) => { | |
eprintln!("JSON parse error: {}", e); | |
} | |
} | |
} else { | |
println!("{}", msg); | |
} | |
} | |
} | |
fn audio_callback_f32( | |
src_rate: f64, | |
data: &[f32], | |
running_step: &Arc<AtomicUsize>, | |
tx_audio: &broadcast::Sender<Vec<u8>>, | |
) { | |
let resampled_f32 = resample_to_16000(src_rate, data); | |
let mut buf = Vec::with_capacity(resampled_f32.len() * 2); | |
let mut volume_sum = 0.0_f32; | |
for &sample_f32 in &resampled_f32 { | |
volume_sum += sample_f32.abs(); | |
let sample_i16 = (sample_f32.clamp(-1.0, 1.0) * i16::MAX as f32) as i16; | |
buf.extend_from_slice(&sample_i16.to_le_bytes()); | |
} | |
let volume = if !resampled_f32.is_empty() { | |
volume_sum / resampled_f32.len() as f32 | |
} else { | |
0.0 | |
}; | |
let volume_threshold = 0.01; | |
if volume > volume_threshold { | |
let step = running_step.load(Ordering::SeqCst); | |
if step == 0 { | |
print!("{}", "🎤 :"); | |
let _ = std::io::stdout().flush(); | |
running_step.store(1, Ordering::SeqCst); | |
} | |
print!("*"); | |
let _ = std::io::stdout().flush(); | |
} | |
if let Err(e) = tx_audio.send(buf) { | |
eprintln!("Audio broadcast send error: {}", e); | |
} | |
} | |
fn resample_to_16000(src_rate: f64, input: &[f32]) -> Vec<f32> { | |
let dst_rate = 16000.0; | |
let resample_ratio = dst_rate / src_rate; | |
let max_resample_ratio_relative = 1.0; | |
let chunk_size = input.len(); | |
let nbr_channels = 1; | |
let params = InterpolationParameters { | |
interpolation: InterpolationType::Cubic, | |
sinc_len: 256, | |
f_cutoff: 0.95, | |
oversampling_factor: 160, | |
window: WindowFunction::BlackmanHarris2, | |
}; | |
let mut resampler = SincFixedIn::<f32>::new( | |
resample_ratio, | |
max_resample_ratio_relative, | |
params, | |
chunk_size, | |
nbr_channels, | |
) | |
.expect("Failed to create SincFixedIn re-sampler"); | |
let input_2d = vec![input.to_vec()]; | |
let output_2d = resampler | |
.process(&input_2d, None) | |
.expect("Error during re-sampler.process"); | |
let output_f32 = output_2d[0].clone(); | |
output_f32 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment