Created
September 2, 2025 11:54
-
-
Save kujirahand/eaa6aae9a17b2969720c99f19cb69df3 to your computer and use it in GitHub Desktop.
Candleを使ってMLPを実装、果物データセットの判定を行う
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| use anyhow::{Context, Result}; | |
| use candle_core::{D, DType, Device, Tensor}; | |
| use candle_nn::{linear, loss, Module, VarBuilder, VarMap}; | |
| use candle_optimisers::adam::{Adam, ParamsAdam}; | |
| use candle_nn::optim::Optimizer; | |
| use image::{imageops::FilterType}; | |
| use rand::{seq::SliceRandom, rng}; | |
| use std::fs; | |
| use std::path::Path; | |
| use walkdir::WalkDir; | |
// Hyperparameters used for training --- (*1)
const IMG_SIZE: u32 = 32; // image size (32, 32) * 3 channels = 3072 input dimensions
const EPOCHS: usize = 20; // number of training epochs
const BATCH_SIZE: usize = 128; // mini-batch size used during training
/// One sample of the dataset --- (*2)
#[derive(Clone)]
struct Sample {
    x: Vec<f32>, // flattened pixel values, normalized to [0, 1]
    y: u32,      // class ID
}
/// A simple MLP (multi-layer perceptron) model --- (*3)
struct Mlp {
    // Three stacked fully-connected layers (candle_nn::Linear)
    l1: candle_nn::Linear,
    l2: candle_nn::Linear,
    l3: candle_nn::Linear,
}
| impl Mlp { | |
| /// MLPモデルを構築する(vbは重み・バイアス, in_dimは入力次元, out_dimは出力次元) | |
| fn new(vb: VarBuilder, in_dim: usize, hidden1: usize, hidden2: usize, out_dim: usize) -> Result<Self> { | |
| Ok(Self { | |
| l1: linear(in_dim, hidden1, vb.pp("l1"))?, | |
| l2: linear(hidden1, hidden2, vb.pp("l2"))?, | |
| l3: linear(hidden2, out_dim, vb.pp("l3"))?, | |
| }) | |
| } | |
| } | |
| impl Module for Mlp { | |
| /// 推論 | |
| fn forward(&self, xs: &Tensor) -> candle_core::Result<Tensor> { | |
| // 入力を [N, in_dim] に整形 | |
| let x = if xs.rank() == 2 { xs.clone() } else { xs.flatten_from(1)? }; | |
| // 隠れ層 + ReLU活性化関数 | |
| let x = self.l1.forward(&x)?.relu()?; | |
| let x = self.l2.forward(&x)?.relu()?; | |
| // 出力層 | |
| self.l3.forward(&x) | |
| } | |
| } | |
fn main() -> Result<()> {
    // Load the dataset placed under the data/ folder --- (*4)
    let base = Path::new("data");
    anyhow::ensure!(base.exists(), "data/ ディレクトリが見つかりません。fruits_db を配置してください。");
    // Enumerate the class subfolders (e.g. lemon, strawberry);
    // sorting makes the class-ID assignment deterministic across runs.
    let mut classes = list_subdirs(base)?;
    classes.sort();
    anyhow::ensure!(!classes.is_empty(), "クラスフォルダが見つかりません。");
    println!("classes: {:?}", classes);
    // Build a class-name -> class-ID map
    let class_to_id = classes
        .iter()
        .enumerate()
        .map(|(i, name)| (name.clone(), i as u32))
        .collect::<std::collections::HashMap<_, _>>();
    // Load the image files --- (*5)
    let mut samples = vec![];
    for c in &classes {
        let label = *class_to_id.get(c).unwrap();
        let dir = base.join(c);
        let mut cnt = 0usize;
        for entry in WalkDir::new(&dir).into_iter().filter_map(|e| e.ok()) {
            let p = entry.path();
            if p.is_file() && is_image_ext(p) {
                // Best-effort load: unreadable/corrupt images are silently skipped.
                if let Ok(x) = load_image_as_vec(p) {
                    samples.push(Sample { x, y: label });
                    cnt += 1;
                }
            }
        }
        // Report how many images were loaded for this class
        println!("loaded: {:>10} => {} files", c, cnt);
    }
    anyhow::ensure!(!samples.is_empty(), "画像が読み込めませんでした。");
    // Shuffle, then split into training and test sets --- (*6)
    // 80% for training, 20% for testing
    samples.shuffle(&mut rng());
    let n = samples.len();
    let n_train = (n as f32 * 0.8).round() as usize;
    let (train, test) = samples.split_at(n_train);
    let in_dim = (IMG_SIZE * IMG_SIZE * 3) as usize; // x3 because the images are RGB
    let out_dim = classes.len();
    // Convert the data into tensors --- (*7)
    let device = Device::Cpu;
    /*
    // To use the Metal device on macOS:
    let device = match Device::new_metal(0) {
        Ok(device) => {
            println!("Using Metal GPU acceleration");
            device
        }
        Err(_) => {
            println!("Metal not available, falling back to CPU");
            Device::Cpu
        }
    };
    */
    let (x_train, y_train) = to_tensors(train, in_dim, &device)?;
    let (x_test, y_test) = to_tensors(test, in_dim, &device)?;
    // Model and optimizer --- (*8)
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);
    let model = Mlp::new(vb, in_dim, 512, 128, out_dim)?;
    let mut opt = Adam::new(varmap.all_vars(), ParamsAdam { lr: 1e-3, ..Default::default() })?;
    // Training loop --- (*9)
    let n_train = x_train.dims()[0];
    println!("train: {n_train}, test: {}", x_test.dims()[0]);
    for epoch in 1..=EPOCHS {
        // Visit the training samples in a fresh random order every epoch.
        let mut idx: Vec<usize> = (0..n_train).collect();
        idx.shuffle(&mut rng());
        let mut total = 0f32;
        let mut steps = 0usize;
        for chunk in idx.chunks(BATCH_SIZE) {
            // Gather this mini-batch by row index via index_select.
            let chunk_u32: Vec<u32> = chunk.iter().map(|&x| x as u32).collect();
            let chunk_tensor = Tensor::from_vec(chunk_u32, chunk.len(), &device)?;
            let xb = x_train.index_select(&chunk_tensor, 0)?;
            let yb = y_train.index_select(&chunk_tensor, 0)?;
            let logits = model.forward(&xb)?;
            let loss = loss::cross_entropy(&logits, &yb)?;
            total += loss.to_scalar::<f32>()?;
            steps += 1;
            opt.backward_step(&loss)?;
        }
        // Check accuracy on the held-out test set --- (*10)
        let acc = accuracy(&model, &x_test, &y_test)?;
        println!("epoch {epoch}: loss={:.4}, acc={:.2}%", total / steps as f32, acc * 100.0);
    }
    // Print the class-ID -> name mapping --- (*11)
    // NOTE(review): HashMap iteration order is unspecified, so these lines
    // print in arbitrary order from run to run.
    println!("class id map:");
    for (name, id) in class_to_id.iter() {
        println!(" {} -> {}", id, name);
    }
    Ok(())
}
| // ユーティリティ関数 | |
| /// 指定フォルダの直下にあるサブフォルダ名を列挙する | |
| fn list_subdirs(base: &Path) -> Result<Vec<String>> { | |
| let mut v = vec![]; | |
| for e in fs::read_dir(base).with_context(|| format!("read_dir {:?}", base))? { | |
| let e = e?; | |
| if e.file_type()?.is_dir() { | |
| let name = e.file_name().to_string_lossy().to_string(); | |
| v.push(name); | |
| } | |
| } | |
| Ok(v) | |
| } | |
/// Return true when the path's extension is jpg/jpeg/png (case-insensitive).
fn is_image_ext(p: &Path) -> bool {
    p.extension()
        .and_then(|s| s.to_str())
        .map(|ext| ext.to_ascii_lowercase())
        .is_some_and(|ext| matches!(ext.as_str(), "jpg" | "jpeg" | "png"))
}
| /// 画像ファイルを読み込み、32x32にリサイズしてVec<f32>に変換 | |
| fn load_image_as_vec(p: &Path) -> Result<Vec<f32>> { | |
| let img = image::open(p).with_context(|| format!("open image {:?}", p))?; | |
| // RGB画像として32x32にリサイズ | |
| let rgb_img = img.to_rgb8(); | |
| let resized = image::imageops::resize(&rgb_img, IMG_SIZE, IMG_SIZE, FilterType::Triangle); | |
| // 各ピクセルのR,G,B値を0.0..=1.0に正規化してフラットなVec<f32>に格納 | |
| let mut v = Vec::with_capacity((IMG_SIZE * IMG_SIZE * 3) as usize); | |
| for pixel in resized.pixels() { | |
| v.push(pixel.0[0] as f32 / 255.0); // R | |
| v.push(pixel.0[1] as f32 / 255.0); // G | |
| v.push(pixel.0[2] as f32 / 255.0); // B | |
| } | |
| Ok(v) | |
| } | |
| /// SampleのスライスをTensorに変換 | |
| fn to_tensors(samples: &[Sample], in_dim: usize, device: &Device) -> Result<(Tensor, Tensor)> { | |
| let mut xs = Vec::with_capacity(samples.len() * in_dim); | |
| let mut ys = Vec::with_capacity(samples.len()); | |
| for s in samples { | |
| anyhow::ensure!(s.x.len() == in_dim, "unexpected dim {}", s.x.len()); | |
| xs.extend_from_slice(&s.x); | |
| ys.push(s.y); | |
| } | |
| let x = Tensor::from_vec(xs, (samples.len(), in_dim), device)?; | |
| let y = Tensor::from_vec(ys, samples.len(), device)?.to_dtype(DType::U32)?; | |
| Ok((x, y)) | |
| } | |
| /// モデルの精度を計算 | |
| fn accuracy<M: Module>(model: &M, x: &Tensor, y: &Tensor) -> Result<f32> { | |
| let logits = model.forward(x)?; | |
| let pred = logits.argmax(D::Minus1)?; | |
| let correct = pred | |
| .eq(y)? | |
| .to_dtype(DType::F32)? | |
| .sum_all()? | |
| .to_scalar::<f32>()?; | |
| Ok(correct / (y.dims()[0] as f32)) | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment