To run `CUDA_VISIBLE_DEVICES=0,1,2,3 bash training_scripts/baichuan/run_baichuan_7b.sh`, you also need to `pip install bitsandbytes` and pass `trust_remote_code=True` in the tokenizer/model loading calls in `utils.py` and `model_utils.py`, as in the modified functions below.
In `utils.py`:

```python
def get_tokenizer(model_name_or_path, fast_tokenizer=True):
    if "llama" in model_name_or_path:
        from transformers.models.llama import LlamaTokenizer
        tokenizer = LlamaTokenizer.from_pretrained(
            model_name_or_path, fast_tokenizer=fast_tokenizer, trust_remote_code=True
        )
        if tokenizer.pad_token is None:
            # assert tokenizer.eos_token is not None
            # tokenizer.add_special_tokens({'pad_token': tokenizer.eos_token})
            tokenizer.add_special_tokens({"pad_token": "[PAD]"})
            tokenizer.padding_side = "right"
    else:
        tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path, fast_tokenizer=fast_tokenizer, trust_remote_code=True
        )
        tokenizer.pad_token = tokenizer.eos_token
        # make sure tokenizer is right padded in our logic
        tokenizer.padding_side = "right"
    return tokenizer
```
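As a quick sanity check, the patched loader can be exercised on its own. This is a minimal sketch, not part of the original patch; the checkpoint name `baichuan-inc/Baichuan-7B` is an assumption, so substitute whatever path `run_baichuan_7b.sh` points at.

```python
# Hypothetical smoke test for the patched get_tokenizer (checkpoint path assumed).
tokenizer = get_tokenizer("baichuan-inc/Baichuan-7B", fast_tokenizer=True)
print(tokenizer.pad_token, tokenizer.padding_side)  # expect the eos token as pad, "right"

batch = tokenizer(["hello world"], padding="max_length", max_length=8, return_tensors="pt")
print(batch["input_ids"].shape)  # padded to the requested length
```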
In `model_utils.py`:

```python
def create_hf_model(model_class,
                    model_name_or_path,
                    tokenizer,
                    ds_config=None,
                    rlhf_training=False,
                    dropout=None):
    model_config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
    configure_dropout(model_config, dropout)

    # Note: dschf is defined in function scope to avoid global effects
    # https://huggingface.co/docs/transformers/main_classes/deepspeed#nontrainer-deepspeed-integration
    if ds_config is not None and ds_config["zero_optimization"]["stage"] == 3:
        dschf = HfDeepSpeedConfig(ds_config)
    else:
        dschf = None

    if rlhf_training:
        # the weight loading is handled by create critic model
        model = model_class.from_config(model_config, trust_remote_code=True)
    else:
        model = model_class.from_pretrained(
            model_name_or_path,
            from_tf=bool(".ckpt" in model_name_or_path),
            config=model_config,
            trust_remote_code=True)

    model.config.end_token_id = tokenizer.eos_token_id
    model.config.pad_token_id = model.config.eos_token_id
    model.resize_token_embeddings(int(
        8 * math.ceil(len(tokenizer) / 8.0)))  # make the vocab size a multiple of 8
    return model
```
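For reference, a minimal call sketch under the same assumptions (the Baichuan checkpoint name is assumed, and the surrounding DeepSpeed-Chat imports such as `AutoConfig`, `HfDeepSpeedConfig`, and `configure_dropout` are already available in `model_utils.py`):

```python
from transformers import AutoModelForCausalLM

# Minimal sketch, assuming the patched utils.py / model_utils.py are importable
# and the checkpoint path matches the one used by run_baichuan_7b.sh.
tokenizer = get_tokenizer("baichuan-inc/Baichuan-7B", fast_tokenizer=True)
model = create_hf_model(AutoModelForCausalLM,
                        "baichuan-inc/Baichuan-7B",
                        tokenizer,
                        ds_config=None,       # no ZeRO-3 config for this smoke test
                        rlhf_training=False,
                        dropout=None)
print(model.config.pad_token_id)
print(model.get_input_embeddings().weight.shape[0])  # vocab size padded to a multiple of 8
```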