Skip to content

Commit dc7cb7a

Browse files
committed
Remove sol in variable names
1 parent acd529f commit dc7cb7a

File tree

1 file changed

+24
-24
lines changed

1 file changed

+24
-24
lines changed

pytensor/link/jax/ops.py

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -160,11 +160,11 @@ def func(*args, **kwargs):
160160

161161
### Call the function that accepts flat inputs, which in turn calls the one that
162162
### combines the inputs and static variables.
163-
jitted_sol_op_jax = jax.jit(func_flattened)
163+
jitted_jax_op = jax.jit(func_flattened)
164164
len_gz = len(pttypes_outvars)
165165

166-
vjp_sol_op_jax = _get_vjp_sol_op_jax(func_flattened, len_gz)
167-
jitted_vjp_sol_op_jax = jax.jit(vjp_sol_op_jax)
166+
vjp_jax_op = _get_vjp_jax_op(func_flattened, len_gz)
167+
jitted_vjp_jax_op = jax.jit(vjp_jax_op)
168168

169169
# Get classes that create a Pytensor Op out of our function that accepts
170170
# flattened inputs. They are created each time, to set a custom name for the
@@ -194,8 +194,8 @@ class VJPJAXOp_local(VJPJAXOp):
194194
outvars_treedef,
195195
input_types=pt_vars_types_flat,
196196
output_types=pttypes_outvars,
197-
jitted_sol_op_jax=jitted_sol_op_jax,
198-
jitted_vjp_sol_op_jax=jitted_vjp_sol_op_jax,
197+
jitted_jax_op=jitted_jax_op,
198+
jitted_vjp_jax_op=jitted_vjp_jax_op,
199199
)
200200

201201
### Evaluate the Pytensor Op and return unflattened results
@@ -265,8 +265,8 @@ def get_func_with_vars(self, vars):
265265
return interior_func
266266

267267

268-
def _get_vjp_sol_op_jax(jaxfunc, len_gz):
269-
def vjp_sol_op_jax(args):
268+
def _get_vjp_jax_op(jaxfunc, len_gz):
269+
def vjp_jax_op(args):
270270
y0 = args[:-len_gz]
271271
gz = args[-len_gz:]
272272
if len(gz) == 1:
@@ -290,7 +290,7 @@ def func(*inputs):
290290
else:
291291
return tuple(vjp_fn(gz))
292292

293-
return vjp_sol_op_jax
293+
return vjp_jax_op
294294

295295

296296
def _partition_jaxfunc(jaxfunc, static_vars, func_vars):
@@ -350,16 +350,16 @@ def __init__(
350350
output_treeedef,
351351
input_types,
352352
output_types,
353-
jitted_sol_op_jax,
354-
jitted_vjp_sol_op_jax,
353+
jitted_jax_op,
354+
jitted_vjp_jax_op,
355355
):
356-
self.vjp_sol_op = None
356+
self.vjp_jax_op = None
357357
self.input_treedef = input_treedef
358358
self.output_treedef = output_treeedef
359359
self.input_types = input_types
360360
self.output_types = output_types
361-
self.jitted_sol_op_jax = jitted_sol_op_jax
362-
self.jitted_vjp_sol_op_jax = jitted_vjp_sol_op_jax
361+
self.jitted_jax_op = jitted_jax_op
362+
self.jitted_vjp_jax_op = jitted_vjp_jax_op
363363

364364
def make_node(self, *inputs):
365365
self.num_inputs = len(inputs)
@@ -368,24 +368,24 @@ def make_node(self, *inputs):
368368
outputs = [pt.as_tensor_variable(type()) for type in self.output_types]
369369
self.num_outputs = len(outputs)
370370

371-
self.vjp_sol_op = VJPJAXOp(
371+
self.vjp_jax_op = VJPJAXOp(
372372
self.input_treedef,
373373
self.input_types,
374-
self.jitted_vjp_sol_op_jax,
374+
self.jitted_vjp_jax_op,
375375
)
376376

377377
return Apply(self, inputs, outputs)
378378

379379
def perform(self, node, inputs, outputs):
380-
results = self.jitted_sol_op_jax(inputs)
380+
results = self.jitted_jax_op(inputs)
381381
if self.num_outputs > 1:
382382
for i in range(self.num_outputs):
383383
outputs[i][0] = np.array(results[i], self.output_types[i].dtype)
384384
else:
385385
outputs[0][0] = np.array(results, self.output_types[0].dtype)
386386

387387
def perform_jax(self, *inputs):
388-
results = self.jitted_sol_op_jax(inputs)
388+
results = self.jitted_jax_op(inputs)
389389
return results
390390

391391
def grad(self, inputs, output_gradients):
@@ -399,7 +399,7 @@ def grad(self, inputs, output_gradients):
399399
)
400400
else:
401401
output_gradients[i] = pt.zeros((), self.output_types[i].dtype)
402-
result = self.vjp_sol_op(inputs, output_gradients)
402+
result = self.vjp_jax_op(inputs, output_gradients)
403403

404404
if self.num_inputs > 1:
405405
return result
@@ -413,11 +413,11 @@ def __init__(
413413
self,
414414
input_treedef,
415415
input_types,
416-
jitted_vjp_sol_op_jax,
416+
jitted_vjp_jax_op,
417417
):
418418
self.input_treedef = input_treedef
419419
self.input_types = input_types
420-
self.jitted_vjp_sol_op_jax = jitted_vjp_sol_op_jax
420+
self.jitted_vjp_jax_op = jitted_vjp_jax_op
421421

422422
def make_node(self, y0, gz):
423423
y0 = [
@@ -436,15 +436,15 @@ def make_node(self, y0, gz):
436436
return Apply(self, y0 + gz_not_disconntected, outputs)
437437

438438
def perform(self, node, inputs, outputs):
439-
results = self.jitted_vjp_sol_op_jax(tuple(inputs))
439+
results = self.jitted_vjp_jax_op(tuple(inputs))
440440
if len(self.input_types) > 1:
441441
for i, result in enumerate(results):
442442
outputs[i][0] = np.array(result, self.input_types[i].dtype)
443443
else:
444444
outputs[0][0] = np.array(results, self.input_types[0].dtype)
445445

446446
def perform_jax(self, *inputs):
447-
results = self.jitted_vjp_sol_op_jax(tuple(inputs))
447+
results = self.jitted_vjp_jax_op(tuple(inputs))
448448
if self.num_outputs == 1:
449449
if isinstance(results, Sequence):
450450
return results[0]
@@ -455,10 +455,10 @@ def perform_jax(self, *inputs):
455455

456456

457457
@jax_funcify.register(JAXOp)
458-
def sol_op_jax_funcify(op, **kwargs):
458+
def jax_op_funcify(op, **kwargs):
459459
return op.perform_jax
460460

461461

462462
@jax_funcify.register(VJPJAXOp)
463-
def vjp_sol_op_jax_funcify(op, **kwargs):
463+
def vjp_jax_op_funcify(op, **kwargs):
464464
return op.perform_jax

0 commit comments

Comments
 (0)