
Create attack

create_attack(attack, model=None, normalize=None, device=None, *, eps=None, **kwargs)

Create a torchattack attack instance based on the provided attack name (or attack class) and config.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `attack` | `Union[Type[Attack], str]` | The attack to create, either by its registered name or as an `Attack` subclass. | *required* |
| `model` | `Optional[Union[Module, AttackModel]]` | The model to be attacked. Can be an instance of `nn.Module` or `AttackModel`. Not passed to generative attacks. | `None` |
| `normalize` | `Optional[Callable[[Tensor], Tensor]]` | The normalization function specific to the model. Not passed to generative attacks. | `None` |
| `device` | `Optional[device]` | The device on which the attack will be executed. | `None` |
| `eps` | `Optional[float]` | The epsilon value for the attack (keyword-only). If `None`, it is not forwarded and the attack class's own default applies. | `None` |
| `kwargs` | `Any` | Additional config parameters, forwarded to the attack's constructor. | `{}` |
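The argument forwarding described above can be illustrated with a small, self-contained sketch. It does not import torchattack: `ToyAttack` and `toy_create_attack` are hypothetical stand-ins that mirror only the `eps`/`kwargs` handling in the source code listed on this page, where `eps` is keyword-only and is added to `kwargs` only when it is not `None`.

```python
from typing import Any, Optional


class ToyAttack:
    """Hypothetical attack class with its own default epsilon (not part of torchattack)."""

    def __init__(
        self,
        model: Any = None,
        normalize: Any = None,
        device: Any = None,
        eps: float = 8 / 255,
        steps: int = 10,
    ) -> None:
        self.model = model
        self.normalize = normalize
        self.device = device
        self.eps = eps
        self.steps = steps


def toy_create_attack(
    attack_cls: type,
    model: Any = None,
    normalize: Any = None,
    device: Any = None,
    *,
    eps: Optional[float] = None,
    **kwargs: Any,
) -> Any:
    # Mirrors the source: `eps` is keyword-only and forwarded only when set.
    if eps is not None:
        kwargs['eps'] = eps
    # Everything else in `kwargs` goes straight to the attack's constructor.
    return attack_cls(model=model, normalize=normalize, device=device, **kwargs)


default_attack = toy_create_attack(ToyAttack)
custom_attack = toy_create_attack(ToyAttack, eps=4 / 255, steps=20)
print(default_attack.eps == 8 / 255, custom_attack.eps == 4 / 255, custom_attack.steps)
# → True True 20
```

Because `eps` sits after `*`, passing it positionally raises a `TypeError`, which keeps it from being confused with `model`, `normalize`, or `device`.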

Returns:

| Type | Description |
|---|---|
| `Attack` | An instance of the specified attack. |

Raises:

| Type | Description |
|---|---|
| `ValueError` | If the specified attack name (or the `attack_name` of the given class) is not registered in torchattack's `ATTACK_REGISTRY`. |
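The name-resolution step behind this `ValueError` can be sketched in isolation. `TOY_REGISTRY`, `ToyFGSM`, `Unregistered`, and `resolve_attack` are hypothetical names used only for illustration; the sketch assumes, as the source does, that a string is looked up directly and that a class is validated through its `attack_name` attribute.

```python
from typing import Dict, List, Union


class ToyFGSM:
    attack_name = 'ToyFGSM'


class Unregistered:
    attack_name = 'Unregistered'


# Hypothetical stand-in for torchattack's ATTACK_REGISTRY (name -> class).
TOY_REGISTRY: Dict[str, type] = {'ToyFGSM': ToyFGSM}


def resolve_attack(attack: Union[type, str]) -> type:
    # Strings are used as the name; classes are checked via `attack_name`.
    name = attack if isinstance(attack, str) else attack.attack_name
    if name not in TOY_REGISTRY:
        raise ValueError(f"Attack '{name}' is not supported within torchattack.")
    # A string is mapped to its class; a class is returned unchanged.
    return TOY_REGISTRY[attack] if isinstance(attack, str) else attack


errors: List[str] = []
for bad in ('NoSuchAttack', Unregistered):
    try:
        resolve_attack(bad)
    except ValueError as err:
        errors.append(str(err))
print(errors)
```

Note that passing a class does not bypass the check: a class whose `attack_name` is missing from the registry is rejected just like an unknown string.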

Source code in torchattack/create_attack.py
```python
def create_attack(
    attack: Union[Type['Attack'], str],
    model: Optional[Union[nn.Module, AttackModel]] = None,
    normalize: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
    device: Optional[torch.device] = None,
    *,
    eps: Optional[float] = None,
    **kwargs: Any,
) -> Attack:
    """Create a torchattack instance based on the provided attack name and config.

    Args:
        attack: The attack to create, either by name or as an Attack class.
        model: The model to be attacked. Can be an instance of nn.Module or AttackModel. Defaults to None.
        normalize: The normalization function specific to the model. Defaults to None.
        device: The device on which the attack will be executed. Defaults to None.
        eps: The epsilon value for the attack. Defaults to None.
        kwargs: Additional config parameters, forwarded to the attack's constructor.

    Returns:
        An instance of the specified attack.

    Raises:
        ValueError: If the specified attack name is not supported within torchattack.
    """

    # Determine attack name and check if it is supported
    attack_name = attack if isinstance(attack, str) else attack.attack_name
    if attack_name not in ATTACK_REGISTRY:
        raise ValueError(f"Attack '{attack_name}' is not supported within torchattack.")
    # Get attack class if passed as a string
    attack_cls = ATTACK_REGISTRY[attack] if isinstance(attack, str) else attack

    # `eps` is explicitly set as it is such a common argument
    # All other arguments should be passed as keyword arguments
    if eps is not None:
        kwargs['eps'] = eps

    # Special handling for generative attacks
    attacker: Attack = (
        attack_cls(device=device, **kwargs)
        if attack_cls.is_category('GENERATIVE')
        else attack_cls(model=model, normalize=normalize, device=device, **kwargs)
    )
    return attacker
```
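The generative branch at the end of the source can be reproduced with toy classes. Everything here is hypothetical (`ToyBase`, `ToyGradientAttack`, `ToyGenerativeAttack`, `toy_dispatch`); the sketch only assumes, as the source implies, that `is_category` can be called on the class itself. It shows that a generative attack never receives `model` or `normalize`, even when the caller supplies them.

```python
from typing import Any


class ToyBase:
    """Hypothetical base class; mimics only the class-level `is_category` check."""

    categories: tuple = ()

    @classmethod
    def is_category(cls, name: str) -> bool:
        return name in cls.categories


class ToyGradientAttack(ToyBase):
    def __init__(self, model: Any, normalize: Any, device: Any = None, eps: float = 8 / 255) -> None:
        self.model = model
        self.normalize = normalize
        self.device = device
        self.eps = eps


class ToyGenerativeAttack(ToyBase):
    categories = ('GENERATIVE',)

    # Accepts no `model` or `normalize` argument at all.
    def __init__(self, device: Any = None, eps: float = 10 / 255) -> None:
        self.device = device
        self.eps = eps


def toy_dispatch(
    attack_cls: type, model: Any = None, normalize: Any = None, device: Any = None, **kwargs: Any
) -> Any:
    # Mirrors the source: generative attacks get only `device` and `kwargs`.
    if attack_cls.is_category('GENERATIVE'):
        return attack_cls(device=device, **kwargs)
    return attack_cls(model=model, normalize=normalize, device=device, **kwargs)


victim = object()
gen = toy_dispatch(ToyGenerativeAttack, model=victim, device='cpu')
grad = toy_dispatch(ToyGradientAttack, model=victim, normalize=None, device='cpu', eps=0.03)
print(hasattr(gen, 'model'), grad.model is victim)  # → False True
```

Dropping `model` and `normalize` in the factory, rather than in every generative class, lets callers use one call shape for all attacks without generative constructors having to accept arguments they ignore.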