Skip to content

Commit 6aeed97

Browse files
committed
Simplify typify of Generators in JAXLinker
1 parent 9f80bdc commit 6aeed97

File tree

1 file changed

+2
-4
lines changed

1 file changed

+2
-4
lines changed

pytensor/link/jax/linker.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -117,10 +117,8 @@ def create_thunk_inputs(self, storage_map):
117117
for n in self.fgraph.inputs:
118118
sinput = storage_map[n]
119119
if isinstance(sinput[0], Generator):
120-
new_value = jax_typify(
121-
sinput[0], dtype=getattr(sinput[0], "dtype", None)
122-
)
123-
sinput[0] = new_value
120+
# Neet to convert Generator into JAX PRNGkey
121+
sinput[0] = jax_typify(sinput[0])
124122
thunk_inputs.append(sinput)
125123

126124
return thunk_inputs

0 commit comments

Comments
 (0)