一、tensordot概述
tensordot是一種numpy中的數學函數,它旨在實現高維張量的乘法操作。在實際深度學習的應用中,特別是卷積神經網路中,tensordot是一項核心技術,因此學習如何使用它是至關重要的。
tensordot最基本的使用形式為:np.tensordot(a, b, axes),其中a和b都是具有多個軸的張量。在這個基本形式中,tensordot將a和b中的軸進行匹配,然後對它們進行乘法操作,最終返回一個新的張量c。
import numpy as np a = np.random.rand(3, 4, 5) b = np.random.rand(4, 5, 6) c = np.tensordot(a, b, axes=([1, 2], [0, 1])) print(c.shape) # 輸出(3, 6)
在此示例中,我們定義了兩個張量a和b,分別是shape為(3, 4, 5)和shape為(4, 5, 6)的張量。我們對a的最後兩個維度(4和5)和b的第一二個維度(4和5)進行了匹配,然後執行了張量相乘,得到了一個新的張量c,它的shape為(3,6)。
二、理解tensordot的axes參數
tensordot的axes參數用於指定張量a和張量b的維度匹配方式。在基本形式中,它採用了默認值,即axes=2,它會從a和b中的最後兩個維度開始匹配兩個張量,並輸出其他維度的乘積。實際上,axes接受一個元組(x,y),其中x和y都是張量的維度,表示我們要將a的第x個維度和b的第y個維度進行匹配。因此,當我們將axes設置為([1, 2], [0, 1])時,它將從a和b中的第1和第2個維度開始匹配,並輸出其他維度的乘積。
下面通過一個更高級的例子,來進一步理解axes參數的作用。
import numpy as np a = np.random.rand(3,4,5) b = np.random.rand(4,5,6) c = np.tensordot(a,b,axes=([1], [0])) print(c.shape) # 輸出(3,6,6)
在此示例中,我們設置了axes=([1], [0]),這意味著我們要從a的第1個維度開始匹配,從b的第0個維度開始匹配。此時,a的第1個維度大小為4,b的第0個維度的大小也為4,因此,這種匹配方式是合法的。然後,我們執行[a[:,i,:] * b[i,:,:] for i in range(4)]操作,將這些張量相加,得到一個新的張量,它的shape為(3,6,6)。
三、tensordot的高級操作
在深度學習中,tensordot還有很多高級用法。
1. tensordot的broadcasting行為
tensordot類似於廣播操作,它可以自動擴展輸入張量的形狀,以適應要執行的操作。因此,我們可以使用不同形狀的張量來執行tensordot操作,根據axes參數的設置,可以自動調整張量的形狀,以執行正確的操作。
import numpy as np x = np.random.rand(2, 3) y = np.random.rand(3, 4, 5) z = np.tensordot(x, y, axes=(1, 0)) print(z.shape) # 輸出(2,4,5)
在本例中,我們定義了一個形狀為(2,3)的張量x,和一個形狀為(3,4,5)的張量y。我們設置axes=(1,0),這意味著通過將x的第1個維度與y的第0個維度相匹配並相乘來計算tensordot。x的第1個維度大小為3,與y的第0個維度的大小相同,因此它們能正確匹配。我們得到的新張量的形狀是(2,4,5)。
2. tensordot的reshape操作
在某些情況下,我們需要將張量的維度進行重新排列,以使它們可以在tensordot操作中正確匹配。這個過程在numpy中的實現非常簡單,我們可以使用reshape函數來輕鬆地重塑張量的形狀。
import numpy as np a = np.random.rand(3, 4, 5) b = np.random.rand(4, 5, 6) a = np.reshape(a, (3, 20)) b = np.reshape(b, (20, 6)) c = np.tensordot(a, b, axes=1) print(c.shape) # 輸出(3,6)
在此示例中,我們定義了兩個張量a和b,分別是形狀為(3, 4, 5)和(4, 5, 6)的張量。然後,我們使用reshape函數將張量a和b的形狀分別改變為(3,20)和(20,6),這使它們可以正確匹配,進行tensordot操作。我們得到的新張量的形狀是(3,6)。
3. tensordot的內積實現
tensordot還可以用於計算內積。對於兩個形狀都為(N,)的張量,它們的內積可以通過tensordot來計算。
import numpy as np x = np.random.rand(3) y = np.random.rand(3) ip = np.tensordot(x,y,axes=0) print(ip) # 輸出單個實數
在此示例中,我們定義了兩個為(3,)形張量x和y。我們將axes設置為0,這意味著我們要計算兩個張量的內積,即[sum(x[i]*y[i])],得到的結果是一個單個的實數。
四、總結
tensordot是numpy中的一種高級操作,可用於計算張量的乘法。在深度學習中,tensordot是卷積神經網路的核心技術之一。通過本文,我們深入理解了numpy中tensordot的基本用法和高級用法。可以根據具體的需求來選擇合適的axes參數,輕鬆實現高維張量的乘法操作。
原創文章,作者:小藍,如若轉載,請註明出處:https://www.506064.com/zh-tw/n/250459.html