Skip to content

Commit dee236c

Browse files
committed
Standardize various method overrides
1 parent c17f657 commit dee236c

File tree

20 files changed

+239
-190
lines changed

20 files changed

+239
-190
lines changed

pytensor/graph/destroyhandler.py

Lines changed: 47 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -521,72 +521,72 @@ def fast_destroy(self, fgraph, app, reason):
521521
# assert len(v) <= 1
522522
# assert len(d) <= 1
523523

524-
def on_import(self, fgraph, app, reason):
524+
def on_import(self, fgraph, node, reason):
525525
"""
526526
Add Apply instance to set which must be computed.
527527
528528
"""
529-
if app in self.debug_all_apps:
529+
if node in self.debug_all_apps:
530530
raise ProtocolError("double import")
531-
self.debug_all_apps.add(app)
531+
self.debug_all_apps.add(node)
532532
# print 'DH IMPORT', app, id(app), id(self), len(self.debug_all_apps)
533533

534534
# If it's a destructive op, add it to our watch list
535-
dmap = app.op.destroy_map
536-
vmap = app.op.view_map
535+
dmap = node.op.destroy_map
536+
vmap = node.op.view_map
537537
if dmap:
538-
self.destroyers.add(app)
538+
self.destroyers.add(node)
539539
if self.algo == "fast":
540-
self.fast_destroy(fgraph, app, reason)
540+
self.fast_destroy(fgraph, node, reason)
541541

542542
# add this symbol to the forward and backward maps
543543
for o_idx, i_idx_list in vmap.items():
544544
if len(i_idx_list) > 1:
545545
raise NotImplementedError(
546-
"destroying this output invalidates multiple inputs", (app.op)
546+
"destroying this output invalidates multiple inputs", (node.op)
547547
)
548-
o = app.outputs[o_idx]
549-
i = app.inputs[i_idx_list[0]]
548+
o = node.outputs[o_idx]
549+
i = node.inputs[i_idx_list[0]]
550550
self.view_i[o] = i
551551
self.view_o.setdefault(i, OrderedSet()).add(o)
552552

553553
# update self.clients
554-
for i, input in enumerate(app.inputs):
555-
self.clients.setdefault(input, {}).setdefault(app, 0)
556-
self.clients[input][app] += 1
554+
for i, input in enumerate(node.inputs):
555+
self.clients.setdefault(input, {}).setdefault(node, 0)
556+
self.clients[input][node] += 1
557557

558-
for i, output in enumerate(app.outputs):
558+
for i, output in enumerate(node.outputs):
559559
self.clients.setdefault(output, {})
560560

561561
self.stale_droot = True
562562

563-
def on_prune(self, fgraph, app, reason):
563+
def on_prune(self, fgraph, node, reason):
564564
"""
565565
Remove Apply instance from set which must be computed.
566566
567567
"""
568-
if app not in self.debug_all_apps:
568+
if node not in self.debug_all_apps:
569569
raise ProtocolError("prune without import")
570-
self.debug_all_apps.remove(app)
570+
self.debug_all_apps.remove(node)
571571

572572
# UPDATE self.clients
573-
for input in set(app.inputs):
574-
del self.clients[input][app]
573+
for input in set(node.inputs):
574+
del self.clients[input][node]
575575

576-
if app.op.destroy_map:
577-
self.destroyers.remove(app)
576+
if node.op.destroy_map:
577+
self.destroyers.remove(node)
578578

579579
# Note: leaving empty client dictionaries in the struct.
580580
# Why? It's a pain to remove them. I think they aren't doing any harm, they will be
581581
# deleted on_detach().
582582

583583
# UPDATE self.view_i, self.view_o
584-
for o_idx, i_idx_list in app.op.view_map.items():
584+
for o_idx, i_idx_list in node.op.view_map.items():
585585
if len(i_idx_list) > 1:
586586
# destroying this output invalidates multiple inputs
587587
raise NotImplementedError()
588-
o = app.outputs[o_idx]
589-
i = app.inputs[i_idx_list[0]]
588+
o = node.outputs[o_idx]
589+
i = node.inputs[i_idx_list[0]]
590590

591591
del self.view_i[o]
592592

@@ -595,53 +595,53 @@ def on_prune(self, fgraph, app, reason):
595595
del self.view_o[i]
596596

597597
self.stale_droot = True
598-
if app in self.fail_validate:
599-
del self.fail_validate[app]
598+
if node in self.fail_validate:
599+
del self.fail_validate[node]
600600

601-
def on_change_input(self, fgraph, app, i, old_r, new_r, reason):
601+
def on_change_input(self, fgraph, node, i, var, new_var, reason=None):
602602
"""
603-
app.inputs[i] changed from old_r to new_r.
603+
node.inputs[i] changed from var to new_var.
604604
605605
"""
606-
if isinstance(app.op, Output):
607-
# app == 'output' is special key that means FunctionGraph is redefining which nodes are being
606+
if isinstance(node.op, Output):
607+
# node == 'output' is special key that means FunctionGraph is redefining which nodes are being
608608
# considered 'outputs' of the graph.
609609
pass
610610
else:
611-
if app not in self.debug_all_apps:
611+
if node not in self.debug_all_apps:
612612
raise ProtocolError("change without import")
613613

614614
# UPDATE self.clients
615-
self.clients[old_r][app] -= 1
616-
if self.clients[old_r][app] == 0:
617-
del self.clients[old_r][app]
615+
self.clients[var][node] -= 1
616+
if self.clients[var][node] == 0:
617+
del self.clients[var][node]
618618

619-
self.clients.setdefault(new_r, {}).setdefault(app, 0)
620-
self.clients[new_r][app] += 1
619+
self.clients.setdefault(new_var, {}).setdefault(node, 0)
620+
self.clients[new_var][node] += 1
621621

622622
# UPDATE self.view_i, self.view_o
623-
for o_idx, i_idx_list in app.op.view_map.items():
623+
for o_idx, i_idx_list in node.op.view_map.items():
624624
if len(i_idx_list) > 1:
625625
# destroying this output invalidates multiple inputs
626626
raise NotImplementedError()
627627
i_idx = i_idx_list[0]
628-
output = app.outputs[o_idx]
628+
output = node.outputs[o_idx]
629629
if i_idx == i:
630-
if app.inputs[i_idx] is not new_r:
630+
if node.inputs[i_idx] is not new_var:
631631
raise ProtocolError("wrong new_r on change")
632632

633-
self.view_i[output] = new_r
633+
self.view_i[output] = new_var
634634

635-
self.view_o[old_r].remove(output)
636-
if not self.view_o[old_r]:
637-
del self.view_o[old_r]
635+
self.view_o[var].remove(output)
636+
if not self.view_o[var]:
637+
del self.view_o[var]
638638

639-
self.view_o.setdefault(new_r, OrderedSet()).add(output)
639+
self.view_o.setdefault(new_var, OrderedSet()).add(output)
640640

641641
if self.algo == "fast":
642-
if app in self.fail_validate:
643-
del self.fail_validate[app]
644-
self.fast_destroy(fgraph, app, reason)
642+
if node in self.fail_validate:
643+
del self.fail_validate[node]
644+
self.fast_destroy(fgraph, node, reason)
645645
self.stale_droot = True
646646

647647
def validate(self, fgraph):

pytensor/graph/features.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -416,11 +416,11 @@ def on_detach(self, fgraph):
416416
del fgraph.revert
417417
del self.history[fgraph]
418418

419-
def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
419+
def on_change_input(self, fgraph, node, i, var, new_var, reason=None):
420420
if self.history[fgraph] is None:
421421
return
422422
h = self.history[fgraph]
423-
h.append(LambdaExtract(fgraph, node, i, r, reason))
423+
h.append(LambdaExtract(fgraph, node, i, var, reason))
424424

425425
def revert(self, fgraph, checkpoint):
426426
"""
@@ -544,9 +544,9 @@ def on_attach(self, fgraph):
544544
raise ValueError("Full History already attached to another fgraph")
545545
self.fg = fgraph
546546

547-
def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
548-
self.bw.append(LambdaExtract(fgraph, node, i, r, reason))
549-
self.fw.append(LambdaExtract(fgraph, node, i, new_r, reason))
547+
def on_change_input(self, fgraph, node, i, var, new_var, reason=None):
548+
self.bw.append(LambdaExtract(fgraph, node, i, var, reason))
549+
self.fw.append(LambdaExtract(fgraph, node, i, new_var, reason))
550550
self.pointer += 1
551551
if self.callback:
552552
self.callback()
@@ -832,15 +832,15 @@ class PreserveVariableAttributes(Feature):
832832
This preserve some variables attributes and tag during optimization.
833833
"""
834834

835-
def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
835+
def on_change_input(self, fgraph, node, i, var, new_var, reason=None):
836836
# Don't change the name of constants
837-
if r.owner and r.name is not None and new_r.name is None:
838-
new_r.name = r.name
837+
if var.owner and var.name is not None and new_var.name is None:
838+
new_var.name = var.name
839839
if (
840-
getattr(r.tag, "nan_guard_mode_check", False)
841-
and getattr(new_r.tag, "nan_guard_mode_check", False) is False
840+
getattr(var.tag, "nan_guard_mode_check", False)
841+
and getattr(new_var.tag, "nan_guard_mode_check", False) is False
842842
):
843-
new_r.tag.nan_guard_mode_check = r.tag.nan_guard_mode_check
843+
new_var.tag.nan_guard_mode_check = var.tag.nan_guard_mode_check
844844

845845

846846
class NoOutputFromInplace(Feature):

pytensor/graph/rewriting/basic.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -550,15 +550,15 @@ def on_attach(self, fgraph):
550550
def clone(self):
551551
return type(self)()
552552

553-
def on_change_input(self, fgraph, node, i, r, new_r, reason):
553+
def on_change_input(self, fgraph, node, i, var, new_var, reason=None):
554554
if node in self.nodes_seen:
555555
# If inputs to a node change, it's not guaranteed that the node is
556556
# distinct from the other nodes in `self.nodes_seen`.
557557
self.nodes_seen.discard(node)
558558
self.process_node(fgraph, node)
559559

560-
if isinstance(new_r, AtomicVariable):
561-
self.process_atomic(fgraph, new_r)
560+
if isinstance(new_var, AtomicVariable):
561+
self.process_atomic(fgraph, new_var)
562562

563563
def on_import(self, fgraph, node, reason):
564564
for c in node.inputs:
@@ -973,7 +973,7 @@ def __init__(self, fn, tracks=None, requirements=()):
973973
)
974974
self.requirements = requirements
975975

976-
def transform(self, fgraph, node, enforce_tracks: bool = True):
976+
def transform(self, fgraph, node, enforce_tracks: bool = True, *args, **kwargs):
977977
if enforce_tracks and self._tracks:
978978
node_op = node.op
979979
if not (
@@ -1230,7 +1230,7 @@ def tracks(self):
12301230
t.extend(at)
12311231
return t
12321232

1233-
def transform(self, fgraph, node, enforce_tracks=False):
1233+
def transform(self, fgraph, node, enforce_tracks=False, *args, **kwargs):
12341234
if len(self.rewrites) == 0:
12351235
return
12361236

@@ -1385,7 +1385,7 @@ def __init__(self, op1, op2, transfer_tags=True):
13851385
def tracks(self):
13861386
return [self.op1]
13871387

1388-
def transform(self, fgraph, node, enforce_tracks=True):
1388+
def transform(self, fgraph, node, enforce_tracks=True, *args, **kwargs):
13891389
if enforce_tracks and (node.op != self.op1):
13901390
return False
13911391
repl = self.op2.make_node(*node.inputs)
@@ -1713,9 +1713,9 @@ def on_prune(self, fgraph, node, reason):
17131713
if self.pruner:
17141714
self.pruner(node)
17151715

1716-
def on_change_input(self, fgraph, node, i, r, new_r, reason):
1716+
def on_change_input(self, fgraph, node, i, var, new_var, reason=None):
17171717
if self.chin:
1718-
self.chin(node, i, r, new_r, reason)
1718+
self.chin(node, i, var, new_var, reason)
17191719

17201720
def on_detach(self, fgraph):
17211721
# To allow pickling this object
@@ -2160,7 +2160,7 @@ def on_import(self, fgraph, node, reason):
21602160
self.nb_imported += 1
21612161
self.changed = True
21622162

2163-
def on_change_input(self, fgraph, node, i, r, new_r, reason):
2163+
def on_change_input(self, fgraph, node, i, var, new_var, reason=None):
21642164
self.changed = True
21652165

21662166
def reset(self):

pytensor/graph/rewriting/db.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -396,10 +396,10 @@ def __init__(
396396
self.__position__ = {}
397397
self.failure_callback = failure_callback
398398

399-
def register(self, name, obj, *tags, **kwargs):
399+
def register(self, name, rewriter, *tags, **kwargs):
400400
position = kwargs.pop("position", "last")
401401

402-
super().register(name, obj, *tags, **kwargs)
402+
super().register(name, rewriter, *tags, **kwargs)
403403

404404
if position == "last":
405405
if len(self.__position__) == 0:
@@ -497,8 +497,8 @@ def __init__(
497497
self.node_rewriter = node_rewriter
498498
self.__name__: str = ""
499499

500-
def register(self, name, obj, *tags, position="last", **kwargs):
501-
super().register(name, obj, *tags, position=position, **kwargs)
500+
def register(self, name, rewriter, *tags, position="last", **kwargs):
501+
super().register(name, rewriter, *tags, position=position, **kwargs)
502502

503503
def query(self, *tags, **kwtags):
504504
rewrites = list(super().query(*tags, **kwtags))

pytensor/graph/rewriting/kanren.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def results_filter(
7474
self.node_filter = node_filter
7575
super().__init__()
7676

77-
def transform(self, fgraph, node, enforce_tracks: bool = True):
77+
def transform(self, fgraph, node, enforce_tracks: bool = True, *args, **kwargs):
7878
if self.node_filter(node) is False:
7979
return False
8080

pytensor/misc/ordered_set.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ def __init__(self, iterable: Iterable | None = None) -> None:
1111
else:
1212
self.values = dict.fromkeys(iterable)
1313

14-
def __contains__(self, value) -> bool:
15-
return value in self.values
14+
def __contains__(self, x) -> bool:
15+
return x in self.values
1616

1717
def __iter__(self) -> Iterator:
1818
yield from self.values

0 commit comments

Comments
 (0)