Deep Learning without labels – Self-Supervised Learning¶
In this blog post we’ll discuss Self-Supervised Learning!
Classical supervised learning suffers from four main problems:
- Fully labelled datasets are expensive or not available at all.
- Large amounts of unlabeled data exist which cannot be leveraged by supervised learning.
- Difficulty in creating One-Shot or Few-Shot systems. For example: why are we able to recognize a dog after seeing only a few images of one, whereas a CNN needs to see thousands?
- The model is fixed to the classes it was trained on and is not easy to extend to classes missing from the original labelled dataset.
NOTE: The pretrained weights used in this blog post can be found here: https://drive.google.com/file/d/13TDyUL8yc48srKfLWfJkDnlju4_-fcyH/view?usp=sharing
Self-Supervised Learning¶
Imagine the (common) task of creating an image classifier for a customer.
Unfortunately, the customer has only described the problem to you and has not provided any class labels. In order to train a CNN you would need to label a large portion of the dataset yourself, which is not a fun task at all!
How can we avoid that?
One idea would be to pretrain our network on a task for which labels are available.
However, it would be even better to pretrain it on a task where the data itself defines the labels. If that sounds confusing, don’t worry, we will look at a concrete example shortly!
This forces the network to learn something about the data (e.g. that grass is often green and the sky is blue).
Pretraining this way should increase performance when we afterwards train on only a small labelled subset, removing the need to label the full dataset.
Example: Image Colorization (Colorization as a Proxy Task for Visual Understanding, Larsson et al. 2017, Link)
In their work, Larsson et al. showed that colorizing a previously decolorized image can act as a pretraining strategy!
During training, the dataloader first converts the input image to grayscale. A CNN then has the task of predicting the colorization of that image. Because the original colors are known, we can use any regression loss (e.g. MSE) to optimize the weights of the network.
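To make this concrete, here is a minimal sketch of a single colorization training step. It assumes a colorization network net, an optimizer, and a batch of RGB images imgs scaled to [0, 1] (these names are placeholders); the full Lightning implementation follows further down, where the grayscale conversion is done with OpenCV instead of a simple channel average.
import torch
import torch.nn.functional as F

def colorization_step(net, optimizer, imgs):
    # The "label" comes from the data itself: the target is the original RGB image
    imgs_gray = imgs.mean(dim=1, keepdim=True)   # (B, 3, H, W) -> (B, 1, H, W)
    pred_color = net(imgs_gray)                  # predict the missing colors
    loss = F.mse_loss(pred_color, imgs)          # any regression loss works
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()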
The figure below illustrates the entire process:
Using this pretraining strategy, the network learns what landscapes look like.
After this initial self-supervised training we can replace the regression head responsible for colorizing the image with the output head we actually want to predict and finally fine-tune on our small labelled subset!
The authors showed that Top-5 classification accuracy (the ground truth is amongst the 5 predictions with the highest score) increased by over 20 percentage points on ImageNet, from 42.5% to 63.1%.
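As a quick aside, here is one way such a top-5 accuracy could be computed in PyTorch. This is just an illustrative sketch assuming a batch of class scores of shape (batch, num_classes) and integer targets, not code from the paper.
import torch

def top5_accuracy(scores, targets):
    # Indices of the 5 highest-scoring classes per sample: (batch, 5)
    top5 = scores.topk(k=5, dim=1).indices
    # A prediction counts as correct if the ground truth appears among them
    correct = (top5 == targets.unsqueeze(1)).any(dim=1)
    return correct.float().mean().item()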
Let’s build this self-supervised framework and use it on CIFAR-10.
We start with the dataset class. Luckily, CIFAR-10 is directly available in torchvision.
import torch
import torch.nn as nn
import torchvision
import numpy as np
import pytorch_lightning as pl
import matplotlib.pyplot as plt
import cv2
from tqdm.notebook import tqdm
train_data = torchvision.datasets.CIFAR10(root="~/Temp", train=True, download=True)
val_data = torchvision.datasets.CIFAR10(root="~/Temp", train=False, download=True)
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /Users/marci/Temp/cifar-10-python.tar.gz
Extracting /Users/marci/Temp/cifar-10-python.tar.gz to /Users/marci/Temp
Files already downloaded and verified
class Cifar10GrayscaleDataset(torch.utils.data.Dataset):
    """Wraps a CIFAR-10 dataset and returns (grayscale image, color image, label) triples."""
    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        img, target = self.dataset.data[idx], self.dataset.targets[idx]
        # The grayscale version acts as the input, the original color image as the regression target
        img_gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        return np.expand_dims(img_gray / 255, 0).astype(np.float32), \
               np.transpose(img / 255, (2, 0, 1)).astype(np.float32), \
               target
train_dataset = Cifar10GrayscaleDataset(train_data)
val_dataset = Cifar10GrayscaleDataset(val_data)
img_gray, img_color, label = train_dataset[1]
fig, axis = plt.subplots(1, 2)
axis[0].imshow(img_color.transpose(1, 2, 0))
axis[1].imshow(img_gray[0], cmap="gray")
fig.suptitle(f"Label: {label} ")
Next, we build the model class. Let’s use a ResNet-50 encoder-decoder architecture.
class Colorizer(pl.LightningModule):
def __init__(self):
super().__init__()
        # ResNet-50 encoder: accept a single (grayscale) input channel and drop the avgpool and fc layers
        self.downsample = torchvision.models.resnet50()
        self.downsample.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        self.downsample = torch.nn.Sequential(*list(self.downsample.children())[:-2])
        # Decoder: five x2 upsampling stages bring the 1x1 feature map back to 32x32
        self.upsample = nn.Sequential(
nn.Conv2d(2048, 512, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(512),
nn.ReLU(),
nn.Upsample(scale_factor=2),
nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(256),
nn.ReLU(),
nn.Upsample(scale_factor=2),
nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(128),
nn.ReLU(),
nn.Upsample(scale_factor=2),
nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(32),
nn.ReLU(),
nn.Upsample(scale_factor=2),
nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(16),
nn.ReLU(),
nn.Upsample(scale_factor=2),
nn.Conv2d(16, 3, kernel_size=3, stride=1, padding=1)
)
self.loss_fn = torch.nn.MSELoss()
def forward(self, X):
features = self.downsample(X)
upsampled = self.upsample(features)
return upsampled
    def training_step(self, batch, batch_idx):
        img_gray, img, _ = batch
        pred = self(img_gray)
        loss = self.loss_fn(pred, img)
        return loss

    def validation_step(self, batch, batch_idx):
        # Report the same reconstruction loss on the validation set
        return self.training_step(batch, batch_idx)
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
return [optimizer]
model = Colorizer()
train_dl = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True)
val_dl = torch.utils.data.DataLoader(val_dataset, batch_size=64)
model(torch.randn(2, 1, 32, 32)).shape
torch.Size([2, 3, 32, 32])
trainer = pl.Trainer(gpus=0)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
trainer.fit(model, train_dataloaders=train_dl, val_dataloaders=val_dl)
  | Name    | Type    | Params
------------------------------------
0 | model   | ResNet  | 12.7 M
1 | loss_fn | MSELoss | 0
------------------------------------
12.7 M    Trainable params
0         Non-trainable params
12.7 M    Total params
50.985    Total estimated model params size (MB)
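The checkpoint loaded in the next cell is the pretrained one linked at the top of the post. If you trained the colorizer yourself instead, you can write your own checkpoint to the expected filename first (a small optional step using Lightning's save_checkpoint):
trainer.save_checkpoint("colorizer.ckpt")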
# Load the colorizer weights (either the pretrained checkpoint linked above or your own)
model = Colorizer.load_from_checkpoint("colorizer.ckpt")
model.eval();
model.cpu();
img_gray, img_color, _ = val_dataset[3]
img_gray_tensor = torch.tensor(img_gray).unsqueeze(0)
img_color_predicted = model(img_gray_tensor)[0]
plt.figure()
plt.imshow(img_color.transpose(1, 2, 0))
plt.figure()
plt.imshow(img_color_predicted.detach().permute(1, 2, 0).clamp(0, 1))  # clamp predictions to the valid [0, 1] image range
Finetuning¶
Before fine-tuning the pretrained network from above, let us first train a standard classifier from scratch, using only 10 labelled samples per class, as a baseline.
train_dataset_subsampled = torchvision.datasets.CIFAR10(root="~/Temp", train=True, download=True)
Files already downloaded and verified
# Collect the indices of the first 10 samples of each of the 10 classes
idx_list = []
for i in range(10):
    idx_list.append(np.where(np.array(train_dataset_subsampled.targets) == i)[0][:10])
train_dataset_subsampled.data = train_data.data[np.array(idx_list).flatten()]
train_dataset_subsampled.targets = np.array(train_data.targets)[np.array(idx_list).flatten()]
train_dataset_subsampled.targets
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9])
train_dataset_subsampled = Cifar10GrayscaleDataset(train_dataset_subsampled)
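As a quick sanity check (just an illustrative assertion, not part of the original pipeline), the subsampled training set should now contain exactly 10 classes x 10 samples = 100 images:
assert len(train_dataset_subsampled) == 100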
class CifarClassifier(pl.LightningModule):
def __init__(self):
super().__init__()
self.model = torchvision.models.resnet50()
self.model.fc = torch.nn.Linear(in_features=2048, out_features=10)
self.loss_fn = torch.nn.CrossEntropyLoss()
def forward(self, X):
return self.model(X)
def training_step(self, batch, batch_idx):
_, img, target = batch
pred = self(img)
loss = self.loss_fn(pred, target)
return loss
def validation_step(self, batch, batch_idx):
_, img, target = batch
pred = self(img)
loss = self.loss_fn(pred, target)
return loss
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=1e-4)
return [optimizer]
classifier = CifarClassifier()
train_dl = torch.utils.data.DataLoader(train_dataset_subsampled, batch_size=64, shuffle=True, num_workers=0)
val_dl = torch.utils.data.DataLoader(val_dataset, batch_size=64, num_workers=0)
trainer = pl.Trainer(max_epochs=25)
trainer.fit(classifier, train_dataloaders=train_dl)
classifier.eval()
val_predictions = []
val_targets = []
for (_, img, label) in tqdm(val_dataset):
img = torch.tensor(img).unsqueeze(0).float()
pred = classifier(img).argmax(1).item()
val_predictions.append(pred)
val_targets.append(label)
(np.array(val_predictions) == np.array(val_targets)).mean()
0.1121
We can see that the accuracy is roughly 0.1, which is essentially a random guess given that we have 10 different classes.
Semi-Supervised Learning¶
Now, after training the colorization model, we fine-tune it on the CIFAR-10 subset defined above.
To do so, we keep only the encoder (downsample) and add a linear layer on top that maps the 2048-dimensional feature vector to the 10 class scores.
class CifarClassifierFineTune(pl.LightningModule):
    def __init__(self, pretrained_model):
        super().__init__()
        # Reuse the pretrained encoder; swap the first convolution back to 3 input channels
        # because fine-tuning is done on color images
        self.downsample = pretrained_model.downsample
        self.downsample[0] = torch.nn.Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        # Classification head on top of the 2048-dimensional encoder features
        self.fc = nn.Linear(in_features=2048, out_features=10, bias=True)
        self.loss_fn = torch.nn.CrossEntropyLoss()
def forward(self, X):
features = self.downsample(X)
features = torch.nn.functional.adaptive_avg_pool2d(features, output_size=(1, 1))
features = torch.flatten(features, 1)
prediction = self.fc(features)
return prediction
def training_step(self, batch, batch_idx):
_, img, label = batch
pred = self(img)
loss = self.loss_fn(pred, label)
return loss
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=1e-4)
return [optimizer]
fine_tune = CifarClassifierFineTune(model)
train_dl = torch.utils.data.DataLoader(train_dataset_subsampled, batch_size=64, shuffle=True, num_workers=0)
val_dl = torch.utils.data.DataLoader(val_dataset, batch_size=64, num_workers=0)
trainer = pl.Trainer(max_epochs=25)
trainer.fit(fine_tune, train_dataloaders=train_dl)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name       | Type             | Params
------------------------------------------------
0 | downsample | Sequential       | 23.5 M
1 | fc         | Linear           | 20.5 K
2 | loss_fn    | CrossEntropyLoss | 0
------------------------------------------------
23.5 M    Trainable params
0         Non-trainable params
23.5 M    Total params
94.114    Total estimated model params size (MB)
fine_tune.eval().cpu()
val_predictions = []
val_targets = []
for (_, img, label) in tqdm(val_dataset):
img = torch.tensor(img).unsqueeze(0).float()
pred = fine_tune(img).argmax(1).item()
val_predictions.append(pred)
val_targets.append(label)
(np.array(val_predictions) == np.array(val_targets)).mean()
0.21
With only 100 labelled images, the fine-tuned encoder roughly doubles the accuracy of the classifier trained from scratch (0.21 vs. 0.11). Hopefully this blog post helped cement some of the concepts of self-supervised learning for you!