一、什麼是原型網路
原型網路(Prototypical Network)是一種深度學習中用於學習表示空間的神經網路,由於能夠處理較小的數據集而被廣泛應用於圖像分類、目標識別等領域。
原型網路將樣本分為若干個類別,對每一類樣本通過計算其與該類別原型的距離來進行分類。其中,原型是該類別所有樣本向量均值的表徵。
二、原型網路的優點
與傳統深度學習分類演算法相比,原型網路有以下優點:
- 對於小樣本任務,其具有更好的表示學習能力。
- 相對於其他深度學習方法,原型網路的計算複雜度和存儲需求較低。
- 原型網路需要更少的訓練時間,即使在訓練樣本數量較大的情況下也具有很好的收斂性。
三、原型網路的應用
1. 圖像分類
原型網路可以通過將向量表示為圖像從而實現圖像分類。在訓練期間,網路對於所有樣本類別計算其原型。在測試期間,對於任意一個測試樣本,它將預測它屬於哪個類別。該預測將基於該測試樣本與計算所有原型之間的距離,以及該測試樣本屬於哪些類別中的距離最近。
class ProtoNet(nn.Module): def __init__(self, x_dim=1, y_dim=1, hid_dim=64, z_dim=64): super(ProtoNet, self).__init__() self.encoder = Encoder(x_dim, hid_dim, z_dim) self.fc = nn.Linear(z_dim, y_dim) def set_forward_loss(self, X, Y, n_support): z_support, z_query = self.parse_feature(X, n_support) y_support, y_query = self.parse_label(Y, n_support) z_support = z_support.contiguous() z_proto = z_support.view(n_support * self.n_way, -1).mean(0) z_query = z_query.contiguous().view(-1, *z_query.size()[2:]) dists = euclidean_dist(self.fc(z_query), self.fc(z_proto)) scores = -dists return scores def set_forward(self, X, n_support): z_support, z_query = self.parse_feature(X, n_support) z_support = z_support.contiguous() z_proto = z_support.view(n_support * self.n_way, -1).mean(0) z_query = z_query.contiguous().view(-1, *z_query.size()[2:]) dists = euclidean_dist(self.fc(z_query), self.fc(z_proto)) scores = -dists return scores.argmax(dim=1) def parse_feature(self, X, n_support): X = X.reshape(self.n_way, n_support + self.n_query, *X.shape[2:]) support, query = X[:, :n_support], X[:, n_support:] z_support = self.encoder(support.reshape(self.n_way * n_support, *support.shape[2:])) z_query = self.encoder(query.reshape(self.n_way * self.n_query, *query.shape[2:])) return z_support, z_query def parse_label(self, Y, n_support): Y = Y.reshape(self.n_way, n_support + self.n_query) label_support = Y[:, :n_support] label_query = Y[:, n_support:] return label_support, label_query
2. 目標識別
原型網路可以將其用於目標識別等視覺任務。首先將所有圖像通過特徵提取器進行編碼,然後對於每個類別,通過計算該類別所有樣本的均值得到該類別原型。在測試期間,任何一個樣本都可以與所有原型進行比較,並選擇距離最近的一個作為其類別。
class PrototypicalLoss(nn.Module): def __init__(self, n_support): super(PrototypicalLoss, self).__init__() self.n_support = n_support def forward(self, input, target): assert input.size(0) % self.n_support == 0, "erro, #class not matched with #support" input = input.reshape(input.size(0) // self.n_support, self.n_support, -1) input_cpu = input.detach().cpu() target_cpu = target.detach().cpu() classes = torch.unique(target_cpu) n_classes = len(classes) prototypes = torch.zeros(n_classes, input.size(2)) for i, class_id in enumerate(classes): mask = (target_cpu == class_id).unsqueeze(1).expand_as(input_cpu) class_support_features = torch.masked_select(input_cpu, mask).reshape(self.n_support, -1) prototypes[i] = class_support_features.mean(0).squeeze() mask = (target == classes.unsqueeze(1)).expand_as(input) input = input.reshape(-1, input.size(2)) target = torch.nonzero(mask)[:, 1] dists = euclidean_dist(input, prototypes) log_p_y = F.log_softmax(-dists, dim=1).view(self.n_support, -1, n_classes) target = target.view(self.n_support, -1) loss = -log_p_y.gather(2, target).squeeze().view(-1).mean() _, y_hat = log_p_y.max(2) acc = torch.eq(y_hat, target).float().mean() return loss, acc
3. 元學習
原型網路還可以用於元學習,僅通過極少量的數據樣本學習一般類別的表示空間,並從更廣泛的語境中進行歸納或外推。
class Embedding(nn.Module): def __init__(self, input_size=1, embedding_size=64): super(Embedding, self).__init__() self.layer1 = nn.Linear(input_size, embedding_size) self.layer2 = nn.Linear(embedding_size, embedding_size) self.layer3 = nn.Linear(embedding_size, embedding_size) def forward(self, x): x = F.relu(self.layer1(x)) x = F.relu(self.layer2(x)) x = self.layer3(x) return x class PrototypicalNet(nn.Module): def __init__(self, encoder): super(PrototypicalNet, self).__init__() self.encoder = encoder def forward(self, x): embeddings = self.encoder(x) return embeddings def embed(self, x): return self.forward(x)
四、總結
總的來說,原型網路是一個非常有效且快速的深度學習演算法,它可以從有限量的樣本中學習表示空間,因此對於小數據集任務非常有用。而且它可以用於圖像分類、目標識別和元學習等眾多領域,是一種兼顧效率和效果的可靠選擇。
原創文章,作者:小藍,如若轉載,請註明出處:https://www.506064.com/zh-tw/n/248278.html