@@ -160,11 +160,11 @@ def func(*args, **kwargs):
 
     ### Call the function that accepts flat inputs, which in turn calls the one that
     ### combines the inputs and static variables.
-    jitted_sol_op_jax = jax.jit(func_flattened)
+    jitted_jax_op = jax.jit(func_flattened)
     len_gz = len(pttypes_outvars)
 
-    vjp_sol_op_jax = _get_vjp_sol_op_jax(func_flattened, len_gz)
-    jitted_vjp_sol_op_jax = jax.jit(vjp_sol_op_jax)
+    vjp_jax_op = _get_vjp_jax_op(func_flattened, len_gz)
+    jitted_vjp_jax_op = jax.jit(vjp_jax_op)
 
     # Get classes that creates a Pytensor Op out of our function that accept
     # flattened inputs. They are created each time, to set a custom name for the
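For context on the "flat inputs" above: the wrapper operates on a flat list of leaves plus a treedef, the standard JAX pytree technique, which is why treedefs are threaded through the Ops below. A minimal sketch of that mechanism, using only stock jax.tree_util (the variable names here are illustrative, not from this patch):

    import jax

    # tree_flatten splits an arbitrary pytree into flat leaves plus a
    # treedef; tree_unflatten rebuilds the original structure from them.
    params = {"a": 1.0, "b": (2.0, 3.0)}
    leaves, treedef = jax.tree_util.tree_flatten(params)
    restored = jax.tree_util.tree_unflatten(treedef, leaves)
    assert restored == params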
@@ -194,8 +194,8 @@ class VJPJAXOp_local(VJPJAXOp):
         outvars_treedef,
         input_types=pt_vars_types_flat,
         output_types=pttypes_outvars,
-        jitted_sol_op_jax=jitted_sol_op_jax,
-        jitted_vjp_sol_op_jax=jitted_vjp_sol_op_jax,
+        jitted_jax_op=jitted_jax_op,
+        jitted_vjp_jax_op=jitted_vjp_jax_op,
     )
 
     ### Evaluate the Pytensor Op and return unflattened results
@@ -265,8 +265,8 @@ def get_func_with_vars(self, vars):
         return interior_func
 
 
-def _get_vjp_sol_op_jax(jaxfunc, len_gz):
-    def vjp_sol_op_jax(args):
+def _get_vjp_jax_op(jaxfunc, len_gz):
+    def vjp_jax_op(args):
         y0 = args[:-len_gz]
         gz = args[-len_gz:]
         if len(gz) == 1:
@@ -290,7 +290,7 @@ def func(*inputs):
         else:
             return tuple(vjp_fn(gz))
 
-    return vjp_sol_op_jax
+    return vjp_jax_op
 
 
 def _partition_jaxfunc(jaxfunc, static_vars, func_vars):
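The renamed helper `_get_vjp_jax_op` splits the flat argument tuple into primal inputs `y0` and output cotangents `gz`, then applies the pullback `vjp_fn` seen in the hunk above. A minimal sketch of the underlying jax.vjp technique (the function `f` is illustrative):

    import jax
    import jax.numpy as jnp

    def f(x):
        return jnp.sin(x) * x

    # jax.vjp returns the primal output and a pullback; calling the
    # pullback with a cotangent gz yields the vector-Jacobian product.
    y, vjp_fn = jax.vjp(f, jnp.array(0.5))
    (grad_x,) = vjp_fn(jnp.array(1.0))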
@@ -350,16 +350,16 @@ def __init__(
         output_treeedef,
         input_types,
         output_types,
-        jitted_sol_op_jax,
-        jitted_vjp_sol_op_jax,
+        jitted_jax_op,
+        jitted_vjp_jax_op,
     ):
-        self.vjp_sol_op = None
+        self.vjp_jax_op = None
         self.input_treedef = input_treedef
         self.output_treedef = output_treeedef
         self.input_types = input_types
         self.output_types = output_types
-        self.jitted_sol_op_jax = jitted_sol_op_jax
-        self.jitted_vjp_sol_op_jax = jitted_vjp_sol_op_jax
+        self.jitted_jax_op = jitted_jax_op
+        self.jitted_vjp_jax_op = jitted_vjp_jax_op
 
     def make_node(self, *inputs):
         self.num_inputs = len(inputs)
@@ -368,24 +368,24 @@ def make_node(self, *inputs):
         outputs = [pt.as_tensor_variable(type()) for type in self.output_types]
         self.num_outputs = len(outputs)
 
-        self.vjp_sol_op = VJPJAXOp(
+        self.vjp_jax_op = VJPJAXOp(
             self.input_treedef,
             self.input_types,
-            self.jitted_vjp_sol_op_jax,
+            self.jitted_vjp_jax_op,
         )
 
         return Apply(self, inputs, outputs)
 
     def perform(self, node, inputs, outputs):
-        results = self.jitted_sol_op_jax(inputs)
+        results = self.jitted_jax_op(inputs)
         if self.num_outputs > 1:
             for i in range(self.num_outputs):
                 outputs[i][0] = np.array(results[i], self.output_types[i].dtype)
         else:
             outputs[0][0] = np.array(results, self.output_types[0].dtype)
 
     def perform_jax(self, *inputs):
-        results = self.jitted_sol_op_jax(inputs)
+        results = self.jitted_jax_op(inputs)
         return results
 
     def grad(self, inputs, output_gradients):
@@ -399,7 +399,7 @@ def grad(self, inputs, output_gradients):
                 )
             else:
                 output_gradients[i] = pt.zeros((), self.output_types[i].dtype)
-        result = self.vjp_sol_op(inputs, output_gradients)
+        result = self.vjp_jax_op(inputs, output_gradients)
 
         if self.num_inputs > 1:
             return result
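Note that `grad` above does no numeric work; it returns the symbolic outputs of a second Op (`vjp_jax_op`). That is the standard PyTensor custom-Op pattern: `make_node` declares types, `perform` computes values, and `grad` wires another Op into the graph. A minimal self-contained sketch of the pattern, assuming PyTensor's public Op interface (DoubleOp is a made-up example, not part of this patch):

    import numpy as np
    import pytensor.tensor as pt
    from pytensor.graph.basic import Apply
    from pytensor.graph.op import Op

    class DoubleOp(Op):
        def make_node(self, x):
            # Declare symbolic input/output types for one application.
            x = pt.as_tensor_variable(x)
            return Apply(self, [x], [x.type()])

        def perform(self, node, inputs, outputs):
            # Numeric evaluation: write into the output storage cells.
            outputs[0][0] = 2 * np.asarray(inputs[0])

        def grad(self, inputs, output_gradients):
            # Symbolic gradient: d(2x)/dx = 2, scaled by the cotangent.
            return [2 * output_gradients[0]]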
@@ -413,11 +413,11 @@ def __init__(
         self,
         input_treedef,
         input_types,
-        jitted_vjp_sol_op_jax,
+        jitted_vjp_jax_op,
     ):
         self.input_treedef = input_treedef
         self.input_types = input_types
-        self.jitted_vjp_sol_op_jax = jitted_vjp_sol_op_jax
+        self.jitted_vjp_jax_op = jitted_vjp_jax_op
 
     def make_node(self, y0, gz):
         y0 = [
@@ -436,15 +436,15 @@ def make_node(self, y0, gz):
         return Apply(self, y0 + gz_not_disconntected, outputs)
 
     def perform(self, node, inputs, outputs):
-        results = self.jitted_vjp_sol_op_jax(tuple(inputs))
+        results = self.jitted_vjp_jax_op(tuple(inputs))
         if len(self.input_types) > 1:
             for i, result in enumerate(results):
                 outputs[i][0] = np.array(result, self.input_types[i].dtype)
         else:
             outputs[0][0] = np.array(results, self.input_types[0].dtype)
 
     def perform_jax(self, *inputs):
-        results = self.jitted_vjp_sol_op_jax(tuple(inputs))
+        results = self.jitted_vjp_jax_op(tuple(inputs))
         if self.num_outputs == 1:
             if isinstance(results, Sequence):
                 return results[0]
@@ -455,10 +455,10 @@ def perform_jax(self, *inputs):
 
 
 @jax_funcify.register(JAXOp)
-def sol_op_jax_funcify(op, **kwargs):
+def jax_op_funcify(op, **kwargs):
     return op.perform_jax
 
 
 @jax_funcify.register(VJPJAXOp)
-def vjp_sol_op_jax_funcify(op, **kwargs):
+def vjp_jax_op_funcify(op, **kwargs):
     return op.perform_jax
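The two renamed `jax_funcify` registrations are what let PyTensor's JAX backend lower these Ops: for each registered Op class, the handler returns a plain callable implementing the Op in JAX, here the `perform_jax` methods. A minimal sketch of the same dispatch hook, assuming the import path `pytensor.link.jax.dispatch` (MyOp is a hypothetical Op, for illustration only):

    import jax.numpy as jnp
    from pytensor.graph.op import Op
    from pytensor.link.jax.dispatch import jax_funcify

    class MyOp(Op):  # hypothetical Op; make_node/perform omitted
        ...

    @jax_funcify.register(MyOp)
    def my_op_funcify(op, **kwargs):
        # The JAX backend calls this once per Op instance and inlines
        # the returned callable into the jitted graph.
        def my_op_jax(*inputs):
            return jnp.add(*inputs)
        return my_op_jax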