tensorflow與pytorch的區別「tensorflow文檔教程」

文章學習資源來自TensorFlow官網文檔

一、 說明

本文訓練一個網路模型來進行服裝分類,比如衣服是T恤還是夾克。這可以快速入門了解TensorFlow2.0怎麼進行分類任務的。

二、步驟

1. 引入 tf.keras

from __future__ import absolute_import, division, print_function, unicode_literals

# TensorFlow and tf.keras
import tensorflow as tf
from tensorflow import keras

# Helper libraries
import numpy as np
import matplotlib.pyplot as plt

print(tf.__version__)

2. 導入MNIST時裝數據集

Fashion MNIST 包含了10類、70000張灰度圖。這個數據集被打造為圖像識別任務的Hello World程序。
數據集地址 :
https://github.com/zalandoresearch/fashion-mnist
下面圖片是一些圖片示例(28*28像素):

TensorFlow2學習五、基本圖像分類任務
fashion_mnist = keras.datasets.fashion_mnist

(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

執行代碼,程序會自動下載數據集。

TensorFlow2學習五、基本圖像分類任務

載入的數據集返回4個NumPy數組:

  • train_images , train_labels 數組:模型數據訓練集
  • test_images,test_labes 數組:模型測試集

圖像是28*28的NumPy數組,像素值從0-255。標是整數,0-9,下面是含義:

LabelClass0T-shirt/top1Trouser2Pullover3Dress4Coat5Sandal6Shirt7Sneaker8Bag9Ankle boot

下面定義標註名稱:

class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

3. 分析數據

通過train_images.shape可以查看訓練模型的數據格式,這裡會顯示它是60000張圖片的訓練集,每個圖片28*28像素:

TensorFlow2學習五、基本圖像分類任務

查看len(train_labels) 訓練標註:

TensorFlow2學習五、基本圖像分類任務

類似的,也可以查看測試集。

4. 預處理數據

訓練前要先把數據預處理。這裡可以先試著看一張圖片:

plt.figure()
plt.imshow(train_images[0])
plt.colorbar()
plt.grid(False)
plt.show()

結果:

TensorFlow2學習五、基本圖像分類任務

可以看到像素值是0-255。下面將值轉換到0-1。訓練集和測試集必須採用同樣的處理方法 。

train_images = train_images / 255.0
test_images = test_images / 255.0

下面顯示25張圖片,看看圖片轉換的結果:

TensorFlow2學習五、基本圖像分類任務

5. 重點來了,創建神經網路模型

過程: 1. 配置 ;2.編譯

i. 建順序層

模型的基本單位是層。使用keras會比傳統手工更容易創建一個層:

model = keras.Sequential([
 keras.layers.Flatten(input_shape=(28, 28)),
 keras.layers.Dense(128, activation='relu'),
 keras.layers.Dense(10, activation='softmax')
])

第1個層:tf.keras.layers.Flatten,將圖片從2維(2828像素)數組,轉成一維數組(2828=784像素)。這個層只是把數據平面化。
下面是兩個tf.keras.layers.Dense層,它們稱為緊密連接或全連接、或神經層。1層有128個神經節點,第二個有10節點的softmax激活函數,它返回 10個可能性分值,這些分值總和是1.每個節點都表示當前圖片屬於哪種分類的分值。

2. 編譯模型

編譯要定義三個參數:

  • 損失函數
  • 優化器
  • 評估指標:用來監視訓練和測試的步驟。下面是使用accuracy。
model.compile(optimizer='adam',
 loss='sparse_categorical_crossentropy',
 metrics=['accuracy'])

3. 訓練模型 ,3個步驟:

  1. 輸入訓練數據
  2. 模型學習圖片和標註間的規律
  3. 測試集測試

開始訓練:

model.fit(train_images, train_labels, epochs=10)

訓練過程中會顯示損失值、準確度。

TensorFlow2學習五、基本圖像分類任務

4. 測試集測試,看看訓練的準確度怎麼樣

test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=2)

print('nTest accuracy:', test_acc)
TensorFlow2學習五、基本圖像分類任務

5. 預測

這裡使用測試集試試預測效果:

predictions = model.predict(test_images)
TensorFlow2學習五、基本圖像分類任務

輸出是一個數組,表示屬於10種分類的可能性值。使用argmax取最大置信度的值:看看和標註值可一致:

print('predict = %i; label=%i' % (np.argmax(predictions[0]),test_labels[0]))
TensorFlow2學習五、基本圖像分類任務

三、完整程序:

from __future__ import absolute_import, division, print_function, unicode_literals

# TensorFlow and tf.keras
import tensorflow as tf
from tensorflow import keras

# Helper libraries
import numpy as np
import matplotlib.pyplot as plt

print(tf.__version__)

fashion_mnist = keras.datasets.fashion_mnist

(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

train_images = train_images / 255.0
test_images = test_images / 255.0

model = keras.Sequential([
 keras.layers.Flatten(input_shape=(28, 28)),
 keras.layers.Dense(128, activation='relu'),
 keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
 loss='sparse_categorical_crossentropy',
 metrics=['accuracy'])
model.fit(train_images, train_labels, epochs=10)

test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=2)

print('nTest accuracy:', test_acc)
predictions = model.predict(test_images)
print('predict = %i; label=%i' % (np.argmax(predictions[0]),test_labels[0]))

原創文章,作者:投稿專員,如若轉載,請註明出處:https://www.506064.com/zh-tw/n/274622.html

(0)
打賞 微信掃一掃 微信掃一掃 支付寶掃一掃 支付寶掃一掃
投稿專員的頭像投稿專員
上一篇 2024-12-17 14:15
下一篇 2024-12-17 14:15

相關推薦

發表回復

登錄後才能評論