
Attack runner

runner

CLI tool for evaluating adversarial attack transferability.

Evaluates the effectiveness of attacks against a white-box surrogate model and optionally measures transfer rates to black-box victim models. Supports all built-in attacks from torchattack, with models from both torchvision.models and timm usable as surrogates and victims.

Example CLI usage
$ python -m torchattack.evaluate.runner --attack PGD --eps 16/255 \
    --model-name resnet18 --victim-model-names vgg11 densenet121 \
    --dataset-root datasets/nips2017 --max-samples 200

run_attack(attack, attack_args=None, model_name='resnet50', victim_model_names=None, dataset_root='datasets/nips2017', max_samples=100, batch_size=4, save_adv_batch=-1)

Helper function to run evaluation on attacks.

Example
>>> from torchattack import FGSM
>>> args = {"eps": 8 / 255, "clip_min": 0.0, "clip_max": 1.0}
>>> run_attack(attack=FGSM, attack_args=args)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `attack` | `str \| Type[Attack]` | The name of the attack to run, or the attack class itself. | *required* |
| `attack_args` | `dict \| None` | A dict of keyword arguments to pass to the attack. Defaults to None. | `None` |
| `model_name` | `str` | The name of the white-box surrogate model to attack. Defaults to "resnet50". | `'resnet50'` |
| `victim_model_names` | `list[str] \| None` | The names of the black-box victim models to attack. Defaults to None. | `None` |
| `dataset_root` | `str` | Root directory of the NIPS2017 dataset. Defaults to "datasets/nips2017". | `'datasets/nips2017'` |
| `max_samples` | `int` | Max number of samples used for the evaluation. Defaults to 100. | `100` |
| `batch_size` | `int` | Batch size for the dataloader. Defaults to 4. | `4` |
| `save_adv_batch` | `int` | Batch index for optionally saving a batch of adversarial examples to visualize. Set to -1 to disable. Defaults to -1. | `-1` |
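
For instance, a transfer evaluation that mirrors the CLI invocation at the top of this page might look like the following sketch; the attack is given by name, and the surrogate, victim models, epsilon, and sample budget are the same illustrative values used in the CLI example.

>>> run_attack(
...     attack="PGD",
...     attack_args={"eps": 16 / 255},
...     model_name="resnet18",
...     victim_model_names=["vgg11", "densenet121"],
...     dataset_root="datasets/nips2017",
...     max_samples=200,
... )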
Source code in torchattack/evaluate/runner.py
def run_attack(
    attack: str | Type['Attack'],
    attack_args: dict | None = None,
    model_name: str = 'resnet50',
    victim_model_names: list[str] | None = None,
    dataset_root: str = 'datasets/nips2017',
    max_samples: int = 100,
    batch_size: int = 4,
    save_adv_batch: int = -1,
) -> None:
    """Helper function to run evaluation on attacks.

    Example:
        ```pycon
        >>> from torchattack import FGSM
        >>> args = {"eps": 8 / 255, "clip_min": 0.0, "clip_max": 1.0}
        >>> run_attack(attack=FGSM, attack_args=args)
        ```

    Args:
        attack: The name of the attack to run, or the attack class itself.
        attack_args: A dict of keyword arguments to pass to the attack. Defaults to None.
        model_name: The name of the white-box surrogate model to attack. Defaults to "resnet50".
        victim_model_names: The names of the black-box victim models to attack. Defaults to None.
        dataset_root: Root directory of the NIPS2017 dataset. Defaults to "datasets/nips2017".
        max_samples: Max number of samples used for the evaluation. Defaults to 100.
        batch_size: Batch size for the dataloader. Defaults to 4.
        save_adv_batch: Batch index for optionally saving a batch of adversarial examples
            to visualize. Set to -1 to disable. Defaults to -1.
    """

    import torch
    from rich import print
    from rich.progress import track

    from torchattack import AttackModel, create_attack
    from torchattack.evaluate.dataset import NIPSLoader
    from torchattack.evaluate.metric import FoolingRateMetric

    if attack_args is None:
        attack_args = {}
    is_targeted = attack_args.get('targeted', False)

    # Setup model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = AttackModel.from_pretrained(model_name).to(device)
    transform, normalize = model.transform, model.normalize

    # Set up dataloader
    dataloader = NIPSLoader(
        root=dataset_root,
        batch_size=batch_size,
        transform=transform,
        max_samples=max_samples,
        return_target_label=is_targeted,
    )
    dataloader = track(dataloader, description='Attacking')  # type: ignore

    # Set up attack and trackers
    frm = FoolingRateMetric(is_targeted)
    adversary = create_attack(attack, model, normalize, device, **attack_args)
    print(adversary)

    # Setup victim models if provided
    if victim_model_names:
        victims = [AttackModel.from_pretrained(vn) for vn in victim_model_names]
        victim_frms = [FoolingRateMetric(is_targeted) for _ in victim_model_names]

    # Run attack over the dataset
    for i, (x, y, _) in enumerate(dataloader):
        if is_targeted:
            yl, yt = y  # Unpack target labels from `y` if the attack is targeted
            x, y = x.to(device), (yl.to(device), yt.to(device))
        else:
            x, y = x.to(device), y.to(device)

        # Create adversarial examples. Pass target labels if the attack is targeted
        advs = adversary(x, y[1]) if is_targeted else adversary(x, y)

        # Track accuracy
        cln_outs = model(normalize(x))
        adv_outs = model(normalize(advs))
        frm.update(y, cln_outs, adv_outs)

        # Save one batch of adversarial examples if requested
        if i == save_adv_batch:
            from torchattack.evaluate import save_image_batch

            save_image_batch(advs, f'outputs_{adversary.attack_name}_b{i}')

        # Track transfer fooling rates if victim models are provided
        if victim_model_names:
            for v, vfrm in zip(victims, victim_frms):
                v.to(device)
                vtransform = v.create_relative_transform(model)
                v_cln_outs = v(v.normalize(vtransform(x)))
                v_adv_outs = v(v.normalize(vtransform(advs)))
                vfrm.update(y, v_cln_outs, v_adv_outs)

    # Print results
    cln_acc, adv_acc, fr = frm.compute()
    print(f'Surrogate ({model_name}): {cln_acc=:.2%}, {adv_acc=:.2%} ({fr=:.2%})')

    if victim_model_names:
        for v, vfrm in zip(victims, victim_frms):
            vcln_acc, vadv_acc, vfr = vfrm.compute()
            print(
                f'Victim ({v.model_name}): cln_acc={vcln_acc:.2%}, '
                f'adv_acc={vadv_acc:.2%} (fr={vfr:.2%})'
            )