WIP: Feature/upgrade dpex syntax #123

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
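This PR upgrades the kernels shown below from the legacy numba_dpex dispatch syntax
to the numba_dpex.experimental API: dpex.kernel and dpex.func become dpex_exp.kernel
and dpex_exp.device_func, kernels take an explicit NdItem as their first argument,
and the kernel[global_size, work_group_size](...) launch is replaced by
dpex_exp.call_kernel with an explicit NdRange. A minimal sketch of the recurring
pattern (the kernel, sizes and dtype below are illustrative, not taken from the
diff):

    import dpctl.tensor as dpt
    import numba_dpex.experimental as dpex_exp
    import numpy as np
    from numba_dpex.kernel_api import NdItem, NdRange

    n_items = 16
    work_group_size = 8
    zero_idx = np.int64(0)

    # New-style kernel: an explicit NdItem argument replaces dpex.get_global_id().
    @dpex_exp.kernel
    def add_one(nd_item: NdItem, data):
        item_idx = nd_item.get_global_id(zero_idx)
        if item_idx >= n_items:
            return
        data[item_idx] += 1.0

    data = dpt.zeros(n_items, dtype=dpt.float32)

    # Legacy launch (removed by this PR): add_one[n_items, work_group_size](data)
    # New launch: global and local sizes are passed through an NdRange.
    dpex_exp.call_kernel(add_one, NdRange((n_items,), (work_group_size,)), data)
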
75 changes: 50 additions & 25 deletions sklearn_numba_dpex/common/kernels.py
@@ -2,24 +2,26 @@
from functools import lru_cache

import dpctl.tensor as dpt
import numba_dpex as dpex
import numba_dpex.experimental as dpex_exp
import numpy as np
from numba_dpex.kernel_api import NdItem, NdRange

zero_idx = np.int64(0)


@lru_cache
def make_apply_elementwise_func(shape, func, work_group_size):
func = dpex.func(func)
func = dpex_exp.device_func(func)
n_items = math.prod(shape)

@dpex.kernel
@dpex_exp.kernel
# fmt: off
def elementwise_ops_kernel(
nd_item: NdItem,
data, # INOUT (n_items,)
):
# fmt: on
item_idx = dpex.get_global_id(zero_idx)
item_idx = nd_item.get_global_id(zero_idx)
if item_idx >= n_items:
return

@@ -30,7 +32,9 @@ def elementwise_ops_kernel(

def elementwise_ops(data):
data = dpt.reshape(data, (-1,))
elementwise_ops_kernel[global_size, work_group_size](data)
dpex_exp.call_kernel(
elementwise_ops_kernel, NdRange((global_size,), (work_group_size,)), data
)

return elementwise_ops

@@ -41,9 +45,9 @@ def make_initialize_to_zeros_kernel(shape, work_group_size, dtype):
global_size = math.ceil(n_items / work_group_size) * work_group_size
zero = dtype(0.0)

@dpex.kernel
def initialize_to_zeros_kernel(data):
item_idx = dpex.get_global_id(zero_idx)
@dpex_exp.kernel
def initialize_to_zeros_kernel(nd_item: NdItem, data):
item_idx = nd_item.get_global_id(zero_idx)

if item_idx >= n_items:
return
@@ -52,7 +56,11 @@ def initialize_to_zeros_kernel(data):

def initialize_to_zeros(data):
data = dpt.reshape(data, (-1,))
initialize_to_zeros_kernel[global_size, work_group_size](data)
dpex_exp.call_kernel(
initialize_to_zeros_kernel,
NdRange((global_size,), (work_group_size,)),
data,
)

return initialize_to_zeros

@@ -65,9 +73,9 @@ def make_broadcast_division_1d_2d_axis0_kernel(shape, work_group_size):
# NB: the left operand is modified inplace, the right operand is only read into.
# Optimized for C-contiguous array and for
# size1 >> preferred_work_group_size_multiple
@dpex.kernel
def broadcast_division(dividend_array, divisor_vector):
col_idx = dpex.get_global_id(zero_idx)
@dpex_exp.kernel
def broadcast_division(nd_item: NdItem, dividend_array, divisor_vector):
col_idx = nd_item.get_global_id(zero_idx)

if col_idx >= n_cols:
return
@@ -79,28 +87,34 @@ def broadcast_division(dividend_array, divisor_vector):
dividend_array[row_idx, col_idx] / divisor
)

return broadcast_division[global_size, work_group_size]
def kernel_call(*args):
return dpex_exp.call_kernel(
broadcast_division, NdRange((global_size,), (work_group_size,)), *args
)

return kernel_call


@lru_cache
def make_broadcast_ops_1d_2d_axis1_kernel(shape, ops, work_group_size):
"""
ops must be a function that will be interpreted as a dpex.func and is subject to
the same rules. It is expected to take two scalar arguments and return one scalar
value. lambda functions are advised against since the cache will not work with lamda
functions. sklearn_numba_dpex.common._utils expose some pre-defined `ops`.
ops must be a function that will be interpreted as a dpex_exp.device_func and is
subject to the same rules. It is expected to take two scalar arguments and return
one scalar value. Lambda functions are advised against since the cache will not
work with lambda functions. sklearn_numba_dpex.common._utils exposes some
pre-defined `ops`.
"""
n_rows, n_cols = shape

global_size = math.ceil(n_cols / work_group_size) * work_group_size
ops = dpex.func(ops)
ops = dpex_exp.device_func(ops)

# NB: the left operand is modified inplace, the right operand is only read into.
# Optimized for C-contiguous array and for
# size1 >> preferred_work_group_size_multiple
@dpex.kernel
def broadcast_ops(left_operand_array, right_operand_vector):
col_idx = dpex.get_global_id(zero_idx)
@dpex_exp.kernel
def broadcast_ops(nd_item: NdItem, left_operand_array, right_operand_vector):
col_idx = nd_item.get_global_id(zero_idx)

if col_idx >= n_cols:
return
@@ -110,7 +124,12 @@ def broadcast_ops(left_operand_array, right_operand_vector):
left_operand_array[row_idx, col_idx], right_operand_vector[row_idx]
)

return broadcast_ops[global_size, work_group_size]
def kernel_call(*args):
return dpex_exp.call_kernel(
broadcast_ops, NdRange((global_size,), (work_group_size,)), *args
)

return kernel_call


@lru_cache
@@ -122,14 +141,15 @@ def make_half_l2_norm_2d_axis0_kernel(shape, work_group_size, dtype):

# Optimized for C-contiguous array and for
# size1 >> preferred_work_group_size_multiple
@dpex.kernel
@dpex_exp.kernel
# fmt: off
def half_l2_norm(
nd_item: NdItem,
data, # IN (size0, size1)
result, # OUT (size1,)
):
# fmt: on
col_idx = dpex.get_global_id(zero_idx)
col_idx = nd_item.get_global_id(zero_idx)

if col_idx >= n_cols:
return
Expand All @@ -142,4 +162,9 @@ def half_l2_norm(

result[col_idx] = l2_norm / two

return half_l2_norm[global_size, work_group_size]
def kernel_call(*args):
return dpex_exp.call_kernel(
half_l2_norm, NdRange((global_size,), (work_group_size,)), *args
)

return kernel_call
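
Each factory above now returns a plain kernel_call closure rather than a pre-sized
kernel[global_size, work_group_size] dispatcher, so call sites keep the same
signature. A minimal usage sketch for one of the factories, assuming a SYCL device
is available (the shapes, dtype and work group size are illustrative):

    import dpctl.tensor as dpt

    from sklearn_numba_dpex.common.kernels import (
        make_broadcast_division_1d_2d_axis0_kernel,
    )

    broadcast_division = make_broadcast_division_1d_2d_axis0_kernel(
        shape=(4, 8), work_group_size=64
    )

    # Assumed shapes: dividend (n_rows, n_cols), divisor (n_cols,).
    dividend = dpt.ones((4, 8), dtype=dpt.float32)
    divisor = dpt.full((8,), 2.0, dtype=dpt.float32)

    # The dividend array is divided in place by the divisor vector.
    broadcast_division(dividend, divisor)
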
42 changes: 24 additions & 18 deletions sklearn_numba_dpex/common/matmul.py
@@ -2,7 +2,9 @@
from functools import lru_cache

import numba_dpex as dpex
import numba_dpex.experimental as dpex_exp
import numpy as np
from numba_dpex.kernel_api import MemoryScope, NdItem, NdRange, group_barrier

from sklearn_numba_dpex.common._utils import _enforce_matmul_like_work_group_geometry

@@ -54,7 +56,7 @@ def make_matmul_2d_kernel(
If `out_fused_elementwise_fn` is not None, it will be applied element-wise once to
each element of the output array right before it is returned. This function is
compiled and fused into the first kernel as a device function with the help of
`dpex.func`. This comes with limitations as explained in:
`dpex_exp.device_func`. This comes with limitations as explained in:

https://intelpython.github.io/numba-dpex/latest/user_guides/kernel_programming_guide/device-functions.html # noqa
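
A valid out_fused_elementwise_fn is any one-argument scalar function, for instance
this sketch (illustrative, not taken from this diff):

    def out_fused_elementwise_fn(x):
        return x * x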

@@ -147,21 +149,21 @@ def make_matmul_2d_kernel(

if multiply_fn is None:

@dpex.func
@dpex_exp.device_func
def multiply_fn_(x, y):
return x * y

else:
multiply_fn_ = dpex.func(multiply_fn)
multiply_fn_ = dpex_exp.device_func(multiply_fn)

if out_fused_elementwise_fn is None:

@dpex.func
@dpex_exp.device_func
def out_fused_elementwise_fn(x):
return x

else:
out_fused_elementwise_fn = dpex.func(out_fused_elementwise_fn)
out_fused_elementwise_fn = dpex_exp.device_func(out_fused_elementwise_fn)

# Under the same assumption, this value is equal to the number of results computed
# by a single work group
@@ -268,21 +270,22 @@ def out_fused_elementwise_fn(x):
grid_n_groups = global_grid_n_rows * global_grid_n_cols
global_size = grid_n_groups * work_group_size

@dpex.kernel
@dpex_exp.kernel
# fmt: off
def matmul(
nd_item: NdItem,
X, # IN (X_n_rows, n_cols)
Y_t, # IN (Y_t_n_rows, n_cols)
result # OUT (X_n_rows, Y_t_n_rows)
):
# fmt: on
work_item_idx = dpex.get_local_id(zero_idx)
work_item_idx = nd_item.get_local_id(zero_idx)

# Index the work items in the base sliding window:
work_item_row_idx = work_item_idx // sub_group_size
work_item_col_idx = work_item_idx % sub_group_size

group_idx = dpex.get_group_id(zero_idx)
group_idx = nd_item.get_group().get_group_id(zero_idx)

# Get indices of the row and the column of the top-left corner of the sub-array
# of results covered by this work group
@@ -352,7 +355,7 @@ def matmul(
)
window_loaded_col_idx += sub_group_size

dpex.barrier(dpex.LOCAL_MEM_FENCE)
group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP)

_accumulate_private_windows(
first_private_loaded_sliding_X_value_idx,
@@ -365,7 +368,7 @@ def matmul(
private_result
)

dpex.barrier(dpex.LOCAL_MEM_FENCE)
group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP)

_write_result(
group_first_row_idx + first_private_loaded_sliding_X_value_idx,
@@ -376,7 +379,7 @@ def matmul(
)

# HACK 906: see sklearn_numba_dpex.patches.tests.test_patches.test_need_to_workaround_numba_dpex_906 # noqa
@dpex.func
@dpex_exp.device_func
# fmt: off
def _load_sliding_windows(
work_item_row_idx, # PARAM
@@ -419,7 +422,7 @@ def _load_sliding_windows(
Y_t_loaded_row_idx += base_result_window_side
Y_t_local_loaded_row_idx += base_result_window_side

@dpex.func
@dpex_exp.device_func
# fmt: off
def _accumulate_private_windows(
private_first_loaded_sliding_X_value_idx, # PARAM
@@ -465,7 +468,7 @@ def _accumulate_private_windows(

private_array_first_col += private_Y_t_sliding_window_width

@dpex.func
@dpex_exp.device_func
# fmt: off
def _write_result(
result_first_row_idx, # PARAM
@@ -486,14 +489,17 @@ def _write_result(
result_col_idx += nb_work_items_for_Y_t_window
result_row_idx += nb_work_items_for_X_window

return matmul[global_size, work_group_size]
def kernel_call(*args):
dpex_exp.call_kernel(matmul, NdRange((global_size,), (work_group_size,)), *args)

return kernel_call


def _make_accumulate_step_unrolled_kernel_func(private_result_array_width, multiply_fn):

if private_result_array_width == 1:

@dpex.func
@dpex_exp.device_func
def _accumulate_step_unrolled(
i, j, private_loaded_X_value, private_Y_t_sliding_window, private_result
):
@@ -503,7 +509,7 @@ def _accumulate_step_unrolled(

elif private_result_array_width == 2:

@dpex.func
@dpex_exp.device_func
def _accumulate_step_unrolled(
i, j, private_loaded_X_value, private_Y_t_sliding_window, private_result
):
@@ -516,7 +522,7 @@ def _accumulate_step_unrolled(

elif private_result_array_width == 4:

@dpex.func
@dpex_exp.device_func
def _accumulate_step_unrolled(
i, j, private_loaded_X_value, private_Y_t_sliding_window, private_result
):
@@ -535,7 +541,7 @@ def _accumulate_step_unrolled(

elif private_result_array_width == 8:

@dpex.func
@dpex_exp.device_func
def _accumulate_step_unrolled(
i, j, private_loaded_X_value, private_Y_t_sliding_window, private_result
):