Gemini English Teacher, powered by GenerativeService and BidiGenerateContent. Rust version of https://github.com/nishuzumi/gemini-teacher
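To run it locally (assuming a recent stable Rust toolchain): create a Cargo project from the two files below (Cargo.toml and main.rs), set GEMINI_API_KEY to a valid Gemini API key, and start it with cargo run.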
# Cargo.toml
[package]
name = "gemini-teacher"
version = "0.1.0"
edition = "2021"
default-run = "gemini-teacher"
[dependencies]
tokio = { version = "1.32", features = ["full"] }
tokio-tungstenite = { version = "0.17", features = ["rustls-tls-native-roots"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
base64 = "0.21"
cpal = "0.15"
colored = "2.0"
futures-util = "0.3.31"
rubato = "0.11"
// main.rs
use std::env;
use std::io::Write;
use std::sync::{
    atomic::{AtomicUsize, Ordering},
    Arc,
};

use base64::engine::general_purpose::STANDARD as base64_engine;
use base64::Engine as _;
use colored::*;
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
use futures_util::{SinkExt, StreamExt};
use rubato::{InterpolationParameters, InterpolationType, Resampler, SincFixedIn, WindowFunction};
use serde_json::json;
use tokio::sync::{broadcast, mpsc};
use tokio::task::JoinSet;
use tokio_tungstenite::{connect_async, tungstenite::Message};

const HOST: &str = "generativelanguage.googleapis.com";
const MODEL: &str = "gemini-2.0-flash-exp";
const PROMPT: &str = "You are a professional spoken-English teacher helping the user correct pronunciation and grammar. The user will say a sentence in English; tell them what English you recognized, point out problems with their pronunciation and any grammar mistakes, and correct their pronunciation step by step. Once a sentence is pronounced correctly, propose a sentence for the next scenario based on the current one, and keep looping until the user says OK, I want to quit. Always answer in Chinese. If you understand, reply with exactly the word OK";
#[tokio::main]
async fn main() {
    let api_key = env::var("GEMINI_API_KEY").expect("Missing GEMINI_API_KEY");
    let uri = format!(
        "wss://{}/ws/google.ai.generativelanguage.v1alpha.GenerativeService.BidiGenerateContent?key={}",
        HOST, api_key
    );
    println!("{}", "Gemini Spoken-English Assistant".green());
    println!("{}", "Made by twitter: @BoxMrChen".green());
    println!("{}", "Rewritten in Rust by xring: @xringxie".green());
    println!("{}", "============================================".yellow());
    match connect_async(&uri).await {
        Ok((ws_stream, response)) => {
            println!("Handshake response status: {:?}", response.status());
            let (mut ws_sink, mut ws_stream) = ws_stream.split();
            // Outbound frames queued for the WebSocket sink.
            let (tx_ws, mut rx_ws_out) = mpsc::channel::<Message>(100);
            // Raw PCM chunks captured from the microphone.
            let (tx_audio, _) = broadcast::channel::<Vec<u8>>(100);
            let audio_loop = AudioLoop::new(tx_ws.clone(), tx_audio);
            // Queue the setup and initial-prompt frames; they are flushed once
            // the sender task below starts draining rx_ws_out.
            audio_loop.startup(MODEL).await;
            let (tx_ws_in, rx_ws_in) = mpsc::channel::<Message>(100);
            let mut tasks = JoinSet::new();
            // Microphone capture runs on a dedicated blocking thread.
            {
                let audio_loop = audio_loop.clone();
                tasks.spawn_blocking(move || {
                    audio_loop.listen_audio_blocking();
                });
            }
            // Forward captured audio chunks to the server as realtime input.
            {
                let audio_loop = audio_loop.clone();
                tasks.spawn(async move {
                    audio_loop.send_audio().await;
                });
            }
            // Parse and print the model's responses.
            {
                let step = audio_loop.running_step.clone();
                tasks.spawn(async move {
                    receive_audio(rx_ws_in, step).await;
                });
            }
            // Pump incoming WebSocket frames into the receiver channel.
            tasks.spawn(async move {
                while let Some(Ok(msg)) = ws_stream.next().await {
                    if let Err(e) = tx_ws_in.send(msg).await {
                        eprintln!("Forward to rx_ws_in error: {}", e);
                        break;
                    }
                }
            });
            // Drain outbound frames into the WebSocket sink.
            tasks.spawn(async move {
                while let Some(msg) = rx_ws_out.recv().await {
                    if let Err(e) = ws_sink.send(msg).await {
                        eprintln!("Failed to send WebSocket message: {}", e);
                        break;
                    }
                }
            });
            while let Some(res) = tasks.join_next().await {
                if let Err(e) = res {
                    eprintln!("A task panicked: {:?}", e);
                }
            }
        }
        Err(e) => {
            eprintln!("Failed to connect to Gemini: {}", e);
        }
    }
}

#[derive(Debug, Clone)]
struct AudioLoop {
    tx_ws: mpsc::Sender<Message>,
    tx_audio: broadcast::Sender<Vec<u8>>,
    // Status-line state machine: 0 = idle, 1 = user speech detected,
    // 2 = model response streaming.
    running_step: Arc<AtomicUsize>,
}

impl AudioLoop {
    fn new(tx_ws: mpsc::Sender<Message>, tx_audio: broadcast::Sender<Vec<u8>>) -> Self {
        Self {
            tx_ws,
            tx_audio,
            running_step: Arc::new(AtomicUsize::new(0)),
        }
    }
    /// Sends the session setup followed by the initial coaching prompt.
    /// BidiGenerateContent expects the `setup` frame to be the first client
    /// message; content may only be streamed after it.
    async fn startup(&self, model: &str) {
        let setup_msg = json!({
            "setup": {
                "model": format!("models/{}", model),
                "generation_config": {
                    "response_modalities": ["TEXT"]
                }
            }
        });
        let _ = self.tx_ws.send(Message::Text(setup_msg.to_string())).await;
        let initial_msg = json!({
            "client_content": {
                "turns": [
                    {
                        "role": "user",
                        "parts": [{ "text": PROMPT }]
                    }
                ],
                "turn_complete": true
            }
        });
        let _ = self
            .tx_ws
            .send(Message::Text(initial_msg.to_string()))
            .await;
    }
    /// Captures microphone input with cpal and broadcasts raw f32 buffers
    /// to the audio channel via the per-buffer callback.
    fn listen_audio_blocking(&self) {
        let host = cpal::default_host();
        let input_device = host
            .default_input_device()
            .expect("No default input device available");
        let mut supported_configs_range = input_device
            .supported_input_configs()
            .expect("Error while querying supported input configs");
        // Take the first supported config at its maximum sample rate.
        let supported_config = supported_configs_range
            .next()
            .expect("No supported input configs on this device");
        let config = supported_config.with_max_sample_rate();
        let src_rate = config.sample_rate().0 as f64;
        println!(
            "{}",
            "🎤 Say a sentence in English! For example: What is the Rust language?".yellow()
        );
        let tx_audio = self.tx_audio.clone();
        let running_step = self.running_step.clone();
        println!("{:?}", config);
        let stream = input_device
            .build_input_stream(
                &config.into(),
                move |data: &[f32], _cbinfo| {
                    audio_callback_f32(src_rate, data, &running_step, &tx_audio);
                },
                move |err| {
                    eprintln!("cpal input stream error: {}", err);
                },
                None,
            )
            .expect("Failed to build input stream");
        stream.play().expect("Failed to start input stream");
        // Keep this blocking thread alive so the stream keeps running.
        loop {
            std::thread::sleep(std::time::Duration::from_secs(1));
        }
    }
    /// Base64-encodes each captured PCM chunk and streams it to the server
    /// as a realtime media chunk.
    async fn send_audio(&self) {
        let mut rx_audio = self.tx_audio.subscribe();
        let tx_ws = self.tx_ws.clone();
        while let Ok(chunk) = rx_audio.recv().await {
            let encoded = base64_engine.encode(&chunk);
            let msg = json!({
                "realtime_input": {
                    "media_chunks": [
                        { "data": encoded, "mime_type": "audio/pcm" }
                    ]
                }
            });
            let _ = tx_ws.send(Message::Text(msg.to_string())).await;
        }
    }
}
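// The server replies with JSON text in binary frames. A minimal sketch of the
// shape this client consumes (only the fields read by receive_audio below;
// the full BidiGenerateContent schema carries more):
//
// {
//   "serverContent": {
//     "modelTurn": { "parts": [ { "text": "..." } ] },
//     "turnComplete": true
//   }
// }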
async fn receive_audio(mut rx_ws_in: mpsc::Receiver<Message>, running_step: Arc<AtomicUsize>) {
    let mut current_response = String::new();
    while let Some(msg) = rx_ws_in.recv().await {
        if let Message::Binary(bytes) = msg {
            let string_text = match String::from_utf8(bytes) {
                Ok(s) => s,
                Err(e) => {
                    eprintln!("Received non-UTF-8 frame: {}", e);
                    continue;
                }
            };
            match serde_json::from_str::<serde_json::Value>(&string_text) {
                Ok(js) => {
                    // First response chunk after speech: switch the status line.
                    let step = running_step.load(Ordering::SeqCst);
                    if step == 1 {
                        print!("\n{}", "♻️ Processing:".yellow());
                        let _ = std::io::stdout().flush();
                        running_step.store(2, Ordering::SeqCst);
                    }
                    if let Some(server_content) = js.get("serverContent") {
                        // Accumulate streamed text parts, echoing a dash per part.
                        let mut combined = String::new();
                        if let Some(parts) = server_content
                            .get("modelTurn")
                            .and_then(|mt| mt.get("parts"))
                            .and_then(|p| p.as_array())
                        {
                            for part in parts {
                                if let Some(text_part) = part.get("text").and_then(|v| v.as_str()) {
                                    current_response.push_str(text_part);
                                    combined.push_str(text_part);
                                    print!("{}", "-".blue());
                                    let _ = std::io::stdout().flush();
                                }
                            }
                        }
                        // On turn completion, print the full accumulated answer.
                        if server_content
                            .get("turnComplete")
                            .and_then(|tc| tc.as_bool())
                            .unwrap_or(false)
                            && !current_response.is_empty()
                        {
                            println!(
                                "\n{}",
                                "🤖 =============================================".yellow()
                            );
                            println!("{}", current_response);
                            current_response.clear();
                            running_step.store(0, Ordering::SeqCst);
                        }
                        // The prompt asks the model to reply with exactly "OK"
                        // once it has understood its instructions.
                        if combined.starts_with("OK") {
                            println!("Initialization complete ✅");
                        }
                    }
                }
                Err(e) => {
                    eprintln!("JSON parse error: {}", e);
                }
            }
        } else {
            println!("{}", msg);
        }
    }
}
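// Per-buffer capture callback: resamples the f32 input down to 16 kHz,
// converts it to little-endian 16-bit PCM (the payload sent as "audio/pcm"),
// and uses the mean absolute amplitude as a crude voice-activity gate that
// drives the "🎤" / "*" status output.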
fn audio_callback_f32(
    src_rate: f64,
    data: &[f32],
    running_step: &Arc<AtomicUsize>,
    tx_audio: &broadcast::Sender<Vec<u8>>,
) {
    let resampled_f32 = resample_to_16000(src_rate, data);
    let mut buf = Vec::with_capacity(resampled_f32.len() * 2);
    let mut volume_sum = 0.0_f32;
    for &sample_f32 in &resampled_f32 {
        volume_sum += sample_f32.abs();
        let sample_i16 = (sample_f32.clamp(-1.0, 1.0) * i16::MAX as f32) as i16;
        buf.extend_from_slice(&sample_i16.to_le_bytes());
    }
    let volume = if !resampled_f32.is_empty() {
        volume_sum / resampled_f32.len() as f32
    } else {
        0.0
    };
    let volume_threshold = 0.01;
    if volume > volume_threshold {
        let step = running_step.load(Ordering::SeqCst);
        if step == 0 {
            print!("🎤 :");
            let _ = std::io::stdout().flush();
            running_step.store(1, Ordering::SeqCst);
        }
        print!("*");
        let _ = std::io::stdout().flush();
    }
    if let Err(e) = tx_audio.send(buf) {
        eprintln!("Audio broadcast send error: {}", e);
    }
}
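// The target is 16 kHz mono because that is the input sample rate the Live
// API expects for PCM audio. Note that a SincFixedIn is rebuilt on every
// callback since cpal does not guarantee a fixed buffer size; reusing one
// long-lived resampler behind a fixed-size buffer would be cheaper, but is
// left out here.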
fn resample_to_16000(src_rate: f64, input: &[f32]) -> Vec<f32> {
    let dst_rate = 16000.0;
    let resample_ratio = dst_rate / src_rate;
    let max_resample_ratio_relative = 1.0;
    let chunk_size = input.len();
    let nbr_channels = 1;
    let params = InterpolationParameters {
        interpolation: InterpolationType::Cubic,
        sinc_len: 256,
        f_cutoff: 0.95,
        oversampling_factor: 160,
        window: WindowFunction::BlackmanHarris2,
    };
    let mut resampler = SincFixedIn::<f32>::new(
        resample_ratio,
        max_resample_ratio_relative,
        params,
        chunk_size,
        nbr_channels,
    )
    .expect("Failed to create SincFixedIn resampler");
    // rubato works on one Vec per channel; this input is mono.
    let input_2d = vec![input.to_vec()];
    let mut output_2d = resampler
        .process(&input_2d, None)
        .expect("Error during resampler.process");
    output_2d.remove(0)
}