Adversarial-Attacks-PyTorch
Torchattacks is a PyTorch library that provides adversarial attacks to generate adversarial examples. It offers a PyTorch-like interface and functions that make it easier for PyTorch users to implement adversarial attacks (README [KOR]).
```python
import torchattacks
atk = torchattacks.PGD(model, eps=8/255, alpha=2/255, steps=4)
# If the input images are normalized:
# atk.set_normalization_used(mean=[...], std=[...])
adv_images = atk(images, labels)
```
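Besides `set_normalization_used`, one common pattern (an assumption here, not something this README prescribes) is to fold the normalization into the model itself, so the attack only ever sees images in [0, 1]. `Normalize` and `with_normalization` below are hypothetical helpers written for illustration:

```python
import torch
import torch.nn as nn


class Normalize(nn.Module):
    """Hypothetical helper: applies (x - mean) / std per channel inside the model."""

    def __init__(self, mean, std):
        super().__init__()
        # Buffers follow .to(device) but are not trainable parameters.
        self.register_buffer("mean", torch.tensor(mean).view(1, -1, 1, 1))
        self.register_buffer("std", torch.tensor(std).view(1, -1, 1, 1))

    def forward(self, x):
        return (x - self.mean) / self.std


def with_normalization(model, mean, std):
    """Hypothetical helper: wrap a classifier so it accepts unnormalized [0, 1] images."""
    return nn.Sequential(Normalize(mean, std), model)
```

With this wrapper you would pass `with_normalization(model, mean, std)` to the attack and leave `set_normalization_used` uncalled, so normalization is not applied twice.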
Table of Contents
- Requirements and Installation
- Getting Started
- Performance Comparison
- Citation
- Contribution
- Recommended Sites and Packages
Requirements and Installation
Requirements
- PyTorch version >=1.4.0
- Python version >=3.6
Installation
```
pip install torchattacks
```
or install from source:
```
pip install git+https://github.com/Harry24k/adversarial-attacks-pytorch.git
```
Getting Started
Precautions
- All models should return ONLY ONE vector of `(N, C)`, where `C` is the number of classes. Since most models in torchvision.models return one vector of shape `(N, C)`, where `N` is the number of inputs and `C` is the number of classes, torchattacks also supports only this form of output. Please check the shape of the model's output carefully.
- Set `torch.backends.cudnn.deterministic = True` to get the same adversarial examples with a fixed random seed. Some operations are non-deterministic with float tensors on GPU [discuss], so if you want the same results for the same inputs, run `torch.backends.cudnn.deterministic = True` [ref].
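Both precautions can be checked before running an attack. The sketch below is an assumption about how you might do this in plain PyTorch; `check_logits_shape`, `FirstOutput`, and `make_deterministic` are hypothetical helpers, not part of torchattacks. `FirstOutput` covers the common case of a model that returns a tuple such as `(logits, features)`.

```python
import torch
import torch.nn as nn


def check_logits_shape(model, images, num_classes):
    """Hypothetical helper: raise unless model(images) is a single (N, C) tensor."""
    with torch.no_grad():
        out = model(images)
    if not isinstance(out, torch.Tensor):
        raise TypeError(f"expected a single tensor, got {type(out).__name__}")
    expected = (images.shape[0], num_classes)
    if tuple(out.shape) != expected:
        raise ValueError(f"expected output shape {expected}, got {tuple(out.shape)}")
    return tuple(out.shape)


class FirstOutput(nn.Module):
    """Hypothetical wrapper: keep only the first element of a tuple/list output."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        out = self.model(x)
        return out[0] if isinstance(out, (tuple, list)) else out


def make_deterministic(seed=0):
    """Fix the seed and ask cuDNN for deterministic kernels."""
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False  # autotuning may pick different kernels per run
```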
Demos
- White-box Attack on CIFAR10 (code, nbviewer)
- White-box Attack on ImageNet (code, nbviewer)
- Transfer Attack on CIFAR10 (code, nbviewer)
Torchattacks supports the following functions:
Targeted mode
- Random target label:
  ```python
  # random labels as target labels.
  atk.set_mode_targeted_random(n_classes)
  ```
- Least likely label:
  ```python
  # the label with the k-th smallest probability is used as the target label.
  atk.set_mode_targeted_least_likely(kth_min)
  ```
- By custom function:
  ```python
  # target labels from a mapping function.
  atk.set_mode_targeted_by_function(target_map_function=lambda images, labels: (labels + 1) % 10)
  ```
- By labels:
  ```python
  # target labels provided by the user.
  atk = torchattacks.PGD(model, eps=8/255, alpha=2/255, steps=4)
  atk.set_mode_targeted_by_label(quiet=True)  # do not show the message
  # shift every class label one to the right: 1=>2, 2=>3, ..., 9=>0
  target_labels = (labels + 1) % 10
  adv_images = atk(images, target_labels)
  ```
- Return to default:
  ```python
  atk.set_mode_default()
  ```
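To make the "least likely" and "shift by one" targets concrete, here is a small sketch of how such target labels can be computed from a model's logits. It is an illustration under our own assumptions, not torchattacks' internal code; `kth_least_likely_labels` and `shift_labels` are hypothetical names. Because softmax is monotonic, ranking logits gives the same order as ranking probabilities.

```python
import torch


def kth_least_likely_labels(model, images, kth_min=1):
    """Hypothetical helper: per input, the class with the kth_min-th smallest score (1 = least likely)."""
    with torch.no_grad():
        logits = model(images)
    # torch.kthvalue returns (values, indices) of the k-th smallest entry along dim.
    _, target_labels = torch.kthvalue(logits, kth_min, dim=1)
    return target_labels


def shift_labels(labels, num_classes=10):
    """Hypothetical helper: the cyclic shift used in the by-label example (1=>2, ..., 9=>0)."""
    return (labels + 1) % num_classes
```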
Save adversarial images
```python
# Save
atk.save(data_loader, save_path="./data.pt", verbose=True)
# Load
adv_loader = atk.load(load_path="./data.pt")
```
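Conceptually, saving adversarial images means running the attack over every batch, storing the results, and rebuilding a `DataLoader` from the stored tensors later. The sketch below shows that idea with plain `torch.save`/`torch.load`; it is an assumption for illustration, does not reproduce the file format written by `atk.save`, and `save_adversarial`/`load_adversarial` are hypothetical names.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def save_adversarial(attack, data_loader, save_path):
    """Hypothetical helper: run `attack` on every batch and store (adv_images, labels)."""
    adv_batches, label_batches = [], []
    for images, labels in data_loader:
        adv_batches.append(attack(images, labels).detach().cpu())
        label_batches.append(labels.cpu())
    torch.save((torch.cat(adv_batches), torch.cat(label_batches)), save_path)


def load_adversarial(load_path, batch_size=128):
    """Hypothetical helper: rebuild a DataLoader from a file written by save_adversarial."""
    adv_images, labels = torch.load(load_path)
    return DataLoader(TensorDataset(adv_images, labels), batch_size=batch_size, shuffle=False)
```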
Training/Eval during attack
```python
# For RNN-based models, cuDNN cannot compute gradients in eval mode,
# so the model should be switched to training mode during the attack.
atk.set_training_mode(model_training=False, batchnorm_training=False, dropout_training=False)
```
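The three flags separate the mode of the model as a whole from the mode of its batch-norm and dropout layers. The sketch below illustrates that idea in plain PyTorch; `set_attack_modes` is a hypothetical helper, not the torchattacks implementation. For an RNN you would typically pass `model_training=True` while keeping batch norm and dropout in eval mode, so that gradients are available but the layers stay deterministic.

```python
import torch.nn as nn

# Layer types whose behaviour differs between train and eval mode (assumed list).
_BATCHNORM = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
_DROPOUT = (nn.Dropout, nn.Dropout2d, nn.Dropout3d, nn.AlphaDropout)


def set_attack_modes(model, model_training=False, batchnorm_training=False, dropout_training=False):
    """Hypothetical helper: set the global mode, then override BN and dropout layers."""
    model.train(model_training)  # recursively sets every submodule
    for module in model.modules():
        if isinstance(module, _BATCHNORM):
            module.train(batchnorm_training)
        elif isinstance(module, _DROPOUT):
            module.train(dropout_training)
    return model
```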
Make a set of attacks
- Strong attacks
  ```python
  atk1 = torchattacks.FGSM(model, eps=8/255)
  atk2 = torchattacks.PGD(model, eps=8/255, alpha=2/255, steps=40, random_start=True)
  atk = torchattacks.MultiAttack([atk1, atk2])
  ```
- Grid search over `c` for CW
  ```python
  atk1 = torchattacks.CW(model, c=0.1, steps=1000, lr=0.01)
  atk2 = torchattacks.CW(model, c=1, steps=1000, lr=0.01)
  atk = torchattacks.MultiAttack([atk1, atk2])
  ```
- Random restarts
  ```python
  atk1 = torchattacks.PGD(model, eps=8/255, alpha=2/255, steps=40, random_start=True)
  atk2 = torchattacks.PGD(model, eps=8/255, alpha=2/255, steps=40, random_start=True)
  atk = torchattacks.MultiAttack([atk1, atk2])
  ```
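The idea behind combining attacks is that later attacks only need to handle inputs the earlier ones failed on. The following is a minimal sketch of that "keep the first success" logic, assuming each attack is a callable `attack(images, labels)` returning adversarial images; `multi_attack` is a hypothetical function, not the `MultiAttack` source. Inputs that no attack fools are returned unchanged.

```python
import torch


def multi_attack(attacks, model, images, labels):
    """Hypothetical sketch: apply attacks in order, keeping the first successful result per input."""
    final = images.clone()
    remaining = torch.arange(len(images))  # indices still classified correctly
    for attack in attacks:
        if len(remaining) == 0:
            break
        adv = attack(images[remaining], labels[remaining]).detach()
        with torch.no_grad():
            pred = model(adv).argmax(dim=1)
        fooled = pred != labels[remaining]
        final[remaining[fooled]] = adv[fooled]  # store successful adversarial examples
        remaining = remaining[~fooled]          # retry only the failures
    return final
```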
Torchattacks also supports collaboration with other attack packages.
Foolbox
https://github.com/bethgelab/foolbox
```python
from torchattacks.attack import Attack
import foolbox as fb

# L2BrendelBethge
class L2BrendelBethge(Attack):
    def __init__(self, model):
        super(L2BrendelBethge, self).__init__("L2BrendelBethge", model)
        self.fmodel = fb.PyTorchModel(self.model, bounds=(0,1), device=self.device)
        self.init_attack = fb.attacks.DatasetAttack()
        self.adversary = fb.attacks.L2BrendelBethgeAttack(init_attack=self.init_attack)
        self._attack_mode = 'only_default'

    def forward(self, images, labels):
        images, labels = images.to(self.device), labels.to(self.device)

        # DatasetAttack
        batch_size = len(images)
        batches = [(images[:batch_size//2], labels[:batch_size//2]),
                   (images[batch_size//2:], labels[batch_size//2:])]
        self.init_attack.feed(model=self.fmodel, inputs=batches[0][0])  # feed 1st batch of inputs
        self.init_attack.feed(model=self.fmodel, inputs=batches[1][0])  # feed 2nd batch of inputs
        criterion = fb.Misclassification(labels)
        init_advs = self.init_attack.run(self.fmodel, images, criterion)

        # L2BrendelBethge
        adv_images = self.adversary.run(self.fmodel, images, labels, starting_points=init_advs)
        return adv_images

atk = L2BrendelBethge(model)
```
Adversarial-Robustness-Toolbox (ART)
https://github.com/IBM/adversarial-robustness-toolbox
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torchattacks.attack import Attack
import art.attacks.evasion as evasion
from art.estimators.classification import PyTorchClassifier

# SaliencyMapMethod (or Jacobian-based saliency map attack)
class JSMA(Attack):
    def __init__(self, model, theta=1/255, gamma=0.15, batch_size=128):
        super(JSMA, self).__init__("JSMA", model)
        self.classifier = PyTorchClassifier(
            model=self.model, clip_values=(0, 1),
            loss=nn.CrossEntropyLoss(),
            optimizer=optim.Adam(self.model.parameters(), lr=0.01),
            input_shape=(1, 28, 28), nb_classes=10)
        self.adversary = evasion.SaliencyMapMethod(classifier=self.classifier,
                                                   theta=theta, gamma=gamma,
                                                   batch_size=batch_size)
        self.target_map_function = lambda labels: (labels+1)%10
        self._attack_mode = 'only_default'

    def forward(self, images, labels):
        # ART works on NumPy arrays, so convert the images and target labels first.
        x = images.detach().cpu().numpy()
        y = self.target_map_function(labels).cpu().numpy()
        adv_images = self.adversary.generate(x, y)
        return torch.from_numpy(adv_images).to(self.device)

atk = JSMA(model)
```
List of implemented papers
The distance measure is given in parentheses.
Name | Paper | Remark |
---|---|---|
FGSM (Linf) | Explaining and harnessing adversarial examples (Goodfellow et al., 2014) | |
BIM (Linf) | Adversarial Examples in the Physical World (Kurakin et al., 2016) | Basic iterative method or Iterative-FGSM |
CW (L2) | Towards Evaluating the Robustness of Neural Networks (Carlini et al., 2016) | |
RFGSM (Linf) | Ensemble Adversarial Training: Attacks and Defenses (Tramèr et al., 2017) | Random initialization + FGSM |
PGD (Linf) | Towards Deep Learning Models Resistant to Adversarial Attacks (Madry et al., 2017) | Projected Gradient Descent |
PGDL2 (L2) | Towards Deep Learning Models Resistant to Adversarial Attacks (Madry et al., 2017) | Projected Gradient Descent |
MIFGSM (Linf) | Boosting Adversarial Attacks with Momentum (Dong et al., 2017) | |
TPGD (Linf) | Theoretically Principled Trade-off between Robustness and Accuracy (Zhang et al., 2019) | |
EOTPGD (Linf) | Comment on "Adv-BNN: Improved Adversarial Defense through Robust Bayesian Neural Network" (Zimmermann, 2019) | EOT+PGD |
APGD (Linf, L2) | Reliable evaluation of adversarial robustness with an ensemble of diverse parameter-free attacks (Croce et al., 2020) | |
APGDT (Linf, L2) | Reliable evaluation of adversarial robustness with an ensemble of diverse parameter-free attacks (Croce et al., 2020) | Targeted APGD |
FAB (Linf, L2, L1) | Minimally distorted Adversarial Examples with a Fast Adaptive Boundary Attack (Croce et al., 2019) | |
Square (Linf, L2) | Square Attack: a query-efficient black-box adversarial attack via random search (Andriushchenko et al., 2019) | |
AutoAttack (Linf, L2) | Reliable evaluation of adversarial robustness with an ensemble of diverse parameter-free attacks (Croce et al., 2020) | APGD+APGDT+FAB+Square |
DeepFool (L2) | DeepFool: A Simple and Accurate Method to Fool Deep Neural Networks (Moosavi-Dezfooli et al., 2016) | |
OnePixel (L0) | One pixel attack for fooling deep neural networks (Su et al., 2019) | |
SparseFool (L0) | SparseFool: a few pixels make a big difference (Modas et al., 2019) | |
DIFGSM (Linf) | Improving Transferability of Adversarial Examples with Input Diversity (Xie et al., 2019) | |
TIFGSM (Linf) | Evading Defenses to Transferable Adversarial Examples by Translation-Invariant Attacks (Dong et al., 2019) | |
NIFGSM (Linf) | Nesterov Accelerated Gradient and Scale Invariance for Adversarial Attacks (Lin, et al., 2022) | |
SINIFGSM (Linf) | Nesterov Accelerated Gradient and Scale Invariance for Adversarial Attacks (Lin, et al., 2022) | |
VMIFGSM (Linf) | Enhancing the Transferability of Adversarial Attacks through Variance Tuning (Wang, et al., 2022) | |
VNIFGSM (Linf) | Enhancing the Transferability of Adversarial Attacks through Variance Tuning (Wang, et al., 2022) | |
Jitter (Linf) | Exploring Misclassifications of Robust Neural Networks to Enhance Adversarial Attacks (Schwinn, Leo, et al., 2021) | |
Pixle (L0) | Pixle: a fast and effective black-box attack based on rearranging pixels (Pomponi, Jary, et al., 2022) | |
LGV (Linf, L2, L1, L0) | LGV: Boosting Adversarial Example Transferability from Large Geometric Vicinity (Gubri, et al., 2022) | |
SPSA (Linf) | Adversarial Risk and the Dangers of Evaluating Against Weak Attacks (Uesato, Jonathan, et al., 2018) | |
JSMA (L0) | The Limitations of Deep Learning in Adversarial Settings (Papernot, Nicolas, et al., 2016) | |
EADL1 (L1) | EAD: Elastic-Net Attacks to Deep Neural Networks (Chen, Pin-Yu, et al., 2018) | |
EADEN (L1, L2) | EAD: Elastic-Net Attacks to Deep Neural Networks (Chen, Pin-Yu, et al., 2018) | |
Performance Comparison
For a fair comparison, Robustbench is used. As comparison packages, the most cited packages that are still actively updated were selected: Foolbox and ART.
The table reports robust accuracy against each attack and the elapsed time on the first 50 images of CIFAR10. For L2 attacks, the average L2 distance between the adversarial images and the original images is also recorded (accuracy / distance). All experiments were done on a GeForce RTX 2080. For the latest results, please refer to (code, nbviewer).
Attack | Package | Standard | Wong2020Fast | Rice2020Overfitting | Remark |
---|---|---|---|---|---|
FGSM (Linf) | Torchattacks | 34% (54ms) | 48% (5ms) | 62% (82ms) | |
FGSM (Linf) | Foolbox* | 34% (15ms) | 48% (8ms) | 62% (30ms) | |
FGSM (Linf) | ART | 34% (214ms) | 48% (59ms) | 62% (768ms) | |
PGD (Linf) | Torchattacks | 0% (174ms) | 44% (52ms) | 58% (1348ms) | |
PGD (Linf) | Foolbox* | 0% (354ms) | 44% (56ms) | 58% (1856ms) | |
PGD (Linf) | ART | 0% (1384ms) | 44% (437ms) | 58% (4704ms) | |
CW† (L2) | Torchattacks | 0% / 0.40 (2596ms) | 14% / 0.61 (3795ms) | 22% / 0.56 (43484ms) | |
CW† (L2) | Foolbox* | 0% / 0.40 (2668ms) | 32% / 0.41 (3928ms) | 34% / 0.43 (44418ms) | |
CW† (L2) | ART | 0% / 0.59 (196738ms) | 24% / 0.70 (66067ms) | 26% / 0.65 (694972ms) | |
PGD (L2) | Torchattacks | 0% / 0.41 (184ms) | 68% / 0.5 (52ms) | 70% / 0.5 (1377ms) | |
PGD (L2) | Foolbox* | 0% / 0.41 (396ms) | 68% / 0.5 (57ms) | 70% / 0.5 (1968ms) | |
PGD (L2) | ART | 0% / 0.40 (1364ms) | 68% / 0.5 (429ms) | 70% / 0.5 (4777ms) | |
\* Note that Foolbox returns accuracy and adversarial images simultaneously, so the actual time for generating adversarial images may be shorter than recorded.
† Considering that the binary search algorithm for the constant `c` can be time-consuming, torchattacks supports MultiAttack for grid searching `c`.
To push further, I introduce rai-toolbox, a newly added package!
Attack | Package | Time/step (accuracy) |
---|---|---|
FGSM (Linf) | rai-toolbox | 58 ms (0%) |
FGSM (Linf) | Torchattacks | 81 ms (0%) |
FGSM (Linf) | Foolbox | 105 ms (0%) |
FGSM (Linf) | ART | 83 ms (0%) |
PGD (Linf) | rai-toolbox | 58 ms (44%) |
PGD (Linf) | Torchattacks | 79 ms (44%) |
PGD (Linf) | Foolbox | 82 ms (44%) |
PGD (Linf) | ART | 90 ms (44%) |
PGD (L2) | rai-toolbox | 58 ms (70%) |
PGD (L2) | Torchattacks | 81 ms (70%) |
PGD (L2) | Foolbox | 82 ms (70%) |
PGD (L2) | ART | 89 ms (70%) |
The rai-toolbox takes a unique approach to gradient-based perturbations: they are implemented in terms of parameter-transforming optimizers and perturbation models. This enables users to implement diverse algorithms (like universal perturbations and concept probing with sparse gradients) using the same paradigm as a standard PGD attack.
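To illustrate the "optimizer plus perturbation model" paradigm, here is a sketch in plain PyTorch; it does not use rai-toolbox's actual API. Under our assumptions, the perturbation `delta` is an `nn.Parameter` of a hypothetical `AdditivePerturbation` module, a standard `torch.optim.SGD` optimizer updates it, and the gradient is transformed to its sign before each step, which turns the update into an L∞ PGD step. `pgd_linf_via_optimizer` is likewise a hypothetical name. Note that `loss.backward()` also populates gradients on the model's parameters; a real implementation would likely freeze them.

```python
import torch
import torch.nn as nn


class AdditivePerturbation(nn.Module):
    """Hypothetical perturbation model: the trainable parameter is delta, applied as x + delta."""

    def __init__(self, reference):
        super().__init__()
        self.delta = nn.Parameter(torch.zeros_like(reference))

    def forward(self, x):
        return x + self.delta


def pgd_linf_via_optimizer(model, images, labels, eps=8/255, alpha=2/255, steps=10):
    images = images.detach()
    pert = AdditivePerturbation(images)
    # Plain SGD on delta; with the sign transform below, each step moves every pixel by alpha.
    opt = torch.optim.SGD(pert.parameters(), lr=alpha)
    loss_fn = nn.CrossEntropyLoss()
    for _ in range(steps):
        opt.zero_grad()
        # Minimizing the negative loss maximizes the classification loss.
        loss = -loss_fn(model(pert(images)), labels)
        loss.backward()
        # Gradient transform: replace the gradient by its sign (Linf steepest ascent).
        pert.delta.grad.sign_()
        opt.step()
        with torch.no_grad():
            # Project onto the eps-ball, then keep x + delta inside [0, 1].
            pert.delta.clamp_(-eps, eps)
            pert.delta.copy_((images + pert.delta).clamp(0, 1) - images)
    return pert(images).detach().clamp(0, 1)
```

Swapping the sign transform for an L2 normalization, or the additive module for another perturbation model, changes the attack without changing the optimization loop, which is the point the paragraph above makes.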
Citation
If you use this package, please cite the following BibTeX (SemanticScholar, GoogleScholar):
```bibtex
@article{kim2020torchattacks,
  title={Torchattacks: A pytorch repository for adversarial attacks},
  author={Kim, Hoki},
  journal={arXiv preprint arXiv:2010.01950},
  year={2020}
}
```
Recommended Sites and Packages
- Adversarial Attack Packages:
  - https://github.com/IBM/adversarial-robustness-toolbox: Adversarial attack and defense package made by IBM. TensorFlow, Keras, PyTorch available.
  - https://github.com/bethgelab/foolbox: Adversarial attack package made by Bethge Lab. TensorFlow, PyTorch available.
  - https://github.com/tensorflow/cleverhans: Adversarial attack package made by Google Brain. TensorFlow available.
  - https://github.com/BorealisAI/advertorch: Adversarial attack package made by BorealisAI. PyTorch available.
  - https://github.com/DSE-MSU/DeepRobust: Adversarial attack (especially on GNN) package made by DSE-MSU. PyTorch available.
  - https://github.com/fra31/auto-attack: Set of attacks that is believed to be the strongest in existence. TensorFlow, PyTorch available.
  - https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/: PyTorch-centric tools for evaluating and enhancing both the robustness and the explainability of AI models. PyTorch available.
- Adversarial Defense Leaderboard:
- Adversarial Attack and Defense Papers:
  - https://nicholas.carlini.com/writing/2019/all-adversarial-example-papers.html: A Complete List of All (arXiv) Adversarial Example Papers made by Nicholas Carlini.
  - https://github.com/chawins/Adversarial-Examples-Reading-List: Adversarial Examples Reading List made by Chawin Sitawarin.