"""

Ryan Tietjen

Aug 2024

Creates a vit base 16 model for the demo

"""

import torch
import torchvision
from torch import nn

def vit_b_16(num_classes: int = 101,
             seed: int = 31,
             freeze_gradients: bool = True,
             unfreeze_blocks: int = 0):
    """

    Initializes and configures a Vision Transformer (ViT-B/16) model with options for freezing gradients

    and adjusting the number of trainable blocks.



    This function sets up a ViT-B/16 model pre-trained on the ImageNet-1K dataset, modifies the classification

    head to accommodate a specified number of classes, and optionally freezes the gradients of certain blocks

    to prevent them from being updated during training.



    Parameters:

    num_classes (int): The number of output classes for the new classification head. Default is 101.

    seed (int): Random seed for reproducibility. Default is 31.

    freeze_gradients (bool): If True, freezes the gradients of the model's parameters, except for the last few

                             blocks specified by `unfreeze_blocks`. Default is True.

    unfreeze_blocks (int): Number of transformer blocks from the end whose parameters will have trainable gradients.

                           Default is 0, implying all are frozen except the new classification head.



    Returns:

    tuple: A tuple containing:

        - model (torch.nn.Module): The modified ViT-B/16 model with a new classifier head.

        - transforms (callable): The transformation function required for input images, as recommended by the

                                 pre-trained weights.



    Example:

    ```python

    model, transform = vit_b_16(num_classes=101, seed=31, freeze_gradients=True, unfreeze_blocks=2)

    ```



    Notes:

    - The total number of parameters in the model is calculated and used to determine which parameters to freeze.

    - The classifier head of the model is replaced with a new linear layer that outputs to the specified number of classes.

    """

    torch.manual_seed(seed) 

    # Create the model and extract the pre-trained weights and transforms
    weights = torchvision.models.ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1
    transforms = weights.transforms()
    model = torchvision.models.vit_b_16(weights=weights)

    params = list(model.parameters())
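    # In torchvision's ViT-B/16, each encoder block holds 12 parameter tensors
    # (two LayerNorms, the attention in/out projections, and a two-layer MLP,
    # each with weight + bias); the final encoder LayerNorm and the classification
    # head add 4 more, which is what the formula below counts.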
    params_to_unfreeze = 4 + (12 * unfreeze_blocks)
    # Total number of parameter tensors
    total_params = len(params)


    # Freeze gradients to avoid modifying the original pre-trained weights
    if freeze_gradients:
        for i, param in enumerate(params):
            # Set requires_grad to False for all but the last `unfreeze_blocks` encoder blocks
            if i < total_params - params_to_unfreeze:
                param.requires_grad = False

    # Replace the classifier head with a new linear layer that maps the 768-dim
    # ViT-B/16 embedding to the requested number of classes
    model.heads = nn.Sequential(
        nn.Linear(in_features=768,
                  out_features=num_classes))
    
    return model, transforms
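

if __name__ == "__main__":
    # Quick sanity-check sketch: build the model and run a dummy forward pass.
    # Assumes the SWAG E2E weights selected above, which expect 384x384 inputs;
    # for real images, apply the returned `transform` instead of random noise.
    model, transform = vit_b_16(num_classes=101, unfreeze_blocks=2)
    model.eval()
    dummy = torch.randn(1, 3, 384, 384)
    with torch.inference_mode():
        logits = model(dummy)
    print(logits.shape)  # expected: torch.Size([1, 101])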