File tree Expand file tree Collapse file tree
chapter15_Differential_Privacy Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -26,9 +26,13 @@ def get_model(name="vgg16", pretrained=True):
2626 else :
2727 return model
2828
29+
2930def model_norm (model_1 , model_2 ):
30- squared_sum = 0
31- for name , layer in model_1 .named_parameters ():
32- # print(torch.mean(layer.data), torch.mean(model_2.state_dict()[name].data))
33- squared_sum += torch .sum (torch .pow (layer .data - model_2 .state_dict ()[name ].data , 2 ))
34- return math .sqrt (squared_sum )
31+ params_1 = torch .cat ([param .view (- 1 ) for param in model_1 .parameters ()])
32+ params_2 = torch .cat ([param .view (- 1 ) for param in model_2 .parameters ()])
33+
34+ return torch .norm (params_1 - params_2 , p = 2 )
35+
36+ def quick_model_norm (model_1 , model_2 ):
37+ diffs = [(p1 - p2 ).view (- 1 ) for p1 , p2 in zip (model_1 .parameters (), model_2 .parameters ())]
38+ return torch .norm (torch .cat (diffs ), p = 2 )
You can’t perform that action at this time.
0 commit comments