
Register attack

New in 1.4.0

The register_attack() decorator was first introduced in v1.4.0.

register_attack() is a decorator that adds an attack class to the attack registry. This lets create_attack() recognize attacks defined outside of torchattack, so that they can be created by name.

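The attack registry is ATTACK_REGISTRY, a mapping from attack names to attack classes. It is populated at import time: a decorated class is added to the registry when its class definition is executed, that is, when the module defining it is imported. To register an additional attack, decorate the attack class with @register_attack().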

from torchattack import Attack, register_attack


@register_attack()
class MyNewAttack(Attack):
    def __init__(self, model, normalize, device):
        super().__init__(model, normalize, device)

    def forward(self, x):
        # Placeholder: a real attack would compute and return adversarial examples here.
        return x

Once the module that defines the attack has been imported, the attack can be created by name in the same way as the built-in attacks.

import torch
from torchattack import create_attack, AttackModel

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = AttackModel.from_pretrained('resnet50').to(device)
adversary = create_attack('MyNewAttack', model=model, normalize=model.normalize, device=device)
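Conceptually, registration only has to make a class reachable through a string key. The self-contained sketch below illustrates that name-based lookup pattern in plain Python so it can run without torchattack. REGISTRY, register and create_by_name are hypothetical stand-ins, not torchattack's API and not its implementation of create_attack(); MyNewAttack here is a dummy class that does not subclass Attack.

```python
from typing import Any, Callable, Dict, Optional, Type

# Hypothetical stand-in for a registry such as ATTACK_REGISTRY: name -> class.
REGISTRY: Dict[str, Type[Any]] = {}


def register(cls: Type[Any]) -> Type[Any]:
    """Hypothetical minimal decorator: store a class under its own name."""
    # setdefault keeps the first class registered under a given name.
    return REGISTRY.setdefault(cls.__name__, cls)


def create_by_name(name: str, **kwargs: Any) -> Any:
    """Hypothetical factory: look a class up by name and instantiate it."""
    if name not in REGISTRY:
        raise ValueError(f"Unknown attack {name!r}; registered: {sorted(REGISTRY)}")
    return REGISTRY[name](**kwargs)


@register
class MyNewAttack:
    """Dummy attack used only for this sketch (it does not subclass Attack)."""

    def __init__(
        self, model: Any, normalize: Optional[Callable] = None, device: str = "cpu"
    ) -> None:
        self.model = model
        self.normalize = normalize
        self.device = device

    def __call__(self, x: Any) -> Any:
        # No-op "attack": return the input unchanged.
        return x


adversary = create_by_name("MyNewAttack", model="dummy-model", device="cpu")
print(type(adversary).__name__)  # MyNewAttack
print(adversary([1, 2, 3]))      # [1, 2, 3]
```

Because the registry is filled when the class statement runs, a lookup by name only succeeds after the module defining the class has been imported.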

register_attack(name=None, category=AttackCategory.COMMON)

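Decorator to register an attack class in the attack registry. The registry key is name if it is given, otherwise the class name (attack_cls.__name__). category accepts an AttackCategory member or a string and is validated with AttackCategory.verify(). The registered class receives attack_name and attack_category attributes. If the key is already registered, the previously registered class is returned and the registry is left unchanged.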

Source code in torchattack/attack.py
def register_attack(
    name: str | None = None, category: str | AttackCategory = AttackCategory.COMMON
) -> Callable[[Type['Attack']], Type['Attack']]:
    """Decorator to register an attack class in the attack registry."""

    def wrapper(attack_cls: Type['Attack']) -> Type['Attack']:
        key = name if name else attack_cls.__name__
        if key in ATTACK_REGISTRY:
            return ATTACK_REGISTRY[key]
        attack_cls.attack_name = key
        attack_cls.attack_category = AttackCategory.verify(category)
        ATTACK_REGISTRY[key] = attack_cls
        return attack_cls

    return wrapper
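The wrapper above has two behaviors worth noting: an explicit name replaces the class name as the registry key, and registering under a key that is already taken returns the existing class instead of the new one. The standalone sketch below reproduces that logic so it can be run without torchattack. Category, its verify() method and register are hypothetical stand-ins for AttackCategory, AttackCategory.verify() and register_attack(); the enum members and the validation rule are assumptions.

```python
from enum import Enum
from typing import Callable, Dict, Optional, Union


class Category(Enum):
    """Hypothetical stand-in for AttackCategory; the members are illustrative."""

    COMMON = "COMMON"
    OTHER = "OTHER"

    @classmethod
    def verify(cls, value: Union[str, "Category"]) -> "Category":
        # Assumed behavior: accept a member or its string value, reject anything else.
        if isinstance(value, cls):
            return value
        try:
            return cls(value)
        except ValueError:
            raise ValueError(f"Unknown category: {value!r}") from None


# Hypothetical stand-in for ATTACK_REGISTRY.
REGISTRY: Dict[str, type] = {}


def register(
    name: Optional[str] = None, category: Union[str, Category] = Category.COMMON
) -> Callable[[type], type]:
    def wrapper(attack_cls: type) -> type:
        key = name if name else attack_cls.__name__
        if key in REGISTRY:
            # Key already taken: hand back the class that was registered first.
            return REGISTRY[key]
        attack_cls.attack_name = key
        attack_cls.attack_category = Category.verify(category)
        REGISTRY[key] = attack_cls
        return attack_cls

    return wrapper


@register()
class Foo:
    pass


@register(name="Foo")  # clashes with the existing key "Foo"
class Bar:
    pass


@register(name="baz", category="OTHER")
class Baz:
    pass


print(Bar is Foo)                            # True
print(Foo.attack_name, Foo.attack_category)  # Foo Category.COMMON
print(sorted(REGISTRY))                      # ['Foo', 'baz']
```

After the second decoration the name Bar is bound to Foo, so a name clash silently discards the new class rather than raising an error. Passing a distinct name to the decorator avoids this.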