
Commit 3eff2b7

Merge pull request #51 from Immortalise/main
add "Improving Generalization of Adversarial Training via Robust Critical Fine-Tuning" ICCV 2023
2 parents 8b8bde5 + 8558848 commit 3eff2b7

32 files changed: 4365 additions, 1 deletion

README.md

Lines changed: 5 additions & 1 deletion
@@ -21,8 +21,12 @@
 Latest research in robust machine learning, including adversarial/backdoor attack and defense, out-of-distribution (OOD) generalization, and safe transfer learning.

 Hosted projects:
+
+- **RiFT** (ICCV 2023, #Adversarial Robustness, #Generalization, #OOD)
+  - [Code](./RiFT/) | [Improving Generalization of Adversarial Training via Robust Critical Fine-Tuning](https://arxiv.org/abs/2308.02533)
+
 - **Diversify** (ICLR 2023, #OOD):
-  - [Code](./diversify/) | [Out-of-distribution Representation Learning for Time Series Classification](https://arxiv.org/abs/2209.07027)
+  - [Code](./diversify/) | [Out-of-distribution Representation Learning for Time Series Classification](https://arxiv.org/abs/2209.07027)
 - **DRM** (KDD 2023, #OOD):
   - [Code](./drm/) | [Domain-Specific Risk Minimization for Out-of-Distribution Generalization](https://arxiv.org/abs/2208.08661)
 - **DDLearn** (KDD 2023, #OOD):

RiFT/.gitignore

Lines changed: 166 additions & 0 deletions
@@ -0,0 +1,166 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

data/
results*/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

RiFT/README.md

Lines changed: 105 additions & 0 deletions
@@ -0,0 +1,105 @@
![](https://files.mdnice.com/user/45288/023bf2cb-1685-43ce-bba8-1ba9b66f80b4.png)

# Improving Generalization of Adversarial Training via Robust Critical Fine-Tuning

This is the official implementation of the ICCV 2023 paper [Improving Generalization of Adversarial Training via Robust Critical Fine-Tuning](https://arxiv.org/abs/2308.02533).

**Abstract**: Deep neural networks are susceptible to adversarial examples, posing a significant security risk in critical applications. Adversarial Training (AT) is a well-established technique to enhance adversarial robustness, but it often comes at the cost of decreased generalization ability. This paper proposes Robustness Critical Fine-Tuning (RiFT), a novel approach to enhance generalization without compromising adversarial robustness. The core idea of RiFT is to exploit the redundant capacity for robustness by fine-tuning the adversarially trained model on its non-robust-critical module. To do so, we introduce module robust criticality (MRC), a measure that evaluates the significance of a given module to model robustness under worst-case weight perturbations. Using this measure, we identify the module with the lowest MRC value as the non-robust-critical module and fine-tune its weights to obtain fine-tuned weights. Subsequently, we linearly interpolate between the adversarially trained weights and fine-tuned weights to derive the optimal fine-tuned model weights. We demonstrate the efficacy of RiFT on ResNet18, ResNet34, and WideResNet34-10 models trained on CIFAR10, CIFAR100, and Tiny-ImageNet datasets. Our experiments show that RiFT can significantly improve both generalization and out-of-distribution robustness by around 1.5% while maintaining or even slightly enhancing adversarial robustness. Code is available at https://github.com/microsoft/robustlearn.
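Loosely speaking, the MRC of a module measures how much the adversarial loss can increase when only that module's weights are perturbed within a bounded neighborhood; a module with a low MRC is non-robust-critical and can be fine-tuned with little impact on robustness. A rough sketch of the idea (notation ours, not taken verbatim from the paper):

$$
\mathrm{MRC}(\theta_i) \approx \max_{\lVert \Delta\theta_i \rVert \le \epsilon} \; \mathcal{L}_{\mathrm{adv}}(\theta + \Delta\theta_i) - \mathcal{L}_{\mathrm{adv}}(\theta)
$$

where $\theta_i$ denotes the weights of module $i$ and $\mathcal{L}_{\mathrm{adv}}$ is the adversarial training loss.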
## Requirements

### Running Environment

To install requirements:

```
conda env create -f env.yaml
conda activate rift
```
### Datasets

CIFAR10 and CIFAR100 can be downloaded via PyTorch (a minimal sketch is shown at the end of this section).

For the other datasets:

1. [Tiny-ImageNet](http://cs231n.stanford.edu/tiny-imagenet-200.zip)
2. [CIFAR10-C](https://drive.google.com/drive/folders/1HDVw6CmX3HiG0ODFtI75iIfBDxSiSz2K)
3. [CIFAR100-C](https://drive.google.com/drive/folders/1HDVw6CmX3HiG0ODFtI75iIfBDxSiSz2K)
4. [Tiny-ImageNet-C](https://berkeley.app.box.com/s/6zt1qzwm34hgdzcvi45svsb10zspop8a)

After downloading these datasets, move them to ./data.

The images in Tiny-ImageNet are 64x64 and cover 200 classes.
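For CIFAR10/CIFAR100, a minimal sketch of the PyTorch download path (assuming torchvision is available in the rift environment and the same ./data root as above; this is illustrative, not the repository's own loader):

```python
# Minimal sketch: download CIFAR10/CIFAR100 into ./data via torchvision.
from torchvision import datasets, transforms

transform = transforms.ToTensor()
cifar10 = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
cifar100 = datasets.CIFAR100(root="./data", train=True, download=True, transform=transform)
print(len(cifar10), len(cifar100))  # 50000 training images each
```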
## Robust Critical Fine-Tuning

### Demo

Here we present an example of running RiFT on ResNet18 trained on CIFAR10.

Download the adversarially trained model weights [here](https://drive.google.com/drive/folders/1Uzqm1cOYFXLa97GZjjwfiVS2OcbpJK4o?usp=drive_link).

```
python main.py --layer=layer2.1.conv2 --resume="./ResNet18_CIFAR10.pth"
```
- layer: the name of the layer to fine-tune.

Here, layer2.1.conv2 is a non-robust-critical module.

The non-robust-critical modules of each model on each dataset are summarized as follows:

|          | CIFAR10              | CIFAR100             | Tiny-ImageNet        |
| -------- | -------------------- | -------------------- | -------------------- |
| ResNet18 | layer2.1.conv2       | layer2.1.conv2       | layer3.1.conv2       |
| ResNet34 | layer2.3.conv2       | layer2.3.conv2       | layer3.5.conv2       |
| WRN34-10 | block1.layer.3.conv2 | block1.layer.2.conv2 | block1.layer.2.conv2 |
### Pipeline

1. Characterize the MRC of each module.
   `python main.py --cal_mrc --resume=/path/to/your/model`
   This will output the MRC of each module.
2. Fine-tune the non-robust-critical module.
   Based on the MRC output, choose a module with the lowest MRC value to fine-tune.
   In our experience, the **middle layers** are usually good choices.
   Try different learning rates; a small learning rate is usually preferred.
   `python main.py --layer=xxx --lr=yyy --resume=zzz`
   When fine-tuning finishes, the script automatically interpolates between the adversarially trained weights and the fine-tuned weights (see the sketch after this list).
   Robust accuracy and in-distribution test accuracy are evaluated during the interpolation procedure.
3. Test OOD performance. Pick the best interpolation factor (the one with the largest in-distribution generalization gain that does not noticeably reduce robustness).
   `python eval_ood.py --resume=xxx`
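For reference, the interpolation step amounts to a convex combination of the two checkpoints. A minimal PyTorch sketch (illustrative only; the helper name, state-dict handling, and `evaluate` call are assumptions, not the repository's actual code):

```python
# Sketch of linearly interpolating between adversarially trained (theta_at)
# and fine-tuned (theta_ft) weights; both are state_dicts with identical keys.
import torch

def interpolate_weights(theta_at, theta_ft, alpha):
    """Return (1 - alpha) * theta_at + alpha * theta_ft for every parameter tensor."""
    return {k: (1 - alpha) * theta_at[k] + alpha * theta_ft[k] for k in theta_at}

# Sweep interpolation factors and keep the one with the best trade-off between
# clean accuracy and robust accuracy (evaluate() is hypothetical).
# for alpha in [i / 10 for i in range(11)]:
#     model.load_state_dict(interpolate_weights(theta_at, theta_ft, alpha))
#     evaluate(model)
```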
## Results

![](https://files.mdnice.com/user/45288/c3c98491-a292-4888-82cc-081bc8d3c3c6.png)

![](https://files.mdnice.com/user/45288/bad5bb9f-788d-4350-ac5c-ddd850ade04f.png)
## References & Open-Source Code

- Classification models: [code](https://github.com/kuangliu/pytorch-cifar)
- Adversarial training: [code](https://github.com/P2333/Bag-of-Tricks-for-AT)

## Contact

- Kaijie Zhu: [email protected]
- Jindong Wang: [email protected]
## Citation

```
@inproceedings{zhu2023improving,
  title={Improving Generalization of Adversarial Training via Robust Critical Fine-Tuning},
  author={Zhu, Kaijie and Hu, Xixu and Wang, Jindong and Xie, Xing and Yang, Ge},
  year={2023},
  booktitle={International Conference on Computer Vision},
}
```
