一、什么是原型网络
原型网络(Prototypical Network)是一种深度学习中用于学习表示空间的神经网络,由于能够处理较小的数据集而被广泛应用于图像分类、目标识别等领域。
原型网络将样本分为若干个类别,对每一类样本通过计算其与该类别原型的距离来进行分类。其中,原型是该类别所有样本向量均值的表征。
二、原型网络的优点
与传统深度学习分类算法相比,原型网络有以下优点:
- 对于小样本任务,其具有更好的表示学习能力。
- 相对于其他深度学习方法,原型网络的计算复杂度和存储需求较低。
- 原型网络需要更少的训练时间,即使在训练样本数量较大的情况下也具有很好的收敛性。
三、原型网络的应用
1. 图像分类
原型网络可以通过将向量表示为图像从而实现图像分类。在训练期间,网络对于所有样本类别计算其原型。在测试期间,对于任意一个测试样本,它将预测它属于哪个类别。该预测将基于该测试样本与计算所有原型之间的距离,以及该测试样本属于哪些类别中的距离最近。
class ProtoNet(nn.Module):
def __init__(self, x_dim=1, y_dim=1, hid_dim=64, z_dim=64):
super(ProtoNet, self).__init__()
self.encoder = Encoder(x_dim, hid_dim, z_dim)
self.fc = nn.Linear(z_dim, y_dim)
def set_forward_loss(self, X, Y, n_support):
z_support, z_query = self.parse_feature(X, n_support)
y_support, y_query = self.parse_label(Y, n_support)
z_support = z_support.contiguous()
z_proto = z_support.view(n_support * self.n_way, -1).mean(0)
z_query = z_query.contiguous().view(-1, *z_query.size()[2:])
dists = euclidean_dist(self.fc(z_query), self.fc(z_proto))
scores = -dists
return scores
def set_forward(self, X, n_support):
z_support, z_query = self.parse_feature(X, n_support)
z_support = z_support.contiguous()
z_proto = z_support.view(n_support * self.n_way, -1).mean(0)
z_query = z_query.contiguous().view(-1, *z_query.size()[2:])
dists = euclidean_dist(self.fc(z_query), self.fc(z_proto))
scores = -dists
return scores.argmax(dim=1)
def parse_feature(self, X, n_support):
X = X.reshape(self.n_way, n_support + self.n_query, *X.shape[2:])
support, query = X[:, :n_support], X[:, n_support:]
z_support = self.encoder(support.reshape(self.n_way * n_support, *support.shape[2:]))
z_query = self.encoder(query.reshape(self.n_way * self.n_query, *query.shape[2:]))
return z_support, z_query
def parse_label(self, Y, n_support):
Y = Y.reshape(self.n_way, n_support + self.n_query)
label_support = Y[:, :n_support]
label_query = Y[:, n_support:]
return label_support, label_query
2. 目标识别
原型网络可以将其用于目标识别等视觉任务。首先将所有图像通过特征提取器进行编码,然后对于每个类别,通过计算该类别所有样本的均值得到该类别原型。在测试期间,任何一个样本都可以与所有原型进行比较,并选择距离最近的一个作为其类别。
class PrototypicalLoss(nn.Module):
def __init__(self, n_support):
super(PrototypicalLoss, self).__init__()
self.n_support = n_support
def forward(self, input, target):
assert input.size(0) % self.n_support == 0, "erro, #class not matched with #support"
input = input.reshape(input.size(0) // self.n_support, self.n_support, -1)
input_cpu = input.detach().cpu()
target_cpu = target.detach().cpu()
classes = torch.unique(target_cpu)
n_classes = len(classes)
prototypes = torch.zeros(n_classes, input.size(2))
for i, class_id in enumerate(classes):
mask = (target_cpu == class_id).unsqueeze(1).expand_as(input_cpu)
class_support_features = torch.masked_select(input_cpu, mask).reshape(self.n_support, -1)
prototypes[i] = class_support_features.mean(0).squeeze()
mask = (target == classes.unsqueeze(1)).expand_as(input)
input = input.reshape(-1, input.size(2))
target = torch.nonzero(mask)[:, 1]
dists = euclidean_dist(input, prototypes)
log_p_y = F.log_softmax(-dists, dim=1).view(self.n_support, -1, n_classes)
target = target.view(self.n_support, -1)
loss = -log_p_y.gather(2, target).squeeze().view(-1).mean()
_, y_hat = log_p_y.max(2)
acc = torch.eq(y_hat, target).float().mean()
return loss, acc
3. 元学习
原型网络还可以用于元学习,仅通过极少量的数据样本学习一般类别的表示空间,并从更广泛的语境中进行归纳或外推。
class Embedding(nn.Module):
def __init__(self, input_size=1, embedding_size=64):
super(Embedding, self).__init__()
self.layer1 = nn.Linear(input_size, embedding_size)
self.layer2 = nn.Linear(embedding_size, embedding_size)
self.layer3 = nn.Linear(embedding_size, embedding_size)
def forward(self, x):
x = F.relu(self.layer1(x))
x = F.relu(self.layer2(x))
x = self.layer3(x)
return x
class PrototypicalNet(nn.Module):
def __init__(self, encoder):
super(PrototypicalNet, self).__init__()
self.encoder = encoder
def forward(self, x):
embeddings = self.encoder(x)
return embeddings
def embed(self, x):
return self.forward(x)
四、总结
总的来说,原型网络是一个非常有效且快速的深度学习算法,它可以从有限量的样本中学习表示空间,因此对于小数据集任务非常有用。而且它可以用于图像分类、目标识别和元学习等众多领域,是一种兼顾效率和效果的可靠选择。
原创文章,作者:小蓝,如若转载,请注明出处:https://www.506064.com/n/248278.html
微信扫一扫
支付宝扫一扫