@@ -349,26 +349,29 @@ void LaunchBmmCustomKernel(OpKernelContext* ctx, const T* A, const T* B, T* C,
349349 sycl::range<3 > local{1 , BS_X, BS_Y};
350350 Tensor A_offset_tensor, B_offset_tensor;
351351
352+ if (src_dims > 3 && is_bcast_required) {
353+ const std::vector<int64_t >& x_batch_indices = bcast.x_batch_indices ();
354+ const std::vector<int64_t >& y_batch_indices = bcast.y_batch_indices ();
355+ OP_REQUIRES_OK (ctx,
356+ ctx->allocate_temp (DataTypeToEnum<int64_t >::value,
357+ TensorShape ({bs}), &A_offset_tensor));
358+ OP_REQUIRES_OK (ctx,
359+ ctx->allocate_temp (DataTypeToEnum<int64_t >::value,
360+ TensorShape ({bs}), &B_offset_tensor));
361+ stream
362+ ->memcpy (GetTensorBuffer<int64_t >(&A_offset_tensor),
363+ x_batch_indices.data (), bs * sizeof (int64_t ))
364+ .wait ();
365+ stream
366+ ->memcpy (GetTensorBuffer<int64_t >(&B_offset_tensor),
367+ y_batch_indices.data (), bs * sizeof (int64_t ))
368+ .wait ();
369+ }
370+
352371 stream->submit ([&](sycl::handler& cgh) {
353372 LocalAcc<T> Asub (sycl::range<2 >{c_M * BS_X, TILE_K}, cgh);
354373 LocalAcc<T> Bsub (sycl::range<2 >{TILE_K, c_P * BS_Y}, cgh);
355374 if (src_dims > 3 && is_bcast_required) {
356- const std::vector<int64_t >& x_batch_indices = bcast.x_batch_indices ();
357- const std::vector<int64_t >& y_batch_indices = bcast.y_batch_indices ();
358- OP_REQUIRES_OK (ctx,
359- ctx->allocate_temp (DataTypeToEnum<int64_t >::value,
360- TensorShape ({bs}), &A_offset_tensor));
361- OP_REQUIRES_OK (ctx,
362- ctx->allocate_temp (DataTypeToEnum<int64_t >::value,
363- TensorShape ({bs}), &B_offset_tensor));
364- stream
365- ->memcpy (GetTensorBuffer<int64_t >(&A_offset_tensor),
366- x_batch_indices.data (), bs * sizeof (int64_t ))
367- .wait ();
368- stream
369- ->memcpy (GetTensorBuffer<int64_t >(&B_offset_tensor),
370- y_batch_indices.data (), bs * sizeof (int64_t ))
371- .wait ();
372375 BatchMatMulWithBcastKernel<T, c_M, c_P, BS_X, BS_Y, TILE_K, TILE_AB> task (
373376 A, B, C, bs, M, N, P, Asub, Bsub, adj_A, adj_B,
374377 static_cast <int64_t *>(GetTensorBuffer<int64_t >(&A_offset_tensor)),
0 commit comments