Important Pytorch Stuff
In [2]:
import torch.nn as nn
import torch
from torch.autograd.variable import Variable
from torchvision import datasets, models, transforms
In [20]:
(https://github.com/Spandan-Madan/A-Collection-of-important-tasks-in-pytorch)
https://spandan-madan.github.io/A-Collection-of-important-tasks-in-pytorch/ 1/20
8/1/23, 7:55 PM Important Pytorch Stuff
Let us first explore this model's layers and then decide which ones we want to freeze. By freezing, we mean keeping the parameters of those layers fixed. When fine-tuning a model, we take a model trained on Dataset A and then train it further on a new Dataset B. We could start training from scratch instead, but that would be like reinventing the wheel. Let me explain why.
Suppose I want to train a network to differentiate between a car and a bicycle. I could gather images of both categories and train a network from scratch. But given the large body of existing work, it's easy to find a model trained to identify things like dogs, cats, and humans. Admittedly, none of these three look like cars or bicycles. However, starting from such a model is still better than starting from nothing, because its early layers have already learned generic features such as edges and textures. We could start with this model and train it to learn car vs. bicycle. The gains: 1) training will be faster, and 2) we need fewer images of cars and bicycles.
Now, let's look at the contents of a resnet18. We use the .children() method for this: it iterates over the model's immediate sub-modules, i.e. its top-level layers. Then we use the .parameters() method to access the parameters (weights) of any layer. Finally, every parameter has a .requires_grad attribute that determines whether it is trained or frozen. By default it is True, and the parameter is updated in every iteration. If it is set to False, no gradient is computed for it, it is not updated, and it is said to be "frozen".
In [21]:
child_counter = 0
for child in model.children():
    print(" child", child_counter, "is -")
    print(child)
    child_counter += 1
child 0 is -
Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
child 1 is -
BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
child 2 is -
ReLU (inplace)
child 3 is -
MaxPool2d (size=(3, 3), stride=(2, 2), padding=(1, 1), dilation=(1, 1))
child 4 is -
Sequential (
  (0): BasicBlock (
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
    (relu): ReLU (inplace)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
  )
  (1): BasicBlock (
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
    (relu): ReLU (inplace)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
  )
)
child 5 is -
Sequential (
  (0): BasicBlock (
    (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
    (relu): ReLU (inplace)
    (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
    (downsample): Sequential (
      (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
    )
  )
  (1): BasicBlock (
    (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
    (relu): ReLU (inplace)
    (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
  )
)
child 6 is -
Sequential (
  (0): BasicBlock (
    (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
    (relu): ReLU (inplace)
    (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
    (downsample): Sequential (
      (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
    )
  )
  (1): BasicBlock (
    (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
    (relu): ReLU (inplace)
    (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
  )
)
child 7 is -
Sequential (
  (0): BasicBlock (
    (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
    (relu): ReLU (inplace)
    (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
    (downsample): Sequential (
      (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
    )
  )
  (1): BasicBlock (
    (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
    (relu): ReLU (inplace)
    (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
  )
)
child 8 is -
AvgPool2d (
)
child 9 is -
Linear (512 -> 1000)
Now, you can see that some of the children are actually big chunks that have layers within them. To go one level deeper, we can call .children() on a child object as well!
Let's say we want to freeze all parameters up to and including the first BasicBlock of child 6. First, let's look at a parameter and set it to frozen:
In [22]:
(Output abridged: printing this parameter dumps every value of the weight tensor of the first convolution layer, child 0, which holds 64 x 3 x 7 x 7 = 9,408 numbers in all.)
[torch.FloatTensor of size 64x3x7x7]
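Evidently, training all of these weights takes a lot of computation. Setting a bunch of them to frozen means their gradients are not computed and they are not updated, so training becomes much faster. Now, let's freeze everything up to and including the first BasicBlock of child 6: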
In [23]:
child_counter = 0
for child in model.children():
    if child_counter < 6:
        print("child ", child_counter, " was frozen")
        for param in child.parameters():
            param.requires_grad = False
    elif child_counter == 6:
        children_of_child_counter = 0
        for children_of_child in child.children():
            if children_of_child_counter < 1:
                for param in children_of_child.parameters():
                    param.requires_grad = False
                print('child ', children_of_child_counter, 'of child', child_counter, ' was frozen')
            else:
                print('child ', children_of_child_counter, 'of child', child_counter, ' was not frozen')
            children_of_child_counter += 1
    else:
        print("child ", child_counter, " was not frozen")
    child_counter += 1
Important Note
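Now that you have frozen part of this network, one more thing has to change to make this work: your optimizer. The optimizer is what actually updates these values. By default, the optimizer is given model.parameters(), i.e. every parameter of the model.
In older PyTorch releases this gives an error, because the optimizer tries to update all the parameters of the model even though you've set a bunch of them to frozen. (Recent releases accept it and simply skip parameters that receive no gradient.) So, pass the optimizer only the parameters that are still being updated, i.e. those whose requires_grad is True.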
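A sketch of the filtered optimizer on a toy model; the two-layer network, SGD, and lr=0.1 are assumptions for illustration. The often-seen one-liner filter(lambda p: p.requires_grad, model.parameters()) is equivalent to the list comprehension used here. After one training step, the frozen layer is unchanged and has no gradient, while the trainable layer has moved.

```python
import torch
import torch.nn as nn

# Toy stand-in for the resnet (hypothetical): two Linear layers, the first frozen.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
for param in model[0].parameters():
    param.requires_grad = False

# Give the optimizer only the parameters that are still trainable.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable_params, lr=0.1)  # optimizer and lr are assumptions

frozen_before = model[0].weight.clone()
trained_bias_before = model[2].bias.clone()

loss = model(torch.randn(3, 4)).sum()
optimizer.zero_grad()
loss.backward()
optimizer.step()
```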
In [ ]:
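# Saving a Model: store only the learned parameters (the state_dict), not the whole object.
# MODEL_PATH is a file path string defined elsewhere in the notebook.
torch.save(model.state_dict(), MODEL_PATH)
# Loading: first create a model and define its architecture as done above in this notebook.
# (If you want a custom architecture, read on: it is covered below.)
# Then copy the saved parameters into it.
checkpoint = torch.load(MODEL_PATH)
model.load_state_dict(checkpoint)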
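A self-contained round trip of this save/load pattern, as a sketch: a small nn.Sequential stands in for the resnet, and the file lives in a temporary directory (the file name model.pth is hypothetical). Passing map_location="cpu" lets a checkpoint saved on a GPU be loaded on a machine without one.

```python
import os
import tempfile

import torch
import torch.nn as nn

def make_model():
    # Any architecture works, as long as saving and loading use the same one.
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

model = make_model()
with tempfile.TemporaryDirectory() as tmp:
    model_path = os.path.join(tmp, "model.pth")   # hypothetical file name
    torch.save(model.state_dict(), model_path)

    restored = make_model()                        # same architecture, fresh random weights
    checkpoint = torch.load(model_path, map_location="cpu")
    restored.load_state_dict(checkpoint)

# Every saved tensor now matches the original model's tensor.
same = all(torch.equal(a, b) for a, b in zip(model.state_dict().values(),
                                             restored.state_dict().values()))
print(same)  # True
```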
In [25]:
# Get the number of input features of the last layer; we need this to replace the final layer.
num_final_in = model.fc.in_features
# The final layer of the model is model.fc, so we can simply overwrite it
# so that the output size equals the number of classes we need. Say, 300 classes.
NUM_CLASSES = 300
model.fc = nn.Linear(num_final_in, NUM_CLASSES)
In [ ]:
We can get the layers using model.children() as before, and convert them into a list by calling list() on the result. Then we can remove the last layer by slicing the list. Finally, we can use the PyTorch class nn.Sequential() to stack the modified list into a new model. You can edit the list in any way you want. For example, you can delete the last 2 layers if you want the features of an image from the 3rd-to-last layer!
You may even delete layers from the middle of the model. But this will usually break it: most layers change the shape of their output, so the layer after the deleted one would receive the wrong number of features. In this case, you can index that specific layer of the model and overwrite it, just as I showed immediately above!
In [33]:
new_model = nn.Sequential(*list(model.children())[:-1])
In [34]:
new_model_2_removed = nn.Sequential(*list(model.children())[:-2])
ADDING LAYERS
Say you want to add a fully connected layer to the model we have right now. One obvious way would be to edit the list discussed above and append another layer to it. However, we often already have such a model trained and want to load it and add just a new layer on top of it. As mentioned above, the model we load into must have the SAME architecture as the saved one, so we can't simply edit the list before loading.
We need to add layers on top. The way to do this is simple in PyTorch: we just need to create a custom model! And this brings us to our next section: creating custom models!
In [ ]:
# New models are defined as classes. To create a model, we instantiate an object of this class.
class Resnet_Added_Layers_Half_Frozen(nn.Module):
    def __init__(self, LOAD_VIS_URL=None):
        # Bug fix: super() must name this class (the original referred to ResnetCombinedFull2).
        super(Resnet_Added_Layers_Half_Frozen, self).__init__()
        # Start with a resnet18 and swap out the final layer, because that's the model we defined above.
        # No pretrained weights are needed here: the trained ones are loaded below.
        model = models.resnet18()
        num_final_in = model.fc.in_features
        model.fc = nn.Linear(num_final_in, 300)
        # Now that the architecture matches the saved one, load the model we trained above.
        # (LOAD_VIS_URL is not used in this snippet.)
        checkpoint = torch.load(MODEL_PATH)
        model.load_state_dict(checkpoint)
        # Freeze the same layers as above: same code as above without the print statements.
        child_counter = 0
        for child in model.children():
            if child_counter < 6:
                for param in child.parameters():
                    param.requires_grad = False
            elif child_counter == 6:
                children_of_child_counter = 0
                for children_of_child in child.children():
                    if children_of_child_counter < 1:
                        for param in children_of_child.parameters():
                            param.requires_grad = False
                    # Bug fix: increment on every iteration, not only in an else branch;
                    # otherwise every block of child 6 would be frozen.
                    children_of_child_counter += 1
            child_counter += 1
        # The layers used in forward() are not defined in the surviving text; these
        # definitions are an assumed reconstruction (the 400-unit hidden size is hypothetical).
        self.vismodel = nn.Sequential(*list(model.children())[:-1])  # resnet18 without its fc layer
        self.projective = nn.Linear(num_final_in, 400)
        self.nonlinearity = nn.ReLU(inplace=True)
        self.projective2 = nn.Linear(400, 300)

    # The forward function defines the flow of the input data and thus decides which layer/chunk goes on top of what.
    def forward(self, x):
        x = self.vismodel(x)
        # (N, 512, 1, 1) -> (N, 512); torch.squeeze would also drop a batch dimension of size 1.
        x = torch.flatten(x, 1)
        x = self.projective(x)
        x = self.nonlinearity(x)
        x = self.projective2(x)
        return x
Sometimes we need to define our own loss functions. Here are a few things to know about them:
- Custom loss functions are defined using a custom class too. They inherit from torch.nn.Module, just like a custom model.
- Often, we need to change the dimensions of one of our inputs. This can be done using the view() function.
- If we want to add a dimension to a tensor, use the unsqueeze() function.
- The value finally returned by a loss function MUST BE a scalar (a single value), not a vector or multi-element tensor, so that .backward() can be called on it without arguments.
- The returned value must also be tracked by autograd so that it can be used to update the parameters. In the PyTorch versions this notebook was written for, that meant it had to be a Variable. The best way to ensure this was to make sure both x and y being passed in are Variables; that way, any function of the two will also be a Variable.
- A PyTorch Variable is just a PyTorch Tensor whose operations PyTorch tracks, so that it can backpropagate to get the gradient. Since PyTorch 0.4, Variable and Tensor are merged: any tensor with requires_grad=True (such as a model's parameters) is tracked directly.
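Since Variable and Tensor are merged in current PyTorch, this checklist can be exercised with plain tensors. A small sketch (shapes chosen arbitrarily) of view(), unsqueeze(), and a scalar loss that autograd can backpropagate through:

```python
import torch

x = torch.randn(5, 10, requires_grad=True)  # a leaf tensor that autograd tracks
y = torch.randn(5, 5, 10)

flat = x.view(50)                 # view(): same data, new shape (5 * 10 = 50)
x3 = x.unsqueeze(1)               # unsqueeze(): insert a size-1 dimension at position 1
loss = ((y - x3) ** 2).sum()      # (5,1,10) broadcasts against (5,5,10); .sum() gives a scalar
loss.backward()                   # allowed without arguments because loss is 0-dimensional
print(flat.shape, x3.shape, loss.dim(), x.grad.shape)
```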
Here I show a custom loss called Regress_Loss, which takes two inputs, x and y. It reshapes x to match y and then returns the loss as the squared L2 difference between the reshaped x and y. This is a standard pattern you'll run across very often when training networks.
Consider x to have shape (5,10) and y to have shape (5,5,10). We need to add a dimension to x, giving (5,1,10), and then repeat it along the added dimension to match the shape of y. Then (x-y) will have shape (5,5,10). To get a scalar, we have to sum over all three dimensions; a single torch.sum() with no dimension argument already does this.
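The shape bookkeeping in this paragraph can be checked directly. A short sketch with random tensors of the stated shapes (the values themselves are arbitrary), comparing three per-dimension sums with one full reduction:

```python
import torch

x = torch.randn(5, 10)
y = torch.randn(5, 5, 10)

x_added_dim = x.unsqueeze(1)               # (5, 1, 10)
x_repeated = x_added_dim.repeat(1, 5, 1)   # (5, 5, 10): x copied 5 times along dim 1
diff = (y - x_repeated) ** 2               # (5, 5, 10)
loss_nested = diff.sum(2).sum(1).sum(0)    # three reductions, one per dimension
loss_flat = diff.sum()                     # one reduction over all dimensions
print(x_repeated.shape, loss_nested.dim(), torch.allclose(loss_nested, loss_flat))
```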
In [38]:
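class Regress_Loss(torch.nn.Module):
    def __init__(self):
        super(Regress_Loss, self).__init__()

    def forward(self, x, y):
        # x: (batch, features), y: (batch, num_targets, features)
        y_shape = y.size()[1]  # number of targets per example
        x_added_dim = x.unsqueeze(1)  # (batch, 1, features)
        # Bug fix: repeat by y's second dimension (the original used an undefined NUM_WORDS).
        x_stacked_along_dimension1 = x_added_dim.repeat(1, y_shape, 1)  # (batch, num_targets, features)
        diff = torch.sum((y - x_stacked_along_dimension1) ** 2, 2)  # sum over features -> (batch, num_targets)
        totloss = torch.sum(diff)  # sum over all remaining dimensions -> scalar
        return totloss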
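Because PyTorch broadcasts size-1 dimensions automatically, the explicit repeat() is optional. Below is a sketch of an equivalent loss that relies on broadcasting instead; the class name Regress_Loss_Broadcast is hypothetical. The block checks it against the repeat-based computation and confirms that backward() runs on the scalar it returns.

```python
import torch
import torch.nn as nn

class Regress_Loss_Broadcast(nn.Module):  # hypothetical variant of Regress_Loss
    def forward(self, x, y):
        # x: (batch, features), y: (batch, num_targets, features)
        # x.unsqueeze(1) is (batch, 1, features); broadcasting stretches dim 1 to num_targets.
        return ((y - x.unsqueeze(1)) ** 2).sum()

x = torch.randn(5, 10, requires_grad=True)
y = torch.randn(5, 5, 10)

loss = Regress_Loss_Broadcast()(x, y)
# The repeat-based computation from the paragraph above, for comparison.
reference = ((y - x.unsqueeze(1).repeat(1, y.size(1), 1)) ** 2).sum()
loss.backward()
print(loss.dim(), torch.allclose(loss, reference), x.grad.shape)
```

Broadcasting avoids materialising the repeated copy of x, which saves memory when num_targets is large; the result is the same.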
In [ ]: