
Commit 5a1a095

[test] refactored with the new rerun decorator (#763)
* [test] refactored with the new rerun decorator
* polish test case
1 parent deaf99f · commit 5a1a095

34 files changed · +80 -75 lines changed
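
Every call site below replaces the same verbose decoration, @rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*"), with a zero-argument @rerun_if_address_is_in_use(). The new name is presumably a thin convenience wrapper over the general-purpose decorator. A minimal sketch of the idea follows; the real implementation lives in colossalai.testing and may differ, and the retry loop, max_try default, and DOTALL search here are assumptions:

import re
from functools import wraps

import torch.multiprocessing as mp


def rerun_on_exception(exception_type=Exception, pattern=None, max_try=5):
    # Rerun the wrapped function while it raises exception_type with a
    # message matching pattern; give up and re-raise after max_try attempts.
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(max_try):
                try:
                    return func(*args, **kwargs)
                except exception_type as e:
                    # DOTALL so the pattern can match multi-line messages,
                    # e.g. the traceback text inside ProcessRaisedException.
                    if pattern is not None and not re.search(pattern, str(e), re.DOTALL):
                        raise    # unrelated failure: do not retry
                    if attempt == max_try - 1:
                        raise    # retries exhausted
        return wrapper
    return decorator


def rerun_if_address_is_in_use():
    # Exactly the argument pair that every call site in this commit
    # previously spelled out by hand.
    return rerun_on_exception(exception_type=mp.ProcessRaisedException,
                              pattern=".*Address already in use.*")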

tests/test_amp/test_naive_fp16.py

Lines changed: 2 additions & 2 deletions
@@ -3,7 +3,7 @@
 import torch.multiprocessing as mp
 from colossalai.amp import convert_to_naive_amp, convert_to_apex_amp
 from tests.components_to_test.registry import non_distributed_component_funcs
-from colossalai.testing import assert_close_loose, rerun_on_exception
+from colossalai.testing import assert_close_loose, rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.amp import convert_to_naive_amp, convert_to_apex_amp
 
@@ -84,7 +84,7 @@ def run_dist(rank, world_size, port):
 
 
 @pytest.mark.dist
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_naive_amp():
     world_size = 1
     run_func = partial(run_dist, world_size=world_size, port=free_port())
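
Every file in this commit follows the same shape as the hunk above: pick a free port, bind the worker arguments with functools.partial, spawn world_size processes, and let the decorator rerun the whole test if the port was taken in the meantime. A trimmed-down sketch of that shared pattern; test_something and the empty worker body are placeholders, not code from this commit:

from functools import partial

import pytest
import torch.multiprocessing as mp

from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port


def run_dist(rank, world_size, port):
    # Placeholder worker: each real test calls colossalai's launch here and
    # runs its checks. If the port is already bound, the resulting
    # "Address already in use" error propagates out of mp.spawn as a
    # ProcessRaisedException, which the decorator catches and retries.
    pass


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_something():
    world_size = 4
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)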

tests/test_comm/test_comm.py

Lines changed: 2 additions & 2 deletions
@@ -9,7 +9,7 @@
 from colossalai.core import global_context as gpc
 from colossalai.initialize import launch
 from colossalai.utils import free_port, get_current_device
-from colossalai.testing import rerun_on_exception
+from colossalai.testing import rerun_if_address_is_in_use
 
 CONFIG = dict(parallel=dict(data=8, pipeline=1, tensor=dict(mode=None, size=1)))
 
@@ -64,7 +64,7 @@ def check_layer(rank, world_size, port):
 
 
 @pytest.mark.dist
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_comm():
     world_size = 4
     run_func = partial(check_layer, world_size=world_size, port=free_port())

tests/test_context/test_hybrid_parallel.py

Lines changed: 2 additions & 2 deletions
@@ -13,7 +13,7 @@
 from colossalai.utils import free_port
 from colossalai.context import reset_seeds
 from colossalai.global_variables import tensor_parallel_env as tp_env
-from colossalai.testing import rerun_on_exception
+from colossalai.testing import rerun_if_address_is_in_use
 
 CONFIG_PATH_LIST = list(Path(__file__).parent.glob('configs/*.py'))
 
@@ -141,7 +141,7 @@ def run_dist(rank, world_size, backend, port_list, host):
 
 
 @pytest.mark.cpu
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_context():
     """
     As no computation or communication is done, we can run this test on CPU.

tests/test_data/test_data_parallel_sampler.py

Lines changed: 2 additions & 2 deletions
@@ -17,7 +17,7 @@
 from colossalai.context import ParallelMode, Config
 from colossalai.core import global_context as gpc
 from colossalai.utils import get_dataloader, free_port
-from colossalai.testing import rerun_on_exception
+from colossalai.testing import rerun_if_address_is_in_use
 
 CONFIG = Config(
     dict(
@@ -67,7 +67,7 @@ def run_data_sampler(rank, world_size, port):
 
 
 @pytest.mark.cpu
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_data_sampler():
     world_size = 4
     test_func = partial(run_data_sampler, world_size=world_size, port=free_port())

tests/test_data/test_deterministic_dataloader.py

Lines changed: 2 additions & 2 deletions
@@ -17,7 +17,7 @@
 from colossalai.context import ParallelMode, Config
 from colossalai.core import global_context as gpc
 from colossalai.utils import free_port
-from colossalai.testing import rerun_on_exception
+from colossalai.testing import rerun_if_address_is_in_use
 
 CONFIG = Config(
     dict(
@@ -79,7 +79,7 @@ def run_data_sampler(rank, world_size, port):
 
 
 @pytest.mark.cpu
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_data_sampler():
     world_size = 4
     test_func = partial(run_data_sampler, world_size=world_size, port=free_port())

tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor.py

Lines changed: 6 additions & 5 deletions
@@ -15,17 +15,18 @@
 from colossalai.trainer import Trainer, hooks
 from colossalai.utils import free_port, get_dataloader
 from colossalai.utils.gradient_accumulation import GradAccumLrSchedulerByStep
-from colossalai.testing import rerun_on_exception
+from colossalai.testing import rerun_if_address_is_in_use
 from model_zoo.vit import vit_tiny_patch4_32
 from torchvision import transforms
 from torchvision.datasets import CIFAR10
 
 BATCH_SIZE = 4
 NUM_EPOCHS = 60
 WARMUP_EPOCHS = 5
-CONFIG = dict(NUM_MICRO_BATCHES=2, parallel=dict(pipeline=2, tensor=dict(size=2, mode='1d')),
-              fp16=dict(mode=AMP_TYPE.NAIVE),
-              gradient_accumulation=2)
+CONFIG = dict(NUM_MICRO_BATCHES=2,
+              parallel=dict(pipeline=2, tensor=dict(size=2, mode='1d')),
+              fp16=dict(mode=AMP_TYPE.NAIVE),
+              gradient_accumulation=2)
 
 
 def run_trainer(rank, world_size, port):
@@ -79,7 +80,7 @@ def run_trainer(rank, world_size, port):
 
 
 @pytest.mark.dist
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_hybrid_parallel():
     world_size = 8
     run_func = partial(run_trainer, world_size=world_size, port=free_port())

tests/test_engine/test_engine.py

Lines changed: 2 additions & 2 deletions
@@ -7,7 +7,7 @@
 from colossalai.core import global_context as gpc
 from colossalai.utils import free_port
 from tests.components_to_test.registry import non_distributed_component_funcs
-from colossalai.testing import parameterize, rerun_on_exception
+from colossalai.testing import parameterize, rerun_if_address_is_in_use
 
 CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)),
               fp16=dict(mode=None),
@@ -56,7 +56,7 @@ def run_engine(rank, world_size, port):
 
 
 @pytest.mark.dist
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_engine():
     world_size = 2
     run_func = partial(run_engine, world_size=world_size, port=free_port())

tests/test_layers/test_1d/test_1d.py

Lines changed: 2 additions & 2 deletions
@@ -10,7 +10,7 @@
 from colossalai.logging import disable_existing_loggers
 from colossalai.initialize import launch
 from colossalai.utils import free_port
-from colossalai.testing import rerun_on_exception
+from colossalai.testing import rerun_if_address_is_in_use
 from checks_1d.check_layer_1d import *
 
 CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=4, mode='1d')),)
@@ -35,7 +35,7 @@ def check_layer(rank, world_size, port):
 
 
 @pytest.mark.dist
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_1d():
     world_size = 4
     run_func = partial(check_layer, world_size=world_size, port=free_port())

tests/test_layers/test_2d/test_2d.py

Lines changed: 2 additions & 2 deletions
@@ -10,7 +10,7 @@
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
 from colossalai.utils import free_port
-from colossalai.testing import rerun_on_exception
+from colossalai.testing import rerun_if_address_is_in_use
 from checks_2d.check_layer_2d import (check_classifier_given_embed_weight, check_classifier_no_given_weight,
                                       check_embed, check_layernorm, check_linear, check_loss, check_patch_embed,
                                       check_vocab_parallel_classifier_given_embed_weight,
@@ -55,7 +55,7 @@ def check_layer_and_operation(rank, world_size, port):
 
 
 @pytest.mark.dist
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_2d():
     world_size = 4
     run_func = partial(check_layer_and_operation, world_size=world_size, port=free_port())

tests/test_layers/test_2p5d/test_2p5d.py

Lines changed: 2 additions & 2 deletions
@@ -7,7 +7,7 @@
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
 from colossalai.utils import free_port
-from colossalai.testing import rerun_on_exception
+from colossalai.testing import rerun_if_address_is_in_use
 from checks_2p5d.check_layer_2p5d import *
 from checks_2p5d.check_operation_2p5d import check_AB, check_ABT, check_ATB
 
@@ -51,7 +51,7 @@ def check_layer_and_operation(rank, world_size, port):
 
 
 @pytest.mark.dist
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_2p5d():
     world_size = 4
     run_func = partial(check_layer_and_operation, world_size=world_size, port=free_port())

tests/test_layers/test_3d/test_3d.py

Lines changed: 2 additions & 2 deletions
@@ -9,7 +9,7 @@
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
 from colossalai.utils import free_port
-from colossalai.testing import rerun_on_exception
+from colossalai.testing import rerun_if_address_is_in_use
 from checks_3d.check_layer_3d import (check_classifier_given_embed_weight, check_classifier_no_given_weight,
                                       check_embed, check_layernorm, check_linear, check_loss, check_patch_embed,
                                       check_vocab_parallel_classifier_given_embed_weight,
@@ -51,7 +51,7 @@ def check_layer_and_operation(rank, world_size, port):
 
 
 @pytest.mark.dist
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_3d():
     world_size = 8
     run_func = partial(check_layer_and_operation, world_size=world_size, port=free_port())

tests/test_layers/test_sequence/test_sequence.py

Lines changed: 2 additions & 2 deletions
@@ -7,7 +7,7 @@
 
 from colossalai.core import global_context as gpc
 from colossalai.context import ParallelMode
-from colossalai.testing import rerun_on_exception
+from colossalai.testing import rerun_if_address_is_in_use
 from functools import partial
 
 CONFIG = dict(parallel=dict(tensor=dict(size=4, mode='sequence')))
@@ -132,7 +132,7 @@ def run_test(rank, world_size):
 
 
 @pytest.mark.dist
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_sequence():
     world_size = 4
     run_func = partial(run_test, world_size=world_size)

tests/test_moe/test_grad_handler.py

Lines changed: 2 additions & 3 deletions
@@ -10,8 +10,7 @@
 from colossalai.context.moe_context import MOE_CONTEXT
 from colossalai.utils.moe import sync_moe_model_param
 from colossalai.engine.gradient_handler import MoeGradientHandler
-from colossalai.testing import assert_equal_in_group
-from colossalai.testing import rerun_on_exception
+from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use
 
 BATCH_SIZE = 4
 DIM = 16
@@ -63,7 +62,7 @@ def run_test(rank, world_size, port):
 
 
 @pytest.mark.dist
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_grad_handler():
     world_size = 4
     run_func = partial(run_test, world_size=world_size, port=free_port())

tests/test_moe/test_kernel.py

Lines changed: 2 additions & 2 deletions
@@ -9,7 +9,7 @@
 from colossalai.utils import free_port, get_current_device
 from colossalai.nn.layer.moe import Top1Router, Top2Router, MoeLayer, Experts
 from colossalai.context.moe_context import MOE_CONTEXT
-from colossalai.testing import rerun_on_exception
+from colossalai.testing import rerun_if_address_is_in_use
 
 BATCH_SIZE = 16
 NUM_EXPERTS = 4
@@ -87,7 +87,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
 @pytest.mark.parametrize("hidden_size", [32, 144])
 @pytest.mark.parametrize("data_type", [torch.float32, torch.float16])
 @pytest.mark.parametrize("router", [Top1Router, Top2Router])
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_moe_kernel(rs, hidden_size, data_type, router):
     world_size = 4
     run_func = partial(run_routing,

tests/test_moe/test_moe_group.py

Lines changed: 2 additions & 2 deletions
@@ -8,7 +8,7 @@
 from colossalai.nn.layer.moe import Experts
 from colossalai.context.moe_context import MOE_CONTEXT
 from colossalai.utils.moe import sync_moe_model_param
-from colossalai.testing import assert_equal_in_group, rerun_on_exception
+from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use
 
 D_MODEL = 4
 D_FF = 8
@@ -60,7 +60,7 @@ def run_test(rank, port):
 
 
 @pytest.mark.dist
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_moe_initialization():
     world_size = 4
     run_func = partial(run_test, port=free_port())

tests/test_moe/test_moe_zero_init.py

Lines changed: 2 additions & 2 deletions
@@ -14,7 +14,7 @@
 from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 
-from colossalai.testing import rerun_on_exception
+from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import get_current_device
 from tests.test_zero.common import CONFIG
 
@@ -91,7 +91,7 @@ def _run_dist(rank, world_size, port):
 
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [2, 4])
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_moe_zero_init(world_size):
     run_func = partial(_run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)

tests/test_moe/test_moe_zero_model.py

Lines changed: 2 additions & 2 deletions
@@ -4,7 +4,7 @@
 import pytest
 import torch
 import torch.multiprocessing as mp
-from colossalai.testing import parameterize, rerun_on_exception
+from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
@@ -65,7 +65,7 @@ def run_dist(rank, world_size, port):
 
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [2])
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_moe_zero_model(world_size):
     run_func = partial(run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)

tests/test_moe/test_moe_zero_optim.py

Lines changed: 2 additions & 2 deletions
@@ -6,7 +6,7 @@
 import torch.multiprocessing as mp
 from colossalai.amp import convert_to_apex_amp
 from colossalai.nn.optimizer import CPUAdam
-from colossalai.testing import parameterize, rerun_on_exception
+from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
@@ -120,7 +120,7 @@ def _run_dist(rank, world_size, port):
 # use_cpuadam = True can be used with cpu_offload = False
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [2])
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_moe_zero_optim(world_size):
     run_func = partial(_run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)

tests/test_trainer/test_trainer_with_non_pipe_schedule.py

Lines changed: 2 additions & 2 deletions
@@ -9,7 +9,7 @@
 from colossalai.trainer import Trainer
 from colossalai.utils import MultiTimer, free_port
 from tests.components_to_test.registry import non_distributed_component_funcs
-from colossalai.testing import parameterize, rerun_on_exception
+from colossalai.testing import parameterize, rerun_if_address_is_in_use
 
 BATCH_SIZE = 4
 IMG_SIZE = 32
@@ -51,7 +51,7 @@ def run_dist(rank, world_size, port):
 
 
 @pytest.mark.dist
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_trainer_no_pipeline():
     world_size = 4
     run_func = partial(run_dist, world_size=world_size, port=free_port())

tests/test_trainer/test_trainer_with_pipe_schedule.py

Lines changed: 6 additions & 3 deletions
@@ -17,13 +17,16 @@
 from torchvision import transforms
 from torchvision.datasets import CIFAR10
 from torchvision.models import resnet18
-from colossalai.testing import rerun_on_exception
+from colossalai.testing import rerun_if_address_is_in_use
 
 BATCH_SIZE = 4
 IMG_SIZE = 32
 NUM_EPOCHS = 200
 
-CONFIG = dict(NUM_MICRO_BATCHES=2, parallel=dict(pipeline=2),)
+CONFIG = dict(
+    NUM_MICRO_BATCHES=2,
+    parallel=dict(pipeline=2),
+)
 
 
 def run_trainer_with_pipeline(rank, world_size, port):
@@ -85,7 +88,7 @@ def forward(self, x):
 
 
 @pytest.mark.dist
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_trainer_with_pipeline():
     world_size = 4
     run_func = partial(run_trainer_with_pipeline, world_size=world_size, port=free_port())
