How skip connections changed Deep Learning

Problem

In training deep neural networks there was a problem that occurred time and time again. As the number of hidden layers was increased, even when the problem of exploding/vanishing gradients was addressed, model accuracy would not improve. Interestingly, accuracy would actually start degrading rapidly!

This was noticed by the authors of the ResNet paper. As they describe it:

When deeper networks are able to start converging, a degradation problem has been exposed: with the network depth increasing, accuracy gets saturated (which might be unsurprising) and then degrades rapidly. Unexpectedly, such degradation is not caused by overfitting, and adding more layers to a suitably deep model leads to higher training error.

They noted that this degradation indicates that not all systems are similarly easy to optimize, even though, by construction, “a deeper model should produce no higher training error than its shallower counterpart.”

This degradation implied that somehow the information captured by the earlier layers was being abstracted away by the time it reached the end of the network.

They proposed a simple architectural change.

A simple solution

Up until then, larger networks were created by stacking more and more layers on top of each other. The issue was that the later layers didn’t retain much of the information from the early layers.

The solution was to connect the early layers to the later layers (shortcut connections). In ResNets, this is done by means of an identity connection: instead of a block computing only some transformation F(x), it outputs F(x) + x.
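In code, the idea boils down to a single addition. Here is a minimal sketch (the name block below is just an illustrative stand-in for any stack of layers, not something from the ResNet paper):

def skip_connect(block, x):
    # block(x) is the usual output of the stacked layers;
    # adding x back gives the shortcut (identity) connection
    return block(x) + x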

This is also wonderfully described in the DenseNet paper:

“As CNNs become increasingly deep, a new research problem emerges: as information about the input or gradient passes through many layers, it can vanish and “wash out” by the time it reaches the end (or beginning) of the network. Many recent publications address this or related problems. ResNets and Highway Networks bypass signal from one layer to the next via identity connections. Stochastic depth shortens ResNets by randomly dropping layers during training to allow better information and gradient flow. FractalNets repeatedly combine several parallel layer sequences with different number of convolutional blocks to obtain a large nominal depth, while maintaining many short paths in the network. Although these different approaches vary in network topology and training procedure, they all share a key characteristic: they create short paths from early layers to later layers.”

This allowed the training of truly deep CNNs. The original LeNet-5 consisted of 5 layers and VGG-19 of 19 layers; with Highway Networks and ResNets we are now able to cross the 100-layer barrier.

Simple implementation in PyTorch

import torch
import torch.nn as nn


def conv_layer(ni:int, nf:int, kernel_size:int=3, stride:int=2) -> nn.Sequential:
    '''ni is the number of input channels
    nf is the number of output channels
    Returns a (convolution, batch norm and activation) block'''
    return nn.Sequential(
        nn.Conv2d(ni, nf, kernel_size, stride, padding=1),  # Convolution
        nn.BatchNorm2d(nf),  # Batch Normalization
        nn.ReLU()  # Activation function
    )

class ResBlock(nn.Module):
    def __init__(self, nf:int):
        super(ResBlock, self).__init__()  #initialize
        self.conv1 = conv_layer(nf, nf, 3, 1) #input and output channels are of same size
        self.conv2 = conv_layer(nf, nf, 3, 1) #input and output channels are of same size
    
    def forward(self, x:torch.Tensor) -> torch.Tensor:
        return self.conv2(self.conv1(x)) + x #SKIP CONNECTION

Note that we have to extend PyTorch’s nn.Module class.

x is the input, and conv1 and conv2 are our two convolution blocks, each followed by batch normalization and a ReLU activation. In the forward pass, the input goes through the first convolution block, its output goes through the second, and then the original input x is added to the result.
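As a quick, purely illustrative sanity check (not part of the original post), we can stack a downsampling conv_layer with a ResBlock and confirm that the residual block preserves both the number of channels and the spatial size, which is what makes the + x addition possible:

model = nn.Sequential(
    conv_layer(3, 16),  # 3 input channels -> 16 feature maps, stride 2 halves the spatial size
    ResBlock(16)        # residual block: same number of channels and spatial size in and out
)

x = torch.randn(8, 3, 32, 32)  # batch of 8 random 32x32 "RGB images"
out = model(x)
print(out.shape)               # torch.Size([8, 16, 16, 16])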

With that, we have implemented a skip connection. Very often, the solutions to seemingly complicated problems are surprisingly simple.
