Skip to content

Commit 80a4877

Browse files
brianhou0208rwightman
authored andcommitted
Fix self.reset_classifier num_classes update
1 parent 84631cb commit 80a4877

File tree

11 files changed

+11
-1
lines changed

11 files changed

+11
-1
lines changed

timm/models/davit.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -633,6 +633,7 @@ def get_classifier(self) -> nn.Module:
633633
return self.head.fc
634634

635635
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
636+
self.num_classes = num_classes
636637
self.head.reset(num_classes, global_pool)
637638

638639
def forward_features(self, x):

timm/models/focalnet.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -455,6 +455,7 @@ def get_classifier(self) -> nn.Module:
455455
return self.head.fc
456456

457457
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
458+
self.num_classes = num_classes
458459
self.head.reset(num_classes, pool_type=global_pool)
459460

460461
def forward_features(self, x):

timm/models/metaformer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -584,6 +584,7 @@ def get_classifier(self) -> nn.Module:
584584
return self.head.fc
585585

586586
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
587+
self.num_classes = num_classes
587588
if global_pool is not None:
588589
self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
589590
self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity()

timm/models/nextvit.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -557,6 +557,7 @@ def get_classifier(self) -> nn.Module:
557557
return self.head.fc
558558

559559
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
560+
self.num_classes = num_classes
560561
self.head.reset(num_classes, pool_type=global_pool)
561562

562563
def forward_features(self, x):

timm/models/nfnet.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,7 @@ def get_classifier(self) -> nn.Module:
434434
return self.head.fc
435435

436436
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
437+
self.num_classes = num_classes
437438
self.head.reset(num_classes, global_pool)
438439

439440
def forward_features(self, x):

timm/models/pvt_v2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,7 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
384384
if global_pool is not None:
385385
assert global_pool in ('avg', '')
386386
self.global_pool = global_pool
387-
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
387+
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
388388

389389
def forward_features(self, x):
390390
x = self.patch_embed(x)

timm/models/rdnet.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,7 @@ def get_classifier(self) -> nn.Module:
349349
return self.head.fc
350350

351351
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
352+
self.num_classes = num_classes
352353
self.head.reset(num_classes, global_pool)
353354

354355
def forward_features(self, x):

timm/models/regnet.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -515,6 +515,7 @@ def get_classifier(self) -> nn.Module:
515515
return self.head.fc
516516

517517
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
518+
self.num_classes = num_classes
518519
self.head.reset(num_classes, pool_type=global_pool)
519520

520521
def forward_intermediates(

timm/models/tresnet.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,7 @@ def get_classifier(self) -> nn.Module:
225225
return self.head.fc
226226

227227
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
228+
self.num_classes = num_classes
228229
self.head.reset(num_classes, pool_type=global_pool)
229230

230231
def forward_features(self, x):

timm/models/vision_transformer_sam.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -537,6 +537,7 @@ def get_classifier(self) -> nn.Module:
537537
return self.head
538538

539539
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
540+
self.num_classes = num_classes
540541
self.head.reset(num_classes, global_pool)
541542

542543
def forward_intermediates(

timm/models/xception_aligned.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,7 @@ def get_classifier(self) -> nn.Module:
275275
return self.head.fc
276276

277277
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
278+
self.num_classes = num_classes
278279
self.head.reset(num_classes, pool_type=global_pool)
279280

280281
def forward_features(self, x):

0 commit comments

Comments
 (0)