diff --git a/inferencer.py b/inferencer.py
index 7f19d1a..4abdcb8 100644
--- a/inferencer.py
+++ b/inferencer.py
@@ -7,8 +7,8 @@
 from PIL import Image
 import torch
 
-from data.data_utils import pil_img2rgb
-from modeling.bagel.qwen2_navit import NaiveCache
+from .data.data_utils import pil_img2rgb
+from .modeling.bagel.qwen2_navit import NaiveCache
diff --git a/modeling/bagel/bagel.py b/modeling/bagel/bagel.py
index 7f38131..7cba1e4 100644
--- a/modeling/bagel/bagel.py
+++ b/modeling/bagel/bagel.py
@@ -13,7 +13,7 @@
 from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_utils import PreTrainedModel
 
-from data.data_utils import (
+from ...data.data_utils import (
     create_sparse_mask,
     get_flattened_position_ids_extrapolate,
     get_flattened_position_ids_interpolate,
diff --git a/modeling/bagel/qwen2_navit.py b/modeling/bagel/qwen2_navit.py
index f818565..7363be5 100644
--- a/modeling/bagel/qwen2_navit.py
+++ b/modeling/bagel/qwen2_navit.py
@@ -22,7 +22,7 @@
 from transformers.utils import ModelOutput
 from flash_attn import flash_attn_varlen_func
 
-from modeling.qwen2.modeling_qwen2 import (
+from ..qwen2.modeling_qwen2 import (
     Qwen2Attention,
     Qwen2MLP,
     Qwen2PreTrainedModel,
@@ -31,7 +31,7 @@
     apply_rotary_pos_emb,
 )
 
-from modeling.qwen2.configuration_qwen2 import Qwen2Config as _Qwen2Config
+from ..qwen2.configuration_qwen2 import Qwen2Config as _Qwen2Config
 
 torch._dynamo.config.cache_size_limit = 512
diff --git a/modeling/bagel/siglip_navit.py b/modeling/bagel/siglip_navit.py
index 26efea3..716481c 100644
--- a/modeling/bagel/siglip_navit.py
+++ b/modeling/bagel/siglip_navit.py
@@ -13,8 +13,8 @@
 from torch import nn
 from transformers.activations import ACT2FN
 
-from modeling.siglip.configuration_siglip import SiglipVisionConfig as _SiglipVisionConfig
-from modeling.siglip.modeling_siglip import SiglipAttention, SiglipPreTrainedModel
+from ..siglip.configuration_siglip import SiglipVisionConfig as _SiglipVisionConfig
+from ..siglip.modeling_siglip import SiglipAttention, SiglipPreTrainedModel
 
 from flash_attn import flash_attn_varlen_func
diff --git a/nodes.py b/nodes.py
index 1b15e0f..dbcf21d 100644
--- a/nodes.py
+++ b/nodes.py
@@ -22,15 +22,15 @@
 from accelerate import infer_auto_device_map, load_checkpoint_and_dispatch, init_empty_weights
 from safetensors.torch import load_file
 
-from data.transforms import ImageTransform
-from data.data_utils import pil_img2rgb, add_special_tokens
-from modeling.bagel import (
+from .data.transforms import ImageTransform
+from .data.data_utils import pil_img2rgb, add_special_tokens
+from .modeling.bagel import (
     BagelConfig, Bagel, Qwen2Config, Qwen2ForCausalLM, SiglipVisionConfig, SiglipVisionModel
 )
-from modeling.qwen2 import Qwen2Tokenizer
-from modeling.bagel.qwen2_navit import NaiveCache
-from modeling.autoencoder import load_ae
-from inferencer import InterleaveInferencer
+from .modeling.qwen2 import Qwen2Tokenizer
+from .modeling.bagel.qwen2_navit import NaiveCache
+from .modeling.autoencoder import load_ae
+from .inferencer import InterleaveInferencer
 
 
 class LoadBAGELModel:
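
Rationale (a sketch accompanying the patch, not part of it): nodes.py defines ComfyUI-style node classes such as LoadBAGELModel, so this repo is presumably imported as a package by a custom-node loader rather than run with the repo root on sys.path. Under that assumption, absolute imports like "from data.data_utils import ..." raise ModuleNotFoundError, while the package-relative forms above resolve against the package wherever it is mounted. The snippet below illustrates the difference; the package name "custom_nodes.ComfyUI_BAGEL" is hypothetical.

import importlib

# A loader in this style imports the repo as a submodule of custom_nodes,
# so nodes.py executes with __package__ == "custom_nodes.ComfyUI_BAGEL"
# (hypothetical name) instead of as a top-level script.
nodes = importlib.import_module(".nodes", package="custom_nodes.ComfyUI_BAGEL")

# Inside nodes.py:
#   from inferencer import InterleaveInferencer   # searches sys.path for a
#                                                 # top-level "inferencer" -> fails
#   from .inferencer import InterleaveInferencer  # resolves relative to the
#                                                 # enclosing package -> works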