Skip to content

Commit a0f3bc1

Browse files
committed
fixes with compatibility with GPflow API
1 parent 2b6a5b8 commit a0f3bc1

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

dgplib/specialized_kernels.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def __init__(self, kern_list, output_dim, name=None):
1717
super(SwitchedKernel, self).__init__(kernels=kern_list,
1818
name=name)
1919
self.output_dim = output_dim
20-
self.num_kernels = len(self.kern_list)
20+
self.num_kernels = len(self.kernels)
2121
assert self.output_dim==self.num_kernels
2222

2323
@params_as_tensors
@@ -43,7 +43,7 @@ def K(self, X, X2=None, presliced=False):
4343
ind_X2, self.output_dim)
4444

4545
Ks = []
46-
for k, p, p2 in zip(self.kern_list, ind_X_parts, ind_X2_parts):
46+
for k, p, p2 in zip(self.kernels, ind_X_parts, ind_X2_parts):
4747
gram = k.K(tf.gather(X, p), tf.gather(X2, p2))
4848
Ks.append(gram)
4949

@@ -68,5 +68,5 @@ def Kdiag(self, X, prescliced=False):
6868
ind_X_parts = tf.dynamic_partition(tf.range(0, tf.size(ind_X)),
6969
ind_X, self.output_dim)
7070

71-
Ks = [k.Kdiag(tf.gather(X, p)) for k, p in zip(self.kern_list, ind_X_parts)]
71+
Ks = [k.Kdiag(tf.gather(X, p)) for k, p in zip(self.kernels, ind_X_parts)]
7272
return tf.dynamic_stitch(ind_X_parts, Ks)

0 commit comments

Comments
 (0)