We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 105a667 commit c68d724Copy full SHA for c68d724
timm/models/_manipulate.py
@@ -8,6 +8,7 @@
8
import torch
9
import torch.utils.checkpoint
10
from torch import nn as nn
11
+from torch import Tensor
12
13
from timm.layers import use_reentrant_ckpt
14
@@ -284,7 +285,7 @@ def forward(_x):
284
285
return x
286
287
-def adapt_input_conv(in_chans, conv_weight):
288
+def adapt_input_conv(in_chans: int, conv_weight: Tensor) -> Tensor:
289
conv_type = conv_weight.dtype
290
conv_weight = conv_weight.float() # Some weights are in torch.half, ensure it's float for sum on CPU
291
O, I, J, K = conv_weight.shape
0 commit comments