-
Notifications
You must be signed in to change notification settings - Fork 249
Expand file tree
/
Copy pathmodels.py
More file actions
38 lines (31 loc) · 1.19 KB
/
models.py
File metadata and controls
38 lines (31 loc) · 1.19 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import torch
from torchvision import models
import math
def get_model(name="vgg16", pretrained=True):
    """Build a torchvision classification model by name.

    Args:
        name: Architecture to build. One of "resnet18", "resnet50",
            "densenet121", "alexnet", "vgg16", "vgg19", "inception_v3",
            "googlenet".
        pretrained: Forwarded unchanged as the ``pretrained`` keyword of the
            torchvision constructor.

    Returns:
        The constructed model, moved to CUDA via ``.cuda()`` when
        ``torch.cuda.is_available()`` is true, otherwise left as built.

    Raises:
        ValueError: If ``name`` is not one of the supported architectures.
    """
    supported = (
        "resnet18",
        "resnet50",
        "densenet121",
        "alexnet",
        "vgg16",
        "vgg19",
        "inception_v3",
        "googlenet",
    )
    if name not in supported:
        # Without this guard an unknown name left ``model`` unbound and
        # surfaced as an UnboundLocalError far from the real mistake.
        raise ValueError(
            f"unsupported model name {name!r}; expected one of: "
            + ", ".join(supported)
        )
    # Each supported name is exactly the torchvision constructor's name, and
    # the whitelist above prevents getattr from reaching anything else.
    model = getattr(models, name)(pretrained=pretrained)
    if torch.cuda.is_available():
        model = model.cuda()
    return model
def model_norm(model_1, model_2):
    """Return the L2 distance between the parameters of two models.

    Each model's parameters are flattened and concatenated in
    ``parameters()`` order. The result is the Euclidean norm of the
    difference of the two resulting vectors. The result is not detached.

    Args:
        model_1: First module; only ``parameters()`` is used.
        model_2: Second module; only ``parameters()`` is used.

    Returns:
        A 0-dim tensor holding ``||params_1 - params_2||_2``.

    Raises:
        ValueError: If the two models hold a different total number of
            parameter elements.
    """
    # reshape instead of view: view(-1) raises on non-contiguous parameters
    # (e.g. transposed or channels_last weights); reshape copies only then.
    params_1 = torch.cat([param.reshape(-1) for param in model_1.parameters()])
    params_2 = torch.cat([param.reshape(-1) for param in model_2.parameters()])
    if params_1.numel() != params_2.numel():
        # Guard against silent broadcasting (e.g. when one side has a
        # single element) or an opaque size-mismatch error.
        raise ValueError(
            f"models have different parameter counts: "
            f"{params_1.numel()} vs {params_2.numel()}"
        )
    return torch.norm(params_1 - params_2, p=2)
def quick_model_norm(model_1, model_2):
    """Return the L2 distance between the parameters of two models.

    Parameters are paired positionally in ``parameters()`` order. Each
    pairwise difference is flattened, and the norm of the concatenated
    differences is returned. The result is not detached.

    Args:
        model_1: First module; only ``parameters()`` is used.
        model_2: Second module; only ``parameters()`` is used.

    Returns:
        A 0-dim tensor holding the Euclidean norm of all parameter
        differences.

    Raises:
        ValueError: If the models have a different number of parameter
            tensors, or a paired tensor's shapes differ.
    """
    params_1 = list(model_1.parameters())
    params_2 = list(model_2.parameters())
    if len(params_1) != len(params_2):
        # A bare zip would silently drop the extra parameters and report a
        # misleadingly small distance.
        raise ValueError(
            f"models have different numbers of parameter tensors: "
            f"{len(params_1)} vs {len(params_2)}"
        )
    diffs = []
    for index, (p1, p2) in enumerate(zip(params_1, params_2)):
        if p1.shape != p2.shape:
            # Mismatched shapes could otherwise broadcast silently.
            raise ValueError(
                f"parameter {index} shape mismatch: "
                f"{tuple(p1.shape)} vs {tuple(p2.shape)}"
            )
        # reshape instead of view: the difference of non-contiguous
        # parameters can itself be non-contiguous, where view(-1) raises.
        diffs.append((p1 - p2).reshape(-1))
    return torch.norm(torch.cat(diffs), p=2)