# Improving Generalization of Adversarial Training via Robust Critical Fine-Tuning

This is the official implementation of the ICCV 2023 paper [Improving Generalization of Adversarial Training via Robust Critical Fine-Tuning](https://arxiv.org/abs/2308.02533).

**Abstract**: Deep neural networks are susceptible to adversarial examples, posing a significant security risk in critical applications. Adversarial Training (AT) is a well-established technique to enhance adversarial robustness, but it often comes at the cost of decreased generalization ability. This paper proposes Robustness Critical Fine-Tuning (RiFT), a novel approach to enhance generalization without compromising adversarial robustness. The core idea of RiFT is to exploit the redundant capacity for robustness by fine-tuning the adversarially trained model on its non-robust-critical module. To do so, we introduce module robust criticality (MRC), a measure that evaluates the significance of a given module to model robustness under worst-case weight perturbations. Using this measure, we identify the module with the lowest MRC value as the non-robust-critical module and fine-tune its weights to obtain fine-tuned weights. Subsequently, we linearly interpolate between the adversarially trained weights and fine-tuned weights to derive the optimal fine-tuned model weights. We demonstrate the efficacy of RiFT on ResNet18, ResNet34, and WideResNet34-10 models trained on CIFAR10, CIFAR100, and Tiny-ImageNet datasets. Our experiments show that RiFT can significantly improve both generalization and out-of-distribution robustness by around 1.5% while maintaining or even slightly enhancing adversarial robustness. Code is available at https://github.com/microsoft/robustlearn.

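In symbols (notation ours, summarizing the abstract): RiFT's final step selects weights along the line in weight space between the adversarially trained solution and the fine-tuned one,

```math
\theta(\alpha) = (1 - \alpha)\,\theta_{\mathrm{AT}} + \alpha\,\theta_{\mathrm{FT}}, \qquad \alpha \in [0, 1],
```

where the interpolation factor is chosen to maximize the generalization gain while preserving robust accuracy.
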
## Requirements

### Running Environment

To install requirements:

```
conda env create -f env.yaml
conda activate rift
```

### Datasets

CIFAR10 and CIFAR100 can be downloaded automatically via PyTorch.

For other datasets:

1. [Tiny-ImageNet](http://cs231n.stanford.edu/tiny-imagenet-200.zip)
2. [CIFAR10-C](https://drive.google.com/drive/folders/1HDVw6CmX3HiG0ODFtI75iIfBDxSiSz2K)
3. [CIFAR100-C](https://drive.google.com/drive/folders/1HDVw6CmX3HiG0ODFtI75iIfBDxSiSz2K)
4. [Tiny-ImageNet-C](https://berkeley.app.box.com/s/6zt1qzwm34hgdzcvi45svsb10zspop8a)

After downloading these datasets, move them to ./data.

Tiny-ImageNet images are 64x64 with 200 classes.

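For reference, a minimal sketch of fetching CIFAR10 with torchvision (the `./data` root matches the layout above; the transform is a placeholder):

```python
import torchvision
import torchvision.transforms as T

# Downloads to ./data on the first run; later runs reuse the local copy.
train_set = torchvision.datasets.CIFAR10(
    root="./data", train=True, download=True, transform=T.ToTensor()
)
test_set = torchvision.datasets.CIFAR10(
    root="./data", train=False, download=True, transform=T.ToTensor()
)
print(len(train_set), len(test_set))  # 50000 10000
```
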
## Robust Critical Fine-Tuning

### Demo

Here we present an example of applying RiFT to ResNet18 on CIFAR10.

Download the adversarially trained model weights [here](https://drive.google.com/drive/folders/1Uzqm1cOYFXLa97GZjjwfiVS2OcbpJK4o?usp=drive_link).

```
python main.py --layer=layer2.1.conv2 --resume="./ResNet18_CIFAR10.pth"
```

- `--layer`: the name of the layer to fine-tune.

Here, layer2.1.conv2 is a non-robust-critical module.

The non-robust-critical module of each model on each dataset is summarized as follows:

|          | CIFAR10              | CIFAR100             | Tiny-ImageNet        |
| -------- | -------------------- | -------------------- | -------------------- |
| ResNet18 | layer2.1.conv2       | layer2.1.conv2       | layer3.1.conv2       |
| ResNet34 | layer2.3.conv2       | layer2.3.conv2       | layer3.5.conv2       |
| WRN34-10 | block1.layer.3.conv2 | block1.layer.2.conv2 | block1.layer.2.conv2 |

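The module names in the table follow PyTorch's dotted submodule naming. A quick way to list candidate names and check that one resolves, sketched here against the standard torchvision ResNet-18 (the repo's model definitions may differ slightly):

```python
import torch.nn as nn
from torchvision.models import resnet18

model = resnet18(num_classes=10)

# Print the dotted name of every conv module, e.g. "layer2.1.conv2".
for name, module in model.named_modules():
    if isinstance(module, nn.Conv2d):
        print(name)

# Resolve a dotted name back to its module object (raises if absent).
print(model.get_submodule("layer2.1.conv2"))
```
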
### Pipeline

1. Characterize the MRC of each module (a rough sketch of the idea appears after this list):
   `python main.py --cal_mrc --resume=/path/to/your/model`
   This will output the MRC of each module.
2. Fine-tune a non-robust-critical module.
   Based on the MRC output, choose the module with the lowest MRC value to fine-tune.
   In our experience, **middle layers** make the best candidates.
   Try different learning rates; usually a small learning rate is preferred.
   `python main.py --layer=xxx --lr=yyy --resume=zzz`
   When fine-tuning finishes, the script automatically interpolates between the adversarially trained weights and the fine-tuned weights (see the interpolation sketch after this list).
   Robust accuracy and in-distribution test accuracy are evaluated during the interpolation procedure.
3. Test OOD performance. Pick the best interpolation factor: the one that maximizes the in-distribution generalization gain while sacrificing as little adversarial robustness as possible.
   `python eval_ood.py --resume=xxx`

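For intuition, here is a rough gradient-ascent sketch of the MRC idea from step 1: measure how much the loss can increase when only one module's weights are perturbed. This is our simplification, not the repo's implementation; `loss_fn`, `x`, `y`, the step size, and the omitted projection step are all assumptions:

```python
import copy
import torch

def estimate_mrc(model, module_name, loss_fn, x, y, eps=0.05, steps=10):
    """Worst-case loss increase when perturbing one module's weights only.

    Sketch: gradient ascent on the target module with steps scaled by
    eps * ||theta||. A faithful implementation would also project the total
    perturbation back into the eps * ||theta|| ball and evaluate the loss
    on adversarial examples.
    """
    model = copy.deepcopy(model).eval()  # keep the original weights intact
    params = list(model.get_submodule(module_name).parameters())
    with torch.no_grad():
        base_loss = loss_fn(model(x), y).item()

    for _ in range(steps):
        loss = loss_fn(model(x), y)
        grads = torch.autograd.grad(loss, params)
        with torch.no_grad():
            for p, g in zip(params, grads):
                # Ascent step, scaled relative to the parameter's norm.
                p += (eps / steps) * p.norm() * g / (g.norm() + 1e-12)

    with torch.no_grad():
        return loss_fn(model(x), y).item() - base_loss
```

The interpolation in step 2 is plain weight-space averaging between the adversarially trained checkpoint and the fine-tuned one; a minimal sketch (file names are placeholders, and we assume the checkpoints are raw state dicts):

```python
import torch

at_weights = torch.load("ResNet18_CIFAR10.pth")     # adversarially trained
ft_weights = torch.load("ResNet18_CIFAR10_ft.pth")  # fine-tuned (hypothetical name)

def interpolate(w_at, w_ft, alpha):
    # theta(alpha) = (1 - alpha) * theta_AT + alpha * theta_FT
    return {k: (1 - alpha) * w_at[k] + alpha * w_ft[k] for k in w_at}

for alpha in [i / 10 for i in range(11)]:
    weights = interpolate(at_weights, ft_weights, alpha)
    # model.load_state_dict(weights), then evaluate clean and robust accuracy.
```
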
## Results

## References & Open-source Code

- Classification models: [code](https://github.com/kuangliu/pytorch-cifar)
- Adversarial training: [code](https://github.com/P2333/Bag-of-Tricks-for-AT)

## Contact

## Citation

```
@inproceedings{zhu2023improving,
  title={Improving Generalization of Adversarial Training via Robust Critical Fine-Tuning},
  author={Zhu, Kaijie and Hu, Xixu and Wang, Jindong and Xie, Xing and Yang, Ge},
  year={2023},
  booktitle={International Conference on Computer Vision},
}
```