class Attack(ABC):
    """The base class for all attacks.

    Subclasses must implement :meth:`forward`; calling an instance delegates
    to it. Subclasses are also expected to provide the class-level
    ``attack_name`` and ``attack_category`` attributes declared below.

    Note:
        ``__eq__`` is overridden without ``__hash__``, so (per the Python data
        model) instances are unhashable unless a subclass defines ``__hash__``.
    """

    attack_name: str
    attack_category: AttackCategory

    def __init__(
        self,
        model: nn.Module | AttackModel | None,
        normalize: Callable[[torch.Tensor], torch.Tensor] | None,
        device: torch.device | None,
    ) -> None:
        """Store the model, device and normalization callable on the instance.

        Args:
            model: An ``AttackModel`` (its ``.model`` attribute is used), a plain
                ``nn.Module`` (used as is), or ``None`` (an empty
                ``nn.Sequential`` is used as a dummy model).
            normalize: Callable applied to tensors; ``None`` selects the
                identity function.
            device: Target device; ``None`` selects CUDA when
                ``torch.cuda.is_available()`` and CPU otherwise.
        """
        super().__init__()

        if model is None:
            # No model given: an empty nn.Sequential acts as a dummy model.
            self.model = nn.Sequential()
        elif isinstance(model, AttackModel):
            # Unwrap the underlying module of an AttackModel.
            self.model = model.model
        else:
            self.model = model

        # Explicit `is not None` checks so that a valid-but-falsy argument
        # (e.g. a callable defining `__len__` that returns 0) is not replaced.
        if device is not None:
            self.device = device
        else:
            is_cuda = torch.cuda.is_available()
            self.device = torch.device('cuda' if is_cuda else 'cpu')

        self.normalize = normalize if normalize is not None else lambda x: x

    @classmethod
    def is_category(cls, category: str | AttackCategory) -> bool:
        """Check if the attack class belongs to the given category.

        The argument is passed through ``AttackCategory.verify`` and the result
        is compared by identity with ``cls.attack_category``.
        """
        # NOTE(review): presumably `verify` maps a str to the matching
        # AttackCategory member — confirm against AttackCategory.
        return cls.attack_category is AttackCategory.verify(category)

    @abstractmethod
    def forward(self, *args: Any, **kwds: Any) -> Any:
        """Run the attack. Must be implemented by subclasses."""
        pass

    def __call__(self, *args: Any, **kwds: Any) -> Any:
        """Forward all arguments to :meth:`forward` and return its result."""
        return self.forward(*args, **kwds)

    def __repr__(self) -> str:
        """Return ``ClassName(attr=value, ...)`` built from the instance ``__dict__``.

        Floats are printed with three decimals, a fixed set of attributes is
        shown by class name only, and tensors are shown by shape.
        """
        name = self.__class__.__name__

        # Attributes rendered by their class name instead of their full repr.
        name_only = (
            'model',
            'normalize',
            'feature_module',
            'feature_maps',
            'features',
            'hooks',
            'generator',
        )

        def repr_map(k: str, v: Any) -> str:
            if isinstance(v, float):
                return f'{k}={v:.3f}'
            if k in name_only:
                return f'{k}={v.__class__.__name__}'
            if isinstance(v, torch.Tensor):
                return f'{k}={v.shape}'
            return f'{k}={v}'

        args = ', '.join(repr_map(k, v) for k, v in self.__dict__.items())
        return f'{name}({args})'

    def __eq__(self, other: Any) -> bool:
        """Compare two attacks attribute by attribute.

        Attributes listed in ``eq_name_attrs`` are compared by class name only,
        and only when both objects have them. Every other instance attribute of
        either object must be present on both and equal; tensors are compared
        with ``torch.equal``.

        Returns:
            ``False`` if ``other`` is not an ``Attack`` or any attribute
            differs or is missing on one side, ``True`` otherwise.
        """
        if not isinstance(other, Attack):
            return False

        eq_name_attrs = [
            'model',
            'normalize',
            'lossfn',
            'feature_module',  # FIA, ILPD, NAA
            'feature_maps',
            'features',
            'hooks',  # PNAPatchOut, TGR, VDC
            'sub_basis',  # GeoDA
            'generator',  # BIA, CDA, LTP
        ]
        for attr in eq_name_attrs:
            if not (hasattr(self, attr) and hasattr(other, attr)):
                continue
            if (
                getattr(self, attr).__class__.__name__
                != getattr(other, attr).__class__.__name__
            ):
                return False

        # Walk the instance attributes of BOTH objects (order-preserving union)
        # so the comparison is symmetric, and use a sentinel so an attribute
        # missing on one side yields False instead of raising AttributeError.
        missing = object()
        for attr in dict.fromkeys((*vars(self), *vars(other))):
            if attr in eq_name_attrs:
                continue
            self_val = getattr(self, attr, missing)
            other_val = getattr(other, attr, missing)
            if self_val is missing or other_val is missing:
                return False

            self_is_tensor = isinstance(self_val, torch.Tensor)
            other_is_tensor = isinstance(other_val, torch.Tensor)
            if self_is_tensor or other_is_tensor:
                # A tensor never equals a non-tensor. Checking both sides
                # avoids evaluating `value != tensor`, whose truth value is
                # ambiguous for multi-element tensors.
                if not (self_is_tensor and other_is_tensor):
                    return False
                if not torch.equal(self_val, other_val):
                    return False
            elif self_val != other_val:
                return False
        return True