diff --git a/sklearn_numba_dpex/common/kernels.py b/sklearn_numba_dpex/common/kernels.py
index 2824d73..48bcf47 100644
--- a/sklearn_numba_dpex/common/kernels.py
+++ b/sklearn_numba_dpex/common/kernels.py
@@ -2,24 +2,26 @@
 from functools import lru_cache
 
 import dpctl.tensor as dpt
-import numba_dpex as dpex
+import numba_dpex.experimental as dpex_exp
 import numpy as np
+from numba_dpex.kernel_api import NdItem, NdRange
 
 zero_idx = np.int64(0)
 
 
 @lru_cache
 def make_apply_elementwise_func(shape, func, work_group_size):
-    func = dpex.func(func)
+    func = dpex_exp.device_func(func)
     n_items = math.prod(shape)
 
-    @dpex.kernel
+    @dpex_exp.kernel
     # fmt: off
     def elementwise_ops_kernel(
+        nd_item: NdItem,
        data,                  # INOUT    (n_items,)
    ):
         # fmt: on
-        item_idx = dpex.get_global_id(zero_idx)
+        item_idx = nd_item.get_global_id(zero_idx)
 
         if item_idx >= n_items:
             return
@@ -30,7 +32,9 @@ def elementwise_ops_kernel(
 
     def elementwise_ops(data):
         data = dpt.reshape(data, (-1,))
-        elementwise_ops_kernel[global_size, work_group_size](data)
+        dpex_exp.call_kernel(
+            elementwise_ops_kernel, NdRange((global_size,), (work_group_size,)), data
+        )
 
     return elementwise_ops
@@ -41,9 +45,9 @@ def make_initialize_to_zeros_kernel(shape, work_group_size, dtype):
     global_size = math.ceil(n_items / work_group_size) * work_group_size
     zero = dtype(0.0)
 
-    @dpex.kernel
-    def initialize_to_zeros_kernel(data):
-        item_idx = dpex.get_global_id(zero_idx)
+    @dpex_exp.kernel
+    def initialize_to_zeros_kernel(nd_item: NdItem, data):
+        item_idx = nd_item.get_global_id(zero_idx)
 
         if item_idx >= n_items:
             return
@@ -52,7 +56,11 @@ def initialize_to_zeros_kernel(data):
 
     def initialize_to_zeros(data):
         data = dpt.reshape(data, (-1,))
-        initialize_to_zeros_kernel[global_size, work_group_size](data)
+        dpex_exp.call_kernel(
+            initialize_to_zeros_kernel,
+            NdRange((global_size,), (work_group_size,)),
+            data,
+        )
 
     return initialize_to_zeros
@@ -65,9 +73,9 @@
     # NB: the left operand is modified inplace, the right operand is only read into.
     # Optimized for C-contiguous array and for
     # size1 >> preferred_work_group_size_multiple
-    @dpex.kernel
-    def broadcast_division(dividend_array, divisor_vector):
-        col_idx = dpex.get_global_id(zero_idx)
+    @dpex_exp.kernel
+    def broadcast_division(nd_item: NdItem, dividend_array, divisor_vector):
+        col_idx = nd_item.get_global_id(zero_idx)
 
         if col_idx >= n_cols:
             return
@@ -79,28 +87,34 @@ def broadcast_division(dividend_array, divisor_vector):
                 dividend_array[row_idx, col_idx] / divisor
             )
 
-    return broadcast_division[global_size, work_group_size]
+    def kernel_call(*args):
+        return dpex_exp.call_kernel(
+            broadcast_division, NdRange((global_size,), (work_group_size,)), *args
+        )
+
+    return kernel_call
 
 
 @lru_cache
 def make_broadcast_ops_1d_2d_axis1_kernel(shape, ops, work_group_size):
     """
-    ops must be a function that will be interpreted as a dpex.func and is subject to
-    the same rules. It is expected to take two scalar arguments and return one scalar
-    value. lambda functions are advised against since the cache will not work with lamda
-    functions. sklearn_numba_dpex.common._utils expose some pre-defined `ops`.
+    ops must be a function that will be interpreted as a dpex_exp.device_func and is
+    subject to the same rules. It is expected to take two scalar arguments and return
+    one scalar value. lambda functions are advised against since the cache will not
+    work with lambda functions. sklearn_numba_dpex.common._utils exposes some
+    pre-defined `ops`.
""" n_rows, n_cols = shape global_size = math.ceil(n_cols / work_group_size) * work_group_size - ops = dpex.func(ops) + ops = dpex_exp.device_func(ops) # NB: the left operand is modified inplace, the right operand is only read into. # Optimized for C-contiguous array and for # size1 >> preferred_work_group_size_multiple - @dpex.kernel - def broadcast_ops(left_operand_array, right_operand_vector): - col_idx = dpex.get_global_id(zero_idx) + @dpex_exp.kernel + def broadcast_ops(nd_item: NdItem, left_operand_array, right_operand_vector): + col_idx = nd_item.get_global_id(zero_idx) if col_idx >= n_cols: return @@ -110,7 +124,12 @@ def broadcast_ops(left_operand_array, right_operand_vector): left_operand_array[row_idx, col_idx], right_operand_vector[row_idx] ) - return broadcast_ops[global_size, work_group_size] + def kernel_call(*args): + return dpex_exp.call_kernel( + broadcast_ops, NdRange((global_size,), (work_group_size,)), *args + ) + + return kernel_call @lru_cache @@ -122,14 +141,15 @@ def make_half_l2_norm_2d_axis0_kernel(shape, work_group_size, dtype): # Optimized for C-contiguous array and for # size1 >> preferred_work_group_size_multiple - @dpex.kernel + @dpex_exp.kernel # fmt: off def half_l2_norm( + nd_item: NdItem, data, # IN (size0, size1) result, # OUT (size1,) ): # fmt: on - col_idx = dpex.get_global_id(zero_idx) + col_idx = nd_item.get_global_id(zero_idx) if col_idx >= n_cols: return @@ -142,4 +162,9 @@ def half_l2_norm( result[col_idx] = l2_norm / two - return half_l2_norm[global_size, work_group_size] + def kernel_call(*args): + return dpex_exp.call_kernel( + half_l2_norm, NdRange((global_size,), (work_group_size,)), *args + ) + + return kernel_call diff --git a/sklearn_numba_dpex/common/matmul.py b/sklearn_numba_dpex/common/matmul.py index c331824..445de9e 100644 --- a/sklearn_numba_dpex/common/matmul.py +++ b/sklearn_numba_dpex/common/matmul.py @@ -2,7 +2,9 @@ from functools import lru_cache import numba_dpex as dpex +import numba_dpex.experimental as dpex_exp import numpy as np +from numba_dpex.kernel_api import MemoryScope, NdItem, NdRange, group_barrier from sklearn_numba_dpex.common._utils import _enforce_matmul_like_work_group_geometry @@ -54,7 +56,7 @@ def make_matmul_2d_kernel( If `out_fused_elementwise_fn` is not None, it will be applied element-wise once to each element of the output array right before it is returned. This function is compiled and fused into the first kernel as a device function with the help of - `dpex.func`. This comes with limitations as explained in: + `dpex_exp.device_func`. 
     https://intelpython.github.io/numba-dpex/latest/user_guides/kernel_programming_guide/device-functions.html  # noqa
@@ -147,21 +149,21 @@ def make_matmul_2d_kernel(
 
     if multiply_fn is None:
 
-        @dpex.func
+        @dpex_exp.device_func
         def multiply_fn_(x, y):
             return x * y
 
     else:
-        multiply_fn_ = dpex.func(multiply_fn)
+        multiply_fn_ = dpex_exp.device_func(multiply_fn)
 
     if out_fused_elementwise_fn is None:
 
-        @dpex.func
+        @dpex_exp.device_func
         def out_fused_elementwise_fn(x):
             return x
 
     else:
-        out_fused_elementwise_fn = dpex.func(out_fused_elementwise_fn)
+        out_fused_elementwise_fn = dpex_exp.device_func(out_fused_elementwise_fn)
 
     # Under the same assumption, this value is equal to the number of results computed
     # by a single work group
@@ -268,21 +270,22 @@ def out_fused_elementwise_fn(x):
     grid_n_groups = global_grid_n_rows * global_grid_n_cols
     global_size = grid_n_groups * work_group_size
 
-    @dpex.kernel
+    @dpex_exp.kernel
     # fmt: off
     def matmul(
+        nd_item: NdItem,
         X,        # IN      (X_n_rows, n_cols)
         Y_t,      # IN      (Y_t_n_rows, n_cols)
         result    # OUT     (X_n_rows, Y_t_n_rows)
     ):
         # fmt: on
-        work_item_idx = dpex.get_local_id(zero_idx)
+        work_item_idx = nd_item.get_local_id(zero_idx)
 
         # Index the work items in the base sliding window:
         work_item_row_idx = work_item_idx // sub_group_size
         work_item_col_idx = work_item_idx % sub_group_size
 
-        group_idx = dpex.get_group_id(zero_idx)
+        group_idx = nd_item.get_group().get_group_id(zero_idx)
 
         # Get indices of the row and the column of the top-left corner of the sub-array
         # of results covered by this work group
@@ -352,7 +355,7 @@ def matmul(
             )
             window_loaded_col_idx += sub_group_size
 
-        dpex.barrier(dpex.LOCAL_MEM_FENCE)
+        group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP)
 
         _accumulate_private_windows(
             first_private_loaded_sliding_X_value_idx,
@@ -365,7 +368,7 @@ def matmul(
             private_result
         )
 
-        dpex.barrier(dpex.LOCAL_MEM_FENCE)
+        group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP)
 
         _write_result(
             group_first_row_idx + first_private_loaded_sliding_X_value_idx,
@@ -376,7 +379,7 @@ def matmul(
         )
 
     # HACK 906: see sklearn_numba_dpex.patches.tests.test_patches.test_need_to_workaround_numba_dpex_906  # noqa
-    @dpex.func
+    @dpex_exp.device_func
     # fmt: off
     def _load_sliding_windows(
         work_item_row_idx,      # PARAM
@@ -419,7 +422,7 @@ def _load_sliding_windows(
             Y_t_loaded_row_idx += base_result_window_side
             Y_t_local_loaded_row_idx += base_result_window_side
 
-    @dpex.func
+    @dpex_exp.device_func
     # fmt: off
     def _accumulate_private_windows(
         private_first_loaded_sliding_X_value_idx,     # PARAM
@@ -465,7 +468,7 @@ def _accumulate_private_windows(
 
             private_array_first_col += private_Y_t_sliding_window_width
 
-    @dpex.func
+    @dpex_exp.device_func
     # fmt: off
     def _write_result(
         result_first_row_idx,       # PARAM
@@ -486,14 +489,17 @@ def _write_result(
                 result_col_idx += nb_work_items_for_Y_t_window
             result_row_idx += nb_work_items_for_X_window
 
-    return matmul[global_size, work_group_size]
+    def kernel_call(*args):
+        dpex_exp.call_kernel(matmul, NdRange((global_size,), (work_group_size,)), *args)
+
+    return kernel_call
 
 
 def _make_accumulate_step_unrolled_kernel_func(private_result_array_width, multiply_fn):
 
     if private_result_array_width == 1:
 
-        @dpex.func
+        @dpex_exp.device_func
         def _accumulate_step_unrolled(
             i, j, private_loaded_X_value, private_Y_t_sliding_window, private_result
         ):
@@ -503,7 +509,7 @@ def _accumulate_step_unrolled(
 
     elif private_result_array_width == 2:
 
-        @dpex.func
+        @dpex_exp.device_func
         def _accumulate_step_unrolled(
            i, j, private_loaded_X_value, private_Y_t_sliding_window, private_result
         ):
@@ -516,7 +522,7 @@ def _accumulate_step_unrolled(
 
     elif private_result_array_width == 4:
 
-        @dpex.func
+        @dpex_exp.device_func
         def _accumulate_step_unrolled(
             i, j, private_loaded_X_value, private_Y_t_sliding_window, private_result
         ):
@@ -535,7 +541,7 @@ def _accumulate_step_unrolled(
 
     elif private_result_array_width == 8:
 
-        @dpex.func
+        @dpex_exp.device_func
         def _accumulate_step_unrolled(
             i, j, private_loaded_X_value, private_Y_t_sliding_window, private_result
         ):
diff --git a/sklearn_numba_dpex/common/random.py b/sklearn_numba_dpex/common/random.py
index 7def861..3714cfc 100644
--- a/sklearn_numba_dpex/common/random.py
+++ b/sklearn_numba_dpex/common/random.py
@@ -3,9 +3,10 @@
 
 import dpctl
 import dpctl.tensor as dpt
-import numba_dpex as dpex
+import numba_dpex.experimental as dpex_exp
 import numpy as np
 from numba import float32, float64, int64, uint32, uint64
+from numba_dpex.kernel_api import NdRange
 
 from ._utils import _get_sequential_processing_device
 
@@ -46,11 +47,14 @@ def make_random_raw_kernel():
     Note, this always uses and updates state states[0].
     """
 
-    @dpex.kernel
+    @dpex_exp.kernel
    def _get_random_raw_kernel(states, result):
         result[zero_idx] = _xoroshiro128pp_next(states, zero_idx)
 
-    return _get_random_raw_kernel[1, 1]
+    def kernel_call(*args):
+        dpex_exp.call_kernel(_get_random_raw_kernel, NdRange((1,), (1,)), *args)
+
+    return kernel_call
 
 
 def make_rand_uniform_kernel_func(dtype):
@@ -80,7 +84,7 @@ def make_rand_uniform_kernel_func(dtype):
         convert_const = float64(uint64(1) << uint32(53))
         convert_const_one = float64(1)
 
-        @dpex.func
+        @dpex_exp.device_func
         def uint64_to_unit_float(x):
             """Convert uint64 to float64 value in the range [0.0, 1.0)"""
             return float64(x >> convert_rshift) * (convert_const_one / convert_const)
@@ -90,7 +94,7 @@ def uint64_to_unit_float(x):
         convert_const = float32(uint32(1) << uint32(24))
         convert_const_one = float32(1)
 
-        @dpex.func
+        @dpex_exp.device_func
         def uint64_to_unit_float(x):
             """Convert uint64 to float32 value in the range [0.0, 1.0)
 
@@ -109,7 +113,7 @@ def uint64_to_unit_float(x):
             f"dtype.name == {dtype.name}"
         )
 
-    @dpex.func
+    @dpex_exp.device_func
     def xoroshiro128pp_uniform(states, state_idx):
         """Return one random float in [0, 1)
 
@@ -195,7 +199,7 @@ def _make_init_xoroshiro128pp_states_kernel(n_states, subsequence_start):
     splitmix64_rshift_2 = uint32(27)
     splitmix64_rshift_3 = uint32(31)
 
-    @dpex.func
+    @dpex_exp.device_func
     def _splitmix64_next(state):
         new_state = z = state + splitmix64_const_1
         z = (z ^ (z >> splitmix64_rshift_1)) * splitmix64_const_2
@@ -209,7 +213,7 @@ def _splitmix64_next(state):
     long_2 = int64(2)
     long_64 = int64(64)
 
-    @dpex.func
+    @dpex_exp.device_func
     def _xoroshiro128pp_jump(states, state_idx):
         """Advance the RNG in ``states[state_idx]`` by 2**64 steps."""
         s0 = jump_init
@@ -231,7 +235,7 @@ def _xoroshiro128pp_jump(states, state_idx):
 
     init_const_1 = np.uint64(0)
 
-    @dpex.kernel
+    @dpex_exp.kernel
     def init_xoroshiro128pp_states(states, seed):
         """
         Use SplitMix64 to generate an xoroshiro128p state from a uint64 seed.
@@ -260,13 +264,16 @@ def init_xoroshiro128pp_states(states, seed):
             # and jump forward 2**64 steps
             _xoroshiro128pp_jump(states, idx)
 
-    return init_xoroshiro128pp_states[1, 1]
+    def kernel_call(*args):
+        dpex_exp.call_kernel(init_xoroshiro128pp_states, NdRange((1,), (1,)), *args)
+
+    return kernel_call
 
 
 _64_as_uint32 = uint32(64)
 
 
-@dpex.func
+@dpex_exp.device_func
 def _rotl(x, k):
     """Left rotate x by k bits.
    x is expected to be a uint64 integer."""
     return (x << k) | (x >> (_64_as_uint32 - k))
@@ -278,7 +285,7 @@ def _rotl(x, k):
 shift_1 = uint32(21)
 
 
-@dpex.func
+@dpex_exp.device_func
 def _xoroshiro128pp_next(states, state_idx):
     """Return the next random uint64 and advance the RNG in states[state_idx]."""
     s0 = states[state_idx, zero_idx]
diff --git a/sklearn_numba_dpex/common/reductions.py b/sklearn_numba_dpex/common/reductions.py
index c6fa0bf..df18612 100644
--- a/sklearn_numba_dpex/common/reductions.py
+++ b/sklearn_numba_dpex/common/reductions.py
@@ -3,7 +3,9 @@
 
 import dpctl.tensor as dpt
 import numba_dpex as dpex
+import numba_dpex.experimental as dpex_exp
 import numpy as np
+from numba_dpex.kernel_api import MemoryScope, NdItem, NdRange, group_barrier
 
 from sklearn_numba_dpex.common._utils import (
     _check_max_work_group_size,
@@ -41,10 +43,11 @@ def make_argmin_reduction_1d_kernel(size, device, dtype, work_group_size="max"):
     # TODO: the first call of partial_argmin_reduction in the final loop should be
     # written with only two arguments since "previous_result" does not exist yet.
     # It seems it's not possible to get a good factoring of the code to avoid copying
-    # most of the code for this with @dpex.kernel, for now we resort to branching.
-    @dpex.kernel
+    # most of the code for this with @dpex_exp.kernel, for now we resort to branching.
+    @dpex_exp.kernel
     # fmt: off
     def partial_argmin_reduction(
+        nd_item: NdItem,
         values,                   # IN        (size,)
         previous_result,          # IN        (current_size,)
         argmin_indices,           # OUT       (math.ceil(
@@ -53,8 +56,8 @@
             #                                     size /
             #                                     (2 * work_group_size)
             #                                 ))
     ):
         # fmt: on
-        group_id = dpex.get_group_id(zero_idx)
-        local_work_id = dpex.get_local_id(zero_idx)
+        group_id = nd_item.get_group().get_group_id(zero_idx)
+        local_work_id = nd_item.get_local_id(zero_idx)
         first_work_id = local_work_id == zero_idx
 
         previous_result_size = previous_result.shape[zero_idx]
@@ -77,7 +80,7 @@
             local_values,
         )
 
-        dpex.barrier(dpex.LOCAL_MEM_FENCE)
+        group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP)
         n_active_work_items = work_group_size
         for i in range(n_local_iterations):
             n_active_work_items = n_active_work_items // two_as_a_long
@@ -88,7 +91,7 @@
                 local_values,
                 local_argmin
             )
-            dpex.barrier(dpex.LOCAL_MEM_FENCE)
+            group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP)
 
         _register_result(
             first_work_id,
@@ -100,7 +103,7 @@
         )
 
     # HACK 906: see sklearn_numba_dpex.patches.tests.test_patches.test_need_to_workaround_numba_dpex_906  # noqa
-    @dpex.func
+    @dpex_exp.device_func
     # fmt: off
     def _prepare_local_memory(
         local_work_id,            # PARAM
@@ -144,7 +147,7 @@ def _prepare_local_memory(
             local_values[local_work_id] = y
 
    # HACK 906: see sklearn_numba_dpex.patches.tests.test_patches.test_need_to_workaround_numba_dpex_906  # noqa
-    @dpex.func
+    @dpex_exp.device_func
     # fmt: off
     def _local_iteration(
         local_work_id,            # PARAM
@@ -169,7 +172,7 @@ def _local_iteration(
                 local_argmin[local_x_idx] = local_argmin[local_y_idx]
 
     # HACK 906: see sklearn_numba_dpex.patches.tests.test_patches.test_need_to_workaround_numba_dpex_906  # noqa
-    @dpex.func
+    @dpex_exp.device_func
     # fmt: off
     def _register_result(
         first_work_id,            # PARAM
@@ -194,7 +197,7 @@ def _register_result(
     previous_result = dpt.empty((1,), dtype=np.int32, device=device)
     while n_groups > 1:
         n_groups = math.ceil(n_groups / (2 * work_group_size))
-        sizes = (n_groups * work_group_size, work_group_size)
+        sizes = ((n_groups * work_group_size,), (work_group_size,))
         result = dpt.empty(n_groups, dtype=np.int32, device=device)
        kernels_and_empty_tensors_tuples.append(
             (partial_argmin_reduction, sizes, previous_result, result)
@@ -203,7 +206,9 @@ def _register_result(
 
     def argmin_reduction(values):
         for kernel, sizes, previous_result, result in kernels_and_empty_tensors_tuples:
-            kernel[sizes](values, previous_result, result)
+            dpex_exp.call_kernel(
+                kernel, NdRange(*sizes), values, previous_result, result
+            )
 
         return result
 
     return argmin_reduction
@@ -261,7 +266,7 @@ def make_sum_reduction_2d_kernel(
     If `fused_elementwise_func` is not None, it will be applied element-wise once
     to each element of the input array at the beginning of the first kernel
     invocation. This function is compiled and fused into the first
-    kernel as a device function with the help of `dpex.func`. This comes with
+    kernel as a device function with the help of `dpex_exp.device_func`. This comes with
     limitations as explained in:
     https://intelpython.github.io/numba-dpex/latest/user_guides/kernel_programming_guide/device-functions.html  # noqa
@@ -382,7 +387,7 @@ def sum_reduction(summands):
         # TODO: manually dispatch the kernels with a SyclQueue
         for kernel, sizes, result in kernels_and_empty_tensors_pairs:
-            kernel[sizes](summands, result)
+            dpex_exp.call_kernel(kernel, NdRange(*sizes), summands, result)
             summands = result
 
         if is_1d:
@@ -400,12 +405,12 @@ def _prepare_sum_reduction_2d_axis0(
 
     if fused_elementwise_func is None:
 
-        @dpex.func
+        @dpex_exp.device_func
         def fused_elementwise_func_(x):
             return x
 
     else:
-        fused_elementwise_func_ = dpex.func(fused_elementwise_func)
+        fused_elementwise_func_ = dpex_exp.device_func(fused_elementwise_func)
 
     input_work_group_size = work_group_size
     work_group_size = _check_max_work_group_size(
@@ -481,9 +486,10 @@ def _make_partial_sum_reduction_2d_axis0_kernel(
 
     # ???: how does this strategy compare to having each thread reduce N contiguous
     # items?
-    @dpex.kernel
+    @dpex_exp.kernel
     # fmt: off
     def partial_sum_reduction(
+        nd_item: NdItem,
         summands,    # IN        (sum_axis_size, n_cols)
         result,      # OUT       (math.ceil(size / (2 * reduction_block_size), n_cols)
     ):
@@ -504,13 +510,13 @@ def partial_sum_reduction(
         # The work groups are indexed in row-major order. From this let's deduce the
         # position of the window within the column...
-        local_block_id_in_col = dpex.get_group_id(one_idx)
+        local_block_id_in_col = nd_item.get_group().get_group_id(one_idx)
 
         # Let's map the current work item to an index in a 2D grid, where the
         # `work_group_size` work items are mapped in row-major order to the array
         # of size `(n_sub_groups_per_work_group, sub_group_size)`.
-        local_row_idx = dpex.get_local_id(one_idx)       # 2D idx, first coordinate
-        local_col_idx = dpex.get_local_id(zero_idx)      # 2D idx, second coordinate
+        local_row_idx = nd_item.get_local_id(one_idx)    # 2D idx, first coordinate
+        local_col_idx = nd_item.get_local_id(zero_idx)   # 2D idx, second coordinate
 
         # This way, each row in the 2D index can be seen as mapped to two rows in the
         # corresponding window of items of the input `summands`, with the first row of
@@ -550,8 +556,9 @@ def partial_sum_reduction(
         # The current work item uses the following second coordinate (given by the
         # position of the window in the grid of windows, and by the local position of
         # the work item in the 2D index):
+        zero_group_idx = nd_item.get_group().get_group_id(zero_idx)
         col_idx = (
-            (dpex.get_group_id(zero) * sub_group_size) + local_col_idx
+            (zero_group_idx * sub_group_size) + local_col_idx
         )
 
         sum_axis_size = summands.shape[zero_idx]
@@ -567,7 +574,7 @@ def partial_sum_reduction(
             local_values
         )
 
-        dpex.barrier(dpex.LOCAL_MEM_FENCE)
+        group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP)
 
         # Then, the sums of two scalars that have been written in `local_array` are
         # further summed together into `local_array[0, :]`. At each iteration, half
@@ -595,7 +602,7 @@ def partial_sum_reduction(
                 local_values
             )
 
-            dpex.barrier(dpex.LOCAL_MEM_FENCE)
+            group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP)
 
         # At this point local_values[0, :] + local_values[1, :] is equal to the sum of
         # all elements in summands that have been covered by the work group, we write
@@ -610,7 +617,7 @@
         )
 
     # HACK 906: see sklearn_numba_dpex.patches.tests.test_patches.test_need_to_workaround_numba_dpex_906  # noqa
-    @dpex.func
+    @dpex_exp.device_func
     # fmt: off
     def _prepare_local_memory(
         local_row_idx,            # PARAM
@@ -646,12 +653,12 @@ def _prepare_sum_reduction_2d_axis1(
 
     if fused_elementwise_func is None:
 
-        @dpex.func
+        @dpex_exp.device_func
         def fused_elementwise_func_(x):
             return x
 
     else:
-        fused_elementwise_func_ = dpex.func(fused_elementwise_func)
+        fused_elementwise_func_ = dpex_exp.device_func(fused_elementwise_func)
 
     input_work_group_size = work_group_size
     work_group_size = _check_max_work_group_size(
@@ -715,9 +722,10 @@ def _make_partial_sum_reduction_2d_axis1_kernel(
 
     _sum_and_set_items_if = _make_sum_and_set_items_if_kernel_func()
 
-    @dpex.kernel
+    @dpex_exp.kernel
     # fmt: off
     def partial_sum_reduction(
+        nd_item: NdItem,
         summands,    # IN        (n_rows, n_cols)
         result,      # OUT       (n_rows, math.ceil(n_cols / (2 * work_group_size),)
     ):
@@ -735,12 +743,12 @@ def partial_sum_reduction(
         # The work groups are indexed in row-major order, from that let's deduce the
         # row of `summands` to process by work items in `group_id`...
-        row_idx = dpex.get_group_id(one_idx)
+        row_idx = nd_item.get_group().get_group_id(one_idx)
 
        # ... and the position of the window within this row, ranging from 0
         # (first window in the row) to `n_work_groups_per_row - 1` (last window
         # in the row):
-        local_work_group_id_in_row = dpex.get_group_id(zero_idx)
+        local_work_group_id_in_row = nd_item.get_group().get_group_id(zero_idx)
 
         # Since all windows have size `reduction_block_size`, the position of the first
         # item in the window is given by:
@@ -754,7 +762,7 @@ def partial_sum_reduction(
         # The current work item is indexed locally within the group of work items, with
         # index `local_work_id` that can range from `0` (first item in the work group)
         # to `work_group_size - 1` (last item in the work group)
-        local_work_id = dpex.get_local_id(zero_idx)
+        local_work_id = nd_item.get_local_id(zero_idx)
 
         # Let's remember the size of the array to ensure that the last window in the
         # row does not try to access items outside the buffer.
@@ -796,7 +804,7 @@ def partial_sum_reduction(
             local_values
         )
 
-        dpex.barrier(dpex.LOCAL_MEM_FENCE)
+        group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP)
 
         # Then, the sums of two scalars that have been written in `local_array` are
         # further summed together into `local_array[0]`. At each iteration, half
@@ -826,7 +834,7 @@ def partial_sum_reduction(
                 local_values
             )
 
-            dpex.barrier(dpex.LOCAL_MEM_FENCE)
+            group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP)
 
         # At this point local_values[0] + local_values[1] is equal to the sum of all
         # elements in summands that have been covered by the work group, we write it
@@ -842,7 +850,7 @@
         )
 
     # HACK 906: see sklearn_numba_dpex.patches.tests.test_patches.test_need_to_workaround_numba_dpex_906  # noqa
-    @dpex.func
+    @dpex_exp.device_func
     # fmt: off
     def _prepare_local_memory(
         local_work_id,            # PARAM
@@ -871,7 +879,7 @@ def _prepare_local_memory(
 
 # HACK 906: see sklearn_numba_dpex.patches.tests.test_patches.test_need_to_workaround_numba_dpex_906  # noqa
 def _make_sum_and_set_items_if_kernel_func():
-    @dpex.func
+    @dpex_exp.device_func
     # fmt: off
     def set_sum_of_items_kernel_func(
         condition,                # PARAM
diff --git a/sklearn_numba_dpex/common/tests/test_matmul.py b/sklearn_numba_dpex/common/tests/test_matmul.py
index af7bb8c..12318c1 100644
--- a/sklearn_numba_dpex/common/tests/test_matmul.py
+++ b/sklearn_numba_dpex/common/tests/test_matmul.py
@@ -162,3 +162,7 @@ def test_matmul_raise_on_invalid_size_parameters(
         work_group_size=work_group_size,
         sub_group_size=sub_group_size,
     )
+
+
+if __name__ == "__main__":
+    _test_matmul_2d(((4, 4), (4, 4)), "arange", 1, 1, np.float32)
diff --git a/sklearn_numba_dpex/common/tests/test_random.py b/sklearn_numba_dpex/common/tests/test_random.py
index a4ac952..35bf18f 100644
--- a/sklearn_numba_dpex/common/tests/test_random.py
+++ b/sklearn_numba_dpex/common/tests/test_random.py
@@ -3,8 +3,10 @@
 
 import dpctl.tensor as dpt
 import numba_dpex as dpex
+import numba_dpex.experimental as dpex_exp
 import numpy as np
 import pytest
+from numba_dpex.kernel_api import NdItem, NdRange
 from sklearn.utils._testing import assert_allclose
 
 from sklearn_numba_dpex.common.random import (
@@ -138,7 +140,12 @@ def _get_single_rand_value(random_state, dtype):
     """Return a single rand value sampled uniformly in [0, 1)"""
     _get_single_rand_value_kernel = _make_get_single_rand_value_kernel(dtype)
     single_rand_value = dpt.empty(1, dtype=dtype)
-    _get_single_rand_value_kernel[1, 1](random_state, single_rand_value)
+    dpex_exp.call_kernel(
+        _get_single_rand_value_kernel,
+        NdRange((1,), (1,)),
+        random_state,
+        single_rand_value,
+    )
    return dpt.asnumpy(single_rand_value)[0]
@@ -154,7 +161,9 @@ def _rand_uniform(size, dtype, seed, n_work_items=1000):
         math.ceil(size / (size_per_work_item * work_group_size)) * work_group_size
     )
     states = create_xoroshiro128pp_states(n_states=global_size, seed=seed)
-    _rand_uniform_kernel[global_size, work_group_size](states, out)
+    dpex_exp.call_kernel(
+        _rand_uniform_kernel, NdRange((global_size,), (work_group_size,)), states, out
+    )
     return out
@@ -163,7 +172,7 @@ def _make_get_single_rand_value_kernel(dtype):
     rand_uniform_kernel_func = make_rand_uniform_kernel_func(np.dtype(dtype))
     zero_idx = np.int64(0)
 
-    @dpex.kernel
+    @dpex_exp.kernel
     # fmt: off
     def get_single_rand_value(
         random_state,         # IN        (1, 2)
@@ -180,14 +189,15 @@ def _make_rand_uniform_kernel(size, dtype, size_per_work_item):
     rand_uniform_kernel_func = make_rand_uniform_kernel_func(np.dtype(dtype))
     private_states_shape = (1, 2)
 
-    @dpex.kernel
+    @dpex_exp.kernel
     # fmt: off
     def _rand_uniform_kernel(
+        nd_item: NdItem,
         states,               # IN        (global_size, 2)
         out,                  # OUT       (size,)
     ):
         # fmt: on
-        item_idx = dpex.get_global_id(0)
+        item_idx = nd_item.get_global_id(0)
         out_idx = item_idx * size_per_work_item
 
         private_states = dpex.private.array(shape=private_states_shape, dtype=np.uint64)
diff --git a/sklearn_numba_dpex/common/topk.py b/sklearn_numba_dpex/common/topk.py
index 46790ee..7053a31 100644
--- a/sklearn_numba_dpex/common/topk.py
+++ b/sklearn_numba_dpex/common/topk.py
@@ -11,7 +11,9 @@
 
 import dpctl.tensor as dpt
 import numba_dpex as dpex
+import numba_dpex.experimental as dpex_exp
 import numpy as np
+from numba_dpex.kernel_api import AtomicRef, MemoryScope, NdItem, NdRange, group_barrier
 
 from sklearn_numba_dpex.common._utils import (
     _enforce_matmul_like_work_group_geometry,
@@ -86,7 +88,7 @@
     uint_type = uint_type_mapping[dtype]
     sign_mask = uint_type(2 ** (sign_bit_idx))
 
-    @dpex.func
+    @dpex_exp.device_func
     def lexicographical_mapping(item):
         mask = (-(item >> sign_bit_idx)) | sign_mask
         return item ^ mask
@@ -101,7 +103,7 @@
     one_as_uint_dtype = uint_type(1)
     sign_mask = uint_type(2 ** (sign_bit_idx))
 
-    @dpex.func
+    @dpex_exp.device_func
     def lexicographical_unmapping(item):
         mask = ((item >> sign_bit_idx) - one_as_uint_dtype) | sign_mask
         return item ^ mask
@@ -618,9 +620,10 @@
         one_as_uint_dtype << np.uint32(radix_bits)
     ) - one_as_uint_dtype
 
-    @dpex.kernel
+    @dpex_exp.kernel
     # fmt: off
     def create_radix_histogram(
+        nd_item: NdItem,
         array_in_uint,            # IN READ-ONLY  (n_rows, n_items)
         active_rows_mapping,      # IN            (n_rows,)
         mask_for_desired_value,   # IN            (1,)
@@ -654,15 +657,15 @@
         """
         # Row and column indices of the value in `array_in_uint` whose radix will be
         # computed by the current work item
-        row_idx = active_rows_mapping[dpex.get_global_id(one_idx)]
-        col_idx = dpex.get_global_id(zero_idx) + (
-            sub_group_size * dpex.get_global_id(two_idx))
+        row_idx = active_rows_mapping[nd_item.get_global_id(one_idx)]
+        col_idx = nd_item.get_global_id(zero_idx) + (
+            sub_group_size * nd_item.get_global_id(two_idx))
 
         # Index of the subgroup and position within this sub group. Incidentally, this
         # also indexes the location to which the radix value will be written in the
         # shared memory buffer.
-        local_subgroup = dpex.get_local_id(two_idx)
-        local_subgroup_work_id = dpex.get_local_id(zero_idx)
+        local_subgroup = nd_item.get_local_id(two_idx)
+        local_subgroup_work_id = nd_item.get_local_id(zero_idx)
 
         # Like `col_idx`, but where the first value of `array_in_uint` covered by the
         # current work group is indexed with zero.
@@ -702,7 +705,7 @@
             radix_values
         )
 
-        dpex.barrier(dpex.LOCAL_MEM_FENCE)
+        group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP)
 
         # The first `n_local_histograms` work items read `sub_group_size`
         # values each and compute the histogram of their occurrences in private memory.
@@ -721,7 +724,7 @@
             private_counts
         )
 
-        dpex.barrier(dpex.LOCAL_MEM_FENCE)
+        group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP)
 
         # The first `n_local_histograms` work items write their private histogram
         # into the shared memory buffer, effectively sharing it with all other work
@@ -735,7 +738,7 @@
             local_counts
         )
 
-        dpex.barrier(dpex.LOCAL_MEM_FENCE)
+        group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP)
 
         # This is the merge step, where all shared histograms are summed
         # together into the first buffer local_counts[0], in a bracket manner.
@@ -752,7 +755,7 @@
                 # OUT
                 local_counts
             )
-            dpex.barrier(dpex.LOCAL_MEM_FENCE)
+            group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP)
 
         # The current histogram is local to the current work group. Summing right away
         # all histograms to a unique, global histogram in global memory might give poor
@@ -770,13 +773,13 @@
             privatized_counts
         )
 
-    # HACK 906: all instructions inbetween barriers must be defined in `dpex.func`
-    # device functions.
+    # HACK 906: all instructions in between barriers must be defined in
+    # `dpex_exp.device_func` device functions.
     # See sklearn_numba_dpex.patches.tests.test_patches.test_need_to_workaround_numba_dpex_906  # noqa
 
     # HACK 906: start
-    @dpex.func
+    @dpex_exp.device_func
     # fmt: off
     def compute_radixes(
         row_idx,                  # PARAM
@@ -841,7 +844,7 @@
         # `sub_group_size` here.
         col_idx_increment_per_step = n_sub_groups_for_local_histograms * sub_group_size
 
-        @dpex.func
+        @dpex_exp.device_func
         # fmt: off
         def compute_private_histogram(
             col_idx,              # PARAM
@@ -875,7 +878,7 @@
         # always divisible by `n_local_histograms` here.
        n_iter_for_radixes = sub_group_size // n_local_histograms
 
-        @dpex.func
+        @dpex_exp.device_func
         # fmt: off
         def compute_private_histogram(
             col_idx,              # PARAM
@@ -901,7 +904,7 @@
                     current_col_idx += n_local_histograms
                 starting_col_idx += sub_group_size
 
-    @dpex.func
+    @dpex_exp.device_func
     # fmt: off
     def share_private_histograms(
         local_subgroup,           # PARAM
@@ -923,7 +926,7 @@
             ] = private_counts[col_idx]
             col_idx = (col_idx + one_idx) % sub_group_size
 
-    @dpex.func
+    @dpex_exp.device_func
     # fmt: off
     def partial_local_histograms_reduction(
         local_subgroup,           # PARAM
@@ -937,7 +940,7 @@
             local_subgroup + reduction_active_subgroups, local_subgroup_work_id
         ]
 
-    @dpex.func
+    @dpex_exp.device_func
     # fmt: off
     def merge_histogram_in_global_memory(
         row_idx,                  # PARAM
@@ -952,11 +955,10 @@
         privatization_idx = (col_idx // work_group_size) % n_counts_private_copies
 
         if local_subgroup == zero_idx:
-            dpex.atomic.add(
+            AtomicRef(
                 privatized_counts,
                 (privatization_idx, row_idx, local_subgroup_work_id),
-                local_counts[zero_idx, local_subgroup_work_id],
-            )
+            ).fetch_add(local_counts[zero_idx, local_subgroup_work_id])
 
     # HACK 906: end
@@ -976,7 +978,9 @@ def _create_radix_histogram(
         n_active_rows,
         n_local_histograms * n_work_groups_per_row,
     )
-    create_radix_histogram[global_shape, work_group_shape](
+    dpex_exp.call_kernel(
+        create_radix_histogram,
+        NdRange(global_shape, work_group_shape),
         array_in_uint,
         active_rows_mapping,
         mask_for_desired_value,
@@ -1000,9 +1004,10 @@
     uint_type = uint_type_mapping[dtype]
     zero_as_uint_dtype = uint_type(0)
 
-    @dpex.kernel
+    @dpex_exp.kernel
     # fmt: off
     def check_radix_histogram(
+        nd_item: NdItem,
         counts,                    # IN        (n_rows, radix_size,)
         radix_position,            # IN        (1,)
         n_active_rows,             # IN        (1,)
@@ -1014,7 +1019,7 @@
         new_active_rows_mapping,   # OUT       (n_rows,)
     ):
         # fmt: on
-        work_item_idx = dpex.get_global_id(zero_idx)
+        work_item_idx = nd_item.get_global_id(zero_idx)
 
         if work_item_idx >= n_active_rows[zero_idx]:
             return
@@ -1074,14 +1079,14 @@
                 desired_masked_value_ = lexicographical_unmapping(desired_masked_value_)
 
         else:
-            new_active_row_idx = dpex.atomic.add(
-                new_n_active_rows, zero_idx, count_one_as_an_int
-            )
+            new_active_row_idx = AtomicRef(
+                new_n_active_rows, zero_idx,
+            ).fetch_add(np.int64(count_one_as_an_int))
             new_active_rows_mapping[new_active_row_idx] = row_idx
 
         desired_masked_value[row_idx] = desired_masked_value_
 
-    @dpex.kernel
+    @dpex_exp.kernel
     def update_radix_position(radix_position, mask_for_desired_value):
         # The current partial analysis with the current radixes seen was not enough
         # to find the k-th element. Let's inspect the next `radix_bits`.
@@ -1108,7 +1113,9 @@ def _check_radix_histogram(
     n_active_rows_ = int(n_active_rows[0])
     global_size = math.ceil(n_active_rows_ / work_group_size) * work_group_size
 
-    check_radix_histogram[global_size, work_group_size](
+    dpex_exp.call_kernel(
+        check_radix_histogram,
+        NdRange((global_size,), (work_group_size,)),
         counts,
         active_rows_mapping,
         n_active_rows,
@@ -1120,7 +1127,10 @@
         new_n_active_rows,
     )
 
-    return update_radix_position[1, 1], _check_radix_histogram
+    def kernel_call(*args):
+        dpex_exp.call_kernel(update_radix_position, NdRange((1,), (1,)), *args)
+
+    return kernel_call, _check_radix_histogram
 
 
 @lru_cache
@@ -1140,9 +1150,10 @@ def _make_gather_topk_kernel(
     work_group_shape = (1, work_group_size)
     global_shape = (n_rows, n_work_groups_per_row * work_group_size)
 
-    @dpex.kernel
+    @dpex_exp.kernel
     # fmt: off
     def gather_topk(
+        nd_item: NdItem,
         array_in,                          # IN READONLY   (n_rows, n_cols)
         threshold,                         # IN            (n_rows,)
         n_threshold_occurences_in_topk,    # IN            (n_rows,)
@@ -1151,8 +1162,8 @@
         result,                            # OUT           (n_rows, k)
     ):
         # fmt: on
-        row_idx = dpex.get_global_id(zero_idx)
-        col_idx = dpex.get_global_id(one_idx)
+        row_idx = nd_item.get_global_id(zero_idx)
+        col_idx = nd_item.get_global_id(one_idx)
 
         n_threshold_occurences_in_topk_ = n_threshold_occurences_in_topk[row_idx]
 
@@ -1186,7 +1197,7 @@
     # in the top-k values are known. When those two numbers are equal, the kernel can
     # be written more efficiently and in a much simpler way, and the condition is not unusual.
     # Let's write a separate kernel for this special case.
-    @dpex.func
+    @dpex_exp.device_func
     # fmt: off
     def gather_topk_include_all_threshold_occurences(
         row_idx,                  # PARAM
@@ -1206,11 +1217,12 @@
             item = array_in[row_idx, col_idx]
 
             if item >= threshold:
-                result_col_idx_ = dpex.atomic.add(
-                    result_col_idx, row_idx, count_one_as_an_int)
+                result_col_idx_ = AtomicRef(
+                    result_col_idx, row_idx,
+                ).fetch_add(count_one_as_an_int)
                 result[row_idx, result_col_idx_] = item
 
-    @dpex.func
+    @dpex_exp.device_func
     # fmt: off
     def gather_topk_generic(
         row_idx,                  # PARAM
@@ -1240,10 +1252,17 @@
         if item <= threshold:
             return
 
-        result_col_idx_ = dpex.atomic.add(result_col_idx, row_idx, count_one_as_an_int)
+        result_col_idx_ = AtomicRef(
+            result_col_idx, row_idx,
+        ).fetch_add(count_one_as_an_int)
         result[row_idx, result_col_idx_] = item
 
-    return gather_topk[global_shape, work_group_shape]
+    def kernel_call(*args):
+        dpex_exp.call_kernel(
+            gather_topk, NdRange(global_shape, work_group_shape), *args
+        )
+
+    return kernel_call
 
 
 @lru_cache
@@ -1258,9 +1277,10 @@ def _make_gather_topk_idx_kernel(
     work_group_shape = (1, work_group_size)
     global_shape = (n_rows, n_work_groups_per_row * work_group_size)
 
-    @dpex.kernel
+    @dpex_exp.kernel
     # fmt: off
     def gather_topk_idx(
+        nd_item: NdItem,
         array_in,                          # IN READONLY   (n_rows, n_cols)
         threshold,                         # IN            (n_rows,)
         n_threshold_occurences_in_topk,    # IN            (n_rows,)
@@ -1269,8 +1289,8 @@
         result,                            # OUT           (n_rows, k)
     ):
         # fmt: on
-        row_idx = dpex.get_global_id(zero_idx)
-        col_idx = dpex.get_global_id(one_idx)
+        row_idx = nd_item.get_global_id(zero_idx)
+        col_idx = nd_item.get_global_id(one_idx)
 
         n_threshold_occurences_in_topk_ = n_threshold_occurences_in_topk[row_idx]
 
@@ -1296,7 +1316,7 @@
             result,
         )
 
-    @dpex.func
+    @dpex_exp.device_func
     # fmt: off
    def gather_topk_idx_include_all_threshold_occurences(
         row_idx,                  # PARAM
@@ -1316,11 +1336,12 @@
            item = array_in[row_idx, col_idx]
 
             if item >= threshold:
-                result_col_idx_ = dpex.atomic.add(
-                    result_col_idx, row_idx, count_one_as_an_int)
+                result_col_idx_ = AtomicRef(
+                    result_col_idx, row_idx,
+                ).fetch_add(count_one_as_an_int)
                 result[row_idx, result_col_idx_] = col_idx
 
-    @dpex.func
+    @dpex_exp.device_func
     # fmt: off
     def gather_topk_idx_generic(
         row_idx,                  # PARAM
@@ -1344,9 +1365,9 @@ def gather_topk_idx_generic(
                 return
 
             if item > threshold:
-                result_col_idx_ = dpex.atomic.add(
-                    result_col_idx, row_idx, count_one_as_an_int
-                )
+                result_col_idx_ = AtomicRef(
+                    result_col_idx, row_idx,
+                ).fetch_add(count_one_as_an_int)
                 result[row_idx, result_col_idx_] = col_idx
                 return
 
@@ -1357,9 +1378,14 @@ def gather_topk_idx_generic(
                 n_threshold_occurences, row_idx, count_one_as_an_int)
 
             if remaining_n_threshold_occurences > zero_idx:
-                result_col_idx_ = dpex.atomic.add(
-                    result_col_idx, row_idx, count_one_as_an_int
-                )
+                result_col_idx_ = AtomicRef(
+                    result_col_idx, row_idx,
+                ).fetch_add(count_one_as_an_int)
                 result[row_idx, result_col_idx_] = col_idx
 
-    return gather_topk_idx[global_shape, work_group_shape]
+    def kernel_call(*args):
+        dpex_exp.call_kernel(
+            gather_topk_idx, NdRange(global_shape, work_group_shape), *args
+        )
+
+    return kernel_call
diff --git a/sklearn_numba_dpex/kmeans/kernels/_base_kmeans_kernel_funcs.py b/sklearn_numba_dpex/kmeans/kernels/_base_kmeans_kernel_funcs.py
index ac105e8..4c23a4e 100644
--- a/sklearn_numba_dpex/kmeans/kernels/_base_kmeans_kernel_funcs.py
+++ b/sklearn_numba_dpex/kmeans/kernels/_base_kmeans_kernel_funcs.py
@@ -1,4 +1,4 @@
-import numba_dpex as dpex
+import numba_dpex.experimental as dpex_exp
 import numpy as np
 
 zero_as_a_long = np.int64(0)
@@ -79,7 +79,7 @@ def make_pairwise_ops_base_kernel_funcs(
         )
     )
 
-    @dpex.func
+    @dpex_exp.device_func
     def accumulate_dot_products(
         sample_idx,
         first_feature_idx,
@@ -130,7 +130,7 @@ def accumulate_dot_products(
         last_window_n_centroids
     )
 
-    @dpex.func
+    @dpex_exp.device_func
     def initialize_window_half_l2_norm(
         local_row_idx,
         local_col_idx,
@@ -184,7 +184,7 @@ def __init__(self, n_samples, n_features, n_clusters, ops, dtype):
         self.dtype = dtype
 
     def make_initialize_window_half_l2_norm_kernel_func(self, window_n_centroids):
-        @dpex.func
+        @dpex_exp.device_func
         # fmt: off
         def _initialize_window_of_centroids(
             local_row_idx,        # PARAM
@@ -224,7 +224,7 @@ def make_load_window_kernel_func(self):
 
         zero = self.dtype(0.0)
 
-        @dpex.func
+        @dpex_exp.device_func
         # fmt: off
         def _load_window_of_centroids_and_features(
             first_feature_idx,    # PARAM
@@ -261,7 +261,7 @@ def make_accumulate_sum_of_ops_kernel_func(
         n_samples = self.n_samples
         accumulate_dot_product = self.accumulate_dot_product
 
-        @dpex.func
+        @dpex_exp.device_func
         # fmt: off
         def _accumulate_sum_of_ops(
             sample_idx,           # PARAM
@@ -315,7 +315,7 @@ def make_update_closest_centroid_kernel_func(n_clusters, window_n_centroids):
         last_window_n_centroids
     )
 
-    @dpex.func
+    @dpex_exp.device_func
     def update_closest_centroid(
         first_centroid_idx,
         min_idx,
@@ -346,7 +346,7 @@ def update_closest_centroid(
 
 def _make_update_closest_centroid_kernel_func(window_n_centroids):
-    @dpex.func
+    @dpex_exp.device_func
     # fmt: off
     def update_closest_centroid(
         first_centroid_idx,       # PARAM
diff --git a/sklearn_numba_dpex/kmeans/kernels/compute_euclidean_distances.py b/sklearn_numba_dpex/kmeans/kernels/compute_euclidean_distances.py
index 1a54ba4..0db6232 100644
--- a/sklearn_numba_dpex/kmeans/kernels/compute_euclidean_distances.py
+++ b/sklearn_numba_dpex/kmeans/kernels/compute_euclidean_distances.py
@@ -2,7 +2,9 @@
 from functools import lru_cache
 
 import numba_dpex as dpex
+import numba_dpex.experimental as dpex_exp
 import numpy as np
+from numba_dpex.kernel_api import MemoryScope, NdItem, NdRange, group_barrier
 
 from sklearn_numba_dpex.common._utils import _check_max_work_group_size
 
@@ -60,9 +62,10 @@ def make_compute_euclidean_distances_fixed_window_kernel(
     zero_idx = np.int64(0)
     one_idx = np.int64(1)
 
-    @dpex.kernel
+    @dpex_exp.kernel
     # fmt: off
     def compute_distances(
+        nd_item: NdItem,
         X_t,                      # IN READ-ONLY   (n_features, n_samples)
         current_centroids_t,      # IN READ-ONLY   (n_features, n_clusters)
         euclidean_distances_t,    # OUT            (n_clusters, n_samples)
@@ -75,13 +78,13 @@ def compute_distances(
 
         first_centroid_idx = zero_idx
 
-        local_col_idx = dpex.get_local_id(zero_idx)
+        local_col_idx = nd_item.get_local_id(zero_idx)
 
-        window_loading_feature_offset = dpex.get_local_id(one_idx)
+        window_loading_feature_offset = nd_item.get_local_id(one_idx)
         window_loading_centroid_idx = local_col_idx
 
         sample_idx = (
-            (dpex.get_global_id(one_idx) * sub_group_size)
+            (nd_item.get_global_id(one_idx) * sub_group_size)
             + local_col_idx
         )
 
@@ -103,7 +106,7 @@ def compute_distances(
                 centroids_window,
             )
 
-            dpex.barrier(dpex.LOCAL_MEM_FENCE)
+            group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP)
 
             accumulate_sq_distances(
                 sample_idx,
                 first_feature_idx,
@@ -116,7 +119,7 @@ def compute_distances(
 
             first_feature_idx += centroids_window_height
 
-            dpex.barrier(dpex.LOCAL_MEM_FENCE)
+            group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP)
 
         _save_distance(
             sample_idx,
@@ -128,10 +131,10 @@ def compute_distances(
 
             first_centroid_idx += window_n_centroids
 
-            dpex.barrier(dpex.LOCAL_MEM_FENCE)
+            group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP)
 
     # HACK 906: see sklearn_numba_dpex.patches.tests.test_patches.test_need_to_workaround_numba_dpex_906  # noqa
-    @dpex.func
+    @dpex_exp.device_func
     # fmt: off
     def _save_distance(
         sample_idx,               # PARAM
@@ -158,4 +161,10 @@ def _save_distance(
         math.ceil(n_windows_for_sample / centroids_window_height)
         * centroids_window_height,
     )
-    return compute_distances[global_size, work_group_shape]
+
+    def kernel_call(*args):
+        dpex_exp.call_kernel(
+            compute_distances, NdRange(global_size, work_group_shape), *args
+        )
+
+    return kernel_call
diff --git a/sklearn_numba_dpex/kmeans/kernels/compute_inertia.py b/sklearn_numba_dpex/kmeans/kernels/compute_inertia.py
index 59184dc..993ec14 100644
--- a/sklearn_numba_dpex/kmeans/kernels/compute_inertia.py
+++ b/sklearn_numba_dpex/kmeans/kernels/compute_inertia.py
@@ -1,8 +1,9 @@
 import math
 from functools import lru_cache
 
-import numba_dpex as dpex
+import numba_dpex.experimental as dpex_exp
 import numpy as np
+from numba_dpex.kernel_api import NdItem, NdRange
 
 
 @lru_cache
@@ -11,9 +12,10 @@ def make_compute_inertia_kernel(n_samples, n_features, work_group_size, dtype):
     zero_idx = np.int64(0)
     zero_init = dtype(0.0)
 
-    @dpex.kernel
+    @dpex_exp.kernel
     # fmt: off
     def compute_inertia(
+        nd_item: NdItem,
         X_t,                      # IN READ-ONLY   (n_features, n_samples)
         sample_weight,            # IN READ-ONLY   (n_samples,)
         centroids_t,              # IN READ-ONLY   (n_features, n_clusters)
@@ -22,7 +24,7 @@ def compute_inertia(
     ):
         # fmt: on
 
-        sample_idx = dpex.get_global_id(zero_idx)
+        sample_idx = nd_item.get_global_id(zero_idx)
 
         if sample_idx >= n_samples:
             return
@@ -39,4 +41,10 @@ def compute_inertia(
         per_sample_inertia[sample_idx] = inertia * sample_weight[sample_idx]
 
     global_size = (math.ceil(n_samples / work_group_size)) * (work_group_size)
-    return compute_inertia[global_size, work_group_size]
+
+    def kernel_call(*args):
+        dpex_exp.call_kernel(
+            compute_inertia, NdRange((global_size,), (work_group_size,)), *args
+        )
+
+    return kernel_call
diff --git a/sklearn_numba_dpex/kmeans/kernels/compute_labels.py b/sklearn_numba_dpex/kmeans/kernels/compute_labels.py
index 0e1646e..8b379c9 100644
--- a/sklearn_numba_dpex/kmeans/kernels/compute_labels.py
+++ b/sklearn_numba_dpex/kmeans/kernels/compute_labels.py
@@ -2,7 +2,9 @@
 from functools import lru_cache
 
 import numba_dpex as dpex
+import numba_dpex.experimental as dpex_exp
 import numpy as np
+from numba_dpex.kernel_api import MemoryScope, NdItem, NdRange, group_barrier
 
 from sklearn_numba_dpex.common._utils import _check_max_work_group_size
 
@@ -72,19 +74,20 @@ def make_label_assignment_fixed_window_kernel(
     zero_idx = np.int64(0)
     one_idx = np.int64(1)
 
-    @dpex.kernel
+    @dpex_exp.kernel
     # fmt: off
     def assignment(
+        nd_item: NdItem,
         X_t,                       # IN READ-ONLY   (n_features, n_samples)
         centroids_t,               # IN READ-ONLY   (n_features, n_clusters)
         centroids_half_l2_norm,    # IN             (n_clusters,)
         assignments_idx,           # OUT            (n_samples,)
     ):
         # fmt: on
-        local_row_idx = dpex.get_local_id(one_idx)
-        local_col_idx = dpex.get_local_id(zero_idx)
+        local_row_idx = nd_item.get_local_id(one_idx)
+        local_col_idx = nd_item.get_local_id(zero_idx)
 
         sample_idx = (
-            (dpex.get_global_id(one_idx) * sub_group_size)
+            (nd_item.get_global_id(one_idx) * sub_group_size)
             + local_col_idx
         )
 
@@ -129,7 +132,7 @@ def assignment(
                 centroids_window,
             )
 
-            dpex.barrier(dpex.LOCAL_MEM_FENCE)
+            group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP)
 
             accumulate_dot_products(
                 sample_idx,
@@ -144,7 +147,7 @@ def assignment(
 
             first_feature_idx += centroids_window_height
 
-            dpex.barrier(dpex.LOCAL_MEM_FENCE)
+            group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP)
 
         min_idx, min_sample_pseudo_inertia = update_closest_centroid(
             first_centroid_idx,
@@ -157,7 +160,7 @@ def assignment(
 
             first_centroid_idx += window_n_centroids
 
-            dpex.barrier(dpex.LOCAL_MEM_FENCE)
+            group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP)
 
         _setitem_if(
             sample_idx < n_samples,
@@ -168,7 +171,7 @@ def assignment(
         )
 
     # HACK 906: see sklearn_numba_dpex.patches.tests.test_patches.test_need_to_workaround_numba_dpex_906  # noqa
-    @dpex.func
+    @dpex_exp.device_func
     def _setitem_if(condition, index, value, array):
         if condition:
             array[index] = value
@@ -182,4 +185,7 @@ def _setitem_if(condition, index, value, array):
         * centroids_window_height,
     )
 
-    return assignment[global_size, work_group_shape]
+    def kernel_call(*args):
+        dpex_exp.call_kernel(assignment, NdRange(global_size, work_group_shape), *args)
+
+    return kernel_call
diff --git a/sklearn_numba_dpex/kmeans/kernels/kmeans_plusplus.py b/sklearn_numba_dpex/kmeans/kernels/kmeans_plusplus.py
index 2b606d0..c7b8abd 100644
--- a/sklearn_numba_dpex/kmeans/kernels/kmeans_plusplus.py
+++ b/sklearn_numba_dpex/kmeans/kernels/kmeans_plusplus.py
@@ -2,7 +2,9 @@
 from functools import lru_cache
 
 import numba_dpex as dpex
+import numba_dpex.experimental as dpex_exp
 import numpy as np
+from numba_dpex.kernel_api import MemoryScope, NdItem, NdRange, group_barrier
 
 from sklearn_numba_dpex.common._utils import _check_max_work_group_size
 from sklearn_numba_dpex.common.random import make_rand_uniform_kernel_func
 
@@ -24,9 +26,10 @@ def make_kmeansplusplus_init_kernel(
     zero_idx = np.int64(0)
     zero_init = dtype(0.0)
 
-    @dpex.kernel
+    @dpex_exp.kernel
     # fmt: off
     def kmeansplusplus_init(
+        nd_item: NdItem,
         X_t,                      # IN READ-ONLY   (n_features, n_samples)
        sample_weight,            # IN READ-ONLY   (n_samples,)
         centers_t,                # OUT            (n_features, n_clusters)
         center_indices,           # OUT            (n_clusters,)
         closest_dist_sq,          # OUT            (n_samples,)
     ):
         # fmt: on
-        sample_idx = dpex.get_global_id(zero_idx)
+        sample_idx = nd_item.get_global_id(zero_idx)
 
         if sample_idx >= n_samples:
             return
@@ -55,7 +58,13 @@ def kmeansplusplus_init(
             centers_t[feature_idx, zero_idx] = X_t[feature_idx, starting_center_id_]
 
     global_size = (math.ceil(n_samples / work_group_size)) * (work_group_size)
-    return kmeansplusplus_init[global_size, work_group_size]
+
+    def kernel_call(*args):
+        dpex_exp.call_kernel(
+            kmeansplusplus_init, NdRange((global_size,), (work_group_size,)), *args
+        )
+
+    return kernel_call
 
 
 @lru_cache
@@ -74,16 +83,17 @@ def make_sample_center_candidates_kernel(
     zero_init = dtype(0.0)
     max_candidate_id = np.int32(n_samples - 1)
 
-    @dpex.kernel
+    @dpex_exp.kernel
     # fmt: off
     def sample_center_candidates(
+        nd_item: NdItem,
         closest_dist_sq,          # IN             (n_samples,)
         total_potential,          # IN             (1,)
         random_state,             # INOUT          (n_local_trials, 2)
         candidates_id,            # OUT            (n_local_trials,)
     ):
         # fmt: on
-        local_trial_idx = dpex.get_global_id(zero_idx)
+        local_trial_idx = nd_item.get_global_id(zero_idx)
         if local_trial_idx >= n_local_trials:
             return
         random_value = (rand_uniform_kernel_func(random_state, local_trial_idx)
@@ -100,7 +110,13 @@ def sample_center_candidates(
         candidates_id[local_trial_idx] = candidate_id
 
     global_size = (math.ceil(n_local_trials / work_group_size)) * work_group_size
-    return sample_center_candidates[global_size, work_group_size]
+
+    def kernel_call(*args):
+        dpex_exp.call_kernel(
+            sample_center_candidates, NdRange((global_size,), (work_group_size,)), *args
+        )
+
+    return kernel_call
 
 
 @lru_cache
@@ -151,9 +167,10 @@ def make_kmeansplusplus_single_step_fixed_window_kernel(
     zero_idx = np.int64(0)
     one_idx = np.int64(1)
 
-    @dpex.kernel
+    @dpex_exp.kernel
     # fmt: off
     def kmeansplusplus_single_step(
+        nd_item: NdItem,
         X_t,                      # IN READ-ONLY   (n_features, n_samples)
         sample_weight,            # IN READ-ONLY   (n_samples,)
         candidates_ids,           # IN             (n_candidates,)
@@ -167,13 +184,13 @@ def kmeansplusplus_single_step(
 
         first_candidate_idx = zero_idx
 
-        local_col_idx = dpex.get_local_id(zero_idx)
+        local_col_idx = nd_item.get_local_id(zero_idx)
 
-        window_loading_feature_offset = dpex.get_local_id(one_idx)
+        window_loading_feature_offset = nd_item.get_local_id(one_idx)
         window_loading_candidate_idx = local_col_idx
 
         sample_idx = (
-            (dpex.get_global_id(one_idx) * sub_group_size)
+            (nd_item.get_global_id(one_idx) * sub_group_size)
             + local_col_idx
         )
 
@@ -199,7 +216,7 @@ def kmeansplusplus_single_step(
                 candidates_window,
             )
 
-            dpex.barrier(dpex.LOCAL_MEM_FENCE)
+            group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP)
 
             accumulate_sq_distances(
                 sample_idx,
@@ -213,7 +230,7 @@ def kmeansplusplus_single_step(
 
             first_feature_idx += candidates_window_height
 
-            dpex.barrier(dpex.LOCAL_MEM_FENCE)
+            group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP)
 
         _save_sq_distances(
             sample_idx,
@@ -227,10 +244,10 @@ def kmeansplusplus_single_step(
 
             first_candidate_idx += window_n_candidates
 
-            dpex.barrier(dpex.LOCAL_MEM_FENCE)
+            group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP)
 
     # HACK 906: see sklearn_numba_dpex.patches.tests.test_patches.test_need_to_workaround_numba_dpex_906  # noqa
-    @dpex.func
+    @dpex_exp.device_func
     # fmt: off
     def _save_sq_distances(
         sample_idx,               # PARAM
@@ -259,4 +276,10 @@ def _save_sq_distances(
         math.ceil(n_windows_for_samples / candidates_window_height)
         * candidates_window_height,
     )
-    return kmeansplusplus_single_step[global_size, work_group_shape]
+
+    def kernel_call(*args):
+        dpex_exp.call_kernel(
+            kmeansplusplus_single_step, NdRange(global_size, work_group_shape), *args
+        )
+
+    return kernel_call
diff --git a/sklearn_numba_dpex/kmeans/kernels/lloyd_single_step.py b/sklearn_numba_dpex/kmeans/kernels/lloyd_single_step.py
index b35c2ea..a089e61 100644
--- a/sklearn_numba_dpex/kmeans/kernels/lloyd_single_step.py
+++ b/sklearn_numba_dpex/kmeans/kernels/lloyd_single_step.py
@@ -2,7 +2,9 @@
 from functools import lru_cache
 
 import numba_dpex as dpex
+import numba_dpex.experimental as dpex_exp
 import numpy as np
+from numba_dpex.kernel_api import AtomicRef, MemoryScope, NdItem, NdRange, group_barrier
 
 from sklearn_numba_dpex.common._utils import _check_max_work_group_size
 
@@ -139,16 +141,18 @@ def make_lloyd_single_step_fixed_window_kernel(
     # inputs such as X_t it is generally regarded as faster. Once support is available
     # (NB: it's already supported by numba.cuda) X_t should be an input to the factory
     # rather than an input to the kernel.
-    # XXX: parts of the kernels are factorized using `dpex.func` namespace that allow
+    # XXX: parts of the kernels are factorized using the `dpex_exp.device_func`
+    # decorator, which allows
     # defining device functions that can be used within `dpex.kernel` definitions.
-    # Howver, `dpex.func` functions does not support dpex.barrier calls nor
+    # However, `dpex_exp.device_func` functions do not support `group_barrier` calls nor
     # creating local or private arrays. As a consequence, factorizing the kmeans kernels
     # remains a best effort and some code patterns remain duplicated. In particular
     # the following kernel definition contains a lot of inline comments but those
     # comments are not repeated in the similar patterns in the other kernels
-    @dpex.kernel
+    @dpex_exp.kernel
     # fmt: off
     def fused_lloyd_single_step(
+        nd_item: NdItem,
         X_t,                          # IN READ-ONLY   (n_features, n_samples)
         sample_weight,                # IN READ-ONLY   (n_samples,)
         current_centroids_t,          # IN             (n_features, n_clusters)
@@ -189,9 +193,9 @@ def fused_lloyd_single_step(
         # reads like a SYCL kernel that maps 2D group size with a row-major order,
         # despite that `numba_dpex` chose to mimic the column-major order style of
         # mapping 2D group sizes in cuda.
-        sub_group_idx = dpex.get_global_id(one_idx)
-        local_row_idx = dpex.get_local_id(one_idx)
-        local_col_idx = dpex.get_local_id(zero_idx)
+        sub_group_idx = nd_item.get_global_id(one_idx)
+        local_row_idx = nd_item.get_local_id(one_idx)
+        local_col_idx = nd_item.get_local_id(zero_idx)
 
         # Let's start by remapping the 2D grid of work items to a 1D grid that reflects
         # how contiguous work items address one contiguous sample_idx:
@@ -278,7 +282,7 @@ def fused_lloyd_single_step(
                 # Since other work items are responsible for loading the relevant data
                 # for the next step, we need to wait for completion of all work items
                 # before going forward
-                dpex.barrier(dpex.LOCAL_MEM_FENCE)
+                group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP)
 
                 accumulate_dot_products(
                     sample_idx,
@@ -296,7 +300,7 @@ def fused_lloyd_single_step(
                 # When the next iteration starts work items will overwrite shared memory
                 # with new values, so before that we must wait for all reading
                 # operations in the current iteration to be over for all work items.
-                dpex.barrier(dpex.LOCAL_MEM_FENCE)
+                group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP)
 
            # End of inner loop. The pseudo inertia is now computed for all centroids
             # in the window, we can coalesce it to the accumulation of the min pseudo
@@ -315,7 +319,7 @@ def fused_lloyd_single_step(
             # When the next iteration starts work items will overwrite shared memory
             # with new values, so before that we must wait for all reading
             # operations in the current iteration to be over for all work items.
-            dpex.barrier(dpex.LOCAL_MEM_FENCE)
+            group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP)
 
         # End of outer loop. By now min_idx and min_sample_pseudo_inertia
         # contain the expected values.
@@ -335,7 +339,7 @@ def fused_lloyd_single_step(
         )
 
     # HACK 906: see sklearn_numba_dpex.patches.tests.test_patches.test_need_to_workaround_numba_dpex_906  # noqa
-    @dpex.func
+    @dpex_exp.device_func
     # fmt: off
     def _update_result_data(
         sample_idx,                   # PARAM
@@ -400,25 +404,25 @@ def _update_result_data(
         privatization_idx = sub_group_idx % n_centroids_private_copies
         weight = sample_weight[sample_idx]
 
-        dpex.atomic.add(
+        AtomicRef(
             cluster_sizes_private_copies,
             (privatization_idx, min_idx),
-            weight
-        )
+        ).fetch_add(weight)
 
         for feature_idx in range(n_features):
-            dpex.atomic.add(
+            AtomicRef(
                 new_centroids_t_private_copies,
                 (privatization_idx, feature_idx, min_idx),
-                X_t[feature_idx, sample_idx] * weight,
-            )
+            ).fetch_add(X_t[feature_idx, sample_idx] * weight)
 
     global_size = (
         window_n_centroids,
         math.ceil(n_subgroups / centroids_window_height) * centroids_window_height,
     )
 
-    return (
-        n_centroids_private_copies,
-        fused_lloyd_single_step[global_size, work_group_shape],
-    )
+    def kernel_call(*args):
+        dpex_exp.call_kernel(
+            fused_lloyd_single_step, NdRange(global_size, work_group_shape), *args
+        )
+
+    return (n_centroids_private_copies, kernel_call)
diff --git a/sklearn_numba_dpex/kmeans/kernels/utils.py b/sklearn_numba_dpex/kmeans/kernels/utils.py
index 36c8fc6..c62cd3b 100644
--- a/sklearn_numba_dpex/kmeans/kernels/utils.py
+++ b/sklearn_numba_dpex/kmeans/kernels/utils.py
@@ -3,7 +3,9 @@
 
 import dpctl.tensor as dpt
 import numba_dpex as dpex
+import numba_dpex.experimental as dpex_exp
 import numpy as np
+from numba_dpex.kernel_api import AtomicRef, NdItem, NdRange
 
 zero_idx = np.int64(0)
 
@@ -18,9 +20,10 @@ def make_relocate_empty_clusters_kernel(
 
     zero = dtype(0.0)
 
-    @dpex.kernel
+    @dpex_exp.kernel
     # fmt: off
     def relocate_empty_clusters(
+        nd_item: NdItem,
         X_t,                   # IN READ-ONLY   (n_features, n_samples)
         sample_weight,         # IN READ-ONLY   (n_samples,)
         assignments_idx,       # IN             (n_samples,)
@@ -31,8 +34,8 @@ def relocate_empty_clusters(
         cluster_sizes          # INOUT          (n_clusters,)
     ):
         # fmt: on
-        group_idx = dpex.get_group_id(zero_idx)
-        item_idx = dpex.get_local_id(zero_idx)
+        group_idx = nd_item.get_group().get_group_id(zero_idx)
+        item_idx = nd_item.get_local_id(zero_idx)
         relocated_idx = group_idx // n_work_groups_for_cluster
         feature_idx = (
             ((group_idx % n_work_groups_for_cluster) * work_group_size) + item_idx
         )
@@ -73,7 +76,12 @@ def relocate_empty_clusters(
         )
         cluster_sizes[relocated_cluster_idx] = new_location_weight
 
-    return relocate_empty_clusters[global_size, work_group_size]
+    def kernel_call(*args):
+        dpex_exp.call_kernel(
+            relocate_empty_clusters, NdRange((global_size,), (work_group_size,)), *args
+        )
+
+    return kernel_call
 
 
 @lru_cache
@@ -83,15 +91,16 @@ def make_centroid_shifts_kernel(n_clusters, n_features, work_group_size, dtype):
 
     # Optimized for C-contiguous array and for
     # size1 >> preferred_work_group_size_multiple
-    @dpex.kernel
+    @dpex_exp.kernel
     # fmt: off
     def centroid_shifts(
+        nd_item: NdItem,
diff --git a/sklearn_numba_dpex/kmeans/kernels/utils.py b/sklearn_numba_dpex/kmeans/kernels/utils.py
index 36c8fc6..c62cd3b 100644
--- a/sklearn_numba_dpex/kmeans/kernels/utils.py
+++ b/sklearn_numba_dpex/kmeans/kernels/utils.py
@@ -3,7 +3,9 @@
 import dpctl.tensor as dpt
 import numba_dpex as dpex
+import numba_dpex.experimental as dpex_exp
 import numpy as np
+from numba_dpex.kernel_api import AtomicRef, NdItem, NdRange
 
 zero_idx = np.int64(0)
 
@@ -18,9 +20,10 @@ def make_relocate_empty_clusters_kernel(
     zero = dtype(0.0)
 
-    @dpex.kernel
+    @dpex_exp.kernel
     # fmt: off
     def relocate_empty_clusters(
+        nd_item: NdItem,
         X_t,              # IN READ-ONLY  (n_features, n_samples)
         sample_weight,    # IN READ-ONLY  (n_samples,)
         assignments_idx,  # IN            (n_samples,)
@@ -31,8 +34,8 @@ def relocate_empty_clusters(
         cluster_sizes     # INOUT         (n_clusters,)
     ):
     # fmt: on
-        group_idx = dpex.get_group_id(zero_idx)
-        item_idx = dpex.get_local_id(zero_idx)
+        group_idx = nd_item.get_group().get_group_id(zero_idx)
+        item_idx = nd_item.get_local_id(zero_idx)
         relocated_idx = group_idx // n_work_groups_for_cluster
         feature_idx = (
             ((group_idx % n_work_groups_for_cluster) * work_group_size) + item_idx
         )
@@ -73,7 +76,12 @@
         )
         cluster_sizes[relocated_cluster_idx] = new_location_weight
 
-    return relocate_empty_clusters[global_size, work_group_size]
+    def kernel_call(*args):
+        dpex_exp.call_kernel(
+            relocate_empty_clusters, NdRange((global_size,), (work_group_size,)), *args
+        )
+
+    return kernel_call
 
 
 @lru_cache
@@ -83,15 +91,16 @@ def make_centroid_shifts_kernel(n_clusters, n_features, work_group_size, dtype):
 
     # Optimized for C-contiguous array and for
     # size1 >> preferred_work_group_size_multiple
-    @dpex.kernel
+    @dpex_exp.kernel
     # fmt: off
     def centroid_shifts(
+        nd_item: NdItem,
        centroids_t,       # IN   (n_features, n_clusters)
         new_centroids_t,   # IN   (n_features, n_clusters)
         centroid_shifts,   # OUT  (n_clusters,)
     ):
     # fmt: on
-        sample_idx = dpex.get_global_id(zero_idx)
+        sample_idx = nd_item.get_global_id(zero_idx)
 
         if sample_idx >= n_clusters:
             return
@@ -107,7 +116,12 @@ def centroid_shifts(
         centroid_shifts[sample_idx] = squared_centroid_diff
 
-    return centroid_shifts[global_size, work_group_size]
+    def kernel_call(*args):
+        dpex_exp.call_kernel(
+            centroid_shifts, NdRange((global_size,), (work_group_size,)), *args
+        )
+
+    return kernel_call
 
 
 @lru_cache
@@ -127,9 +141,10 @@ def make_reduce_centroid_data_kernel(
 
     # Optimized for C-contiguous array and assuming
     # n_features * n_clusters >> preferred_work_group_size_multiple
-    @dpex.kernel
+    @dpex_exp.kernel
     # fmt: off
     def _reduce_centroid_data_kernel(
+        nd_item: NdItem,
         cluster_sizes_private_copies,         # IN  (n_copies, n_clusters)
         centroids_t_private_copies_reshaped,  # IN  (n_copies, n_features * n_clusters)  # noqa
         cluster_sizes,                        # OUT (n_clusters,)
@@ -138,7 +153,7 @@
         n_empty_clusters,                     # OUT (1,)
     ):
     # fmt: on
-        item_idx = dpex.get_global_id(zero_idx)
+        item_idx = nd_item.get_global_id(zero_idx)
 
         if item_idx >= n_sums:
             return
@@ -161,14 +176,17 @@
         # register empty clusters
         if sum_ == zero:
-            current_n_empty_clusters = dpex.atomic.add(
-                n_empty_clusters, zero_idx, one_incr
-            )
+            current_n_empty_clusters = AtomicRef(
+                n_empty_clusters, zero_idx,
+            ).fetch_add(one_incr)
             empty_clusters_list[current_n_empty_clusters] = cluster_idx
 
-    reduce_centroid_data_kernel = _reduce_centroid_data_kernel[
-        global_size, work_group_size
-    ]
+    def reduce_centroid_data_kernel(*args):
+        dpex_exp.call_kernel(
+            _reduce_centroid_data_kernel,
+            NdRange((global_size,), (work_group_size,)),
+            *args,
+        )
 
     def reduce_centroid_data(
         cluster_sizes_private_copies,
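Note the detail the empty-cluster registration above depends on: like `dpex.atomic.add` before it, `AtomicRef(...).fetch_add(value)` returns the value the target slot held just before the addition, which is what makes it usable to reserve a slot in `empty_clusters_list`. A standalone sketch of the pattern (the histogram kernel is illustrative, not taken from the diff):

import dpctl.tensor as dpt
import numba_dpex.experimental as dpex_exp
import numpy as np
from numba_dpex.kernel_api import AtomicRef, NdItem, NdRange

n_samples = 256
one_incr = np.int32(1)


@dpex_exp.kernel
def count_labels(nd_item: NdItem, labels, counts, n_nonempty_bins):
    sample_idx = nd_item.get_global_id(0)
    if sample_idx >= n_samples:
        return

    # fetch_add returns the pre-increment value of the target slot
    previous_count = AtomicRef(counts, labels[sample_idx]).fetch_add(one_incr)

    # exactly one work item per bin observes the 0 -> 1 transition
    if previous_count == 0:
        AtomicRef(n_nonempty_bins, 0).fetch_add(one_incr)


labels = dpt.zeros((n_samples,), dtype=np.int32)
counts = dpt.zeros((16,), dtype=np.int32)
n_nonempty_bins = dpt.zeros((1,), dtype=np.int32)
dpex_exp.call_kernel(
    count_labels, NdRange((n_samples,), (64,)), labels, counts, n_nonempty_bins
)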
""" - sample_idx = dpex.get_global_id(zero_idx) + sample_idx = nd_item.get_global_id(zero_idx) if sample_idx >= n_samples: return @@ -259,9 +290,11 @@ def make_get_nb_distinct_clusters_kernel( ): one_incr = np.int32(1) - @dpex.kernel - def get_nb_distinct_clusters(labels, clusters_seen, nb_distinct_clusters): - sample_idx = dpex.get_global_id(zero_idx) + @dpex_exp.kernel + def get_nb_distinct_clusters( + nd_item: NdItem, labels, clusters_seen, nb_distinct_clusters + ): + sample_idx = nd_item.get_global_id(zero_idx) if sample_idx >= n_samples: return @@ -271,9 +304,15 @@ def get_nb_distinct_clusters(labels, clusters_seen, nb_distinct_clusters): if clusters_seen[label] > zero_idx: return - previous_value = dpex.atomic.add(clusters_seen, label, one_incr) + previous_value = AtomicRef(clusters_seen, label).fetch_add(one_incr) if previous_value == zero_idx: - dpex.atomic.add(nb_distinct_clusters, zero_idx, one_incr) + AtomicRef(nb_distinct_clusters, zero_idx).fetch_add(one_incr) global_size = math.ceil(n_samples / work_group_size) * work_group_size - return get_nb_distinct_clusters[global_size, work_group_size] + + def kernel_call(*args): + dpex_exp.call_kernel( + get_nb_distinct_clusters, NdRange((global_size,), (work_group_size,)), *args + ) + + return kernel_call diff --git a/sklearn_numba_dpex/patches/tests/test_patches.py b/sklearn_numba_dpex/patches/tests/test_patches.py index e95e94f..ec4fda2 100644 --- a/sklearn_numba_dpex/patches/tests/test_patches.py +++ b/sklearn_numba_dpex/patches/tests/test_patches.py @@ -1,8 +1,10 @@ import dpctl import dpctl.tensor as dpt import numba_dpex as dpex +import numba_dpex.experimental as dpex_exp import numpy as np import pytest +from numba_dpex.kernel_api import MemoryScope, NdItem, NdRange, group_barrier # TODO: remove this test after going through the code base and reverting unnecessary @@ -26,28 +28,28 @@ def test_need_to_workaround_numba_dpex_906(): The hack consist in wrapping instructions that are suspected of triggering the bug (basically all write operations in kernels that also contain a barrier) in - `dpex.func` device functions. + `dpex_exp.device_func` device functions. This hack makes the code significantly harder to read and should be reverted ASAP. 
""" dtype = np.float32 - @dpex.kernel - def kernel(result): - local_idx = dpex.get_local_id(0) + @dpex_exp.kernel + def kernel(nd_item: NdItem, result): + local_idx = nd_item.get_local_id(0) local_values = dpex.local.array((1,), dtype=dtype) if local_idx < 1: local_values[0] = 1 - dpex.barrier(dpex.LOCAL_MEM_FENCE) + group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP) if local_idx < 1: result[0] = 10 result = dpt.zeros((1), dtype=dtype, device=dpctl.SyclDevice("cpu")) - kernel[32, 32](result) + dpex_exp.call_kernel(kernel, NdRange((32,), (32,)), result) rationale = """If this test fails, it means that the bug reported at https://github.com/IntelPython/numba-dpex/issues/906 has been fixed, and all the @@ -58,28 +60,28 @@ def kernel(result): assert dpt.asnumpy(result)[0] != 10, rationale # Test that highlight how the hack works - @dpex.kernel - def kernel(result): - local_idx = dpex.get_local_id(0) + @dpex_exp.kernel + def kernel(nd_item: NdItem, result): + local_idx = nd_item.get_local_id(0) local_values = dpex.local.array((1,), dtype=dtype) _setitem_if((local_idx < 1), 0, 1, local_values) - dpex.barrier(dpex.LOCAL_MEM_FENCE) + group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP) _setitem_if((local_idx < 1), 0, 10, result) _setitem_if = make_setitem_if_kernel_func() result = dpt.zeros((1), dtype=dtype, device=dpctl.SyclDevice("cpu")) - kernel[32, 32](result) + dpex_exp.call_kernel(kernel, NdRange((32,), (32,)), result) assert dpt.asnumpy(result)[0] == 10 # HACK 906: see sklearn_numba_dpex.patches.tests.test_patches.test_need_to_workaround_numba_dpex_906 # noqa def make_setitem_if_kernel_func(): - @dpex.func + @dpex_exp.device_func def _setitem_if(condition, index, value, array): if condition: array[index] = value