Skip to content

Commit ca5909c

Browse files
committed
Merge branch 'fix-temp-default-shape' into 'master'
Fix the default shape of loopy.TemporaryVariable See merge request inducer/loopy!354
2 parents 72f1dd1 + 2526196 commit ca5909c

File tree

4 files changed

+14
-8
lines changed

4 files changed

+14
-8
lines changed

loopy/kernel/array.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -934,7 +934,8 @@ def num_target_axes(self):
934934
return len(target_axes)
935935

936936
def num_user_axes(self, require_answer=True):
937-
if self.shape is not None:
937+
from loopy import auto
938+
if self.shape not in (None, auto):
938939
return len(self.shape)
939940
if self.dim_tags is not None:
940941
return len(self.dim_tags)

loopy/kernel/data.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -525,7 +525,7 @@ class TemporaryVariable(ArrayBase):
525525
"_base_storage_access_may_be_aliasing",
526526
]
527527

528-
def __init__(self, name, dtype=None, shape=(), address_space=None,
528+
def __init__(self, name, dtype=None, shape=auto, address_space=None,
529529
dim_tags=None, offset=0, dim_names=None, strides=None, order=None,
530530
base_indices=None, storage_shape=None,
531531
base_storage=None, initializer=None, read_only=False,
@@ -579,7 +579,10 @@ def __init__(self, name, dtype=None, shape=(), address_space=None,
579579

580580
if shape is auto:
581581
shape = initializer.shape
582-
582+
else:
583+
if shape != initializer.shape:
584+
raise LoopyError("Shape of '{}' does not match that of the"
585+
" initializer.".format(name))
583586
else:
584587
raise LoopyError(
585588
"temporary variable '%s': "
@@ -589,7 +592,7 @@ def __init__(self, name, dtype=None, shape=(), address_space=None,
589592
if order is None:
590593
order = "C"
591594

592-
if base_indices is None:
595+
if base_indices is None and shape is not auto:
593596
base_indices = (0,) * len(shape)
594597

595598
if not read_only and initializer is not None:

loopy/target/opencl.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -551,8 +551,10 @@ def emit_atomic_update(self, codegen_state, lhs_atomicity, lhs_var,
551551
from loopy.kernel.data import TemporaryVariable, AddressSpace
552552
ecm = codegen_state.expression_to_code_mapper.with_assignments(
553553
{
554-
old_val_var: TemporaryVariable(old_val_var, lhs_dtype),
555-
new_val_var: TemporaryVariable(new_val_var, lhs_dtype),
554+
old_val_var: TemporaryVariable(old_val_var, lhs_dtype,
555+
shape=()),
556+
new_val_var: TemporaryVariable(new_val_var, lhs_dtype,
557+
shape=()),
556558
})
557559

558560
lhs_expr_code = ecm(lhs_expr, prec=PREC_NONE, type_context=None)

test/test_loopy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ def test_globals_decl_once_with_multi_subprogram(ctx_factory):
6868
out[ii] = 2*out[ii]+cnst[ii]{id=second}
6969
""",
7070
[lp.TemporaryVariable(
71-
'cnst', shape=('n'), initializer=cnst,
72-
address_space=lp.AddressSpace.GLOBAL,
71+
'cnst', initializer=cnst,
72+
scope=lp.AddressSpace.GLOBAL,
7373
read_only=True), '...'])
7474
knl = lp.fix_parameters(knl, n=16)
7575
knl = lp.add_barrier(knl, "id:first", "id:second")

0 commit comments

Comments
 (0)