@@ -17,7 +17,7 @@ def __init__(self, kern_list, output_dim, name=None):
17
17
super (SwitchedKernel , self ).__init__ (kernels = kern_list ,
18
18
name = name )
19
19
self .output_dim = output_dim
20
- self .num_kernels = len (self .kern_list )
20
+ self .num_kernels = len (self .kernels )
21
21
assert self .output_dim == self .num_kernels
22
22
23
23
@params_as_tensors
@@ -43,7 +43,7 @@ def K(self, X, X2=None, presliced=False):
43
43
ind_X2 , self .output_dim )
44
44
45
45
Ks = []
46
- for k , p , p2 in zip (self .kern_list , ind_X_parts , ind_X2_parts ):
46
+ for k , p , p2 in zip (self .kernels , ind_X_parts , ind_X2_parts ):
47
47
gram = k .K (tf .gather (X , p ), tf .gather (X2 , p2 ))
48
48
Ks .append (gram )
49
49
@@ -68,5 +68,5 @@ def Kdiag(self, X, prescliced=False):
68
68
ind_X_parts = tf .dynamic_partition (tf .range (0 , tf .size (ind_X )),
69
69
ind_X , self .output_dim )
70
70
71
- Ks = [k .Kdiag (tf .gather (X , p )) for k , p in zip (self .kern_list , ind_X_parts )]
71
+ Ks = [k .Kdiag (tf .gather (X , p )) for k , p in zip (self .kernels , ind_X_parts )]
72
72
return tf .dynamic_stitch (ind_X_parts , Ks )
0 commit comments