116 changes: 77 additions & 39 deletions inference/v1/models/rfdetr/rfdetr_backbone_pytorch.py
@@ -9,8 +9,10 @@
from transformers import AutoBackbone
import types

from inference.v1.models.rfdetr.dinov2_with_windowed_attn import WindowedDinov2WithRegistersConfig, \
WindowedDinov2WithRegistersBackbone
from inference.v1.models.rfdetr.dinov2_with_windowed_attn import (
WindowedDinov2WithRegistersConfig,
WindowedDinov2WithRegistersBackbone,
)
from inference.v1.models.rfdetr.misc import NestedTensor
from inference.v1.models.rfdetr.projector import MultiScaleProjector

@@ -35,29 +37,49 @@


def get_config(size, use_registers):
cache_key = (size, use_registers)
if cache_key in _config_cache:
return _config_cache[cache_key]

config_dict = size_to_config_with_registers if use_registers else size_to_config
current_dir = os.path.dirname(os.path.abspath(__file__))
configs_dir = os.path.join(current_dir, "dinov2_configs")
config_path = os.path.join(configs_dir, config_dict[size])
config_file = config_dict[size]
config_path = os.path.join(_configs_dir, config_file)
with open(config_path, "r") as f:
dino_config = json.load(f)
_config_cache[cache_key] = dino_config
return dino_config


class DinoV2(nn.Module):
def __init__(self, shape=(640, 640), out_feature_indexes=[2, 4, 5, 9], size="base", use_registers=True,
use_windowed_attn=True, gradient_checkpointing=False, load_dinov2_weights=True):
def __init__(
self,
shape=(640, 640),
out_feature_indexes=[2, 4, 5, 9],
size="base",
use_registers=True,
use_windowed_attn=True,
gradient_checkpointing=False,
load_dinov2_weights=True,
):
super().__init__()

name = f"facebook/dinov2-with-registers-{size}" if use_registers else f"facebook/dinov2-{size}"
name = (
f"facebook/dinov2-with-registers-{size}"
if use_registers
else f"facebook/dinov2-{size}"
)

self.shape = shape

# Create the encoder

if not use_windowed_attn:
assert not gradient_checkpointing, "Gradient checkpointing is not supported for non-windowed attention"
assert load_dinov2_weights, "Using non-windowed attention requires loading dinov2 weights from hub"
assert (
not gradient_checkpointing
), "Gradient checkpointing is not supported for non-windowed attention"
assert (
load_dinov2_weights
), "Using non-windowed attention requires loading dinov2 weights from hub"
self.encoder = AutoBackbone.from_pretrained(
name,
out_features=[f"stage{i}" for i in out_feature_indexes],
@@ -88,10 +110,14 @@ def __init__(self, shape=(640, 640), out_feature_indexes=[2, 4, 5, 9], size="bas
num_register_tokens=0,
gradient_checkpointing=gradient_checkpointing,
)
self.encoder = WindowedDinov2WithRegistersBackbone.from_pretrained(
name,
config=windowed_dino_config,
) if load_dinov2_weights else WindowedDinov2WithRegistersBackbone(windowed_dino_config)
self.encoder = (
WindowedDinov2WithRegistersBackbone.from_pretrained(
name,
config=windowed_dino_config,
)
if load_dinov2_weights
else WindowedDinov2WithRegistersBackbone(windowed_dino_config)
)

self._out_feature_channels = [size_to_width[size]] * len(out_feature_indexes)
self._export = False
@@ -103,7 +129,7 @@ def export(self):
shape = self.shape

def make_new_interpolated_pos_encoding(
position_embeddings, patch_size, height, width
position_embeddings, patch_size, height, width
):

num_positions = position_embeddings.shape[1] - 1
Expand Down Expand Up @@ -154,13 +180,13 @@ def new_interpolate_pos_encoding(self_mod, embeddings, height, width):

self.encoder.embeddings.position_embeddings = nn.Parameter(new_positions)
self.encoder.embeddings.interpolate_pos_encoding = types.MethodType(
new_interpolate_pos_encoding,
self.encoder.embeddings
new_interpolate_pos_encoding, self.encoder.embeddings
)

def forward(self, x):
assert x.shape[2] % 14 == 0 and x.shape[
3] % 14 == 0, f"Dinov2 requires input shape to be divisible by 14, but got {x.shape}"
assert (
x.shape[2] % 14 == 0 and x.shape[3] % 14 == 0
), f"Dinov2 requires input shape to be divisible by 14, but got {x.shape}"
x = self.encoder(x)
return list(x[0])

@@ -169,29 +195,31 @@ class BackboneBase(nn.Module):
def __init__(self):
super().__init__()

def get_named_param_lr_pairs(self, args, prefix:str):
def get_named_param_lr_pairs(self, args, prefix: str):
raise NotImplementedError


class Backbone(BackboneBase):
"""backbone."""
def __init__(self,
name: str,
pretrained_encoder: str=None,
window_block_indexes: list=None,
drop_path=0.0,
out_channels=256,
out_feature_indexes: list=None,
projector_scale: list=None,
use_cls_token: bool = False,
freeze_encoder: bool = False,
layer_norm: bool = False,
target_shape: tuple[int, int] = (640, 640),
rms_norm: bool = False,
backbone_lora: bool = False,
gradient_checkpointing: bool = False,
load_dinov2_weights: bool = True,
):

def __init__(
self,
name: str,
pretrained_encoder: str = None,
window_block_indexes: list = None,
drop_path=0.0,
out_channels=256,
out_feature_indexes: list = None,
projector_scale: list = None,
use_cls_token: bool = False,
freeze_encoder: bool = False,
layer_norm: bool = False,
target_shape: tuple[int, int] = (640, 640),
rms_norm: bool = False,
backbone_lora: bool = False,
gradient_checkpointing: bool = False,
load_dinov2_weights: bool = True,
):
super().__init__()
# an example name here would be "dinov2_base" or "dinov2_registers_windowed_base"
# if "registers" is in the name, then use_registers is set to True, otherwise it is set to False
@@ -209,7 +237,9 @@ def __init__(self,
if "windowed" in name_parts:
use_windowed_attn = True
name_parts.remove("windowed")
assert len(name_parts) == 2, "name should be dinov2, then either registers, windowed, both, or none, then the size"
assert (
len(name_parts) == 2
), "name should be dinov2, then either registers, windowed, both, or none, then the size"
self.encoder = DinoV2(
size=name_parts[-1],
out_feature_indexes=out_feature_indexes,
@@ -326,6 +356,7 @@ def get_dinov2_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12):
layer_id = int(name[name.find(".layer.") :].split(".")[2]) + 1
return lr_decay_rate ** (num_layers + 1 - layer_id)


def get_dinov2_weight_decay_rate(name, weight_decay_rate=1.0):
if (
("gamma" in name)
@@ -336,4 +367,11 @@ def get_dinov2_weight_decay_rate(name, weight_decay_rate=1.0):
or ("embeddings" in name)
):
weight_decay_rate = 0.0
return weight_decay_rate
return weight_decay_rate


_current_dir = os.path.dirname(os.path.abspath(__file__))
Collaborator:


A 64x speedup sounds nice, but really it's moving the loading of the files to module load time, which, given that this might not get instantiated at all, might actually slow things down on average?


_configs_dir = os.path.join(_current_dir, "dinov2_configs")

_config_cache = {}
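
One way to address the reviewer's concern above while keeping the caching win would be to defer the file reads until the first call rather than running them at import time. The following is only a sketch, not part of this PR: it assumes the module-level size_to_config and size_to_config_with_registers filename mappings that already exist in this file, and it relies on functools.lru_cache so each (size, use_registers) pair is still read from disk at most once.

# Sketch only (not in this PR): lazy, cached config loading. Nothing touches
# the filesystem at import time; the JSON is read on the first call and the
# result is memoized per (size, use_registers) pair by lru_cache.
import functools
import json
import os


@functools.lru_cache(maxsize=None)
def get_config(size, use_registers):
    # size_to_config / size_to_config_with_registers are assumed to be the
    # existing module-level size -> config-filename mappings in this file.
    config_dict = size_to_config_with_registers if use_registers else size_to_config
    configs_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "dinov2_configs")
    with open(os.path.join(configs_dir, config_dict[size]), "r") as f:
        return json.load(f)

Callers would need to treat the returned dict as read-only, since lru_cache hands back the same object on every call; the _config_cache approach in the diff shares that property.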