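// Benchmark: update a "vocabulary" of per-token vectors with SGD, where each
// training step's loss touches only a single variable. `original` uses the
// stock candle_nn::SGD over the whole VarMap; `sparse` uses MySGD below, which
// walks only the gradients produced by the backward pass. Note: MySGD relies
// on a `get_ids` accessor that has to be added to candle_core's GradStore
// (see the comment above `step`).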
use std::collections::HashMap;
use std::time::Instant;

use candle_core::{backprop::GradStore, DType, Device, Result, TensorId, Var};
use candle_nn::{init::DEFAULT_KAIMING_NORMAL, Optimizer, VarBuilder, VarMap, SGD};
/// An SGD variant whose `step` walks the gradient store instead of the full
/// list of registered variables, so its cost scales with the number of
/// tensors that actually received a gradient.
#[derive(Debug)]
pub struct MySGD {
    vars: HashMap<TensorId, Var>,
    learning_rate: f64,
}
impl Optimizer for MySGD {
    type Config = f64;

    fn new(vars: Vec<Var>, learning_rate: f64) -> Result<Self> {
        // Keep only floating-point variables, keyed by tensor id for O(1) lookup.
        let vars = vars
            .into_iter()
            .filter_map(|var| {
                if var.dtype().is_float() {
                    Some((var.id(), var))
                } else {
                    None
                }
            })
            .collect();
        Ok(Self {
            vars,
            learning_rate,
        })
    }

    fn learning_rate(&self) -> f64 {
        self.learning_rate
    }

    // Add the following fn to GradStore:
    //
    // pub fn get_ids(&self) -> impl Iterator<Item = &TensorId> {
    //     self.0.keys()
    // }
    fn step(&mut self, grads: &GradStore) -> Result<()> {
        // Iterate over the tensors that received a gradient in this backward
        // pass, rather than over every variable registered with the optimizer.
        for id in grads.get_ids() {
            if let Some(var) = self.vars.get(id) {
                if let Some(grad) = grads.get(var) {
                    var.set(&var.sub(&(grad * self.learning_rate)?)?)?;
                }
            }
        }
        Ok(())
    }

    fn set_learning_rate(&mut self, lr: f64) {
        self.learning_rate = lr
    }
}
/// Baseline timing: the stock candle_nn::SGD, handed every variable in the VarMap.
fn original(vocabulary_size: usize) -> Result<()> {
    let device = Device::cuda_if_available(0)?;
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);
    let mut tensors = vec![];
    for i in 0..vocabulary_size {
        tensors.push(vb.get_with_hints(128, &format!("var_{i}"), DEFAULT_KAIMING_NORMAL)?);
    }

    let now = Instant::now();
    let mut optimizer = SGD::new(varmap.all_vars(), 0.01)?;
    for idx in 0..vocabulary_size {
        // Each step's loss depends on a single variable, so the backward pass
        // produces exactly one gradient.
        let loss = (&tensors[idx] * &tensors[idx])?.sum_all()?;
        optimizer.backward_step(&loss)?;
    }
    println!("original: {:?}", now.elapsed());
    Ok(())
}
/// Same benchmark, but with MySGD, which only visits variables that have gradients.
fn sparse(vocabulary_size: usize) -> Result<()> {
    let device = Device::cuda_if_available(0)?;
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);
    let mut tensors = vec![];
    for i in 0..vocabulary_size {
        tensors.push(vb.get_with_hints(128, &format!("var_{i}"), DEFAULT_KAIMING_NORMAL)?);
    }

    let now = Instant::now();
    let mut optimizer = MySGD::new(varmap.all_vars(), 0.01)?;
    for idx in 0..vocabulary_size {
        let loss = (&tensors[idx] * &tensors[idx])?.sum_all()?;
        optimizer.backward_step(&loss)?;
    }
    println!("sparse: {:?}", now.elapsed());
    Ok(())
}
fn main() -> Result<()> {
    original(10_000)?;
    sparse(10_000)?;
    original(100_000)?;
    sparse(100_000)?;
    Ok(())
}