Skip to content

Reproducing Random graph matching Results #22

@BarakeelFanseuKamhoua

Description

@BarakeelFanseuKamhoua

Hello Rusty,
Thanks for the amazing paper. I am sorry to report that I am not able to reproduce the results for Erdős–Rényi random graph matching using the softmax variant at the different noise levels up to 0.5, as shown in Figure 2 (a, b, c, and d) of the paper. My results are orders of magnitude lower when using the torch model implementation and the random graph generation code below:

Generate Erdos-Renyi graphs

def to_onehot(mat):
    """One-hot encode an array of non-negative integer labels.

    Args:
        mat: Array of labels. It may be 1-D or a column/row vector; it is
            flattened before encoding. The number of classes is taken from
            ``mat.shape[0]``, so every label must lie in ``[0, mat.shape[0])``.

    Returns:
        Integer array of shape ``(mat.size, mat.shape[0])`` whose row ``i``
        has a single 1 in column ``int(mat.flat[i])``.
    """
    k = mat.shape[0]
    # Flatten first: indexing with an (n, 1) column would broadcast against
    # np.arange(n) and set every label's column in every row.
    labels = np.asarray(mat).astype(int).ravel()
    encoded_arr = np.zeros((labels.size, k), dtype=int)
    encoded_arr[np.arange(labels.size), labels] = 1
    return encoded_arr

def compute_evecs(A, k, type=1):
    """Compute node features for the graph with adjacency matrix ``A``.

    Args:
        A: Square adjacency matrix.
        k: Number of singular vectors to request when ``type == 1``.
        type: Feature kind. ``1`` selects spectral features; any other value
            selects one-hot encoded row sums (node degrees for a 0/1 matrix).

    Returns:
        For ``type == 1``: the element-wise absolute value of the first array
        returned by ``sp.linalg.svds`` (presumably the left singular vectors,
        with ``sp`` being ``scipy.sparse`` -- TODO confirm against the file's
        imports). Otherwise: the result of ``to_onehot`` on the row sums.
    """
    if type == 1:
        # Singular vectors are only defined up to sign; abs() removes that
        # ambiguity so the features of two graphs are comparable.
        return np.abs(sp.linalg.svds(A, k, which="LM")[0])
    # Row sums as a flat 1-D vector, so that the one-hot encoder pairs each
    # node with exactly one column.
    degrees = np.asarray(A @ np.ones(A.shape[0])).ravel()
    return to_onehot(degrees)

def generate_er(n, p, sigma, learned=False, feat_type=1, feat_num=20):
    """Sample a pair of correlated Erdos-Renyi graphs and their node features.

    A parent graph ``G`` is drawn with edge probability ``p / s`` where
    ``s = 1 - sigma**2 * (1 - p)``. Two children are obtained by keeping each
    parent edge independently with probability ``s``, so each child has
    marginal edge probability ``p``. The second child is then relabelled by a
    random permutation.

    Args:
        n: Number of nodes.
        p: Marginal edge probability of each returned graph.
        sigma: Noise level; ``sigma == 0`` gives ``s == 1`` (identical children).
        learned: If true, the permutation is returned as an ``(n, 2)`` array
            of the (row, column) indices of its non-zero entries instead of
            as a dense matrix.
        feat_type: Forwarded to ``compute_evecs`` as its ``type`` argument.
        feat_num: Forwarded to ``compute_evecs`` as its ``k`` argument.

    Returns:
        Tuple ``(A, B, A0, B0, x_A0, x_B0, P_rnd, P_orig)`` where ``A0``/``B0``
        are the 0/1 adjacency matrices (``B0`` already permuted), ``A``/``B``
        are their centred and scaled versions ``(X - p) / sqrt(n p (1 - p))``,
        ``x_A0``/``x_B0`` are the node features, ``P_orig`` is the dense
        permutation matrix and ``P_rnd`` is either that matrix or its index
        form (see ``learned``).
    """

    def _symmetric_mask(prob):
        # Bernoulli(prob) on the strict lower triangle, mirrored to the upper
        # triangle; the diagonal stays False (no self-loops).
        lower = np.tril(np.random.uniform(size=(n, n)) < prob, -1)
        return lower + lower.T

    s = 1 - (sigma ** 2) * (1 - p)
    G = _symmetric_mask(p / s)
    Z1 = _symmetric_mask(s)
    Z2 = _symmetric_mask(s)

    # Each child keeps a parent edge only where its own mask is set.
    A0 = (G * Z1).astype(float)
    B0 = (G * Z2).astype(float)

    # Random relabelling of the second graph: P[idx[c], c] == 1, hence
    # (P @ B0 @ P.T)[idx[c], idx[d]] == B0[c, d].
    idx = np.random.permutation(n)
    P_rnd = np.eye(n)[:, idx]
    B0 = P_rnd @ B0 @ P_rnd.T

    scale = np.sqrt(n * p * (1 - p))
    A = (A0 - p) / scale
    B = (B0 - p) / scale

    P_orig = P_rnd
    if learned:
        # Rows are (node of B, node of A) pairs, sorted by the first column.
        # NOTE(review): confirm this ordering matches the (source, target)
        # convention of whatever consumes these labels.
        P_rnd = np.array(P_rnd.nonzero()).T

    x_A0 = compute_evecs(A0, feat_num, feat_type)
    x_B0 = compute_evecs(B0, feat_num, feat_type)
    return A, B, A0, B0, x_A0, x_B0, P_rnd, P_orig

Model

class DGMC(torch.nn.Module):
    """Two-stage graph matching model with iterative consensus refinement.

    ``psi_1`` produces node embeddings from which an initial soft
    correspondence matrix is computed; ``psi_2`` is applied to random node
    indicator signals that are transported through the current
    correspondence, and an MLP turns the resulting differences into an
    additive update of the correspondence scores.

    Both ``psi_1`` and ``psi_2`` are called as ``psi(adj, x)``. ``psi_2`` must
    expose ``input_dim`` and ``output_dim`` attributes.
    """

    def __init__(self, psi_1, psi_2, num_steps, k=-1, detach=False):
        """Store the sub-networks and build the score-update MLP.

        Args:
            psi_1: Network producing node embeddings for the initial scores.
            psi_2: Network used during the refinement steps.
            num_steps: Number of refinement iterations run in ``forward``.
            k: Stored on the instance; not read by any method of this class.
            detach: If true, ``psi_1`` outputs are detached from the autograd
                graph before the correspondence scores are computed.
        """
        super().__init__()
        self.psi_1 = psi_1
        self.psi_2 = psi_2
        self.num_steps = num_steps
        self.k = k
        self.detach = detach
        # Maps a psi_2 output difference to a scalar score update.
        self.mlp = nn.Sequential(
            nn.Linear(psi_2.output_dim, psi_2.output_dim),
            nn.ReLU(),
            nn.Linear(psi_2.output_dim, 1),
        )

    def reset_parameters(self):
        """Re-initialise ``psi_1``, ``psi_2`` and every resettable MLP layer."""
        self.psi_1.reset_parameters()
        self.psi_2.reset_parameters()
        for layer in self.mlp:
            # nn.ReLU has no parameters and no reset_parameters method.
            if hasattr(layer, 'reset_parameters'):
                layer.reset_parameters()

    def forward(self, x_s, adj_s, x_t, adj_t, y=None):
        """Compute initial and refined soft correspondences.

        Args:
            x_s: Source node features, passed to ``psi_1`` with ``adj_s``.
            adj_s: Source adjacency, passed to ``psi_1`` and ``psi_2``.
            x_t: Target node features, passed to ``psi_1`` with ``adj_t``.
            adj_t: Target adjacency, passed to ``psi_1`` and ``psi_2``.
            y: Accepted for interface compatibility; not used.

        Returns:
            Tuple ``(S_0, S_L)`` of row-wise softmax-normalised tensors of
            shape ``(B, N_s, N_t)``: the initial correspondence and the one
            obtained after ``num_steps`` refinement iterations.
        """
        h_s = self.psi_1(adj_s, x_s)
        h_t = self.psi_1(adj_t, x_t)
        if self.detach:
            h_s, h_t = h_s.detach(), h_t.detach()

        B, N_s, _ = h_s.size()
        N_t = h_t.size(1)
        R_in, R_out = self.psi_2.input_dim, self.psi_2.output_dim

        # Raw similarity scores between every source/target embedding pair.
        S_hat = h_s @ h_t.transpose(-1, -2)
        S_0 = torch.softmax(S_hat, dim=-1)

        for _ in range(self.num_steps):
            S = torch.softmax(S_hat, dim=-1)
            # Fresh random signals on the source nodes, carried over to the
            # target nodes through the current soft correspondence.
            r_s = torch.randn((B, N_s, R_in), dtype=h_s.dtype,
                              device=h_s.device)
            r_t = S.transpose(-1, -2) @ r_s
            o_s = self.psi_2(adj_s, r_s)
            o_t = self.psi_2(adj_t, r_t)
            # Pairwise differences, shape (B, N_s, N_t, R_out).
            D = o_s.view(B, N_s, 1, R_out) - o_t.view(B, 1, N_t, R_out)
            S_hat = S_hat + self.mlp(D).squeeze(-1)

        S_L = torch.softmax(S_hat, dim=-1)
        return S_0, S_L

    def loss(self, S, y, reduction='mean', EPS=1e-8):
        """Negative log-likelihood of the ground-truth correspondences.

        Args:
            S: Correspondence tensor of shape ``(B, N_s, N_t)``.
            y: Integer index tensor of shape ``(B, M, 2)``; the entry
                ``S[b, y[b, m, 0], y[b, m, 1]]`` is treated as the
                probability of the m-th ground-truth pair of sample ``b``.
            reduction: ``'none'`` returns the per-pair values with shape
                ``(B, M)``; ``'sum'`` returns their total; ``'mean'`` returns
                the total divided by the batch size ``y.size(0)``.
            EPS: Added inside the logarithm for numerical stability.

        Returns:
            The loss tensor, reduced as requested.
        """
        assert reduction in ['none', 'mean', 'sum']
        batch_idx = torch.arange(S.size(0), device=S.device).view(-1, 1)
        val = S[batch_idx, y[..., 0], y[..., 1]]
        nll = -torch.log(val + EPS)
        if reduction == 'none':
            return nll
        total = nll.sum()
        if reduction == 'mean':
            # Averaged over samples, not over individual pairs.
            return total / y.size(0)
        return total

    @torch.no_grad()
    def acc(self, S, y, reduction='mean'):
        """Fraction (or count) of rows whose arg-max equals the target index.

        Row ``m`` of ``S[b]`` is compared with ``y[b, m, 1]``; ``y[b, m, 0]``
        is not consulted, so the pairs are assumed to be listed in row order
        -- TODO confirm against the label construction.

        Args:
            S: Correspondence tensor of shape ``(B, N_s, N_t)``.
            y: Integer index tensor of shape ``(B, N_s, 2)``.
            reduction: ``'mean'`` returns the accuracy as a float, ``'sum'``
                the number of correct rows as an int.
        """
        assert reduction in ['mean', 'sum']
        targets = y[..., 1]
        correct = (S.argmax(dim=-1) == targets).sum().item()
        total = targets.numel()
        accuracy = correct / total
        return accuracy if reduction == 'mean' else correct

    @torch.no_grad()
    def hits_at_k(self, k, S, y, reduction='mean'):
        """Fraction (or count) of rows whose target is among the top ``k``.

        Uses the same row-to-label pairing as ``acc``: row ``m`` of ``S[b]``
        is compared with ``y[b, m, 1]``.

        Args:
            k: Number of highest-scoring columns considered per row.
            S: Correspondence tensor of shape ``(B, N_s, N_t)``.
            y: Integer index tensor of shape ``(B, N_s, 2)``.
            reduction: ``'mean'`` returns the hit rate as a float, ``'sum'``
                the number of hits as an int.
        """
        assert reduction in ['mean', 'sum']
        targets = y[..., 1]
        top_k = S.argsort(dim=-1, descending=True)[..., :k]
        correct = (top_k == targets.unsqueeze(-1)).sum().item()
        total = targets.numel()
        hits = correct / total
        return hits if reduction == 'mean' else correct

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions