@jdhao
Last active September 20, 2023 06:36
This snippet calculates the per-channel image mean and standard deviation over the training image set. It is deliberately simple and may not be efficient for large-scale datasets.
"""
This script calculates the per-channel image mean and standard deviation
over the training set. Do not calculate the statistics on the whole
dataset; see http://cs231n.github.io/neural-networks-2/#datapre
"""
import numpy as np
from os import listdir
from os.path import join, isdir
from glob import glob
import cv2
import timeit
# number of channels of the dataset image, 3 for color jpg, 1 for grayscale img
# you need to change it to reflect your dataset
CHANNEL_NUM = 3
def cal_dir_stat(root):
    cls_dirs = [d for d in listdir(root) if isdir(join(root, d))]
    pixel_num = 0  # total number of pixels in the dataset
    channel_sum = np.zeros(CHANNEL_NUM)
    channel_sum_squared = np.zeros(CHANNEL_NUM)

    for idx, d in enumerate(cls_dirs):
        print("#{} class".format(idx))
        im_pths = glob(join(root, d, "*.jpg"))

        for path in im_pths:
            im = cv2.imread(path)  # image in M*N*CHANNEL_NUM shape, channel in BGR order
            if im is None:  # cv2.imread() returns None when a file cannot be read
                continue
            im = im / 255.0
            pixel_num += im.size / CHANNEL_NUM
            channel_sum += np.sum(im, axis=(0, 1))
            channel_sum_squared += np.sum(np.square(im), axis=(0, 1))

    bgr_mean = channel_sum / pixel_num
    # std = sqrt(E[x^2] - (E[x])^2)
    bgr_std = np.sqrt(channel_sum_squared / pixel_num - np.square(bgr_mean))

    # change the format from bgr to rgb
    rgb_mean = list(bgr_mean)[::-1]
    rgb_std = list(bgr_std)[::-1]

    return rgb_mean, rgb_std
# The script assumes that under train_root, there are separate directories for each class
# of training images.
train_root = "/hd1/jdhao/firearm-dataset/train/"
start = timeit.default_timer()
mean, std = cal_dir_stat(train_root)
end = timeit.default_timer()
print("elapsed time: {}".format(end-start))
print("mean:{}\nstd:{}".format(mean, std))
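
The script never holds more than one image in memory: it accumulates a per-channel sum and sum of squares, then recovers the standard deviation from the identity std = sqrt(E[x^2] - (E[x])^2). Below is a minimal sketch that checks this identity on synthetic data. It assumes random arrays can stand in for decoded images; the helper name `running_channel_stats` and the array sizes are illustrative and not part of the gist.

```python
import numpy as np


def running_channel_stats(images):
    """Accumulate per-channel sum and sum of squares, one image at a time.

    `images` is an iterable of H*W*C float arrays. This helper is a
    hypothetical stand-in for the loop in cal_dir_stat, without file I/O.
    """
    pixel_num = 0
    channel_sum = None
    channel_sum_squared = None
    for im in images:
        if channel_sum is None:
            # Size the accumulators from the first image's channel count
            channel_sum = np.zeros(im.shape[2])
            channel_sum_squared = np.zeros(im.shape[2])
        pixel_num += im.shape[0] * im.shape[1]
        channel_sum += im.sum(axis=(0, 1))
        channel_sum_squared += np.square(im).sum(axis=(0, 1))
    mean = channel_sum / pixel_num
    # Var[x] = E[x^2] - (E[x])^2, i.e. the population (ddof=0) variance
    std = np.sqrt(channel_sum_squared / pixel_num - np.square(mean))
    return mean, std


# Synthetic "images" of different sizes with values in [0, 1)
rng = np.random.default_rng(0)
images = [rng.random((h, w, 3)) for h, w in [(4, 5), (8, 3), (6, 6)]]

mean, std = running_channel_stats(images)

# Reference: stack every pixel into one (N, 3) array and use NumPy directly
all_pixels = np.concatenate([im.reshape(-1, 3) for im in images])
print(np.allclose(mean, all_pixels.mean(axis=0)))
print(np.allclose(std, all_pixels.std(axis=0)))
```

Because every pixel is weighted equally, larger images contribute more to the result than smaller ones, which matches the behaviour of the script above. The one-pass formula can lose precision when the variance is tiny relative to the mean; for such data, a numerically stabler scheme such as Welford's algorithm may be worth considering.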
@jdhao
Author

jdhao commented Dec 21, 2018

@ozym4nd145, Updated the gist, thanks for your suggestion.

@canhnht

canhnht commented Oct 22, 2019

@jdhao Thanks a lot for this script.

@ora-zaq

ora-zaq commented Sep 1, 2020

@jdhao
Thank you very much. This script helps me.

@ashnair1

ashnair1 commented Nov 16, 2020

If CHANNEL_NUM = 1, the script throws an error because cv2.imread() reads images with 3 channels by default. You can either keep a single channel right after the image is read:

if CHANNEL_NUM == 1:
    im = im[:,:,0]

or read the image as grayscale in the first place:

if CHANNEL_NUM == 1:
    im = cv2.imread(path, cv2.IMREAD_GRAYSCALE)

@qbaocaca

qbaocaca commented Jun 7, 2022

Thanks for your script.

@morganmcg1

thanks for this :)
