Skip to content

Instantly share code, notes, and snippets.

@kujirahand
Created September 2, 2025 11:54
Show Gist options
  • Save kujirahand/eaa6aae9a17b2969720c99f19cb69df3 to your computer and use it in GitHub Desktop.
Save kujirahand/eaa6aae9a17b2969720c99f19cb69df3 to your computer and use it in GitHub Desktop.
Candleを使ってMLPを実装、果物データセットの判定を行う
use anyhow::{Context, Result};
use candle_core::{D, DType, Device, Tensor};
use candle_nn::{linear, loss, Module, VarBuilder, VarMap};
use candle_optimisers::adam::{Adam, ParamsAdam};
use candle_nn::optim::Optimizer;
use image::{imageops::FilterType};
use rand::{seq::SliceRandom, rng};
use std::fs;
use std::path::Path;
use walkdir::WalkDir;
// 機械学習で使うパラメータを定義 --- (*1)
const IMG_SIZE: u32 = 32; // 画像サイズ(32, 32) * 3色=3072次元
const EPOCHS: usize = 20; // 学習時のエポック数
const BATCH_SIZE: usize = 128; // 学習時のバッチサイズ
/// データセットの1サンプルを表現する構造体 --- (*2)
#[derive(Clone)]
struct Sample {
x: Vec<f32>, // 正規化済み [0,1]
y: u32, // クラスID
}
/// 単純なMLP構造のモデルを定義 --- (*3)
struct Mlp {
// 全結合層(candle_nn::Linear)を3段積んだ 多層パーセプトロン
l1: candle_nn::Linear,
l2: candle_nn::Linear,
l3: candle_nn::Linear,
}
impl Mlp {
/// MLPモデルを構築する(vbは重み・バイアス, in_dimは入力次元, out_dimは出力次元)
fn new(vb: VarBuilder, in_dim: usize, hidden1: usize, hidden2: usize, out_dim: usize) -> Result<Self> {
Ok(Self {
l1: linear(in_dim, hidden1, vb.pp("l1"))?,
l2: linear(hidden1, hidden2, vb.pp("l2"))?,
l3: linear(hidden2, out_dim, vb.pp("l3"))?,
})
}
}
impl Module for Mlp {
/// 推論
fn forward(&self, xs: &Tensor) -> candle_core::Result<Tensor> {
// 入力を [N, in_dim] に整形
let x = if xs.rank() == 2 { xs.clone() } else { xs.flatten_from(1)? };
// 隠れ層 + ReLU活性化関数
let x = self.l1.forward(&x)?.relu()?;
let x = self.l2.forward(&x)?.relu()?;
// 出力層
self.l3.forward(&x)
}
}
fn main() -> Result<()> {
// dataフォルダ以下のデータセットを読み込み --- (*4)
let base = Path::new("data");
anyhow::ensure!(base.exists(), "data/ ディレクトリが見つかりません。fruits_db を配置してください。");
// サブフォルダ(クラス)を列挙(例: lemon, strawberry)
let mut classes = list_subdirs(base)?;
classes.sort();
anyhow::ensure!(!classes.is_empty(), "クラスフォルダが見つかりません。");
println!("classes: {:?}", classes);
// クラス名→IDのマップを作成
let class_to_id = classes
.iter()
.enumerate()
.map(|(i, name)| (name.clone(), i as u32))
.collect::<std::collections::HashMap<_, _>>();
// 画像ファイルを読み込み --- (*5)
let mut samples = vec![];
for c in &classes {
let label = *class_to_id.get(c).unwrap();
let dir = base.join(c);
let mut cnt = 0usize;
for entry in WalkDir::new(&dir).into_iter().filter_map(|e| e.ok()) {
let p = entry.path();
if p.is_file() && is_image_ext(p) {
if let Ok(x) = load_image_as_vec(p) {
samples.push(Sample { x, y: label });
cnt += 1;
}
}
}
// 読み込んだ画像枚数を表示
println!("loaded: {:>10} => {} files", c, cnt);
}
anyhow::ensure!(!samples.is_empty(), "画像が読み込めませんでした。");
// シャッフルして学習用途テスト用にデータセットを分割 --- (*6)
// 80%を学習用、20%をテスト用に分割
samples.shuffle(&mut rng());
let n = samples.len();
let n_train = (n as f32 * 0.8).round() as usize;
let (train, test) = samples.split_at(n_train);
let in_dim = (IMG_SIZE * IMG_SIZE * 3) as usize; // RGB画像なので3倍
let out_dim = classes.len();
// データをTensor化する --- (*7)
let device = Device::Cpu;
/*
// macOS用のMetalデバイスを使う場合
let device = match Device::new_metal(0) {
Ok(device) => {
println!("Using Metal GPU acceleration");
device
}
Err(_) => {
println!("Metal not available, falling back to CPU");
Device::Cpu
}
};
*/
let (x_train, y_train) = to_tensors(train, in_dim, &device)?;
let (x_test, y_test) = to_tensors(test, in_dim, &device)?;
// モデル・最適化器 --- (*8)
let varmap = VarMap::new();
let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);
let model = Mlp::new(vb, in_dim, 512, 128, out_dim)?;
let mut opt = Adam::new(varmap.all_vars(), ParamsAdam { lr: 1e-3, ..Default::default() })?;
// 学習ループ --- (*9)
let n_train = x_train.dims()[0];
println!("train: {n_train}, test: {}", x_test.dims()[0]);
for epoch in 1..=EPOCHS {
let mut idx: Vec<usize> = (0..n_train).collect();
idx.shuffle(&mut rng());
let mut total = 0f32;
let mut steps = 0usize;
for chunk in idx.chunks(BATCH_SIZE) {
let chunk_u32: Vec<u32> = chunk.iter().map(|&x| x as u32).collect();
let chunk_tensor = Tensor::from_vec(chunk_u32, chunk.len(), &device)?;
let xb = x_train.index_select(&chunk_tensor, 0)?;
let yb = y_train.index_select(&chunk_tensor, 0)?;
let logits = model.forward(&xb)?;
let loss = loss::cross_entropy(&logits, &yb)?;
total += loss.to_scalar::<f32>()?;
steps += 1;
opt.backward_step(&loss)?;
}
// テスト精度を確認 --- (*10)
let acc = accuracy(&model, &x_test, &y_test)?;
println!("epoch {epoch}: loss={:.4}, acc={:.2}%", total / steps as f32, acc * 100.0);
}
// クラスID→名前の表示 --- (*11)
println!("class id map:");
for (name, id) in class_to_id.iter() {
println!(" {} -> {}", id, name);
}
Ok(())
}
// ユーティリティ関数
/// 指定フォルダの直下にあるサブフォルダ名を列挙する
fn list_subdirs(base: &Path) -> Result<Vec<String>> {
let mut v = vec![];
for e in fs::read_dir(base).with_context(|| format!("read_dir {:?}", base))? {
let e = e?;
if e.file_type()?.is_dir() {
let name = e.file_name().to_string_lossy().to_string();
v.push(name);
}
}
Ok(v)
}
/// 画像ファイルの拡張子がjpg/jpeg/pngかどうかを判定
fn is_image_ext(p: &Path) -> bool {
if let Some(ext) = p.extension().and_then(|s| s.to_str()) {
let ext = ext.to_ascii_lowercase();
matches!(ext.as_str(), "jpg" | "jpeg" | "png")
} else {
false
}
}
/// 画像ファイルを読み込み、32x32にリサイズしてVec<f32>に変換
fn load_image_as_vec(p: &Path) -> Result<Vec<f32>> {
let img = image::open(p).with_context(|| format!("open image {:?}", p))?;
// RGB画像として32x32にリサイズ
let rgb_img = img.to_rgb8();
let resized = image::imageops::resize(&rgb_img, IMG_SIZE, IMG_SIZE, FilterType::Triangle);
// 各ピクセルのR,G,B値を0.0..=1.0に正規化してフラットなVec<f32>に格納
let mut v = Vec::with_capacity((IMG_SIZE * IMG_SIZE * 3) as usize);
for pixel in resized.pixels() {
v.push(pixel.0[0] as f32 / 255.0); // R
v.push(pixel.0[1] as f32 / 255.0); // G
v.push(pixel.0[2] as f32 / 255.0); // B
}
Ok(v)
}
/// SampleのスライスをTensorに変換
fn to_tensors(samples: &[Sample], in_dim: usize, device: &Device) -> Result<(Tensor, Tensor)> {
let mut xs = Vec::with_capacity(samples.len() * in_dim);
let mut ys = Vec::with_capacity(samples.len());
for s in samples {
anyhow::ensure!(s.x.len() == in_dim, "unexpected dim {}", s.x.len());
xs.extend_from_slice(&s.x);
ys.push(s.y);
}
let x = Tensor::from_vec(xs, (samples.len(), in_dim), device)?;
let y = Tensor::from_vec(ys, samples.len(), device)?.to_dtype(DType::U32)?;
Ok((x, y))
}
/// モデルの精度を計算
fn accuracy<M: Module>(model: &M, x: &Tensor, y: &Tensor) -> Result<f32> {
let logits = model.forward(x)?;
let pred = logits.argmax(D::Minus1)?;
let correct = pred
.eq(y)?
.to_dtype(DType::F32)?
.sum_all()?
.to_scalar::<f32>()?;
Ok(correct / (y.dims()[0] as f32))
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment