CIFAR100数据集下载及介绍

一、CIFAR10数据集下载

import urllib.request
import tarfile
import os

url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
filepath = "cifar-10-python.tar.gz"
if not os.path.isfile(filepath):
    result = urllib.request.urlretrieve(url, filepath)
    print('downloaded:', result)
     
if not os.path.exists("cifar-10-batches-py"):
    tfile = tarfile.open("cifar-10-python.tar.gz", 'r:gz')
    result = tfile.extractall('.')
    print('extracted:', result)
else:
    print('Data has existed.')

CIFAR-10(Canadian Institute For Advanced Research)是一个经典的图像分类数据集,共有10个类别,每个类别有6000张32*32的彩色图片,其中50000张作为训练集,10000张作为测试集。

二、CIFAR10数据集读取

import pickle
import numpy as np
 
def load_CIFAR_batch(filename):
    with open(filename, 'rb') as f:
        datadict = pickle.load(f, encoding='latin1')
        X = datadict['data']
        Y = datadict['labels']
        X = X.reshape(10000, 3, 32, 32).transpose(0,2,3,1).astype("float")
        Y = np.array(Y)
        return X, Y
 
def load_CIFAR10(ROOT):
    xs = []
    ys = []
    for b in range(1,6):
        f = os.path.join(ROOT, 'data_batch_%d' % (b,))
        X, Y = load_CIFAR_batch(f)
        xs.append(X)
        ys.append(Y)
    Xtr = np.concatenate(xs)
    Ytr = np.concatenate(ys)
    del X, Y
    Xte, Yte = load_CIFAR_batch(os.path.join(ROOT, 'test_batch'))
    return Xtr, Ytr, Xte, Yte

CIFAR-10数据集中包含5个批量的训练数据,每个批量大小为10000,测试数据包含一个批量,大小也为10000。以上代码可以读取CIFAR-10数据集并将其转化为易于处理的numpy数组。

三、CIFAR10数据集介绍

import numpy as np
import matplotlib.pyplot as plt
 
def visualize_CIFAR10_data(X_train, y_train, classes, samples_per_class=7):
    num_classes = len(classes)
    for y, cls in enumerate(classes):
        idxes = np.flatnonzero(y_train == y)
        idxes = np.random.choice(idxes, samples_per_class, replace=False)
        for i, idx in enumerate(idxes):
            plt_idx = i * num_classes + y + 1
            plt.subplot(samples_per_class, num_classes, plt_idx)
            plt.imshow(X_train[idx].astype('uint8'))
            plt.axis('off')
            if i == 0:
                plt.title(cls)
    plt.show()

def get_CIFAR10_data(num_training=49000, num_validation=1000, num_test=10000):
    cifar10_dir = '../datasets/cifar-10-batches-py'
    X_train, y_train, X_test, y_test = load_CIFAR10(cifar10_dir)
         
    mask = range(num_training, num_training + num_validation)
    X_val = X_train[mask]
    y_val = y_train[mask]
    mask = range(num_training)
    X_train = X_train[mask]
    y_train = y_train[mask]
    mask = range(num_test)
    X_test = X_test[mask]
    y_test = y_test[mask]
         
    mean_image = np.mean(X_train, axis=0)
    X_train -= mean_image
    X_val -= mean_image
    X_test -= mean_image
     
    return X_train, y_train, X_val, y_val, X_test, y_test

CIFAR-10数据集是由TensorFlow提供的,其中包括训练、测试、验证集,每张图片都有一个标签,总共10个类别。以上代码用于可视化数据集以及获取数据集。

四、CIFAR100数据集下载

import urllib.request
import tarfile
import os

url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
filepath = "cifar-100-python.tar.gz"
if not os.path.isfile(filepath):
    result = urllib.request.urlretrieve(url, filepath)
    print('downloaded:', result)
     
if not os.path.exists("cifar-100-python"):
    tfile = tarfile.open("cifar-100-python.tar.gz", 'r:gz')
    result = tfile.extractall('.')
    print('extracted:', result)
else:
    print('Data has existed.')

CIFAR-100数据集共有100个类别,每个类别有600张32*32的彩色图片,其中50000张作为训练集,10000张作为测试集。CIFAR-100数据集的组织方式与CIFAR-10相似,但是它有更多的类别,更多的样本。运行以上代码可以下载CIFAR-100数据集。

五、CIFAR100数据集介绍

def load_CIFAR100(filename):
    with open(filename, 'rb') as f:
        datadict = pickle.load(f, encoding='latin1')
        X = datadict['data']
        Y = datadict['fine_labels']
        X = X.reshape(50000, 3, 32, 32).transpose(0,2,3,1).astype("float")
        Y = np.array(Y)
        return X, Y

def visualize_CIFAR100_data(X_train, y_train, classes, samples_per_class=7):
    num_classes = len(classes)
    for y, cls in enumerate(classes):
        idxes = np.flatnonzero(y_train == y)
        idxes = np.random.choice(idxes, samples_per_class, replace=False)
        for i, idx in enumerate(idxes):
            plt_idx = i * num_classes + y + 1
            plt.subplot(samples_per_class, num_classes, plt_idx)
            plt.imshow(X_train[idx].astype('uint8'))
            plt.axis('off')
            if i == 0:
                plt.title(cls)
    plt.show()

def get_CIFAR100_data(num_training=49000, num_validation=1000, num_test=10000):
    cifar100_dir = '../datasets/cifar-100-python'
    X_train, y_train = load_CIFAR100(os.path.join(cifar100_dir, 'train'))
    X_test, y_test = load_CIFAR100(os.path.join(cifar100_dir, 'test'))
    classes = ['beaver', 'dolphin', 'otter', 'seal', 'whale','aquarium fish', 'flatfish', 'ray', 'shark', 'trout','orchids', 'poppies', 'roses', 'sunflowers', 'tulips','bottles', 'bowls', 'cans', 'cups', 'plates','apples', 'mushrooms', 'oranges', 'pears', 'sweet peppers','clock', 'computer keyboard', 'lamp', 'telephone', 'television','bed', 'chair', 'couch', 'table', 'wardrobe','bee', 'beetle', 'butterfly', 'caterpillar', 'cockroach','bear', 'leopard', 'lion', 'tiger', 'wolf','bridge', 'castle', 'house', 'road', 'skyscraper','cloud', 'forest', 'mountain', 'plain', 'sea','camel', 'cattle', 'chimpanzee', 'elephant', 'kangaroo','fox', 'porcupine', 'possum', 'raccoon', 'skunk','crab', 'lobster', 'snail', 'spider', 'worm','baby', 'boy', 'girl', 'man', 'woman','crocodile', 'dinosaur', 'lizard', 'snake', 'turtle','hamster', 'mouse', 'rabbit', 'shrew', 'squirrel','maple', 'oak', 'palm', 'pine', 'willow','bicycle', 'bus', 'motorcycle', 'pickup truck', 'train','lawn-mower', 'rocket', 'streetcar', 'tank', 'tractor']
    num_classes = len(classes)
    mask = range(num_training, num_training + num_validation)
    X_val = X_train[mask]
    y_val = y_train[mask]
    mask = range(num_training)
    X_train = X_train[mask]
    y_train = y_train[mask]
    mask = range(num_test)
    X_test = X_test[mask]
    y_test = y_test[mask]
    mean_image = np.mean(X_train, axis=0)
    X_train -= mean_image
    X_val -= mean_image
    X_test -= mean_image
    return X_train, y_train, X_val, y_val, X_test, y_test, classes

CIFAR-100数据集是一个更加复杂的数据集,它有更多的类别,更多的样本。以上代码也用于可视化数据集以及获取数据集。

六、CIFAR100数据集大小

CIFAR-100数据集大小为169MB,可以从以下链接中下载:https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz

七、CIFAR10数据集格式

CIFAR-10数据集中的每个批量都是一个Python Pickle字典,其中包含以下键:

data: 一个10000 * 3072的numpy数组,第一个维度是图像的索引,第二个维度是展平的图像像素值,该数组中的值在0到255之间。

labels: 由大小为10000的1D列表组成的一个长度为10000的numpy数组,其中每个元素是一个类别ID。

八、CIFAR100数据集准确率

CIFAR-100的分类准确率通常在70%到75%之间,具体取决于所用的算法和架构。与CIFAR-10相比,它是更加挑战性的,但是它也是一个非常好的机器学习基准测试集。

原创文章,作者:小蓝,如若转载,请注明出处:https://www.506064.com/n/190352.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
小蓝小蓝
上一篇 2024-11-29 22:32
下一篇 2024-11-29 22:32

相关推荐

  • Python读取CSV数据画散点图

    本文将从以下方面详细阐述Python读取CSV文件并画出散点图的方法: 一、CSV文件介绍 CSV(Comma-Separated Values)即逗号分隔值,是一种存储表格数据的…

    编程 2025-04-29
  • Python中读入csv文件数据的方法用法介绍

    csv是一种常见的数据格式,通常用于存储小型数据集。Python作为一种广泛流行的编程语言,内置了许多操作csv文件的库。本文将从多个方面详细介绍Python读入csv文件的方法。…

    编程 2025-04-29
  • 如何用Python统计列表中各数据的方差和标准差

    本文将从多个方面阐述如何使用Python统计列表中各数据的方差和标准差, 并给出详细的代码示例。 一、什么是方差和标准差 方差是衡量数据变异程度的统计指标,它是每个数据值和该数据值…

    编程 2025-04-29
  • Python多线程读取数据

    本文将详细介绍多线程读取数据在Python中的实现方法以及相关知识点。 一、线程和多线程 线程是操作系统调度的最小单位。单线程程序只有一个线程,按照程序从上到下的顺序逐行执行。而多…

    编程 2025-04-29
  • Python爬取公交数据

    本文将从以下几个方面详细阐述python爬取公交数据的方法: 一、准备工作 1、安装相关库 import requests from bs4 import BeautifulSou…

    编程 2025-04-29
  • Python两张表数据匹配

    本篇文章将详细阐述如何使用Python将两张表格中的数据匹配。以下是具体的解决方法。 一、数据匹配的概念 在生活和工作中,我们常常需要对多组数据进行比对和匹配。在数据量较小的情况下…

    编程 2025-04-29
  • Python数据标准差标准化

    本文将为大家详细讲述Python中的数据标准差标准化,以及涉及到的相关知识。 一、什么是数据标准差标准化 数据标准差标准化是数据处理中的一种方法,通过对数据进行标准差标准化可以将不…

    编程 2025-04-29
  • 如何使用Python读取CSV数据

    在数据分析、数据挖掘和机器学习等领域,CSV文件是一种非常常见的文件格式。Python作为一种广泛使用的编程语言,也提供了方便易用的CSV读取库。本文将介绍如何使用Python读取…

    编程 2025-04-29
  • Python根据表格数据生成折线图

    本文将介绍如何使用Python根据表格数据生成折线图。折线图是一种常见的数据可视化图表形式,可以用来展示数据的趋势和变化。Python是一种流行的编程语言,其强大的数据分析和可视化…

    编程 2025-04-29
  • Python如何打乱数据集

    本文将从多个方面详细阐述Python打乱数据集的方法。 一、shuffle函数原理 shuffle函数是Python中的一个内置函数,主要作用是将一个可迭代对象的元素随机排序。 在…

    编程 2025-04-29

发表回复

登录后才能评论