Hello Rusty,
Thanks for the amazing paper. Unfortunately, I am not able to reproduce the Erdős–Rényi random graph matching results with the softmax variant for the noise levels up to 0.5, as reported in Figure 2 (a, b, c, and d) of the paper. My accuracies are orders of magnitude lower using the PyTorch model implementation and the random graph generation below:
Generate Erdős–Rényi graphs

```python
import numpy as np
import scipy.sparse as sp
import scipy.sparse.linalg  # makes sp.linalg.svds available


def to_onehot(mat):
    # One-hot encode an integer vector (used for degree features).
    mat = np.asarray(mat).ravel()
    k = mat.shape[0]
    encoded_arr = np.zeros((mat.size, k), dtype=int)
    encoded_arr[np.arange(mat.size), mat.astype(int)] = 1
    return encoded_arr


def compute_evecs(A, k, type=1):
    # Node features: top-k left singular vectors (type 1) or one-hot degrees.
    if type == 1:
        return np.abs(sp.linalg.svds(A, k, which="LM")[0])
    else:
        D = A @ np.ones((A.shape[0], 1))
        return to_onehot(D)


def generate_er(n, p, sigma, learned=False, feat_type=1, feat_num=20):
    # Correlated Erdős–Rényi pair: sample a parent graph G with edge
    # probability p/s, then keep each edge independently with probability s
    # in A0 and in B0, so both marginals are ER(n, p).
    s = 1 - (sigma ** 2) * (1 - p)
    G = np.random.uniform(size=(n, n)) < p / s
    G = np.tril(G, -1) + np.tril(G, -1).T
    Z1 = np.random.uniform(size=(n, n)) < s
    Z1 = np.tril(Z1, -1) + np.tril(Z1, -1).T
    Z2 = np.random.uniform(size=(n, n)) < s
    Z2 = np.tril(Z2, -1) + np.tril(Z2, -1).T
    A0 = (G * Z1).astype(float)
    B0 = (G * Z2).astype(float)

    # Apply a random ground-truth permutation to the second graph.
    P_rnd = np.eye(n)
    idx = np.random.permutation(n)
    P_rnd = P_rnd[:, idx]
    B0 = P_rnd @ B0 @ P_rnd.T

    # Center and rescale the adjacency matrices.
    A = (A0 - p) / np.sqrt(n * p * (1 - p))
    B = (B0 - p) / np.sqrt(n * p * (1 - p))

    P_orig = P_rnd
    if learned:
        # Ground truth as (row, column) index pairs of the permutation matrix.
        P_rnd = np.array(P_rnd.nonzero()).T

    real_symm = np.allclose(A @ A.T, A.T @ A) and np.allclose(B @ B.T, B.T @ B)

    # Node features are computed on the unnormalized adjacencies.
    x_A0 = compute_evecs(A0, feat_num, feat_type)
    x_B0 = compute_evecs(B0, feat_num, feat_type)
    return A, B, A0, B0, x_A0, x_B0, P_rnd, P_orig
```
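For what it's worth, this is how I sanity-check the generator (my own check, not from the paper's code): the marginal edge density should be close to p and, assuming I restored the missing multiplication in the formula for s correctly, the edge correlation between the aligned graphs should be close to 1 - sigma².

```python
# Quick sanity check of generate_er (my own addition, not from the paper):
# marginal density ~ p, edge correlation of aligned graphs ~ 1 - sigma**2.
n, p = 100, 0.2
for sigma in [0.0, 0.1, 0.2, 0.3, 0.4, 0.5]:
    A, B, A0, B0, x_A0, x_B0, y, P = generate_er(n, p, sigma, learned=True)
    iu = np.triu_indices(n, k=1)
    B0_aligned = P.T @ B0 @ P  # undo the ground-truth permutation
    density = A0[iu].mean()
    corr = np.corrcoef(A0[iu], B0_aligned[iu])[0, 1]
    print(f"sigma={sigma:.1f}  density={density:.3f}  corr={corr:.3f}")
```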
Model
```python
import torch
import torch.nn as nn


class DGMC(torch.nn.Module):
    def __init__(self, psi_1, psi_2, num_steps, k=-1, detach=False):
        super(DGMC, self).__init__()
        self.psi_1 = psi_1
        self.psi_2 = psi_2
        self.num_steps = num_steps
        self.k = k
        self.detach = detach
        self.mlp = nn.Sequential(
            nn.Linear(psi_2.output_dim, psi_2.output_dim),
            nn.ReLU(),
            nn.Linear(psi_2.output_dim, 1),
        )

    def reset_parameters(self):
        self.psi_1.reset_parameters()
        self.psi_2.reset_parameters()
        for layer in self.mlp:
            if hasattr(layer, 'reset_parameters'):
                layer.reset_parameters()

    def forward(self, x_s, adj_s, x_t, adj_t, y=None):
        # Local node embeddings.
        h_s = self.psi_1(adj_s, x_s)
        h_t = self.psi_1(adj_t, x_t)
        h_s, h_t = (h_s.detach(), h_t.detach()) if self.detach else (h_s, h_t)

        B, N_s, C_out = h_s.size()
        N_t = h_t.size(1)
        R_in, R_out = self.psi_2.input_dim, self.psi_2.output_dim

        # Initial correspondence scores.
        S_hat = h_s @ h_t.transpose(-1, -2)
        S_0 = torch.softmax(S_hat, dim=-1)

        # Neighborhood consensus refinement.
        for _ in range(self.num_steps):
            S = torch.softmax(S_hat, dim=-1)
            r_s = torch.randn((B, N_s, R_in), dtype=h_s.dtype, device=h_s.device)
            r_t = S.transpose(-1, -2) @ r_s
            o_s = self.psi_2(adj_s, r_s)
            o_t = self.psi_2(adj_t, r_t)
            D = o_s.view(B, N_s, 1, R_out) - o_t.view(B, 1, N_t, R_out)
            S_hat = S_hat + self.mlp(D).squeeze(-1)

        S_L = torch.softmax(S_hat, dim=-1)
        return S_0, S_L

    def loss(self, S, y, reduction='mean', EPS=1e-8):
        # Negative log-likelihood of the ground-truth correspondences.
        assert reduction in ['none', 'mean', 'sum']
        B = S.size(0)
        nll = 0
        for i in range(B):
            val = S[i][y[i, :, 0], y[i, :, 1]]
            nll += -torch.log(val + EPS).sum()
        if reduction == 'mean':
            nll = nll / y.size(0)  # average over graphs in the batch
        return nll

    @torch.no_grad()
    def acc(self, S, y, reduction='mean'):
        assert reduction in ['mean', 'sum']
        B = S.size(0)
        correct = 0
        total = 0
        for i in range(B):
            pred = S[i].argmax(dim=-1)
            correct += (pred == y[i, :, 1]).sum().item()
            total += y[i, :, 1].size(0)
        accuracy = correct / total
        return accuracy if reduction == 'mean' else correct

    @torch.no_grad()
    def hits_at_k(self, k, S, y, reduction='mean'):
        assert reduction in ['mean', 'sum']
        B = S.size(0)
        correct = 0
        total = 0
        for i in range(B):
            pred = S[i].argsort(dim=-1, descending=True)[:, :k]
            correct += (pred == y[i, :, 1].view(-1, 1)).sum().item()
            total += y[i, :, 1].size(0)
        hits = correct / total
        return hits if reduction == 'mean' else correct
```
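For reference, psi_1 and psi_2 in my runs are simple dense message-passing networks with the call signature psi(adj, x) and the input_dim / output_dim attributes that DGMC expects. The sketch below only illustrates that interface; the architecture, layer sizes, and num_steps are placeholders of mine, not the paper's settings:

```python
import torch
import torch.nn as nn


class DenseGNN(nn.Module):
    # Minimal dense message-passing network matching the psi(adj, x) call
    # signature and the input_dim / output_dim attributes DGMC uses above.
    # (Sketch only; not the architecture from the paper.)

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=3):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        dims = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        self.lins = nn.ModuleList(
            [nn.Linear(d_in, d_out) for d_in, d_out in zip(dims[:-1], dims[1:])])

    def reset_parameters(self):
        for lin in self.lins:
            lin.reset_parameters()

    def forward(self, adj, x):
        # adj: (B, N, N), x: (B, N, C). Aggregate neighbor features via the
        # adjacency matrix, then transform; ReLU between layers.
        for i, lin in enumerate(self.lins):
            x = lin(adj @ x)
            if i < len(self.lins) - 1:
                x = torch.relu(x)
        return x


# Hypothetical instantiation (input_dim=20 matches feat_num above;
# the remaining sizes and num_steps are guesses):
psi_1 = DenseGNN(input_dim=20, hidden_dim=64, output_dim=64)
psi_2 = DenseGNN(input_dim=32, hidden_dim=32, output_dim=32)
model = DGMC(psi_1, psi_2, num_steps=10)
```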