@@ -54,124 +54,124 @@ def KNN(features, k):
5454 return dists , Idx
5555
5656
57- def test_softknn (debug = False ):
58- # doesn't always converge in time for all random seed
59- seed = 318455
60- logger .info ('random seed: %d' , rand .seed (seed ))
61-
62- D_in = 3
63- D_out = 1
64-
65- target_params = torch .rand (D_in , D_out ).t ()
66- # target_params = torch.tensor([[1, -1, 1]], dtype=torch.float )
67- target_tsf = torch .nn .Linear (D_in , D_out , bias = False )
68- target_tsf .weight .data = target_params
69- for param in target_tsf .parameters ():
70- param .requires_grad = False
71-
72- def produce_output (X ):
73- # get the features
74- y = target_tsf (X )
75- # cluster in feature space
76- dists , Idx = KNN (y , 5 )
77-
78- # take the sum inside each neighbourhood
79- # TODO do a least square fit over X inside each neighbourhood
80- features2 = torch .zeros_like (X )
81- for i in range (dists .shape [0 ]):
82- # md = max(dists[i])
83- # d = md - dists[i]
84- # w = d / torch.norm(d)
85- features2 [i ] = torch .mean (X [Idx [i ]], 0 )
86- # features2[i] = torch.matmul(w, X[Idx[i]])
87-
88- return features2
89-
90- N = 400
91- ds = load_data .RandomNumberDataset (produce_output , num = 400 , input_dim = D_in )
92- train_set , validation_set = load_data .split_train_validation (ds )
93- train_loader = torch .utils .data .DataLoader (train_set , batch_size = N , shuffle = True )
94- val_loader = torch .utils .data .DataLoader (validation_set , batch_size = N , shuffle = False )
95-
96- criterion = torch .nn .MSELoss (reduction = 'sum' )
97-
98- model = SimpleNet (D_in , D_out )
99- optimizer = torch .optim .SGD (model .parameters (), lr = 1e-4 )
100-
101- losses = []
102- vlosses = []
103- pdist = []
104- cosdist = []
105-
106- def evaluateLoss (data ):
107- # target
108- x , y = data
109- pred = model (x , y )
110-
111- loss = criterion (pred , y )
112- return loss
113-
114- def evaluateValidation ():
115- with torch .no_grad ():
116- loss = sum (evaluateLoss (data ) for data in val_loader )
117- return loss / len (val_loader .dataset )
118-
119- # model.linear1.weight.data = target_params.clone()
120- for epoch in range (200 ):
121- for i , data in enumerate (train_loader , 0 ):
122- optimizer .zero_grad ()
123-
124- loss = evaluateLoss (data )
125- loss .backward ()
126- optimizer .step ()
127-
128- avg_loss = loss .item () / len (data [0 ])
129-
130- losses .append (avg_loss )
131- vlosses .append (evaluateValidation ())
132- pdist .append (torch .norm (model .linear1 .weight .data - target_params ))
133- cosdist .append (torch .nn .functional .cosine_similarity (model .linear1 .weight .data , target_params ))
134- if debug :
135- print ('[%d, %5d] loss: %.3f' % (epoch + 1 , i + 1 , avg_loss ))
136-
137- if debug :
138- print ('Finished Training' )
139- print ('Target params: {}' .format (target_params ))
140- print ('Learned params:' )
141- for param in model .parameters ():
142- print (param )
143-
144- print ('validation total loss: {:.3f}' .format (evaluateValidation ()))
145-
146- model .linear1 .weight .data = target_params .clone ()
147- target_loss = evaluateValidation ()
148-
149- if debug :
150- print ('validation total loss with target params: {:.3f}' .format (target_loss ))
151-
152- plt .plot (range (len (losses )), losses )
153- plt .plot (range (len (losses )), vlosses )
154- plt .plot (range (len (losses )), [target_loss ] * len (losses ), linestyle = '--' )
155- plt .legend (['training minibatch' , 'whole validation' , 'validation with target params' ])
156- plt .xlabel ('minibatch' )
157- plt .ylabel ('MSE loss' )
158-
159- plt .figure ()
160- plt .plot (range (len (pdist )), pdist )
161- plt .xlabel ('minibatch' )
162- plt .ylabel ('euclidean distance of model params from target' )
163-
164- plt .figure ()
165- plt .plot (range (len (cosdist )), cosdist )
166- plt .xlabel ('minibatch' )
167- plt .ylabel ('cosine similarity between model params and target' )
168- plt .show ()
169-
170- # check that we're close to the actual KNN performance on validation set
171- last_few = 5
172- loss_tolerance = 0.02
173- assert sum (vlosses [- last_few :]) / last_few - target_loss < target_loss * loss_tolerance
174-
175-
176- if __name__ == "__main__" :
177- test_softknn (True )
57+ # def test_softknn(debug=False):
58+ # # doesn't always converge in time for all random seed
59+ # seed = 318455
60+ # logger.info('random seed: %d', rand.seed(seed))
61+ #
62+ # D_in = 3
63+ # D_out = 1
64+ #
65+ # target_params = torch.rand(D_in, D_out).t()
66+ # # target_params = torch.tensor([[1, -1, 1]], dtype=torch.float )
67+ # target_tsf = torch.nn.Linear(D_in, D_out, bias=False)
68+ # target_tsf.weight.data = target_params
69+ # for param in target_tsf.parameters():
70+ # param.requires_grad = False
71+ #
72+ # def produce_output(X):
73+ # # get the features
74+ # y = target_tsf(X)
75+ # # cluster in feature space
76+ # dists, Idx = KNN(y, 5)
77+ #
78+ # # take the sum inside each neighbourhood
79+ # # TODO do a least square fit over X inside each neighbourhood
80+ # features2 = torch.zeros_like(X)
81+ # for i in range(dists.shape[0]):
82+ # # md = max(dists[i])
83+ # # d = md - dists[i]
84+ # # w = d / torch.norm(d)
85+ # features2[i] = torch.mean(X[Idx[i]], 0)
86+ # # features2[i] = torch.matmul(w, X[Idx[i]])
87+ #
88+ # return features2
89+ #
90+ # N = 400
91+ # ds = load_data.RandomNumberDataset(produce_output, num=400, input_dim=D_in)
92+ # train_set, validation_set = load_data.split_train_validation(ds)
93+ # train_loader = torch.utils.data.DataLoader(train_set, batch_size=N, shuffle=True)
94+ # val_loader = torch.utils.data.DataLoader(validation_set, batch_size=N, shuffle=False)
95+ #
96+ # criterion = torch.nn.MSELoss(reduction='sum')
97+ #
98+ # model = SimpleNet(D_in, D_out)
99+ # optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
100+ #
101+ # losses = []
102+ # vlosses = []
103+ # pdist = []
104+ # cosdist = []
105+ #
106+ # def evaluateLoss(data):
107+ # # target
108+ # x, y = data
109+ # pred = model(x, y)
110+ #
111+ # loss = criterion(pred, y)
112+ # return loss
113+ #
114+ # def evaluateValidation():
115+ # with torch.no_grad():
116+ # loss = sum(evaluateLoss(data) for data in val_loader)
117+ # return loss / len(val_loader.dataset)
118+ #
119+ # # model.linear1.weight.data = target_params.clone()
120+ # for epoch in range(200):
121+ # for i, data in enumerate(train_loader, 0):
122+ # optimizer.zero_grad()
123+ #
124+ # loss = evaluateLoss(data)
125+ # loss.backward()
126+ # optimizer.step()
127+ #
128+ # avg_loss = loss.item() / len(data[0])
129+ #
130+ # losses.append(avg_loss)
131+ # vlosses.append(evaluateValidation())
132+ # pdist.append(torch.norm(model.linear1.weight.data - target_params))
133+ # cosdist.append(torch.nn.functional.cosine_similarity(model.linear1.weight.data, target_params))
134+ # if debug:
135+ # print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, avg_loss))
136+ #
137+ # if debug:
138+ # print('Finished Training')
139+ # print('Target params: {}'.format(target_params))
140+ # print('Learned params:')
141+ # for param in model.parameters():
142+ # print(param)
143+ #
144+ # print('validation total loss: {:.3f}'.format(evaluateValidation()))
145+ #
146+ # model.linear1.weight.data = target_params.clone()
147+ # target_loss = evaluateValidation()
148+ #
149+ # if debug:
150+ # print('validation total loss with target params: {:.3f}'.format(target_loss))
151+ #
152+ # plt.plot(range(len(losses)), losses)
153+ # plt.plot(range(len(losses)), vlosses)
154+ # plt.plot(range(len(losses)), [target_loss] * len(losses), linestyle='--')
155+ # plt.legend(['training minibatch', 'whole validation', 'validation with target params'])
156+ # plt.xlabel('minibatch')
157+ # plt.ylabel('MSE loss')
158+ #
159+ # plt.figure()
160+ # plt.plot(range(len(pdist)), pdist)
161+ # plt.xlabel('minibatch')
162+ # plt.ylabel('euclidean distance of model params from target')
163+ #
164+ # plt.figure()
165+ # plt.plot(range(len(cosdist)), cosdist)
166+ # plt.xlabel('minibatch')
167+ # plt.ylabel('cosine similarity between model params and target')
168+ # plt.show()
169+ #
170+ # # check that we're close to the actual KNN performance on validation set
171+ # last_few = 5
172+ # loss_tolerance = 0.02
173+ # assert sum(vlosses[-last_few:]) / last_few - target_loss < target_loss * loss_tolerance
174+
175+
176+ # if __name__ == "__main__":
177+ # test_softknn(False )
0 commit comments