
[WIP] Distill #853


Draft: wants to merge 2 commits into develop
5 changes: 5 additions & 0 deletions distill.sh
@@ -0,0 +1,5 @@
export CUDA_VISIBLE_DEVICES=6,7

python -m paddle.distributed.launch \
    ./tools/train.py \
    -c ./ppfleetx/configs/nlp/gpt/distill_gpt_345M_single_card.yaml
38 changes: 38 additions & 0 deletions ppfleetx/configs/nlp/gpt/distill_gpt_345M_single_card.yaml
@@ -0,0 +1,38 @@
_base_: ./pretrain_gpt_base.yaml

Global:
  global_batch_size:
  local_batch_size: 1
  micro_batch_size: 1

Engine:
  save_load:
    ckpt_dir: PaddleFleetX_GPT_345M_220826

Model:
  module: "GPTDistillModule"
  vocab_size: 50304
  hidden_size: 1024
  num_layers: 24
  num_attention_heads: 16
  ffn_hidden_size: 4096
  hidden_dropout_prob: 0.1
  attention_probs_dropout_prob: 0.1
  max_position_embeddings: 1024
  type_vocab_size: 16
  initializer_range: 0.02
  use_recompute: False
  recompute_granularity:
  no_recompute_layers:


Distributed:
  dp_degree: 1
  mp_degree: 1
  pp_degree: 1
  sharding:
    sharding_degree: 1
    sharding_stage: 1
    sharding_offload: False
    reduce_overlap: False
    broadcast_overlap: False
4 changes: 2 additions & 2 deletions ppfleetx/configs/nlp/gpt/eval_gpt_345M_single_card.yaml
@@ -6,8 +6,8 @@ Model:


Offline_Eval:
  eval_path: ./wikitext-103/wiki.valid.tokens
  cloze_eval: False
  eval_path: lambada_test.jsonl
  cloze_eval: True
  overlapping_eval: 32
  batch_size: 8
  max_seq_len: 1024
44 changes: 44 additions & 0 deletions ppfleetx/configs/nlp/gpt/eval_gpt_6.7B_single_card.yaml
@@ -0,0 +1,44 @@
_base_: ./pretrain_gpt_base.yaml

Engine:
  save_load:
    ckpt_dir: pretrain_model

Model:
  module: GPTEvalModule
  vocab_size: 50304
  hidden_size: 4096
  num_layers: 32
  num_attention_heads: 32
  ffn_hidden_size:
  hidden_dropout_prob: 0.1
  attention_probs_dropout_prob: 0.1
  max_position_embeddings: 1024
  type_vocab_size: 16
  initializer_range: 0.02
  use_recompute: False
  recompute_granularity:
  no_recompute_layers:


Distributed:
  dp_degree: 1
  mp_degree: 1
  pp_degree: 1
  sharding:
    sharding_degree: 1
    sharding_stage: 1
    sharding_offload: False
    reduce_overlap: False
    broadcast_overlap: False


Offline_Eval:
  eval_path: lambada_test.jsonl
  cloze_eval: True
  overlapping_eval: 32
  batch_size: 8
  max_seq_len: 1024
  logging_freq: 10


91 changes: 49 additions & 42 deletions ppfleetx/core/engine/eager_engine.py
@@ -35,6 +35,13 @@
from ppfleetx.utils.version import version_check
from ppfleetx.utils.export import export_inference_model

import nvidia_smi

nvidia_smi.nvmlInit()

handle = nvidia_smi.nvmlDeviceGetHandleByIndex(6)
# card id 6 is hardcoded here; there is also a call to get all available card ids, so we could iterate
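# A minimal sketch of the iteration mentioned above, assuming the installed
# nvidia-ml-py bindings also expose nvmlDeviceGetCount:
#
#     for card_id in range(nvidia_smi.nvmlDeviceGetCount()):
#         h = nvidia_smi.nvmlDeviceGetHandleByIndex(card_id)
#         mem = nvidia_smi.nvmlDeviceGetMemoryInfo(h)
#         print("GPU %d used %.2f GB" % (card_id, mem.used / 1024**3))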

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

@@ -179,7 +186,7 @@ def configure_optimizers(self):

        # distributed configs
        self._distributed = (dist.get_world_size() > 1)

        '''
        if self._distributed:
            self._hcg = fleet.get_hybrid_communicate_group()
            self._dp_group = self._hcg.get_data_parallel_group()
@@ -194,6 +201,8 @@ def configure_optimizers(self):
            self._wrap_with_fleet()
        else:
            self._dp_rank = 0
        '''
        self._dp_rank = 0

        # using for save/load
        self._load_recovery = {'step': 0, 'epoch': 0, 'rng_state': -1}
@@ -284,7 +293,6 @@ def _train_one_epoch(self,

            loss = self._fit_impl(batch)
            train_losses.append(loss)

            if (step + 1) % self._logging_freq == 0:
                # Sync for profile time, delete it may be a little faster
                paddle.device.cuda.synchronize()
@@ -300,6 +308,7 @@ def _train_one_epoch(self,
                    'loss': sum(numpy_losses) / len(numpy_losses),
                    'lr': self._optimizer.get_lr()
                }
                # if paddle.distributed.get_rank() == 1:
                self._module.training_step_end(log_dict)

                train_start = time.time()
@@ -333,9 +342,10 @@ def _train_one_epoch(self,
                }
                self._module.validation_step_end(log_dict)

                self._module.model.train()
                if paddle.distributed.get_rank() == 1:
                    self._module.model.train()

            if self._save_steps > 0 and step % self._save_steps == 0:
            if self._save_steps > 0 and step % self._save_steps == 0 and paddle.distributed.get_rank() == 1:
                paddle.device.cuda.synchronize()
                self.save(epoch=epoch_index, step=step)
            else:
@@ -361,7 +371,10 @@ def fit(self, epoch=1, train_data_loader=None, valid_data_loader=None):
            valid_data_loader(DataLoader, None): a collection of :class:`paddle.io.DataLoader`, specifying validation samples.

        """
        self._module.model.train()
        if paddle.distributed.get_rank() == 0:
            self._module.model.eval()
        else:
            self._module.model.train()

        train_cost = 0.0
        train_start = time.time()
@@ -386,7 +399,8 @@ def fit(self, epoch=1, train_data_loader=None, valid_data_loader=None):
            if self._run_mode == 'epoch' and self._eval_freq > 0 and \
                    epoch_index % self._eval_freq == 0:
                self._evaluate_one_epoch(epoch_index, valid_data_loader)
                self._module.model.train()
                if paddle.distributed.get_rank() == 1:
                    self._module.model.train()
                eval_cost = time.time() - eval_start
                log_dict = {
                    'epoch': epoch_index,
@@ -416,7 +430,8 @@ def _fit_impl(self, batch):
                                          self._dp_group)
            else:
                loss = self._model_forward_backward(batch)
            self._optim_update_params()
            if paddle.distributed.get_rank() == 1:
                self._optim_update_params()
        else:
            with paddle.amp.auto_cast(
                    self._use_pure_fp16,
@@ -450,13 +465,14 @@ def _model_forward_backward(self, batch):
                    level='O2'):
                loss = self._module.training_step(micro_batch)

            if paddle.distributed.get_rank() == 1:
                loss_bw = self._scaler.scale(loss) if self._use_pure_fp16 else loss
                self._module.backward(loss_bw)
            detach_loss = loss.detach()
            if final_loss is None:
                final_loss = detach_loss
            else:
                final_loss = paddle.add(final_loss, detach_loss)
                detach_loss = loss.detach()
                if final_loss is None:
                    final_loss = detach_loss
                else:
                    final_loss = paddle.add(final_loss, detach_loss)
        if self._accumulate_steps > 1:
            final_loss = final_loss / self._accumulate_steps
        return final_loss
@@ -625,16 +641,20 @@ def save(self, epoch=0, step=0):
            logger.info("DP_Rank %d doesn't save model" % self._dp_rank)
            return

        if paddle.distributed.get_rank() == 0:
            return

        if self._output_dir and isinstance(self._output_dir, str):
            output_dir = os.path.join(self._output_dir,
                                      "epoch_%d_step_%d" % (epoch, step))
            if not os.path.exists(output_dir):
                os.makedirs(output_dir, exist_ok=True)
            logger.info("Save model to %s" % output_dir)

            save_dir = "{}/mp_{:0>2d}_sharding_{:0>2d}_pp_{:0>2d}".format(
                output_dir, self._mp_rank, self._sharding_rank,
                self._pp_rank) if self._distributed else output_dir
            # save_dir = "{}/mp_{:0>2d}_sharding_{:0>2d}_pp_{:0>2d}".format(
            #     output_dir, self._mp_rank, self._sharding_rank,
            #     self._pp_rank) if self._distributed else output_dir
            save_dir = output_dir

            if self._sharding_stage == 3:
                self._module.model.get_all_parameters(convert2cpu=False)
@@ -657,55 +677,42 @@ def load(self):
        """
        load the saved checkpoint file and update the state dicts of model and optimizer.
        """
        if paddle.distributed.get_rank() == 0:
            self._ckpt_dir = 'pretrain_model/'
        elif paddle.distributed.get_rank() == 1:
            self._ckpt_dir = 'PaddleFleetX_GPT_345M_220826/'
            self._ckpt_dir = 'output/epoch_0_step_10000/'
        if self._ckpt_dir and isinstance(self._ckpt_dir, str):
            logger.info("Try to load checkpoint from %s " % self._ckpt_dir)

            load_dir = "{}/mp_{:0>2d}_sharding_{:0>2d}_pp_{:0>2d}".format(
                self._ckpt_dir, self._mp_rank, self._sharding_rank,
                self._pp_rank) if self._distributed else self._ckpt_dir
            load_dir = self._ckpt_dir
            model_path = os.path.join(load_dir, "model.pdparams")
            opt_path = os.path.join(load_dir, "model_state.pdopt")
            meta_path = os.path.join(load_dir, "meta_state.pdopt")

            if os.path.exists(model_path):
                model_dict = paddle.load(model_path)
                for name, param in self._module.model.state_dict().items():
                    print('trying to load {}'.format(name))
                    assert name in model_dict.keys(
                    ), "No param named `{}` was found in checkpoint file.".format(
                        name)

                    if param.dtype != model_dict[name].dtype:
                        model_dict[name] = model_dict[name].cast(param.dtype)

                    print("load: {}".format(name))
                self._module.model.set_state_dict(model_dict)
            else:
                raise ValueError("No optimizer checkpoint file found in %s." %
                                 model_path)

            if self.mode == 'train':
                if os.path.exists(opt_path):
                    opt_dict = paddle.load(opt_path)
                    self._optimizer.set_state_dict(opt_dict)
                else:
                    raise ValueError(
                        "No optimizer checkpoint file found in %s." % opt_path)

                if os.path.exists(meta_path):
                    meta_dict = paddle.load(meta_path)
                    self._load_recovery = {
                        'step': meta_dict['step'],
                        'epoch': meta_dict['epoch'],
                        'rng_state': meta_dict['cuda_rng_state']
                    }
                else:
                    raise ValueError("No meta checkpoint file found in %s." %
                                     meta_path)

            logger.info("successfully load checkpoints")
        else:
            logger.warning("`load` requires a valid value of `ckpt_dir`.")
            raise TypeError("`load` requires a valid value of `ckpt_dir`.")

        info = nvidia_smi.nvmlDeviceGetMemoryInfo(handle)

        print("Total memory:", info.total/1024**3, "GB")
        print("Free memory:", info.free/1024**3, "GB")
        print("Used memory:", info.used/1024**3, "GB")

    def export(self):
        self._module.model.eval()
        input_spec = self._module.input_spec()
7 changes: 5 additions & 2 deletions ppfleetx/data/dataset/gpt_dataset.py
@@ -54,6 +54,8 @@ def __init__(self,

        local_rank = int(os.getenv("PADDLE_RANK_IN_NODE", 0))


        '''
        if local_rank == 0:
            try:
                import ppfleetx.data.data_tools.cpp.fast_index_map_helpers
@@ -68,7 +70,7 @@ def __init__(self,
                    flush=True)

        device_world_size = paddle.distributed.get_world_size()

        device_world_size = 1
        if device_world_size > 1 and local_rank != 0:
            while True:
                try:
@@ -80,12 +82,13 @@ def __init__(self,

        try:
            data_world_size = env.get_data_world_size()

            print(data_world_size)
            logger.info(
                "The distributed run, total device num:{}, distinct dataflow num:{}.".
                format(device_world_size, data_world_size))
        except AttributeError:
            pass
        '''

        assert len(input_dir) == 1, "GPT only support one dataset for now."

2 changes: 1 addition & 1 deletion ppfleetx/models/__init__.py
@@ -16,7 +16,7 @@
import copy

from ppfleetx.core.module.basic_module import BasicModule
from ppfleetx.models.language_model.language_module import GPTModule, GPTGenerationModule, GPTEvalModule, GPTFinetuneModule
from ppfleetx.models.language_model.language_module import GPTModule, GPTGenerationModule, GPTEvalModule, GPTFinetuneModule, GPTDistillModule
from ppfleetx.models.language_model.gpt.auto.auto_module import GPTModuleAuto, GPTGenerationModuleAuto
from ppfleetx.models.vision_model.general_classification_module import GeneralClsModule
from ppfleetx.models.multimodal_model.multimodal_module import ImagenModule