一、CIFAR10數據集下載
import urllib.request import tarfile import os url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz" filepath = "cifar-10-python.tar.gz" if not os.path.isfile(filepath): result = urllib.request.urlretrieve(url, filepath) print('downloaded:', result) if not os.path.exists("cifar-10-batches-py"): tfile = tarfile.open("cifar-10-python.tar.gz", 'r:gz') result = tfile.extractall('.') print('extracted:', result) else: print('Data has existed.')
CIFAR-10(Canadian Institute For Advanced Research)是一個經典的圖像分類數據集,共有10個類別,每個類別有6000張32*32的彩色圖片,其中50000張作為訓練集,10000張作為測試集。
二、CIFAR10數據集讀取
import pickle import numpy as np def load_CIFAR_batch(filename): with open(filename, 'rb') as f: datadict = pickle.load(f, encoding='latin1') X = datadict['data'] Y = datadict['labels'] X = X.reshape(10000, 3, 32, 32).transpose(0,2,3,1).astype("float") Y = np.array(Y) return X, Y def load_CIFAR10(ROOT): xs = [] ys = [] for b in range(1,6): f = os.path.join(ROOT, 'data_batch_%d' % (b,)) X, Y = load_CIFAR_batch(f) xs.append(X) ys.append(Y) Xtr = np.concatenate(xs) Ytr = np.concatenate(ys) del X, Y Xte, Yte = load_CIFAR_batch(os.path.join(ROOT, 'test_batch')) return Xtr, Ytr, Xte, Yte
CIFAR-10數據集中包含5個批量的訓練數據,每個批量大小為10000,測試數據包含一個批量,大小也為10000。以上代碼可以讀取CIFAR-10數據集並將其轉化為易於處理的numpy數組。
三、CIFAR10數據集介紹
import numpy as np import matplotlib.pyplot as plt def visualize_CIFAR10_data(X_train, y_train, classes, samples_per_class=7): num_classes = len(classes) for y, cls in enumerate(classes): idxes = np.flatnonzero(y_train == y) idxes = np.random.choice(idxes, samples_per_class, replace=False) for i, idx in enumerate(idxes): plt_idx = i * num_classes + y + 1 plt.subplot(samples_per_class, num_classes, plt_idx) plt.imshow(X_train[idx].astype('uint8')) plt.axis('off') if i == 0: plt.title(cls) plt.show() def get_CIFAR10_data(num_training=49000, num_validation=1000, num_test=10000): cifar10_dir = '../datasets/cifar-10-batches-py' X_train, y_train, X_test, y_test = load_CIFAR10(cifar10_dir) mask = range(num_training, num_training + num_validation) X_val = X_train[mask] y_val = y_train[mask] mask = range(num_training) X_train = X_train[mask] y_train = y_train[mask] mask = range(num_test) X_test = X_test[mask] y_test = y_test[mask] mean_image = np.mean(X_train, axis=0) X_train -= mean_image X_val -= mean_image X_test -= mean_image return X_train, y_train, X_val, y_val, X_test, y_test
CIFAR-10數據集是由TensorFlow提供的,其中包括訓練、測試、驗證集,每張圖片都有一個標籤,總共10個類別。以上代碼用於可視化數據集以及獲取數據集。
四、CIFAR100數據集下載
import urllib.request import tarfile import os url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz" filepath = "cifar-100-python.tar.gz" if not os.path.isfile(filepath): result = urllib.request.urlretrieve(url, filepath) print('downloaded:', result) if not os.path.exists("cifar-100-python"): tfile = tarfile.open("cifar-100-python.tar.gz", 'r:gz') result = tfile.extractall('.') print('extracted:', result) else: print('Data has existed.')
CIFAR-100數據集共有100個類別,每個類別有600張32*32的彩色圖片,其中50000張作為訓練集,10000張作為測試集。CIFAR-100數據集的組織方式與CIFAR-10相似,但是它有更多的類別,更多的樣本。運行以上代碼可以下載CIFAR-100數據集。
五、CIFAR100數據集介紹
def load_CIFAR100(filename): with open(filename, 'rb') as f: datadict = pickle.load(f, encoding='latin1') X = datadict['data'] Y = datadict['fine_labels'] X = X.reshape(50000, 3, 32, 32).transpose(0,2,3,1).astype("float") Y = np.array(Y) return X, Y def visualize_CIFAR100_data(X_train, y_train, classes, samples_per_class=7): num_classes = len(classes) for y, cls in enumerate(classes): idxes = np.flatnonzero(y_train == y) idxes = np.random.choice(idxes, samples_per_class, replace=False) for i, idx in enumerate(idxes): plt_idx = i * num_classes + y + 1 plt.subplot(samples_per_class, num_classes, plt_idx) plt.imshow(X_train[idx].astype('uint8')) plt.axis('off') if i == 0: plt.title(cls) plt.show() def get_CIFAR100_data(num_training=49000, num_validation=1000, num_test=10000): cifar100_dir = '../datasets/cifar-100-python' X_train, y_train = load_CIFAR100(os.path.join(cifar100_dir, 'train')) X_test, y_test = load_CIFAR100(os.path.join(cifar100_dir, 'test')) classes = ['beaver', 'dolphin', 'otter', 'seal', 'whale','aquarium fish', 'flatfish', 'ray', 'shark', 'trout','orchids', 'poppies', 'roses', 'sunflowers', 'tulips','bottles', 'bowls', 'cans', 'cups', 'plates','apples', 'mushrooms', 'oranges', 'pears', 'sweet peppers','clock', 'computer keyboard', 'lamp', 'telephone', 'television','bed', 'chair', 'couch', 'table', 'wardrobe','bee', 'beetle', 'butterfly', 'caterpillar', 'cockroach','bear', 'leopard', 'lion', 'tiger', 'wolf','bridge', 'castle', 'house', 'road', 'skyscraper','cloud', 'forest', 'mountain', 'plain', 'sea','camel', 'cattle', 'chimpanzee', 'elephant', 'kangaroo','fox', 'porcupine', 'possum', 'raccoon', 'skunk','crab', 'lobster', 'snail', 'spider', 'worm','baby', 'boy', 'girl', 'man', 'woman','crocodile', 'dinosaur', 'lizard', 'snake', 'turtle','hamster', 'mouse', 'rabbit', 'shrew', 'squirrel','maple', 'oak', 'palm', 'pine', 'willow','bicycle', 'bus', 'motorcycle', 'pickup truck', 'train','lawn-mower', 'rocket', 'streetcar', 'tank', 'tractor'] num_classes = len(classes) mask = range(num_training, num_training + num_validation) X_val = X_train[mask] y_val = y_train[mask] mask = range(num_training) X_train = X_train[mask] y_train = y_train[mask] mask = range(num_test) X_test = X_test[mask] y_test = y_test[mask] mean_image = np.mean(X_train, axis=0) X_train -= mean_image X_val -= mean_image X_test -= mean_image return X_train, y_train, X_val, y_val, X_test, y_test, classes
CIFAR-100數據集是一個更加複雜的數據集,它有更多的類別,更多的樣本。以上代碼也用於可視化數據集以及獲取數據集。
六、CIFAR100數據集大小
CIFAR-100數據集大小為169MB,可以從以下鏈接中下載:https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz
七、CIFAR10數據集格式
CIFAR-10數據集中的每個批量都是一個Python Pickle字典,其中包含以下鍵:
data: 一個10000 * 3072的numpy數組,第一個維度是圖像的索引,第二個維度是展平的圖像像素值,該數組中的值在0到255之間。
labels: 由大小為10000的1D列表組成的一個長度為10000的numpy數組,其中每個元素是一個類別ID。
八、CIFAR100數據集準確率
CIFAR-100的分類準確率通常在70%到75%之間,具體取決於所用的演算法和架構。與CIFAR-10相比,它是更加挑戰性的,但是它也是一個非常好的機器學習基準測試集。
原創文章,作者:小藍,如若轉載,請註明出處:https://www.506064.com/zh-tw/n/190352.html