Reference link: https://blog.floydhub.com/n-shot-learning/
Artificial Intelligence is the new electricity. - Andrew Ng
If AI is the new electricity, then data is the new coal.
Unfortunately, just as we’ve seen a hazardous depletion in the amount of available coal, many AI applications have little or no data accessible to them.
New technology has made up for a lack of physical resources; likewise, new techniques are needed to allow applications with little data to perform satisfactorily. This is the issue at the heart of what is becoming a very popular field: N-shot Learning.
You may be asking, what the heck is a shot, anyway? Fair question. A shot is nothing more than a single example available for training, so in N-shot learning, we have N examples for training. With the term “few-shot learning”, the “few” usually lies between zero and five, meaning that training a model with zero examples is known as zero-shot learning, one example is one-shot learning, and so on. All of these variants are trying to solve the same problem with differing levels of training material.
What is N-Shot Learning?
Why do we need this when we are already getting less than 4% error on ImageNet?
To start, ImageNet contains a multitude of examples for machine learning, which is not always the case in fields like medical imaging, drug discovery, and many others where AI could be crucially important. Typical deep learning architectures rely on substantial amounts of data to produce adequate results; a model trained for ImageNet-style classification, for example, would need to see hundreds of hotdog images before accurately assessing new images as hotdogs. And some datasets, much like a fridge after a 4th of July celebration, are greatly lacking in hotdogs.
There are many use cases for machine learning where data is scarce, and that is where this technology comes in. We need to train a deep learning model which has millions or even billions of parameters, all randomly initialized, to learn to classify an unseen image using no more than 5 images. To put it succinctly, our model has to train using a very limited number of hotdog images.
To approach an issue as complex as this one, we need to first define it clearly.
In the N-shot learning field, we have $N$ labeled examples for each of $K$ classes, i.e. $N \times K$ total examples, which we call the support set $S$. We also have to classify a query set $Q$, in which each example belongs to one of the $K$ classes. N-shot learning has three major sub-fields: zero-shot learning, one-shot learning, and few-shot learning, each of which deserves individual attention.
Zero-Shot Learning
To me, this is the most interesting sub-field. With zero-shot learning, the target is to classify unseen classes without a single training example.
How does a machine “learn” without having any data to utilize?
Think about it this way. Can you classify an object without ever seeing it?
Yes, you can if you have adequate information about its appearance, properties, and functionality. Think back to how you came to understand the world as a kid. You could spot Mars in the night sky after reading about its color and where it would be that night, or identify the constellation Cassiopeia from only being told “it’s basically a malformed ‘W’”.
Judging by this year's trends in NLP, zero-shot learning is set to become even more effective.
A machine utilizes the metadata of the images to perform the same task. The metadata is nothing but the features associated with the image. Here is a list of a few papers in this field which gave excellent results.
- Learning to Compare: Relation Network for Few-Shot Learning
- Learning Deep Representations of Fine-Grained Visual Descriptions
- Improving zero-shot learning by mitigating the hubness problem
One-Shot Learning
In one-shot learning, we only have a single example of each class. Now the task is to classify any test image to a class using that constraint. There are many different architectures developed to achieve this goal, such as Siamese Neural Networks, which brought about major progress and led to exceptional results, and then matching networks, which also helped us make great leaps in this field.
There are many excellent papers that help in understanding one-shot learning; a few are listed below.
- Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks
- One-shot Learning with Memory-Augmented Neural Networks
- Prototypical Networks for Few-shot Learning
Few-Shot Learning
Few-shot learning is simply a more flexible version of one-shot learning, in which we have more than one training example (usually two to five images); most of the models mentioned above can be used for few-shot learning as well.
During the 2019 Conference on Computer Vision and Pattern Recognition, Meta-Transfer Learning for Few-Shot Learning was presented. This model set the precedent for future research; it gave state-of-the-art results and paved the path for more sophisticated meta-transfer learning methods.
Many of these meta-learning and reinforcement-learning algorithms are combined with typical deep learning algorithms to produce remarkable results. Prototypical networks are one of the most popular deep learning approaches and are frequently used for this task.
In this article, we’ll accomplish this task using Prototypical Networks and understand how it works and why it works.
The Idea Behind Prototypical Networks
A diagram of the function of the prototypical network. An encoder maps an image into a vector in the embedding space (dark circles). Support images are used to define the prototype (stars). Distances between prototypes and encoded query images are used to classify them. Source
Unlike typical deep learning architecture, prototypical networks do not classify the image directly, and instead learn the mapping of an image in metric space.
For anyone needing a mathematics refresher, metric space deals with the notion of “distance”. It does not have a distinguished “origin” point; instead, in metric space we only compute the distance of one point to another. You therefore lack the operations of addition and scalar multiplication that you have in a vector space (because, unlike with vectors, a point only represents a coordinate, and adding two coordinates or scaling a coordinate makes no sense!). Check out this link to learn more about the difference between vector space and metric space.
Now that we have that background, we can begin to understand how prototypical networks do not classify the image directly, but instead learn the mapping of an image in metric space. As can be seen in the above diagram, the encoder maps the images of the same class within tight proximity to each other, while different classes are spaced at a considerable distance. This means that whenever a new example is given, the network just checks the nearest cluster and classifies the example to its corresponding class. The underlying model in the prototypical net that maps images into metric space can be called an “Image2Vector” model, which is a Convolutional Neural Network (CNN) based architecture.
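Before diving into the details, here is a minimal sketch (not the article's code) of what nearest-prototype classification looks like once images have been encoded. The tensors below are random placeholders standing in for encoder outputs:

import torch

# Hypothetical encoder outputs: 5 classes x 3 support embeddings of dimension 64,
# plus 10 query embeddings. In practice these would come from the Image2Vector CNN.
support = torch.randn(5, 3, 64)   # (K classes, N shots, embedding dim)
queries = torch.randn(10, 64)     # (Q queries, embedding dim)

# One prototype per class: the mean of that class's support embeddings.
prototypes = support.mean(dim=1)  # (5, 64)

# Euclidean distance from every query to every prototype.
dists = torch.cdist(queries, prototypes)  # (10, 5)

# Each query is assigned to the class of its nearest prototype.
predictions = dists.argmin(dim=1)  # (10,)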
Now for those who don’t know a lot about CNNs, you can read more here:
- Check out the list of best deep learning courses here.
- Check out the list of best deep learning books here.
- To learn and apply it quickly, refer to Building Your First ConvNet
A Brief Introduction to Prototypical Networks
Simply put, the aim of a prototypical network is to train a classifier that can generalize to new classes unavailable during training, given only a small number of examples of each new class. Hence, the training set contains images from one set of classes, while the test set contains images from a different set of classes that is entirely disjoint from the first. In this model, the examples are divided randomly into the support set and the query set.
Overview of Prototypical Network
Few-shot prototypes $C_k$ are computed as the mean of the embedded support examples for each class. The encoder maps a new image $x$ into the embedding space, and the image is classified to the closest prototype, e.g. $C_2$ in the image above. Source
In the context of few-shot learning, a training iteration is known as an episode. An episode is nothing but a step in which we train the network once, calculate the loss and backpropagate the error. In each episode, we select $N_c$ classes at random from the training set. For each class, we randomly sample $N_s$ images; these images form the support set, and the resulting model is known as an $N_s$-shot model. We also randomly sample another $N_q$ images per class, which form the query set. Here $N_c$, $N_s$ and $N_q$ are just hyperparameters of the model: $N_c$ is the number of classes per episode, $N_s$ the number of support examples per class, and $N_q$ the number of query examples per class.
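As a rough sketch of how such an episode could be sampled (illustrative only; the names images and labels are assumptions, and the article's own sampling code appears later):

import numpy as np

def sample_episode(images, labels, Nc=60, Ns=5, Nq=5):
    """Randomly builds one episode: Nc classes, Ns support and Nq query images per class."""
    classes = np.random.choice(np.unique(labels), Nc, replace=False)
    support, query = [], []
    for cls in classes:
        idx = np.random.permutation(np.where(labels == cls)[0])
        support.append(images[idx[:Ns]])        # Ns support examples for this class
        query.append(images[idx[Ns:Ns + Nq]])   # Nq query examples for this class
    return classes, support, query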
After that, we retrieve D-dimensional points from the support set images by passing them through the "Image2Vector" model, which encodes each image as a point in the metric space. For each class we now have multiple points, but we need to represent each class by a single point. Hence, we compute the geometric center, i.e. the mean of the embedded points, for each class. After that, we also need to classify the query images.
To do that, we first encode every image in the query set into a point. Then the distance from each centroid to each query point is calculated. Finally, each query image is predicted to lie in the class whose centroid is nearest to it. That is how the model works in general.
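For reference, this is the formulation used in the original Prototypical Networks paper: the prototype $c_k$ of class $k$ is the mean of the embedded support points, and a query $x$ is classified with a softmax over negative distances to the prototypes:

$$
c_k = \frac{1}{|S_k|} \sum_{(x_i, y_i) \in S_k} f_\phi(x_i), \qquad
p_\phi(y = k \mid x) = \frac{\exp\left(-d\left(f_\phi(x), c_k\right)\right)}{\sum_{k'} \exp\left(-d\left(f_\phi(x), c_{k'}\right)\right)}
$$

where $S_k$ is the support set of class $k$, $f_\phi$ is the "Image2Vector" encoder and $d$ is a distance function (squared Euclidean in the paper).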
But the question now is, what is the architecture of this “Image2Vector” model?
Image2Vector function
Image2vector CNN architecture used in the paper.
For all practical purposes, 4–5 CNN blocks are used. As shown in the above image, each block consists of a convolutional layer followed by batch normalization, then a ReLU activation function, which leads into a max-pooling layer. After all the blocks, the remaining output is flattened and returned as the result. This is the architecture used in the paper, but you can use any architecture you like. It is worth noting that although we call it an "Image2Vector" model, it actually converts an image into a 64-dimensional point in the metric space. To understand the difference, check out these Math Stack Exchange answers.
Loss function
The working of negative log-likelihood. Source.
Now that we know how the model works, you might be wondering how we're going to calculate the loss. We need a loss function which is robust enough for our model to learn the representation quickly and efficiently. Prototypical nets use log-softmax loss, which is nothing more than the log applied to the softmax output. The log-softmax has the effect of heavily penalizing the model when it fails to predict the correct class, which is exactly what we need. To know more about the loss function, go here. Here is a very good discussion about softmax and log-softmax.
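As a minimal sketch (with placeholder tensors, not the article's code) of how this loss can be computed from a distance matrix like the one our network will produce:

import torch
import torch.nn.functional as F

dists = torch.randn(10, 5).abs()       # placeholder distance matrix: 10 queries x 5 classes
targets = torch.randint(0, 5, (10,))   # true class index of each query

log_p = F.log_softmax(-dists, dim=1)   # smaller distance -> larger (log-)probability
loss = F.nll_loss(log_p, targets)      # negative log-likelihood of the correct class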
Dataset overview
A few classes of images in Omniglot dataset. Source.
The network was trained on the Omniglot dataset. The Omniglot dataset is designed for developing more human-like learning algorithms. It contains 1,623 different handwritten characters from 50 different alphabets. To increase the number of classes, all the images are rotated by 90, 180 and 270 degrees, with each rotation treated as an additional class, bringing the total to 6,492 (1,623 × 4) classes. We used the images of 4,200 classes for training, while the rest went to the test set. For each episode, we trained the model on 5 examples from each of the 64 randomly selected classes. We trained our model for 1 hour and got about 88% accuracy; the official paper claims an accuracy of 99.7% after training for a few hours and tuning a few parameters.
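If you want to reproduce the rotation augmentation, here is a rough sketch using torchvision's Omniglot loader; the article's actual data pipeline is not shown here and may differ:

from torchvision import datasets

# Every 90-degree rotation of a character is treated as a brand-new class,
# multiplying the number of classes in the loaded split by 4.
base = datasets.Omniglot(root="data", background=True, download=True)

samples = [(img, label) for img, label in base]       # img is a PIL image, label a class index
num_chars = max(label for _, label in samples) + 1    # number of character classes in this split

augmented = [(img.rotate(angle), label + r * num_chars)
             for img, label in samples
             for r, angle in enumerate([0, 90, 180, 270])]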
Time to get your hands dirty!
Let’s dive into the code!
import torch
import torch.nn as nn


class Net(nn.Module):
    """
    Image2Vector CNN which takes an image of dimension (28x28x3) and returns a vector of length 64
    """
    def sub_block(self, in_channels, out_channels=64, kernel_size=3):
        # One block: convolution -> batch norm -> ReLU -> 2x2 max pooling
        block = torch.nn.Sequential(
            torch.nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=out_channels, padding=1),
            torch.nn.BatchNorm2d(out_channels),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2))
        return block

    def __init__(self):
        super(Net, self).__init__()
        self.convnet1 = self.sub_block(3)
        self.convnet2 = self.sub_block(64)
        self.convnet3 = self.sub_block(64)
        self.convnet4 = self.sub_block(64)

    def forward(self, x):
        x = self.convnet1(x)
        x = self.convnet2(x)
        x = self.convnet3(x)
        x = self.convnet4(x)
        x = torch.flatten(x, start_dim=1)
        return x
The above snippet is an implementation of image2vector CNN architecture. It takes images of dimensions 28x28x3 and returns a vector of length 64.
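As a quick sanity check, here is a small usage sketch (assuming the class above is in scope):

import torch

net = Net()
dummy_batch = torch.randn(8, 3, 28, 28)   # 8 RGB images of size 28x28 (NCHW layout)
embeddings = net(dummy_batch)
print(embeddings.shape)                   # expected: torch.Size([8, 64])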
import numpy as np
import torch
import torch.nn as nn
from sklearn.preprocessing import LabelEncoder


class PrototypicalNet(nn.Module):
    def __init__(self, use_gpu=False):
        super(PrototypicalNet, self).__init__()
        self.f = Net()
        self.gpu = use_gpu
        if self.gpu:
            self.f = self.f.cuda()

    def forward(self, datax, datay, Ns, Nc, Nq, total_classes):
        """
        Implementation of one episode in Prototypical Net
        datax: Training images
        datay: Corresponding labels of datax
        Nc: Number of classes per episode
        Ns: Number of support data per class
        Nq: Number of query data per class
        total_classes: Total classes in training set
        """
        k = total_classes.shape[0]
        K = np.random.choice(total_classes, Nc, replace=False)
        Query_x = torch.Tensor()
        if self.gpu:
            Query_x = Query_x.cuda()
        Query_y = []
        Query_y_count = []
        centroid_per_class = {}
        class_label = {}
        label_encoding = 0
        for cls in K:
            S_cls, Q_cls = self.random_sample_cls(datax, datay, Ns, Nq, cls)
            centroid_per_class[cls] = self.get_centroid(S_cls, Ns)
            class_label[cls] = label_encoding
            label_encoding += 1
            Query_x = torch.cat((Query_x, Q_cls), 0)  # Joining all the query sets together
            Query_y += [cls]
            Query_y_count += [Q_cls.shape[0]]
        Query_y, Query_y_labels = self.get_query_y(Query_y, Query_y_count, class_label)
        Query_x = self.get_query_x(Query_x, centroid_per_class, Query_y_labels)
        return Query_x, Query_y

    def random_sample_cls(self, datax, datay, Ns, Nq, cls):
        """
        Randomly samples Ns examples as the support set and Nq as the query set
        """
        data = datax[datay == cls]  # boolean mask keeps the (N, C, H, W) shape
        perm = torch.randperm(data.shape[0])
        idx = perm[:Ns]
        S_cls = data[idx]
        idx = perm[Ns:Ns + Nq]
        Q_cls = data[idx]
        if self.gpu:
            S_cls = S_cls.cuda()
            Q_cls = Q_cls.cuda()
        return S_cls, Q_cls

    def get_centroid(self, S_cls, Ns):
        """
        Returns the centroid vector of the support set of one class
        """
        # Mean of the Ns support embeddings; shape (1, 64)
        return torch.sum(self.f(S_cls), 0).unsqueeze(1).transpose(0, 1) / Ns

    def get_query_y(self, Qy, Qyc, class_label):
        """
        Returns a labeled representation of the classes of the query set and a list of labels.
        """
        labels = []
        m = len(Qy)
        for i in range(m):
            labels += [Qy[i]] * Qyc[i]
        labels = np.array(labels).reshape(len(labels), 1)
        label_encoder = LabelEncoder()
        Query_y = torch.Tensor(label_encoder.fit_transform(labels).astype(int)).long()
        if self.gpu:
            Query_y = Query_y.cuda()
        Query_y_labels = np.unique(labels)
        return Query_y, Query_y_labels

    def get_centroid_matrix(self, centroid_per_class, Query_y_labels):
        """
        Returns the centroid matrix where each row is the centroid of a class.
        """
        centroid_matrix = torch.Tensor()
        if self.gpu:
            centroid_matrix = centroid_matrix.cuda()
        for label in Query_y_labels:
            centroid_matrix = torch.cat((centroid_matrix, centroid_per_class[label]))
        if self.gpu:
            centroid_matrix = centroid_matrix.cuda()
        return centroid_matrix

    def get_query_x(self, Query_x, centroid_per_class, Query_y_labels):
        """
        Returns the distance matrix from each query image to each centroid.
        """
        centroid_matrix = self.get_centroid_matrix(centroid_per_class, Query_y_labels)
        Query_x = self.f(Query_x)
        m = Query_x.size(0)
        n = centroid_matrix.size(0)
        # Expand both matrices to shape (m, n, 64) so that the L2 distance between every
        # query embedding and every centroid can be computed in one shot.
        centroid_matrix = centroid_matrix.expand(m, centroid_matrix.size(0), centroid_matrix.size(1))
        Query_matrix = Query_x.expand(n, Query_x.size(0), Query_x.size(1)).transpose(0, 1)
        Qx = torch.norm(centroid_matrix - Query_matrix, p=2, dim=2)  # (m, n) distance matrix
        return Qx
The above snippet is an implementation of a single episode in Prototypical Net. It is well commented, but if you have any doubts just ask in the comments or create an issue here.
Overview of the Network. Source.
The code is structured in the same format in which the algorithm was explained. We give the prototypical network function the following inputs: the input image data, the input labels, the number of classes per iteration $N_c$, the number of support examples per class $N_s$, and the number of query examples per class $N_q$. The function returns Query_x, a distance matrix from each query point to each class mean, and Query_y, a vector containing the labels corresponding to Query_x, i.e. the classes to which the images in Query_x actually belong. In the above image, we can see that 3 classes are used, i.e. $N_c = 3$, and that for each class a total of 5 examples are used for training, i.e. $N_s = 5$. Here $S$ represents the support set containing those 15 ($N_s \times N_c$) images and $X$ represents the query set. Notice that both the support set and the query set pass through $f$, which is nothing but our "Image2Vector" function; it maps all the images into the metric space. Let's break the whole process down step by step.
First of all, we choose $N_c$ classes randomly from the input data. For each class, we randomly select a support set and a query set from the images using the random_sample_cls function. In the above image, $S$ is the support set and $X$ is the query set. Now that we have chosen the classes ($C_1$, $C_2$ and $C_3$), we pass all the support set examples through the "Image2Vector" model and compute the centroid of each class using the get_centroid function. The same can be observed in the nearby image, where $C_1$ and $C_2$ are the centers computed from their neighboring points. Each centroid represents a class and will be used for classifying queries.
Centroid calculation in the Network. Source.
After computing the centroid of each class, we have to assign each query image to one of the classes. For that, we need the actual label corresponding to each query, which we get using the get_query_y function. Query_y starts out as categorical data, and the function converts these categorical class labels into integer-encoded labels (via LabelEncoder), so that each query image ends up with a numeric index of the class it actually belongs to.
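For clarity, here is what LabelEncoder does in isolation (a toy example with made-up class ids, not part of the network):

from sklearn.preprocessing import LabelEncoder

raw_labels = [957, 12, 957, 431, 12]          # arbitrary class ids sampled for an episode
encoded = LabelEncoder().fit_transform(raw_labels)
print(encoded)                                # [2 0 2 1 0]: classes remapped to 0..Nc-1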
After that, we need the point corresponding to each Query_x image in order to classify it. We obtain the points using the "Image2Vector" model and then compute the distance between each point in Query_x and each class center. This gives us a matrix in which the entry at index $ij$ is the distance of the point corresponding to the $i$-th query image from the center of the $j$-th class. We use the get_query_x function to construct this matrix and store it in the Query_x variable. The same can be seen in the nearby image: for each example in the query set, its distance from $C_1$, $C_2$ and $C_3$ is calculated. In this case, $x$ is closest to $C_2$, so we predict that $x$ belongs to class $C_2$.
Programmatically, we can use a simple argmin over this distance matrix to do the same, i.e. to find the class in which each image is predicted to lie. We then use the predicted class and the actual class to calculate the loss and backpropagate the error.
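To tie it together, here is a rough sketch of what a single training step could look like with the class above; the optimizer, learning rate and episode sizes are illustrative choices, not necessarily those used in the original experiments:

import torch
import torch.nn.functional as F

protonet = PrototypicalNet(use_gpu=False)
optimizer = torch.optim.Adam(protonet.parameters(), lr=1e-3)

def train_step(datax, datay, total_classes, Ns=5, Nc=60, Nq=5):
    optimizer.zero_grad()
    Qx, Qy = protonet(datax, datay, Ns, Nc, Nq, total_classes)  # Qx: distances, Qy: labels
    # Negative distances act as logits; cross_entropy = log-softmax + NLL.
    loss = F.cross_entropy(-Qx, Qy)
    loss.backward()
    optimizer.step()
    # Prediction is simply the nearest centroid (argmin over distances).
    accuracy = (Qx.argmin(dim=1) == Qy).float().mean()
    return loss.item(), accuracy.item()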
If you want to use the trained model or just have to retrain again for yourself, here is my implementation. You can use it as an API and train the model using a couple of lines of code. You can find this network in action here.
Resources
Here are a few resources that might help you learn this topic thoroughly:
- One Shot Learning with Siamese Networks using Keras
- One-Shot Learning: Face Recognition using Siamese Neural Network
- Matching network official implementation
- Prototypical Network official implementation.
- Meta-Learning for Semi-Supervised Few-Shot Classification
Limitations
Though prototypical networks produce great results, they still have limitations. The first is a lack of generalization. The model works well on the Omniglot dataset because all the images there are images of characters and hence share similar characteristics. However, if we tried to use the same model to classify different breeds of cats, it would not give us accurate results: cat images and character images share few characteristics, so the number of common features that can be exploited to map an image into the corresponding metric space is negligible.
Another limitation of prototypical networks is that they use only the mean to determine a class center and ignore the variance in the support set. This hinders the model's ability to classify images when they are noisy. This limitation is overcome by Gaussian prototypical networks, which exploit the variance within a class by modeling the embedded points with Gaussian formulations.
Conclusion
Few-shot learning has been a topic of active research for a while. There are many novel approaches that build on prototypical networks, like this meta-learning one, and they show great results. Researchers are also exploring combinations with reinforcement learning, which also have great potential. The best thing about this model is that it is simple, easy to understand, and gives incredible results.