
Commit f35bda9: move import back

1 parent e200a3b

File tree

7 files changed (+9, -394 lines)

apps/train.py

Lines changed: 3 additions & 2 deletions

@@ -1,15 +1,16 @@
 import torch
 
+from flashmodels import Builder, Trainer, accelerate, arguments
+
+
 def train():
     torch.manual_seed(101)
 
     # parse args
-    from flashmodels import arguments
     args = arguments.parse()
 
     # build model, tokenizer, loader, optimizer and lr_scheduler
     # and use accelerator to speed up training
-    from flashmodels import Builder, Trainer, accelerate
     builder = Builder(args)
     model, loader, tokenizer = builder.build_model_dataloader()
     model, loader = accelerate(model, loader, args)
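The commit message ("move import back") refers to this hunk: the two lazy, in-function imports of flashmodels are replaced by a single module-level import, so the package (and any import-time side effects it has) loads when apps/train.py is loaded rather than on the first call to train(). A minimal, generic sketch of that difference, assuming flashmodels is importable; it is not the repo's actual file:

# Lazy form (what the diff removes): flashmodels only loads when train() first runs.
def train_lazy():
    from flashmodels import arguments
    args = arguments.parse()
    ...

# Module-level form (what the diff restores): flashmodels loads as soon as this
# module is imported or executed, before train() is ever called.
from flashmodels import Builder, Trainer, accelerate, arguments

def train():
    args = arguments.parse()
    ...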

flashmodels/accelerators/cuda_llama_accelerator.py

Lines changed: 4 additions & 2 deletions

@@ -71,7 +71,7 @@ def apply_checkpointing(self, model):
             checkpoint_wrapper,
             checkpoint_impl=CheckpointImpl.NO_REENTRANT,
         )
-        check_fn = lambda submodule: isinstance(LlamaDecoderLayer)
+        check_fn = lambda submodule: isinstance(submodule, LlamaDecoderLayer)
         apply_activation_checkpointing(
             model,
             checkpoint_wrapper_fn=non_reentrant_wrapper,
@@ -97,7 +97,9 @@ def fsdp(self, model):
                 convert_outputs_to_fp32(model.forward.__func__), model)
 
         # Use auto_wrap_policy for nested wrapping instead of only a top-level FSDP.
-        auto_wrap_policy = ModuleWrapPolicy({LlamaDecoderLayer, })
+        auto_wrap_policy = ModuleWrapPolicy({
+            LlamaDecoderLayer,
+        })
 
         mixed_precision_policy = None
         if self.args.fp16 or self.args.bf16:
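Two things change here: the check_fn lambda originally called isinstance with a single argument, which raises TypeError the first time checkpointing is applied, and the ModuleWrapPolicy set literal is reflowed. A minimal sketch of the pattern these lines rely on, using the public torch.distributed utilities; the DecoderLayer class and toy model are stand-ins for the repo's LlamaDecoderLayer and model, not its actual code:

from functools import partial

import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl, apply_activation_checkpointing, checkpoint_wrapper)
from torch.distributed.fsdp.wrap import ModuleWrapPolicy


class DecoderLayer(nn.Module):
    # Stand-in for LlamaDecoderLayer.
    def __init__(self):
        super().__init__()
        self.ff = nn.Linear(16, 16)

    def forward(self, x):
        return self.ff(x)


model = nn.Sequential(DecoderLayer(), DecoderLayer())

# Wrap each matching submodule with a non-reentrant activation checkpoint.
non_reentrant_wrapper = partial(
    checkpoint_wrapper,
    checkpoint_impl=CheckpointImpl.NO_REENTRANT,
)
# isinstance needs both the object and the class, which is what the fix restores.
check_fn = lambda submodule: isinstance(submodule, DecoderLayer)
apply_activation_checkpointing(
    model,
    checkpoint_wrapper_fn=non_reentrant_wrapper,
    check_fn=check_fn,
)

# ModuleWrapPolicy takes the set of module classes that FSDP should wrap as
# their own units, giving nested wrapping instead of a single top-level FSDP.
auto_wrap_policy = ModuleWrapPolicy({DecoderLayer})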

flashmodels/arguments.py.bak

Lines changed: 0 additions & 273 deletions
This file was deleted.

flashmodels/patch/__init__.py.bak

Lines changed: 0 additions & 5 deletions
This file was deleted.

flashmodels/patch/llama_model.py

Lines changed: 0 additions & 1 deletion

@@ -8,7 +8,6 @@
 import torch_xla.core.xla_model as xm
 from torch import nn
 from torchacc.dist.tp import Mesh, mark_sharding
-from transformer.cache_utils import Cache
 from transformers.models.llama.configuration_llama import LlamaConfig
 from transformers.models.llama.modeling_llama import (ACT2FN, LlamaRMSNorm,
                                                        LlamaRotaryEmbedding,

flashmodels/patch/patch.py

Lines changed: 2 additions & 0 deletions

@@ -33,6 +33,8 @@ def rewrite_load():
 
 def patch_llama(use_flash_attn):
     patch.patch_llama(use_flash_attn)
+    from flashmodels.patch.llama_model import (LlamaAttention,
+                                               LlamaDecoderLayer, LlamaMLP)
     if os.environ.get("ACC_LLAMA_TP") == "1":
         transformers.models.llama.modeling_llama.LlamaMLP = LlamaMLP
     if os.getenv("XLA_USE_SPMD") == "1":
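The added import brings the patched classes from flashmodels.patch.llama_model into scope so the assignments guarded by ACC_LLAMA_TP and XLA_USE_SPMD resolve. The patching itself is plain attribute replacement on the transformers module; a minimal sketch of that pattern follows, where PatchedLlamaMLP is a stand-in for the repo's replacement class, not its actual implementation:

import os

from transformers.models.llama import modeling_llama


class PatchedLlamaMLP(modeling_llama.LlamaMLP):
    # Stand-in for the replacement MLP defined in flashmodels.patch.llama_model.
    pass


# The "patch" is plain attribute assignment on the module: any code that later
# looks up modeling_llama.LlamaMLP (e.g. when building a Llama model) gets the
# replacement class. Guarding on an env var keeps the default path unchanged.
if os.environ.get("ACC_LLAMA_TP") == "1":
    modeling_llama.LlamaMLP = PatchedLlamaMLP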
