Last active
August 11, 2024 22:41
-
-
Save jmsdnns/2cc415c1b3f759f65294f4d5157bbbd6 to your computer and use it in GitHub Desktop.
Loads 10000 images, converts them to tensors, and finds the total number of unique colors
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use candle_core::{Device, Result, Tensor}; | |
//use candle_nn; | |
use image; | |
use std::collections::HashSet; | |
const CPUNKS_PATH: &str = "../cpunks-10k/cpunks/images/training/"; | |
const CPUNKS_TOTAL: u16 = 10000; | |
pub fn get_punk_tensor(p_id: u16) -> Result<(Tensor, usize, usize)> { | |
let path = format!("{}punk{:0>4}.png", CPUNKS_PATH, p_id); | |
let img = image::open(path).unwrap(); | |
let (height, width) = (img.height() as usize, img.width() as usize); | |
let img = img.to_rgba8(); | |
let data = img.into_raw(); | |
let data = Tensor::from_vec(data, (height, width, 4), &Device::Cpu)?.permute((2, 0, 1))?; | |
Ok((data, height, width)) | |
} | |
pub fn one_image_colors(img: &Tensor) -> HashSet<(u8, u8, u8, u8)> { | |
let img = img.reshape((4, ())).unwrap(); | |
let mut colors = HashSet::new(); | |
let r = img.get(0).unwrap(); | |
let g = img.get(1).unwrap(); | |
let b = img.get(2).unwrap(); | |
let a = img.get(3).unwrap(); | |
for i in 0..576 { | |
let color = ( | |
r.get(i).unwrap().to_scalar::<u8>().unwrap(), | |
g.get(i).unwrap().to_scalar::<u8>().unwrap(), | |
b.get(i).unwrap().to_scalar::<u8>().unwrap(), | |
a.get(i).unwrap().to_scalar::<u8>().unwrap() | |
); | |
colors.insert(color); | |
} | |
colors | |
} | |
pub fn many_image_colors(ts: &Vec<Tensor>) -> HashSet<(u8, u8, u8, u8)> { | |
let mut all_colors = HashSet::new(); | |
for t in ts { | |
let img = t.reshape((4, ())).unwrap(); | |
let colors = one_image_colors(&img); | |
all_colors.extend(&colors); | |
} | |
all_colors | |
} | |
pub fn get_all_punks() -> Vec<Tensor> { | |
let mut ts: Vec<Tensor> = Vec::with_capacity(CPUNKS_TOTAL as usize); | |
for i in 0..CPUNKS_TOTAL { | |
if let Ok((t, _, _)) = get_punk_tensor(i) { | |
ts.push(t); | |
} | |
} | |
ts | |
} | |
fn main() { | |
let ts = get_all_punks(); | |
let colors = many_image_colors(&ts); | |
println!("{}", colors.len()); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment