[WIP] Add recompute support for imagen model. #750

Open · wants to merge 18 commits into develop
Changes from all commits (18 commits)
6caed08
add super res 512 and 1024
firestonelib Sep 17, 2022
9ca662c
[WIP] Add recompute support for imagen model.
GhostScreaming Sep 19, 2022
02bf031
Merge branch 'develop' of https://github.com/GhostScreaming/FleetX in…
GhostScreaming Sep 19, 2022
15b8c38
Add gradient-merge support.
GhostScreaming Sep 20, 2022
4ef3ded
Fix some problems.
GhostScreaming Sep 20, 2022
6476c29
Adapting imagen model for bfloat16 dtype.
GhostScreaming Sep 21, 2022
7159a2c
Merge branch 'develop' of https://github.com/PaddlePaddle/FleetX into…
GhostScreaming Sep 22, 2022
f12b77a
Merge branch 'develop' of https://github.com/PaddlePaddle/FleetX into…
GhostScreaming Sep 28, 2022
075930d
Merge branch 'imagen-mp' of https://github.com/GhostScreaming/FleetX …
GhostScreaming Sep 28, 2022
c5a8795
[WIP] add sharding and bfloat16 training strategy.
GhostScreaming Sep 28, 2022
71369f5
Merge branch 'develop' of https://github.com/PaddlePaddle/FleetX into…
GhostScreaming Oct 18, 2022
9eb034b
Merge branch 'develop' of https://github.com/PaddlePaddle/FleetX into…
GhostScreaming Oct 25, 2022
f79ef92
Polish Code.
GhostScreaming Oct 25, 2022
30099e6
Polish code.
GhostScreaming Oct 25, 2022
8947f4e
Fix bug of config.py and eager_engine.py
GhostScreaming Oct 25, 2022
1caba29
Polish Code.
GhostScreaming Oct 27, 2022
e6caf8a
Merge branch 'imagen-mp' of https://github.com/GhostScreaming/FleetX …
GhostScreaming Oct 27, 2022
94c25ef
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleFleet…
GhostScreaming Nov 9, 2022
@@ -25,6 +25,15 @@ Model:
dynamic_thresholding_percentile: 0.95
only_train_unet_number: 1
use_recompute: False
recompute_granularity: full

Engine:
mix_precision:
use_pure_fp16: True
scale_loss: 32768.0
custom_black_list: ["reduce_sum", "c_softmax_with_cross_entropy", "elementwise_div"]
custom_white_list: ["lookup_table", "lookup_table_v2"]
fp16_dtype: "bfloat16"

Data:
Train:
@@ -51,6 +60,6 @@ Distributed:
mp_degree: 1
pp_degree: 1
sharding:
sharding_degree: 1
sharding_stage: 1
sharding_degree: 8
sharding_stage: 2
sharding_offload: False
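
The new `Engine.mix_precision` keys feed Paddle's O2 AMP path (the sharding keys are consumed by FleetX's distributed setup, which is not part of this hunk). Below is a rough, self-contained sketch of how these settings map onto `paddle.amp.decorate`, `paddle.amp.auto_cast` and `paddle.amp.GradScaler` — toy model and optimizer, a device with bfloat16 support assumed, not the actual FleetX engine wiring:

```python
# Hedged sketch: how the mix_precision block above maps onto Paddle's AMP APIs.
# The model/optimizer are illustrative; the real wiring is in ppfleetx/core/engine/eager_engine.py.
import paddle

model = paddle.nn.Linear(16, 16)
opt = paddle.optimizer.AdamW(parameters=model.parameters())
dtype = "bfloat16"            # Engine.mix_precision.fp16_dtype ("float16" is the old default)

# level='O2' corresponds to use_pure_fp16: True
model, opt = paddle.amp.decorate(models=model, optimizers=opt, level='O2', dtype=dtype)
scaler = paddle.amp.GradScaler(init_loss_scaling=32768.0)   # scale_loss

x = paddle.randn([4, 16])
with paddle.amp.auto_cast(
        custom_black_list=["reduce_sum", "c_softmax_with_cross_entropy", "elementwise_div"],
        custom_white_list=["lookup_table", "lookup_table_v2"],
        level='O2',
        dtype=dtype):
    loss = model(x).mean()

scaler.scale(loss).backward()
scaler.step(opt)              # unscale gradients, then optimizer step
scaler.update()
opt.clear_grad()
```

Since bfloat16 keeps float32's exponent range, loss scaling is often unnecessary for it in practice; the sketch keeps the scaler only to mirror the `scale_loss` setting above.
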
41 changes: 29 additions & 12 deletions ppfleetx/core/engine/eager_engine.py
@@ -135,6 +135,8 @@ def configure_optimizers(self):
'custom_black_list']
self._custom_white_list = self._configs['mix_precision'][
'custom_white_list']
self._fp16_dtype = "float16" if 'fp16_dtype' not in self._configs['mix_precision'] \
else self._configs['mix_precision']['fp16_dtype']

self._save_steps = self._configs['save_load']['save_steps']
self._save_epoch = self._configs['save_load']['save_epoch']
@@ -158,6 +160,17 @@ def configure_optimizers(self):
self._use_recompute = configs['Model']['use_recompute']
self._quant_mode = True if 'Quantization' in configs and configs[
'Quantization']['enable'] else False

self._lr_scheduler_mode = configs.Optimizer.lr.pop('run_mode', 'step')
assert self._lr_scheduler_mode in [
'epoch', 'step'
], 'lr.run_mode must be epoch or step'
self._lr_scheduler = build_lr_scheduler(
configs.Optimizer.lr) if mode == 'train' else None

self._optimizer = build_optimizer(
configs.Optimizer, self._module.model,
self._lr_scheduler) if mode == 'train' else None

if self._use_pure_fp16:
if mode == 'train':
@@ -167,20 +180,17 @@ def configure_optimizers(self):
# Save dtype is the same as the model dtype. save_dtype='float32' can also be set when
# training with a pure fp16 strategy, but it increases memory usage.
self._module.model = paddle.amp.decorate(
models=self._module.model, level='O2')
models=self._module.model, level='O2', dtype=self._fp16_dtype)
else:
self._scaler = None

self._lr_scheduler_mode = configs.Optimizer.lr.pop('run_mode', 'step')
assert self._lr_scheduler_mode in [
'epoch', 'step'
], 'lr.run_mode must be epoch or step'
self._lr_scheduler = build_lr_scheduler(
configs.Optimizer.lr) if mode == 'train' else None

self._optimizer = build_optimizer(
configs.Optimizer, self._module.model,
self._lr_scheduler) if mode == 'train' else None
# self._lr_scheduler = build_lr_scheduler(
# configs.Optimizer.lr) if mode == 'train' else None

# self._optimizer = build_optimizer(
# configs.Optimizer, self._module.model,
# self._lr_scheduler) if mode == 'train' else None

# distributed configs
self._distributed = (dist.get_world_size() > 1)
@@ -411,6 +421,13 @@ def _fit_impl(self, batch):
self._module.model.train()

batch = self._module.pretreating_batch(batch)
if self._fp16_dtype == 'bfloat16':
with paddle.no_grad():
batch = [
paddle.cast(
t, dtype=paddle.bfloat16)
if t.dtype == paddle.float32 else t for t in batch
]
if self._pp_degree == 1:
if self._use_recompute and isinstance(self._module.model,
paddle.DataParallel):
@@ -456,9 +473,9 @@ def _model_forward_backward(self, batch):
self._use_pure_fp16,
custom_black_list=self._custom_black_list,
custom_white_list=self._custom_white_list,
level='O2'):
level='O2',
dtype=self._fp16_dtype):
loss = self._module.training_step(micro_batch)

loss_bw = self._scaler.scale(loss) if self._use_pure_fp16 else loss
if self._accumulate_steps > 1:
# div the loss for backward
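
The `_fit_impl` change above casts every float32 tensor in the batch to bfloat16 before the forward pass and leaves other dtypes (e.g. int64 token ids) untouched. A standalone sketch of that cast, assuming the batch is a flat list of tensors as it is here:

```python
# Hedged sketch of the bfloat16 batch pretreatment in _fit_impl (toy tensors;
# a device with bfloat16 support assumed). float32 entries are cast, ids pass through.
import paddle

batch = [
    paddle.randn([2, 3]),                          # float32 image-like tensor -> bfloat16
    paddle.to_tensor([[1, 2, 3]], dtype='int64'),  # token ids, left unchanged
]
with paddle.no_grad():
    batch = [
        paddle.cast(t, dtype=paddle.bfloat16) if t.dtype == paddle.float32 else t
        for t in batch
    ]
print([t.dtype for t in batch])  # first entry is now bfloat16, the int64 ids are unchanged
```
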
17 changes: 11 additions & 6 deletions ppfleetx/models/multimodal_model/imagen/modeling.py
@@ -153,6 +153,7 @@ def __init__(self,
dynamic_thresholding_percentile=0.95,
only_train_unet_number=None,
use_recompute=False,
recompute_granularity="full",
fused_linear=False):
super().__init__()

@@ -165,6 +166,9 @@ def __init__(self,

self.channels = in_chans

# use recompute
self.use_recompute = use_recompute

# automatically take care of ensuring that first unet is unconditional
# while the rest of the unets are conditioned on the low resolution image produced by previous unet

@@ -284,15 +288,15 @@ def get_unet(self, unet_number):
assert 0 < unet_number <= len(self.unets)
index = unet_number - 1

if isinstance(self.unets, nn.LayerList):
unets_list = [unet for unet in self.unets]
delattr(self, 'unets')
self.unets = unets_list
# if isinstance(self.unets, nn.LayerList):
# unets_list = [unet for unet in self.unets]
# delattr(self, 'unets')
# self.unets = unets_list
self.unet_being_trained_index = index
return self.unets[index]

def reset_unets(self, ):
self.unets = nn.LayerList([*self.unets])
# self.unets = nn.LayerList([*self.unets])
self.unet_being_trained_index = -1

@contextmanager
@@ -691,7 +695,8 @@ def p_losses(self,
lowres_noise_times=self.lowres_noise_schedule.get_condition(
lowres_aug_times),
lowres_cond_img=lowres_cond_img_noisy,
cond_drop_prob=self.cond_drop_prob, )
cond_drop_prob=self.cond_drop_prob,
use_recompute=self.use_recompute)

# prediction objective

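
The modeling.py changes are plumbing: the Imagen wrapper stores `use_recompute` once in `__init__` and forwards it into the UNet call inside `p_losses`. A hedged, self-contained sketch of that pattern — `Block` and `Wrapper` are illustrative stand-ins, not the actual Imagen classes:

```python
# Hedged sketch of threading a use_recompute flag from a wrapper model into its
# sub-network, mirroring Imagen.__init__ / p_losses above (illustrative names).
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.distributed.fleet.utils import recompute

class Block(nn.Layer):
    def __init__(self, dim=16):
        super().__init__()
        self.fc = nn.Linear(dim, dim)

    def forward(self, x):
        return F.relu(self.fc(x))

class Wrapper(nn.Layer):
    def __init__(self, use_recompute=False):
        super().__init__()
        self.block = Block()
        self.use_recompute = use_recompute   # stored once at construction

    def forward(self, x):
        if self.use_recompute:
            # activations of self.block are dropped and recomputed during backward
            return recompute(self.block, x)
        return self.block(x)

m = Wrapper(use_recompute=True)
x = paddle.randn([4, 16])
x.stop_gradient = False   # recompute needs at least one input that requires grad
m(x).mean().backward()
```
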
43 changes: 34 additions & 9 deletions ppfleetx/models/multimodal_model/imagen/unet.py
@@ -20,6 +20,7 @@
from paddle import nn
from paddle import nn, einsum
import paddle.nn.functional as F
from paddle.distributed.fleet.utils import recompute

from .utils import (zeros_, zero_init_, default, exists, cast_tuple,
resize_image_to, prob_mask_like, masked_mean, Identity,
@@ -1020,8 +1021,9 @@ def __init__(self,
layer_cross_attns = cast_tuple(layer_cross_attns, num_layers)

assert all([
layers == num_layers for layers in
list(map(len, (resnet_groups, layer_attns, layer_cross_attns)))
layers == num_layers
for layers in list(
map(len, (resnet_groups, layer_attns, layer_cross_attns)))
])

# downsample klass
@@ -1298,7 +1300,8 @@ def forward(self,
text_embeds=None,
text_mask=None,
cond_images=None,
cond_drop_prob=0.):
cond_drop_prob=0.,
use_recompute=False):
batch_size = x.shape[0]

# add low resolution conditioning, if present
@@ -1340,6 +1343,8 @@ def forward(self,

time_tokens = self.to_time_tokens(time_hiddens)
t = self.to_time_cond(time_hiddens)
if use_recompute:
t.stop_gradient = True

# add lowres time conditioning to time hiddens
# and add lowres time tokens along sequence dimension for attention
@@ -1426,6 +1431,8 @@ def forward(self,
# normalize conditioning tokens

c = self.norm_cond(c)
if use_recompute:
c.stop_gradient = True

if exists(self.init_resnet_block):
x = self.init_resnet_block(x, t)
@@ -1436,38 +1443,56 @@
if exists(pre_downsample):
x = pre_downsample(x)

x = init_block(x, t, c)
if use_recompute:
x = recompute(init_block, x, t, c)
else:
x = init_block(x, t, c)

for resnet_block in resnet_blocks:
x = resnet_block(x, t)
hiddens.append(x)

x = attn_block(x, c)
if use_recompute:
x = recompute(attn_block, x, c)
else:
x = attn_block(x, c)
hiddens.append(x)

if exists(post_downsample):
x = post_downsample(x)

x = self.mid_block1(x, t, c)
if use_recompute:
x = recompute(self.mid_block1, x, t, c)
else:
x = self.mid_block1(x, t, c)

if exists(self.mid_attn):
x = self.mid_attn(x)

x = self.mid_block2(x, t, c)
if use_recompute:
x = recompute(self.mid_block2, x, t, c)
else:
x = self.mid_block2(x, t, c)

add_skip_connection = lambda x: paddle.concat((x, hiddens.pop() * self.skip_connect_scale), axis=1)

up_hiddens = []

for init_block, resnet_blocks, attn_block, upsample in self.ups:
x = add_skip_connection(x)
x = init_block(x, t, c)
if use_recompute:
x = recompute(init_block, x, t, c)
else:
x = init_block(x, t, c)

for resnet_block in resnet_blocks:
x = add_skip_connection(x)
x = resnet_block(x, t)

x = attn_block(x, c)
if use_recompute:
x = recompute(attn_block, x, c)
else:
x = attn_block(x, c)
up_hiddens.append(x)
x = upsample(x)

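
In the UNet forward, `recompute` wraps blocks that take the feature map plus time and text-conditioning tensors, and those conditioning tensors are marked `stop_gradient = True` beforehand. A hedged multi-input sketch of that call pattern — `CondBlock` and the shapes are illustrative, not the real UNet blocks:

```python
# Hedged sketch: recomputing a block that takes a feature map plus conditioning
# tensors, with the conditioning inputs detached as t and c are in forward() above.
import paddle
import paddle.nn as nn
from paddle.distributed.fleet.utils import recompute

class CondBlock(nn.Layer):
    def __init__(self, dim=8, cond_dim=4):
        super().__init__()
        self.conv = nn.Conv2D(dim, dim, 3, padding=1)
        self.to_time = nn.Linear(cond_dim, dim)
        self.to_cond = nn.Linear(cond_dim, dim)

    def forward(self, x, t, c):
        # broadcast the conditioning embeddings onto the feature map
        scale = (self.to_time(t) + self.to_cond(c)).unsqueeze([2, 3])
        return self.conv(x) + scale

block = CondBlock()
x = paddle.randn([2, 8, 16, 16])
x.stop_gradient = False
t = paddle.randn([2, 4])
t.stop_gradient = True    # conditioning tensors detached, as in the diff
c = paddle.randn([2, 4])
c.stop_gradient = True
out = recompute(block, x, t, c)   # same call signature as block(x, t, c)
out.mean().backward()             # block's forward is re-run here to rebuild activations
```
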
6 changes: 3 additions & 3 deletions ppfleetx/models/multimodal_model/multimodal_module.py
@@ -48,7 +48,7 @@ def training_step(self, batch):
return loss

def training_step_end(self, log_dict):
speed = self.configs.Engine.logging_freq / log_dict['train_cost']
speed = 1.0 / log_dict['train_cost']

logger.info(
"[train] epoch: %d, batch: %d, loss: %.9f, avg_batch_cost: %.5f sec, speed: %.2f step/s, learning rate: %.5e"
@@ -63,7 +63,7 @@ def validation_step(self, batch):
return loss

def validation_step_end(self, log_dict):
speed = self.configs.Engine.logging_freq / log_dict['eval_cost']
speed = 1.0 / log_dict['eval_cost']
logger.info(
"[eval] epoch: %d, batch: %d, loss: %.9f, avg_eval_cost: %.5f sec, speed: %.2f step/s"
% (log_dict['epoch'], log_dict['batch'], log_dict['loss'],
@@ -77,7 +77,7 @@ def test_step(self, batch):
return loss

def test_step_end(self, log_dict):
speed = self.configs.Engine.logging_freq / log_dict['test_cost']
speed = 1.0 / log_dict['test_cost']
logger.info(
"[test] epoch: %d, batch: %d, loss: %.9f, avg_test_cost: %.5f sec, speed: %.2f step/s"
% (log_dict['epoch'], log_dict['batch'], log_dict['loss'],
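
The updated speed formula treats `train_cost`/`eval_cost`/`test_cost` as a per-batch average (they are logged as `avg_batch_cost`/`avg_eval_cost`/`avg_test_cost`), so throughput is simply the reciprocal rather than `logging_freq` divided by the cost. For example:

```python
avg_batch_cost = 0.25         # seconds per step, as logged
speed = 1.0 / avg_batch_cost  # 4.00 step/s, matching the "speed: %.2f step/s" field
```
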
23 changes: 23 additions & 0 deletions projects/imagen/run_super_resolusion_1024_DP8.sh
@@ -0,0 +1,23 @@
#! /bin/bash

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

log_dir=log_imagen_1024_DP8
rm -rf $log_dir

export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
python -m paddle.distributed.launch --log_dir $log_dir --devices "0,1,2,3,4,5,6,7" \
tools/train.py -c ppfleetx/configs/multimodal/imagen/imagen_super_resolusion_1024.yaml \
-o Data.Train.loader.num_workers=8
24 changes: 24 additions & 0 deletions projects/imagen/run_super_resolusion_1024_sharding.sh
@@ -0,0 +1,24 @@
#! /bin/bash

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

log_dir=log_imagen_1024_sharding
rm -rf $log_dir

export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
python -m paddle.distributed.launch --log_dir $log_dir --devices "0,1,2,3,4,5,6,7" \
tools/train.py -c ppfleetx/configs/multimodal/imagen/imagen_super_resolusion_1024.yaml \
-o Data.Train.loader.num_workers=8 -o Distributed.sharding.sharding_degree=8 \
-o Distributed.sharding.sharding_stage=2