"""
Code for calculating the per-channel mean and standard deviation of a dataset.
This is useful for normalizing the dataset to obtain mean 0, std 1.

Programmed by Aladdin Persson
*    2020-05-09 Initial coding
*    2022-12-16 Updated comments, code revision, and checked code still works with latest PyTorch.
"""

import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
from tqdm import tqdm

# NOTE: not used below; all statistics are computed on the CPU tensors the
# loader yields.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def get_mean_std(loader):
    """Compute the per-channel mean and standard deviation of a dataset.

    Uses var[X] = E[X**2] - E[X]**2. Per-channel sums are accumulated over
    every pixel of every image and divided by the total pixel count, so each
    pixel is weighted equally even when the final batch is smaller than the
    others. (Averaging per-batch means would over-weight a short last batch.)
    The result is the population (biased) standard deviation.

    Args:
        loader: iterable yielding ``(data, target)`` pairs, where ``data`` is
            a 4-D tensor with channels on dim 1 (reduced over dims 0, 2, 3).

    Returns:
        tuple ``(mean, std)`` of 1-D tensors with one entry per channel, in
        the same dtype as the input batches.

    Raises:
        ValueError: if ``loader`` yields no data.
    """
    channels_sum, channels_sqrd_sum, num_pixels = 0, 0, 0
    out_dtype = None
    for data, _ in tqdm(loader):
        out_dtype = data.dtype
        # Accumulate in float64 to limit round-off over many batches.
        data = data.double()
        channels_sum += data.sum(dim=[0, 2, 3])
        channels_sqrd_sum += (data**2).sum(dim=[0, 2, 3])
        # Number of pixels per channel in this batch: batch * height * width.
        num_pixels += data.numel() // data.size(1)

    if num_pixels == 0:
        raise ValueError("loader yielded no data")

    mean = channels_sum / num_pixels
    # E[X^2] - E[X]^2 can dip slightly below 0 from floating-point
    # cancellation; clamp so sqrt never produces NaN.
    var = (channels_sqrd_sum / num_pixels - mean**2).clamp(min=0)
    std = var.sqrt()
    return mean.to(out_dtype), std.to(out_dtype)


if __name__ == "__main__":
    train_set = datasets.CIFAR10(
        root="dataset/", transform=transforms.ToTensor(), download=True
    )
    # Order does not affect the statistics, so there is no need to shuffle.
    train_loader = DataLoader(dataset=train_set, batch_size=64, shuffle=False)

    mean, std = get_mean_std(train_loader)
    print(mean)
    print(std)