Loading pretrained models

The AttackModel

To launch an adversarial attack, you first need a model to attack.

torchattack provides AttackModel, a simple abstraction over both torchvision and timm, for loading image classification models pretrained on ImageNet.

First, import torch, import AttackModel from torchattack, and determine the device to use.

import torch
from torchattack import AttackModel

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Pretrained models are loaded by name

Unlike torchvision.models, AttackModel loads a pretrained model by its name.

For instance, to load a ResNet-50 model:

model = AttackModel.from_pretrained(model_name='resnet50').to(device)

The AttackModel.from_pretrained() method does four things under the hood:

  1. It automatically loads the model from either torchvision (by default) or timm (if not found in torchvision).
  2. It sets the model to evaluation mode by calling model.eval().
  3. It resolves the transform and normalize functions associated with the model's pretrained weights and attaches them to the AttackModel instance.
  4. Finally, it stores the resolved transformation attributes in the model's meta attribute.

This way, we get not only the pretrained model set up, but also its associated transform and normalization functions, which are, more importantly, kept separate(1).

  1. Separating the model's normalize function from its transform is crucial for launching attacks, as adversarial perturbations are crafted in the original image space, most often with pixel values in the [0, 1] range.
>>> transform, normalize = model.transform, model.normalize
>>> model.meta
AttackModelMeta(resize_size=232, crop_size=224, interpolation=<InterpolationMode.BILINEAR: 'bilinear'>, antialias=True, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
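
To see why the separation matters, here is a minimal sketch in plain Python, assuming a single RGB pixel and the mean and std values shown in model.meta above. The perturb helper, the pixel values, and the 8/255 budget are illustrative assumptions and not part of torchattack. The point is only the order of operations: the perturbation is bounded and clipped in the [0, 1] image space first, and normalization is applied afterwards.

```python
# Illustrative only: plain-Python stand-ins, not torchattack's implementation.
MEAN = [0.485, 0.456, 0.406]  # per-channel mean, as reported by model.meta
STD = [0.229, 0.224, 0.225]   # per-channel std, as reported by model.meta


def normalize(pixel):
    """Map an RGB pixel from the [0, 1] image space to the model's input space."""
    return [(v - m) / s for v, m, s in zip(pixel, MEAN, STD)]


def perturb(pixel, direction, eps):
    """Shift each channel by eps along `direction` (+1 or -1), then clip to [0, 1]."""
    return [min(1.0, max(0.0, v + eps * d)) for v, d in zip(pixel, direction)]


eps = 8 / 255            # hypothetical L-infinity budget, defined in image space
x = [0.99, 0.50, 0.01]   # one RGB pixel, already in [0, 1]
x_adv = perturb(x, [+1, -1, -1], eps)

# Both the budget and the valid pixel range are enforced before normalization.
assert all(0.0 <= v <= 1.0 for v in x_adv)
assert all(abs(a - b) <= eps + 1e-12 for a, b in zip(x_adv, x))

model_input = normalize(x_adv)  # only now move into the normalized space
```

If normalization were baked into the transform, the same budget and clipping bounds would have to be rescaled per channel by the mean and std, which is easy to get wrong.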

Specifying the model source

AttackModel honors an explicit model source, specified by prefixing the model name with tv/ or timm/ for torchvision and timm, respectively.

For instance, to load the ViT-B/16 model from timm:

vit_b16 = AttackModel.from_pretrained(model_name='timm/vit_base_patch16_224').to(device)

To load the Inception-v3 model from torchvision:

inv_v3 = AttackModel.from_pretrained(model_name='tv/inception_v3').to(device)

Or, explicitly specify timm as the source with from_timm=True:

pit_b = AttackModel.from_pretrained(model_name='pit_b_224', from_timm=True).to(device)
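
The source-selection rules above can be summarized in a small sketch. The resolve_source helper below is hypothetical and is not torchattack's actual code; it only mirrors the behavior described in this section. How a prefix and from_timm=True interact when both are given is not stated here, so letting the prefix take precedence is an assumption.

```python
def resolve_source(model_name, from_timm=False):
    """Return (source, bare_name) for a model name, following the rules above.

    `source` is 'torchvision', 'timm', or 'auto' (torchvision first, then timm).
    Hypothetical helper for illustration; not part of torchattack's API.
    """
    if model_name.startswith('tv/'):
        return 'torchvision', model_name[len('tv/'):]
    if model_name.startswith('timm/'):
        return 'timm', model_name[len('timm/'):]
    if from_timm:
        return 'timm', model_name
    # No prefix and no flag: default lookup order, torchvision then timm.
    return 'auto', model_name


print(resolve_source('resnet50'))                    # ('auto', 'resnet50')
print(resolve_source('tv/inception_v3'))             # ('torchvision', 'inception_v3')
print(resolve_source('timm/vit_base_patch16_224'))   # ('timm', 'vit_base_patch16_224')
print(resolve_source('pit_b_224', from_timm=True))   # ('timm', 'pit_b_224')
```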