Skip to content

L2T

L2T

Bases: Attack

The L2T (Learning to Transform) attack.

From the paper: Learning to Transform Dynamically for Better Adversarial Transferability (arXiv:2405.14077).

Note

The L2T attack requires the torchvision package as it uses torchvision.transforms for image transformations.

Parameters:

Name Type Description Default
model Module | AttackModel

The model to attack.

required
normalize Callable[[Tensor], Tensor] | None

A transform to normalize images.

None
device device | None

Device to use for tensors. Defaults to cuda if available.

None
eps float

The maximum perturbation. Defaults to 8/255.

8 / 255
steps int

Number of steps. Defaults to 10.

10
alpha float | None

Step size, eps / steps if None. Defaults to None.

None
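decay float

Decay factor for the momentum term. Defaults to 1.0.

1.0
num_scale int

Number of augmented forward passes, each with a freshly sampled set of augmentation ops, whose losses are averaged per step. Defaults to 5.

5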
clip_min float

Minimum value for clipping. Defaults to 0.0.

0.0
clip_max float

Maximum value for clipping. Defaults to 1.0.

1.0
targeted bool

Targeted attack if True. Defaults to False.

False
Source code in torchattack/l2t.py
@register_attack()
class L2T(Attack):
    """The L2T (Learning to Transform) attack.

    > From the paper: [Learning to Transform Dynamically for Better Adversarial
    Transferability](https://arxiv.org/abs/2405.14077).

    Note:
        The L2T attack requires the `torchvision` package as it uses
        `torchvision.transforms` for image transformations.

    Args:
        model: The model to attack.
        normalize: A transform to normalize images.
        device: Device to use for tensors. Defaults to cuda if available.
        eps: The maximum perturbation. Defaults to 8/255.
        steps: Number of steps. Defaults to 10.
        alpha: Step size, `eps / steps` if None. Defaults to None.
        decay: Decay factor for the momentum term. Defaults to 1.0.
        num_scale: Number of augmented forward passes (each with a freshly
            sampled set of augmentation ops) whose losses are averaged per
            step. Defaults to 5.
        clip_min: Minimum value for clipping. Defaults to 0.0.
        clip_max: Maximum value for clipping. Defaults to 1.0.
        targeted: Targeted attack if True. Defaults to False.
    """

    def __init__(
        self,
        model: nn.Module | AttackModel,
        normalize: Callable[[torch.Tensor], torch.Tensor] | None = None,
        device: torch.device | None = None,
        eps: float = 8 / 255,
        steps: int = 10,
        alpha: float | None = None,
        decay: float = 1.0,
        num_scale: int = 5,
        clip_min: float = 0.0,
        clip_max: float = 1.0,
        targeted: bool = False,
    ) -> None:
        super().__init__(model, normalize, device)

        self.eps = eps
        self.steps = steps
        self.alpha = alpha
        self.decay = decay
        self.num_scale = num_scale
        self.clip_min = clip_min
        self.clip_max = clip_max
        self.targeted = targeted
        self.lossfn = nn.CrossEntropyLoss()

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Perform L2T on a batch of images.

        Every step samples `num_scale` augmentation-op selections from a
        learnable weight vector (`aug_params`) and averages the classification
        loss over the augmented inputs. It then takes a momentum sign-gradient
        step on the perturbation. Finally, it moves `aug_params` by gradient
        ascent on the probability-weighted loss, favouring op selections that
        produced a higher loss.

        Args:
            x: A batch of images. Shape: (N, C, H, W).
            y: A batch of labels. Shape: (N).

        Returns:
            The perturbed images. Shape: (N, C, H, W).
        """

        g = torch.zeros_like(x)
        delta = torch.zeros_like(x, requires_grad=True)
        # One learnable weight per op in AUG_OPS; ops are sampled from these.
        aug_params = torch.zeros(len(AUG_OPS), requires_grad=True, device=self.device)

        # Number of augmentation ops composed per sample, and the step size of
        # the gradient-ascent update applied to `aug_params`.
        ops_num = 2
        lr = 0.01

        # Resolve the step size locally instead of writing it back to
        # `self.alpha`: persisting it would freeze the first call's eps/steps
        # and ignore later changes to either attribute.
        alpha = self.eps / self.steps if self.alpha is None else self.alpha

        # Perform L2T
        for _ in range(self.steps):
            aug_probs = []
            losses = []

            for _ in range(self.num_scale):
                # Create a random aug search instance for the given number of ops
                rw_search = RWAugSearch(ops_num)

                # Randomly select ops based on the aug params
                ops_indices = select_op(aug_params, ops_num)
                # Compute the joint probs of the selected ops
                aug_prob = trace_prob(aug_params, ops_indices)

                # Update the aug search with the selected ops
                rw_search.ops_num = ops_num
                rw_search.ops_indices = ops_indices

                # Save the probs so the aug params can be updated after this step
                aug_probs.append(aug_prob)

                # Compute loss. The augmented output may hold several copies of
                # the batch (inferred from the label repetition below -- TODO
                # confirm against RWAugSearch), so labels are tiled to match.
                outs = self.model(self.normalize(rw_search(x + delta)))
                num_copies = len(outs) // len(y)
                loss = self.lossfn(outs, y.repeat(num_copies))

                if self.targeted:
                    loss = -loss

                losses.append(loss)

            # Gradient of the averaged loss w.r.t. the perturbation
            loss = torch.stack(losses).mean()
            delta_grad = torch.autograd.grad(loss, delta)[0]

            # Gradient of the probability-weighted loss w.r.t. the aug params
            aug_loss = (torch.stack(aug_probs) * torch.stack(losses)).mean()
            aug_grad = torch.autograd.grad(aug_loss, aug_params)[0]

            # Gradient *ascent*: make high-loss op selections more likely
            aug_params.data = aug_params.data + lr * aug_grad

            # Momentum accumulation with per-sample L1-mean normalisation
            g = self.decay * g + delta_grad / torch.mean(
                torch.abs(delta_grad), dim=(1, 2, 3), keepdim=True
            )

            # Sign step, then project into the eps-ball and the valid pixel range
            delta.data = delta.data + alpha * g.sign()
            delta.data = torch.clamp(delta.data, -self.eps, self.eps)
            delta.data = torch.clamp(x + delta.data, self.clip_min, self.clip_max) - x

        return x + delta

forward(x, y)

Perform L2T on a batch of images.

Parameters:

Name Type Description Default
x Tensor

A batch of images. Shape: (N, C, H, W).

required
y Tensor

A batch of labels. Shape: (N).

required

Returns:

Type Description
Tensor

The perturbed images. Shape: (N, C, H, W).

Source code in torchattack/l2t.py
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Perform L2T on a batch of images.

    Every step samples `num_scale` augmentation-op selections from a
    learnable weight vector (`aug_params`) and averages the classification
    loss over the augmented inputs. It then takes a momentum sign-gradient
    step on the perturbation. Finally, it moves `aug_params` by gradient
    ascent on the probability-weighted loss, favouring op selections that
    produced a higher loss.

    Args:
        x: A batch of images. Shape: (N, C, H, W).
        y: A batch of labels. Shape: (N).

    Returns:
        The perturbed images. Shape: (N, C, H, W).
    """

    g = torch.zeros_like(x)
    delta = torch.zeros_like(x, requires_grad=True)
    # One learnable weight per op in AUG_OPS; ops are sampled from these.
    aug_params = torch.zeros(len(AUG_OPS), requires_grad=True, device=self.device)

    # Number of augmentation ops composed per sample, and the step size of
    # the gradient-ascent update applied to `aug_params`.
    ops_num = 2
    lr = 0.01

    # Resolve the step size locally instead of writing it back to
    # `self.alpha`: persisting it would freeze the first call's eps/steps
    # and ignore later changes to either attribute.
    alpha = self.eps / self.steps if self.alpha is None else self.alpha

    # Perform L2T
    for _ in range(self.steps):
        aug_probs = []
        losses = []

        for _ in range(self.num_scale):
            # Create a random aug search instance for the given number of ops
            rw_search = RWAugSearch(ops_num)

            # Randomly select ops based on the aug params
            ops_indices = select_op(aug_params, ops_num)
            # Compute the joint probs of the selected ops
            aug_prob = trace_prob(aug_params, ops_indices)

            # Update the aug search with the selected ops
            rw_search.ops_num = ops_num
            rw_search.ops_indices = ops_indices

            # Save the probs so the aug params can be updated after this step
            aug_probs.append(aug_prob)

            # Compute loss. The augmented output may hold several copies of
            # the batch (inferred from the label repetition below -- TODO
            # confirm against RWAugSearch), so labels are tiled to match.
            outs = self.model(self.normalize(rw_search(x + delta)))
            num_copies = len(outs) // len(y)
            loss = self.lossfn(outs, y.repeat(num_copies))

            if self.targeted:
                loss = -loss

            losses.append(loss)

        # Gradient of the averaged loss w.r.t. the perturbation
        loss = torch.stack(losses).mean()
        delta_grad = torch.autograd.grad(loss, delta)[0]

        # Gradient of the probability-weighted loss w.r.t. the aug params
        aug_loss = (torch.stack(aug_probs) * torch.stack(losses)).mean()
        aug_grad = torch.autograd.grad(aug_loss, aug_params)[0]

        # Gradient *ascent*: make high-loss op selections more likely
        aug_params.data = aug_params.data + lr * aug_grad

        # Momentum accumulation with per-sample L1-mean normalisation
        g = self.decay * g + delta_grad / torch.mean(
            torch.abs(delta_grad), dim=(1, 2, 3), keepdim=True
        )

        # Sign step, then project into the eps-ball and the valid pixel range
        delta.data = delta.data + alpha * g.sign()
        delta.data = torch.clamp(delta.data, -self.eps, self.eps)
        delta.data = torch.clamp(x + delta.data, self.clip_min, self.clip_max) - x

    return x + delta