@spaghetti-source
Last active August 1, 2024 07:47
use std::collections::HashMap;
use std::time::Instant;
use candle_core::{backprop::GradStore, DType, Device, Result, TensorId, Var};
use candle_nn::{init::DEFAULT_KAIMING_NORMAL, Optimizer, VarBuilder, VarMap, SGD};
/// SGD variant whose `step` walks only the ids present in the `GradStore`,
/// instead of every registered `Var`, so updates stay cheap when each loss
/// touches only a few variables.
#[derive(Debug)]
pub struct MySGD {
    vars: HashMap<TensorId, Var>,
    learning_rate: f64,
}

impl Optimizer for MySGD {
    type Config = f64;

    fn new(vars: Vec<Var>, learning_rate: f64) -> Result<Self> {
        let vars = vars
            .into_iter()
            .filter_map(|var| {
                if var.dtype().is_float() {
                    Some((var.id(), var))
                } else {
                    None
                }
            })
            .collect();
        Ok(Self {
            vars,
            learning_rate,
        })
    }

    fn learning_rate(&self) -> f64 {
        self.learning_rate
    }

    // `GradStore` does not expose its keys in the candle-core version this
    // gist targets, so add the following accessor to candle-core's
    // src/backprop.rs (e.g. in a local fork or path dependency):
    //
    //     pub fn get_ids(&self) -> impl Iterator<Item = &TensorId> {
    //         self.0.keys()
    //     }
    fn step(&mut self, grads: &GradStore) -> Result<()> {
        // Iterate only over the tensors that actually received a gradient,
        // rather than over every registered var.
        for id in grads.get_ids() {
            if let Some(var) = self.vars.get(id) {
                if let Some(grad) = grads.get(var) {
                    // Plain SGD update: var <- var - lr * grad
                    var.set(&var.sub(&(grad * self.learning_rate)?)?)?;
                }
            }
        }
        Ok(())
    }

    fn set_learning_rate(&mut self, lr: f64) {
        self.learning_rate = lr
    }
}
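
// A small sanity check, not part of the original gist (names and shapes are
// illustrative): a loss built from `a` alone must leave `b` untouched,
// because `step` never visits ids that are absent from the GradStore.
#[cfg(test)]
mod tests {
    use super::MySGD;
    use candle_core::{DType, Device, Result, Tensor, Var};
    use candle_nn::Optimizer;

    #[test]
    fn step_skips_vars_without_gradients() -> Result<()> {
        let device = Device::Cpu;
        let a = Var::from_tensor(&Tensor::ones(4, DType::F32, &device)?)?;
        let b = Var::from_tensor(&Tensor::ones(4, DType::F32, &device)?)?;

        let mut optimizer = MySGD::new(vec![a.clone(), b.clone()], 0.1)?;
        // The loss depends on `a` only, so the GradStore holds a single id.
        let loss = (a.as_tensor() * a.as_tensor())?.sum_all()?;
        optimizer.backward_step(&loss)?;

        // `a` was updated, `b` was not (its id never appears in the grads).
        assert_ne!(a.to_vec1::<f32>()?, vec![1.0f32; 4]);
        assert_eq!(b.to_vec1::<f32>()?, vec![1.0f32; 4]);
        Ok(())
    }
}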
fn original(vocabulary_size: usize) -> Result<()> {
    let device = Device::cuda_if_available(0)?;
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);

    // One 128-dim var per vocabulary entry.
    let mut tensors = vec![];
    for i in 0..vocabulary_size {
        tensors.push(vb.get_with_hints(128, &format!("var_{i}"), DEFAULT_KAIMING_NORMAL)?)
    }

    // Built-in SGD: every step scans all registered vars, even though each
    // loss only touches a single one.
    let now = Instant::now();
    let mut optimizer = SGD::new(varmap.all_vars(), 0.01)?;
    for idx in 0..vocabulary_size {
        let loss = (&tensors[idx] * &tensors[idx])?.sum_all()?;
        optimizer.backward_step(&loss)?;
    }
    println!("original: {:?}", now.elapsed());
    Ok(())
}
fn sparse(vocabulary_size: usize) -> Result<()> {
    let device = Device::cuda_if_available(0)?;
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);

    let mut tensors = vec![];
    for i in 0..vocabulary_size {
        tensors.push(vb.get_with_hints(128, &format!("var_{i}"), DEFAULT_KAIMING_NORMAL)?)
    }

    // MySGD: every step only visits the ids present in the GradStore, i.e.
    // the single var used by this loss.
    let now = Instant::now();
    let mut optimizer = MySGD::new(varmap.all_vars(), 0.01)?;
    for idx in 0..vocabulary_size {
        let loss = (&tensors[idx] * &tensors[idx])?.sum_all()?;
        optimizer.backward_step(&loss)?;
    }
    println!("sparse: {:?}", now.elapsed());
    Ok(())
}
fn main() -> Result<()> {
    original(10_000)?;
    sparse(10_000)?;
    original(100_000)?;
    sparse(100_000)?;
    Ok(())
}
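
The point of the benchmark: each loss in the loops above involves a single var, so candle's built-in SGD still scans all `vocabulary_size` registered vars on every `backward_step` (quadratic in `vocabulary_size` overall), while `MySGD` only visits the ids present in the `GradStore`. The sketch below, not part of the original gist, spells out the two phases that `backward_step` performs, assuming the default implementation in `candle_nn::Optimizer` is `loss.backward()` followed by `step()`:

// Written out, `optimizer.backward_step(&loss)?` above amounts to two phases
// (a sketch under the assumption stated in the lead-in; never called here).
#[allow(dead_code)]
fn explicit_backward_step(optimizer: &mut MySGD, loss: &candle_core::Tensor) -> Result<()> {
    // 1) Build the GradStore: keyed by TensorId, it only holds gradients for
    //    the vars reachable from `loss` (a single var per loss in main above).
    let grads = loss.backward()?;
    // 2) Apply the update: MySGD walks grads.get_ids(), so vars that received
    //    no gradient are never even looked at; the built-in SGD scans them all.
    optimizer.step(&grads)
}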