一、作用與基本用法
torch.mul是Pytorch中的一個重要函數,用於對兩個張量逐元素相乘,返回一個新的張量。
torch.mul(input, other, out=None)
- input:第一個相乘的張量
- other:第二個相乘的張量
- out :指定輸出張量
若不指定out,返回的是逐元素相乘後的新張量,若指定out,則原地修改第一個張量,也就是將第二個張量逐元素相乘後結果賦值給第一個張量。
import torch a = torch.tensor([1, 2, 3]) b = torch.tensor([4, 5, 6]) # 返回新的張量 c = torch.mul(a, b) print(c) # tensor([ 4, 10, 18]) # 原地修改第一個張量 torch.mul(a, b, out=a) print(a) # tensor([ 4, 10, 18])
二、特殊使用
1.向量點乘
向量點乘是指兩個向量逐元素相乘然後相加的結果,可以用torch.mul實現。
例如,我們有兩個向量a=[1,2,3]和b=[4,5,6],向量a與向量b的點積為1*4 + 2*5 + 3*6 = 32。
import torch a = torch.tensor([1, 2, 3]) b = torch.tensor([4, 5, 6]) dot_product = torch.sum(torch.mul(a, b)) print(dot_product) # tensor(32)
2.矩陣乘法
矩陣乘法是對兩個矩陣進行操作的一種方式,可以使用torch.mul和torch.sum函數實現。
例如,我們有兩個矩陣A=[[1,2,3],[4,5,6]]和B=[[7,8],[9,10],[11,12]],則矩陣AB為:[[1*7+2*9+3*11, 1*8+2*10+3*12], [4*7+5*9+6*11, 4*8+5*10+6*12]]。
import torch A = torch.tensor([[1, 2, 3], [4, 5, 6]]) B = torch.tensor([[7, 8], [9, 10], [11, 12]]) AB = torch.zeros((A.shape[0], B.shape[1])) for i in range(A.shape[0]): for j in range(B.shape[1]): AB[i][j] = torch.sum(torch.mul(A[i], B[:,j])) print(AB) # tensor([[ 58, 64], # [139, 154]])
三、小結
torch.mul是Pytorch中一個十分常用的函數,可以對兩個張量進行逐元素相乘操作得到新的張量,也可以在指定輸出張量的情況下進行原地修改。此外,torch.mul還可以進行向量點乘和矩陣乘法等特殊用途。
當我們處理神經網絡的深度學習中的線性變換過程時,經常會使用到torch.mul及相關操作函數,掌握好它們的使用方式也是我們提高深度學習技能的重要一步。
原創文章,作者:小藍,如若轉載,請註明出處:https://www.506064.com/zh-hant/n/280514.html