Skip to content

Commit acbabef

Browse files
[SYCL][NFC] Pass around device_impl instead of sycl::device object (#18928)
This PR makes two NFC (no functional change) changes: 1. Rename `isSpecialDeviceImage` -> `isBfloat16DeviceImage` and `isSpecialDeviceImageShouldBeUsed` -> `shouldBF16DeviceImageBeUsed`. 2. Reduce passing around `sycl::device` objects and use `device_impl` instead, at least for functions in the `detail::` namespace.
1 parent 629d3a3 commit acbabef

File tree

6 files changed

+66
-62
lines changed

6 files changed

+66
-62
lines changed

sycl/source/detail/device_image_impl.hpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1050,11 +1050,12 @@ class device_image_impl {
10501050

10511051
// Filter the devices that support the image requirements.
10521052
std::vector<sycl::device> SupportingDevs = Devices;
1053-
auto NewSupportingDevsEnd = std::remove_if(
1054-
SupportingDevs.begin(), SupportingDevs.end(),
1055-
[&NewImageRef](const sycl::device &SDev) {
1056-
return !doesDevSupportDeviceRequirements(SDev, NewImageRef);
1057-
});
1053+
auto NewSupportingDevsEnd =
1054+
std::remove_if(SupportingDevs.begin(), SupportingDevs.end(),
1055+
[&NewImageRef](const sycl::device &SDev) {
1056+
return !doesDevSupportDeviceRequirements(
1057+
*detail::getSyclObjImpl(SDev), NewImageRef);
1058+
});
10581059

10591060
// If there are no devices that support the image, we skip it.
10601061
if (NewSupportingDevsEnd == SupportingDevs.begin())
@@ -1151,7 +1152,7 @@ class device_image_impl {
11511152
std::set<RTDeviceBinaryImage *> ImgDeps;
11521153
for (const device &Device : DevImgImpl->get_devices()) {
11531154
std::set<RTDeviceBinaryImage *> DevImgDeps = PM.collectDeviceImageDeps(
1154-
*NewImage, Device,
1155+
*NewImage, *getSyclObjImpl(Device),
11551156
/*ErrorOnUnresolvableImport=*/State == bundle_state::executable);
11561157
ImgDeps.insert(DevImgDeps.begin(), DevImgDeps.end());
11571158
}

sycl/source/detail/helpers.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ retrieveKernelBinary(queue_impl &Queue, KernelNameStrRefT KernelName,
8282
} else {
8383
auto ContextImpl = Queue.getContextImplPtr();
8484
DeviceImage = &detail::ProgramManager::getInstance().getDeviceImage(
85-
KernelName, *ContextImpl, &Dev);
85+
KernelName, *ContextImpl, Dev);
8686
Program = detail::ProgramManager::getInstance().createURProgram(
8787
*DeviceImage, *ContextImpl, {createSyclObjFromImpl<device>(Dev)});
8888
}

sycl/source/detail/memory_manager.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1146,7 +1146,8 @@ getOrBuildProgramForDeviceGlobal(queue_impl &Queue,
11461146
auto Context = createSyclObjFromImpl<context>(ContextImpl);
11471147
ProgramManager &PM = ProgramManager::getInstance();
11481148
RTDeviceBinaryImage &Img = PM.getDeviceImage(
1149-
DeviceGlobalEntry->MImages, ContextImpl, getSyclObjImpl(Device).get());
1149+
DeviceGlobalEntry->MImages, ContextImpl, *getSyclObjImpl(Device));
1150+
11501151
device_image_plain DeviceImage =
11511152
PM.getDeviceImageFromBinaryImage(&Img, Context, Device);
11521153
device_image_plain BuiltImage =

sycl/source/detail/program_manager/program_manager.cpp

Lines changed: 42 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -610,8 +610,8 @@ static bool compatibleWithDevice(RTDeviceBinaryImage *BinImage,
610610
return (0 == SuitableImageID);
611611
}
612612

613-
// Quick check to see whether BinImage is a compiler-generated device image.
614-
bool ProgramManager::isSpecialDeviceImage(RTDeviceBinaryImage *BinImage) {
613+
// Check if the device image is a BF16 devicelib image.
614+
bool ProgramManager::isBfloat16DeviceImage(RTDeviceBinaryImage *BinImage) {
615615
// SYCL devicelib image.
616616
if ((m_Bfloat16DeviceLibImages[0].get() == BinImage) ||
617617
m_Bfloat16DeviceLibImages[1].get() == BinImage)
@@ -620,7 +620,9 @@ bool ProgramManager::isSpecialDeviceImage(RTDeviceBinaryImage *BinImage) {
620620
return false;
621621
}
622622

623-
bool ProgramManager::isSpecialDeviceImageShouldBeUsed(
623+
// Check whether the device natively supports BF16 conversion and accordingly
624+
// decide whether to use fallback or native BF16 devicelib image.
625+
bool ProgramManager::shouldBF16DeviceImageBeUsed(
624626
RTDeviceBinaryImage *BinImage, const device_impl &DeviceImpl) {
625627
// Decide whether a devicelib image should be used.
626628
int Bfloat16DeviceLibVersion = -1;
@@ -672,7 +674,7 @@ static bool checkLinkingSupport(const device_impl &DeviceImpl,
672674

673675
std::set<RTDeviceBinaryImage *>
674676
ProgramManager::collectDeviceImageDeps(const RTDeviceBinaryImage &Img,
675-
const device &Dev,
677+
const device_impl &Dev,
676678
bool ErrorOnUnresolvableImport) {
677679
// TODO collecting dependencies for virtual functions and imported symbols
678680
// should be combined since one can lead to new unresolved dependencies for
@@ -698,7 +700,7 @@ CheckAndDecompressImage([[maybe_unused]] RTDeviceBinaryImage *Img) {
698700

699701
std::set<RTDeviceBinaryImage *>
700702
ProgramManager::collectDeviceImageDepsForImportedSymbols(
701-
const RTDeviceBinaryImage &MainImg, const device &Dev,
703+
const RTDeviceBinaryImage &MainImg, const device_impl &Dev,
702704
bool ErrorOnUnresolvableImport) {
703705
std::set<RTDeviceBinaryImage *> DeviceImagesToLink;
704706
std::set<std::string> HandledSymbols;
@@ -709,8 +711,7 @@ ProgramManager::collectDeviceImageDepsForImportedSymbols(
709711
HandledSymbols.insert(ISProp->Name);
710712
}
711713
ur::DeviceBinaryType Format = MainImg.getFormat();
712-
if (!WorkList.empty() &&
713-
!checkLinkingSupport(*getSyclObjImpl(Dev).get(), MainImg))
714+
if (!WorkList.empty() && !checkLinkingSupport(Dev, MainImg))
714715
throw exception(make_error_code(errc::feature_not_supported),
715716
"Cannot resolve external symbols, linking is unsupported "
716717
"for the backend");
@@ -724,13 +725,12 @@ ProgramManager::collectDeviceImageDepsForImportedSymbols(
724725
RTDeviceBinaryImage *Img = It->second;
725726

726727
if (!doesDevSupportDeviceRequirements(Dev, *Img) ||
727-
!compatibleWithDevice(Img, *getSyclObjImpl(Dev).get()))
728+
!compatibleWithDevice(Img, Dev))
728729
continue;
729730

730-
// If the image is a special device image, we need to check if it
731+
// If the image is a BF16 device image, we need to check if it
731732
// should be used for this device.
732-
if (isSpecialDeviceImage(Img) &&
733-
!isSpecialDeviceImageShouldBeUsed(Img, *getSyclObjImpl(Dev).get()))
733+
if (isBfloat16DeviceImage(Img) && !shouldBF16DeviceImageBeUsed(Img, Dev))
734734
continue;
735735

736736
// If any of the images is compressed, we need to decompress it
@@ -766,7 +766,7 @@ ProgramManager::collectDeviceImageDepsForImportedSymbols(
766766

767767
std::set<RTDeviceBinaryImage *>
768768
ProgramManager::collectDependentDeviceImagesForVirtualFunctions(
769-
const RTDeviceBinaryImage &Img, const device &Dev) {
769+
const RTDeviceBinaryImage &Img, const device_impl &Dev) {
770770
// If virtual functions are used in a program, then we need to link several
771771
// device images together to make sure that vtable pointers stored in
772772
// objects are valid between different kernels (which could be in different
@@ -890,17 +890,19 @@ ur_program_handle_t ProgramManager::getBuiltURProgram(
890890
sizeof(ur_bool_t), &MustBuildOnSubdevice, nullptr);
891891
}
892892

893-
device Device = createSyclObjFromImpl<device>(
894-
MustBuildOnSubdevice == true ? DeviceImpl : *RootDevImpl);
893+
device_impl &RootOrSubDevImpl =
894+
MustBuildOnSubdevice == true ? DeviceImpl : *RootDevImpl;
895+
895896
const RTDeviceBinaryImage &Img =
896-
getDeviceImage(KernelName, ContextImpl, getSyclObjImpl(Device).get());
897+
getDeviceImage(KernelName, ContextImpl, RootOrSubDevImpl);
897898

898899
// Check that device supports all aspects used by the kernel
899-
if (auto exception = checkDevSupportDeviceRequirements(Device, Img, NDRDesc))
900+
if (auto exception =
901+
checkDevSupportDeviceRequirements(RootOrSubDevImpl, Img, NDRDesc))
900902
throw *exception;
901903

902904
std::set<RTDeviceBinaryImage *> DeviceImagesToLink =
903-
collectDeviceImageDeps(Img, {Device});
905+
collectDeviceImageDeps(Img, {RootOrSubDevImpl});
904906

905907
// Decompress all DeviceImagesToLink
906908
for (RTDeviceBinaryImage *BinImg : DeviceImagesToLink)
@@ -913,7 +915,7 @@ ur_program_handle_t ProgramManager::getBuiltURProgram(
913915
std::back_inserter(AllImages));
914916

915917
return getBuiltURProgram(std::move(AllImages), ContextImpl,
916-
{std::move(Device)});
918+
{createSyclObjFromImpl<device>(RootOrSubDevImpl)});
917919
}
918920

919921
ur_program_handle_t ProgramManager::getBuiltURProgram(
@@ -1483,9 +1485,9 @@ ProgramManager::ProgramManager()
14831485
}
14841486
}
14851487

1486-
const char *getArchName(const device_impl *DeviceImpl) {
1488+
const char *getArchName(const device_impl &DeviceImpl) {
14871489
namespace syclex = sycl::ext::oneapi::experimental;
1488-
auto Arch = DeviceImpl->get_info<syclex::info::device::architecture>();
1490+
auto Arch = DeviceImpl.get_info<syclex::info::device::architecture>();
14891491
switch (Arch) {
14901492
#define __SYCL_ARCHITECTURE(ARCH, VAL) \
14911493
case syclex::architecture::ARCH: \
@@ -1507,7 +1509,7 @@ template <typename StorageKey>
15071509
RTDeviceBinaryImage *getBinImageFromMultiMap(
15081510
const std::unordered_multimap<StorageKey, RTDeviceBinaryImage *> &ImagesSet,
15091511
const StorageKey &Key, context_impl &ContextImpl,
1510-
const device_impl *DeviceImpl) {
1512+
const device_impl &DeviceImpl) {
15111513
auto [ItBegin, ItEnd] = ImagesSet.equal_range(Key);
15121514
if (ItBegin == ItEnd)
15131515
return nullptr;
@@ -1538,18 +1540,17 @@ RTDeviceBinaryImage *getBinImageFromMultiMap(
15381540
// Ask the native runtime under the given context to choose the device image
15391541
// it prefers.
15401542
ContextImpl.getAdapter()->call<UrApiKind::urDeviceSelectBinary>(
1541-
DeviceImpl->getHandleRef(), UrBinaries.data(), UrBinaries.size(),
1542-
&ImgInd);
1543+
DeviceImpl.getHandleRef(), UrBinaries.data(), UrBinaries.size(), &ImgInd);
15431544
return DeviceFilteredImgs[ImgInd];
15441545
}
15451546

15461547
RTDeviceBinaryImage &
15471548
ProgramManager::getDeviceImage(KernelNameStrRefT KernelName,
15481549
context_impl &ContextImpl,
1549-
const device_impl *DeviceImpl) {
1550+
const device_impl &DeviceImpl) {
15501551
if constexpr (DbgProgMgr > 0) {
15511552
std::cerr << ">>> ProgramManager::getDeviceImage(\"" << KernelName << "\", "
1552-
<< ContextImpl.get() << ", " << DeviceImpl << ")\n";
1553+
<< ContextImpl.get() << ", " << &DeviceImpl << ")\n";
15531554

15541555
std::cerr << "available device images:\n";
15551556
debugPrintBinaryImages();
@@ -1592,12 +1593,12 @@ ProgramManager::getDeviceImage(KernelNameStrRefT KernelName,
15921593

15931594
RTDeviceBinaryImage &ProgramManager::getDeviceImage(
15941595
const std::unordered_set<RTDeviceBinaryImage *> &ImageSet,
1595-
context_impl &ContextImpl, const device_impl *DeviceImpl) {
1596+
context_impl &ContextImpl, const device_impl &DeviceImpl) {
15961597
assert(ImageSet.size() > 0);
15971598

15981599
if constexpr (DbgProgMgr > 0) {
15991600
std::cerr << ">>> ProgramManager::getDeviceImage(Custom SPV file "
1600-
<< ContextImpl.get() << ", " << DeviceImpl << ")\n";
1601+
<< ContextImpl.get() << ", " << &DeviceImpl << ")\n";
16011602

16021603
std::cerr << "available device images:\n";
16031604
debugPrintBinaryImages();
@@ -1620,8 +1621,7 @@ RTDeviceBinaryImage &ProgramManager::getDeviceImage(
16201621
}
16211622

16221623
ContextImpl.getAdapter()->call<UrApiKind::urDeviceSelectBinary>(
1623-
DeviceImpl->getHandleRef(), UrBinaries.data(), UrBinaries.size(),
1624-
&ImgInd);
1624+
DeviceImpl.getHandleRef(), UrBinaries.data(), UrBinaries.size(), &ImgInd);
16251625

16261626
ImageIterator = ImageSet.begin();
16271627
std::advance(ImageIterator, ImgInd);
@@ -2646,6 +2646,8 @@ ProgramManager::getSYCLDeviceImagesWithCompatibleState(
26462646
std::unordered_map<RTDeviceBinaryImage *, DeviceBinaryImageInfo> ImageInfoMap;
26472647

26482648
for (const sycl::device &Dev : Devs) {
2649+
2650+
device_impl &DevImpl = *getSyclObjImpl(Dev);
26492651
// Track the highest image state for each requested kernel.
26502652
using StateImagesPairT =
26512653
std::pair<bundle_state, std::vector<RTDeviceBinaryImage *>>;
@@ -2657,8 +2659,8 @@ ProgramManager::getSYCLDeviceImagesWithCompatibleState(
26572659
KernelImageMap.insert({KernelID, {}});
26582660

26592661
for (RTDeviceBinaryImage *BinImage : BinImages) {
2660-
if (!compatibleWithDevice(BinImage, *getSyclObjImpl(Dev).get()) ||
2661-
!doesDevSupportDeviceRequirements(Dev, *BinImage))
2662+
if (!compatibleWithDevice(BinImage, DevImpl) ||
2663+
!doesDevSupportDeviceRequirements(DevImpl, *BinImage))
26622664
continue;
26632665

26642666
auto InsertRes = ImageInfoMap.insert({BinImage, {}});
@@ -2670,7 +2672,7 @@ ProgramManager::getSYCLDeviceImagesWithCompatibleState(
26702672
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
26712673
ImgInfo.KernelIDs = m_BinImg2KernelIDs[BinImage];
26722674
}
2673-
ImgInfo.Deps = collectDeviceImageDeps(*BinImage, {Dev});
2675+
ImgInfo.Deps = collectDeviceImageDeps(*BinImage, {DevImpl});
26742676
}
26752677
const bundle_state ImgState = ImgInfo.State;
26762678
const std::shared_ptr<std::vector<sycl::kernel_id>> &ImageKernelIDs =
@@ -3366,7 +3368,7 @@ ur_kernel_handle_t ProgramManager::getOrCreateMaterializedKernel(
33663368
return UrKernel;
33673369
}
33683370

3369-
bool doesDevSupportDeviceRequirements(const device &Dev,
3371+
bool doesDevSupportDeviceRequirements(const device_impl &Dev,
33703372
const RTDeviceBinaryImage &Img) {
33713373
return !checkDevSupportDeviceRequirements(Dev, Img).has_value();
33723374
}
@@ -3641,7 +3643,7 @@ std::optional<sycl::exception> checkDevSupportJointMatrixMad(
36413643
}
36423644

36433645
std::optional<sycl::exception>
3644-
checkDevSupportDeviceRequirements(const device &Dev,
3646+
checkDevSupportDeviceRequirements(const device_impl &Dev,
36453647
const RTDeviceBinaryImage &Img,
36463648
const NDRDescT &NDRDesc) {
36473649
auto getPropIt = [&Img](const std::string &PropName) {
@@ -3854,29 +3856,29 @@ checkDevSupportDeviceRequirements(const device &Dev,
38543856
}
38553857

38563858
bool doesImageTargetMatchDevice(const RTDeviceBinaryImage &Img,
3857-
const device_impl *DevImpl) {
3859+
const device_impl &DevImpl) {
38583860
auto PropRange = Img.getDeviceRequirements();
38593861
auto PropIt =
38603862
std::find_if(PropRange.begin(), PropRange.end(), [&](const auto &Prop) {
38613863
return Prop->Name == std::string_view("compile_target");
38623864
});
38633865
// Device image has no compile_target property, check target.
38643866
if (PropIt == PropRange.end()) {
3865-
sycl::backend BE = DevImpl->getBackend();
3867+
sycl::backend BE = DevImpl.getBackend();
38663868
const char *Target = Img.getRawData().DeviceTargetSpec;
38673869
if (strcmp(Target, __SYCL_DEVICE_BINARY_TARGET_SPIRV64) == 0) {
38683870
return (BE == sycl::backend::opencl ||
38693871
BE == sycl::backend::ext_oneapi_level_zero);
38703872
}
38713873
if (strcmp(Target, __SYCL_DEVICE_BINARY_TARGET_SPIRV64_X86_64) == 0) {
3872-
return DevImpl->is_cpu();
3874+
return DevImpl.is_cpu();
38733875
}
38743876
if (strcmp(Target, __SYCL_DEVICE_BINARY_TARGET_SPIRV64_GEN) == 0) {
3875-
return DevImpl->is_gpu() && (BE == sycl::backend::opencl ||
3876-
BE == sycl::backend::ext_oneapi_level_zero);
3877+
return DevImpl.is_gpu() && (BE == sycl::backend::opencl ||
3878+
BE == sycl::backend::ext_oneapi_level_zero);
38773879
}
38783880
if (strcmp(Target, __SYCL_DEVICE_BINARY_TARGET_SPIRV64_FPGA) == 0) {
3879-
return DevImpl->is_accelerator();
3881+
return DevImpl.is_accelerator();
38803882
}
38813883
if (strcmp(Target, __SYCL_DEVICE_BINARY_TARGET_NVPTX64) == 0 ||
38823884
strcmp(Target, __SYCL_DEVICE_BINARY_TARGET_LLVM_NVPTX64) == 0) {

sycl/source/detail/program_manager/program_manager.hpp

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -56,15 +56,15 @@ inline namespace _V1 {
5656
class context;
5757
namespace detail {
5858

59-
bool doesDevSupportDeviceRequirements(const device &Dev,
59+
bool doesDevSupportDeviceRequirements(const device_impl &Dev,
6060
const RTDeviceBinaryImage &BinImages);
6161
std::optional<sycl::exception>
62-
checkDevSupportDeviceRequirements(const device &Dev,
62+
checkDevSupportDeviceRequirements(const device_impl &Dev,
6363
const RTDeviceBinaryImage &BinImages,
6464
const NDRDescT &NDRDesc = {});
6565

6666
bool doesImageTargetMatchDevice(const RTDeviceBinaryImage &Img,
67-
const device_impl *DevImpl);
67+
const device_impl &DevImpl);
6868

6969
// This value must be the same as in libdevice/device_itt.h.
7070
// See sycl/doc/design/ITTAnnotations.md for more info.
@@ -136,11 +136,11 @@ class ProgramManager {
136136

137137
RTDeviceBinaryImage &getDeviceImage(KernelNameStrRefT KernelName,
138138
context_impl &ContextImpl,
139-
const device_impl *DeviceImpl);
139+
const device_impl &DeviceImpl);
140140

141141
RTDeviceBinaryImage &getDeviceImage(
142142
const std::unordered_set<RTDeviceBinaryImage *> &ImagesToVerify,
143-
context_impl &ContextImpl, const device_impl *DeviceImpl);
143+
context_impl &ContextImpl, const device_impl &DeviceImpl);
144144

145145
ur_program_handle_t createURProgram(const RTDeviceBinaryImage &Img,
146146
context_impl &ContextImpl,
@@ -380,11 +380,11 @@ class ProgramManager {
380380
getRawDeviceImages(const std::vector<kernel_id> &KernelIDs);
381381

382382
std::set<RTDeviceBinaryImage *>
383-
collectDeviceImageDeps(const RTDeviceBinaryImage &Img, const device &Dev,
383+
collectDeviceImageDeps(const RTDeviceBinaryImage &Img, const device_impl &Dev,
384384
bool ErrorOnUnresolvableImport = true);
385385
std::set<RTDeviceBinaryImage *>
386386
collectDeviceImageDepsForImportedSymbols(const RTDeviceBinaryImage &Img,
387-
const device &Dev,
387+
const device_impl &Dev,
388388
bool ErrorOnUnresolvableImport);
389389

390390
private:
@@ -412,11 +412,11 @@ class ProgramManager {
412412

413413
std::set<RTDeviceBinaryImage *>
414414
collectDependentDeviceImagesForVirtualFunctions(
415-
const RTDeviceBinaryImage &Img, const device &Dev);
415+
const RTDeviceBinaryImage &Img, const device_impl &Dev);
416416

417-
bool isSpecialDeviceImage(RTDeviceBinaryImage *BinImage);
418-
bool isSpecialDeviceImageShouldBeUsed(RTDeviceBinaryImage *BinImage,
419-
const device_impl &DeviceImpl);
417+
bool isBfloat16DeviceImage(RTDeviceBinaryImage *BinImage);
418+
bool shouldBF16DeviceImageBeUsed(RTDeviceBinaryImage *BinImage,
419+
const device_impl &DeviceImpl);
420420

421421
protected:
422422
/// The three maps below are used during kernel resolution. Any kernel is

sycl/source/kernel_bundle.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -369,15 +369,15 @@ bool is_compatible(const std::vector<kernel_id> &KernelIDs, const device &Dev) {
369369
// number of targets. This kernel is compatible with the device if there is
370370
// at least one image (containing this kernel) whose aspects are supported by
371371
// the device and whose target matches the device.
372+
detail::device_impl &DevImpl = *getSyclObjImpl(Dev);
372373
for (const auto &KernelID : KernelIDs) {
373374
std::set<detail::RTDeviceBinaryImage *> BinImages =
374375
detail::ProgramManager::getInstance().getRawDeviceImages({KernelID});
375376

376377
if (std::none_of(BinImages.begin(), BinImages.end(),
377378
[&](const detail::RTDeviceBinaryImage *Img) {
378-
return doesDevSupportDeviceRequirements(Dev, *Img) &&
379-
doesImageTargetMatchDevice(
380-
*Img, getSyclObjImpl(Dev).get());
379+
return doesDevSupportDeviceRequirements(DevImpl, *Img) &&
380+
doesImageTargetMatchDevice(*Img, DevImpl);
381381
}))
382382
return false;
383383
}

0 commit comments

Comments
 (0)