@@ -610,8 +610,8 @@ static bool compatibleWithDevice(RTDeviceBinaryImage *BinImage,
610
610
return (0 == SuitableImageID);
611
611
}
612
612
613
- // Quick check to see whether BinImage is a compiler-generated device image.
614
- bool ProgramManager::isSpecialDeviceImage (RTDeviceBinaryImage *BinImage) {
613
+ // Check if the device image is a BF16 devicelib image.
614
+ bool ProgramManager::isBfloat16DeviceImage (RTDeviceBinaryImage *BinImage) {
615
615
// SYCL devicelib image.
616
616
if ((m_Bfloat16DeviceLibImages[0 ].get () == BinImage) ||
617
617
m_Bfloat16DeviceLibImages[1 ].get () == BinImage)
@@ -620,7 +620,9 @@ bool ProgramManager::isSpecialDeviceImage(RTDeviceBinaryImage *BinImage) {
620
620
return false ;
621
621
}
622
622
623
- bool ProgramManager::isSpecialDeviceImageShouldBeUsed (
623
+ // Check if the device natively supports BF16 conversion and accordingly
624
+ // decide whether to use fallback or native BF16 devicelib image.
625
+ bool ProgramManager::shouldBF16DeviceImageBeUsed (
624
626
RTDeviceBinaryImage *BinImage, const device_impl &DeviceImpl) {
625
627
// Decide whether a devicelib image should be used.
626
628
int Bfloat16DeviceLibVersion = -1 ;
@@ -672,7 +674,7 @@ static bool checkLinkingSupport(const device_impl &DeviceImpl,
672
674
673
675
std::set<RTDeviceBinaryImage *>
674
676
ProgramManager::collectDeviceImageDeps (const RTDeviceBinaryImage &Img,
675
- const device &Dev,
677
+ const device_impl &Dev,
676
678
bool ErrorOnUnresolvableImport) {
677
679
// TODO collecting dependencies for virtual functions and imported symbols
678
680
// should be combined since one can lead to new unresolved dependencies for
@@ -698,7 +700,7 @@ CheckAndDecompressImage([[maybe_unused]] RTDeviceBinaryImage *Img) {
698
700
699
701
std::set<RTDeviceBinaryImage *>
700
702
ProgramManager::collectDeviceImageDepsForImportedSymbols (
701
- const RTDeviceBinaryImage &MainImg, const device &Dev,
703
+ const RTDeviceBinaryImage &MainImg, const device_impl &Dev,
702
704
bool ErrorOnUnresolvableImport) {
703
705
std::set<RTDeviceBinaryImage *> DeviceImagesToLink;
704
706
std::set<std::string> HandledSymbols;
@@ -709,8 +711,7 @@ ProgramManager::collectDeviceImageDepsForImportedSymbols(
709
711
HandledSymbols.insert (ISProp->Name );
710
712
}
711
713
ur::DeviceBinaryType Format = MainImg.getFormat ();
712
- if (!WorkList.empty () &&
713
- !checkLinkingSupport (*getSyclObjImpl (Dev).get (), MainImg))
714
+ if (!WorkList.empty () && !checkLinkingSupport (Dev, MainImg))
714
715
throw exception (make_error_code (errc::feature_not_supported),
715
716
" Cannot resolve external symbols, linking is unsupported "
716
717
" for the backend" );
@@ -724,13 +725,12 @@ ProgramManager::collectDeviceImageDepsForImportedSymbols(
724
725
RTDeviceBinaryImage *Img = It->second ;
725
726
726
727
if (!doesDevSupportDeviceRequirements (Dev, *Img) ||
727
- !compatibleWithDevice (Img, * getSyclObjImpl ( Dev). get () ))
728
+ !compatibleWithDevice (Img, Dev))
728
729
continue ;
729
730
730
- // If the image is a special device image, we need to check if it
731
+ // If the image is a BF16 device image, we need to check if it
731
732
// should be used for this device.
732
- if (isSpecialDeviceImage (Img) &&
733
- !isSpecialDeviceImageShouldBeUsed (Img, *getSyclObjImpl (Dev).get ()))
733
+ if (isBfloat16DeviceImage (Img) && !shouldBF16DeviceImageBeUsed (Img, Dev))
734
734
continue ;
735
735
736
736
// If any of the images is compressed, we need to decompress it
@@ -766,7 +766,7 @@ ProgramManager::collectDeviceImageDepsForImportedSymbols(
766
766
767
767
std::set<RTDeviceBinaryImage *>
768
768
ProgramManager::collectDependentDeviceImagesForVirtualFunctions (
769
- const RTDeviceBinaryImage &Img, const device &Dev) {
769
+ const RTDeviceBinaryImage &Img, const device_impl &Dev) {
770
770
// If virtual functions are used in a program, then we need to link several
771
771
// device images together to make sure that vtable pointers stored in
772
772
// objects are valid between different kernels (which could be in different
@@ -890,17 +890,19 @@ ur_program_handle_t ProgramManager::getBuiltURProgram(
890
890
sizeof (ur_bool_t ), &MustBuildOnSubdevice, nullptr );
891
891
}
892
892
893
- device Device = createSyclObjFromImpl<device>(
894
- MustBuildOnSubdevice == true ? DeviceImpl : *RootDevImpl);
893
+ device_impl &RootOrSubDevImpl =
894
+ MustBuildOnSubdevice == true ? DeviceImpl : *RootDevImpl;
895
+
895
896
const RTDeviceBinaryImage &Img =
896
- getDeviceImage (KernelName, ContextImpl, getSyclObjImpl (Device). get () );
897
+ getDeviceImage (KernelName, ContextImpl, RootOrSubDevImpl );
897
898
898
899
// Check that device supports all aspects used by the kernel
899
- if (auto exception = checkDevSupportDeviceRequirements (Device, Img, NDRDesc))
900
+ if (auto exception =
901
+ checkDevSupportDeviceRequirements (RootOrSubDevImpl, Img, NDRDesc))
900
902
throw *exception;
901
903
902
904
std::set<RTDeviceBinaryImage *> DeviceImagesToLink =
903
- collectDeviceImageDeps (Img, {Device });
905
+ collectDeviceImageDeps (Img, {RootOrSubDevImpl });
904
906
905
907
// Decompress all DeviceImagesToLink
906
908
for (RTDeviceBinaryImage *BinImg : DeviceImagesToLink)
@@ -913,7 +915,7 @@ ur_program_handle_t ProgramManager::getBuiltURProgram(
913
915
std::back_inserter (AllImages));
914
916
915
917
return getBuiltURProgram (std::move (AllImages), ContextImpl,
916
- {std::move (Device )});
918
+ {createSyclObjFromImpl<device>(RootOrSubDevImpl )});
917
919
}
918
920
919
921
ur_program_handle_t ProgramManager::getBuiltURProgram (
@@ -1483,9 +1485,9 @@ ProgramManager::ProgramManager()
1483
1485
}
1484
1486
}
1485
1487
1486
- const char *getArchName (const device_impl * DeviceImpl) {
1488
+ const char *getArchName (const device_impl & DeviceImpl) {
1487
1489
namespace syclex = sycl::ext::oneapi::experimental;
1488
- auto Arch = DeviceImpl-> get_info <syclex::info::device::architecture>();
1490
+ auto Arch = DeviceImpl. get_info <syclex::info::device::architecture>();
1489
1491
switch (Arch) {
1490
1492
#define __SYCL_ARCHITECTURE (ARCH, VAL ) \
1491
1493
case syclex::architecture::ARCH: \
@@ -1507,7 +1509,7 @@ template <typename StorageKey>
1507
1509
RTDeviceBinaryImage *getBinImageFromMultiMap (
1508
1510
const std::unordered_multimap<StorageKey, RTDeviceBinaryImage *> &ImagesSet,
1509
1511
const StorageKey &Key, context_impl &ContextImpl,
1510
- const device_impl * DeviceImpl) {
1512
+ const device_impl & DeviceImpl) {
1511
1513
auto [ItBegin, ItEnd] = ImagesSet.equal_range (Key);
1512
1514
if (ItBegin == ItEnd)
1513
1515
return nullptr ;
@@ -1538,18 +1540,17 @@ RTDeviceBinaryImage *getBinImageFromMultiMap(
1538
1540
// Ask the native runtime under the given context to choose the device image
1539
1541
// it prefers.
1540
1542
ContextImpl.getAdapter ()->call <UrApiKind::urDeviceSelectBinary>(
1541
- DeviceImpl->getHandleRef (), UrBinaries.data (), UrBinaries.size (),
1542
- &ImgInd);
1543
+ DeviceImpl.getHandleRef (), UrBinaries.data (), UrBinaries.size (), &ImgInd);
1543
1544
return DeviceFilteredImgs[ImgInd];
1544
1545
}
1545
1546
1546
1547
RTDeviceBinaryImage &
1547
1548
ProgramManager::getDeviceImage (KernelNameStrRefT KernelName,
1548
1549
context_impl &ContextImpl,
1549
- const device_impl * DeviceImpl) {
1550
+ const device_impl & DeviceImpl) {
1550
1551
if constexpr (DbgProgMgr > 0 ) {
1551
1552
std::cerr << " >>> ProgramManager::getDeviceImage(\" " << KernelName << " \" , "
1552
- << ContextImpl.get () << " , " << DeviceImpl << " )\n " ;
1553
+ << ContextImpl.get () << " , " << & DeviceImpl << " )\n " ;
1553
1554
1554
1555
std::cerr << " available device images:\n " ;
1555
1556
debugPrintBinaryImages ();
@@ -1592,12 +1593,12 @@ ProgramManager::getDeviceImage(KernelNameStrRefT KernelName,
1592
1593
1593
1594
RTDeviceBinaryImage &ProgramManager::getDeviceImage (
1594
1595
const std::unordered_set<RTDeviceBinaryImage *> &ImageSet,
1595
- context_impl &ContextImpl, const device_impl * DeviceImpl) {
1596
+ context_impl &ContextImpl, const device_impl & DeviceImpl) {
1596
1597
assert (ImageSet.size () > 0 );
1597
1598
1598
1599
if constexpr (DbgProgMgr > 0 ) {
1599
1600
std::cerr << " >>> ProgramManager::getDeviceImage(Custom SPV file "
1600
- << ContextImpl.get () << " , " << DeviceImpl << " )\n " ;
1601
+ << ContextImpl.get () << " , " << & DeviceImpl << " )\n " ;
1601
1602
1602
1603
std::cerr << " available device images:\n " ;
1603
1604
debugPrintBinaryImages ();
@@ -1620,8 +1621,7 @@ RTDeviceBinaryImage &ProgramManager::getDeviceImage(
1620
1621
}
1621
1622
1622
1623
ContextImpl.getAdapter ()->call <UrApiKind::urDeviceSelectBinary>(
1623
- DeviceImpl->getHandleRef (), UrBinaries.data (), UrBinaries.size (),
1624
- &ImgInd);
1624
+ DeviceImpl.getHandleRef (), UrBinaries.data (), UrBinaries.size (), &ImgInd);
1625
1625
1626
1626
ImageIterator = ImageSet.begin ();
1627
1627
std::advance (ImageIterator, ImgInd);
@@ -2646,6 +2646,8 @@ ProgramManager::getSYCLDeviceImagesWithCompatibleState(
2646
2646
std::unordered_map<RTDeviceBinaryImage *, DeviceBinaryImageInfo> ImageInfoMap;
2647
2647
2648
2648
for (const sycl::device &Dev : Devs) {
2649
+
2650
+ device_impl &DevImpl = *getSyclObjImpl (Dev);
2649
2651
// Track the highest image state for each requested kernel.
2650
2652
using StateImagesPairT =
2651
2653
std::pair<bundle_state, std::vector<RTDeviceBinaryImage *>>;
@@ -2657,8 +2659,8 @@ ProgramManager::getSYCLDeviceImagesWithCompatibleState(
2657
2659
KernelImageMap.insert ({KernelID, {}});
2658
2660
2659
2661
for (RTDeviceBinaryImage *BinImage : BinImages) {
2660
- if (!compatibleWithDevice (BinImage, * getSyclObjImpl (Dev). get () ) ||
2661
- !doesDevSupportDeviceRequirements (Dev , *BinImage))
2662
+ if (!compatibleWithDevice (BinImage, DevImpl ) ||
2663
+ !doesDevSupportDeviceRequirements (DevImpl , *BinImage))
2662
2664
continue ;
2663
2665
2664
2666
auto InsertRes = ImageInfoMap.insert ({BinImage, {}});
@@ -2670,7 +2672,7 @@ ProgramManager::getSYCLDeviceImagesWithCompatibleState(
2670
2672
std::lock_guard<std::mutex> KernelIDsGuard (m_KernelIDsMutex);
2671
2673
ImgInfo.KernelIDs = m_BinImg2KernelIDs[BinImage];
2672
2674
}
2673
- ImgInfo.Deps = collectDeviceImageDeps (*BinImage, {Dev });
2675
+ ImgInfo.Deps = collectDeviceImageDeps (*BinImage, {DevImpl });
2674
2676
}
2675
2677
const bundle_state ImgState = ImgInfo.State ;
2676
2678
const std::shared_ptr<std::vector<sycl::kernel_id>> &ImageKernelIDs =
@@ -3366,7 +3368,7 @@ ur_kernel_handle_t ProgramManager::getOrCreateMaterializedKernel(
3366
3368
return UrKernel;
3367
3369
}
3368
3370
3369
// Convenience predicate: forwards Dev and Img to
// checkDevSupportDeviceRequirements and returns true when the returned
// optional is empty (no exception was produced for this device/image pair).
- bool doesDevSupportDeviceRequirements (const device &Dev,
3371
+ bool doesDevSupportDeviceRequirements (const device_impl &Dev,
3370
3372
const RTDeviceBinaryImage &Img) {
3371
3373
return !checkDevSupportDeviceRequirements (Dev, Img).has_value ();
3372
3374
}
@@ -3641,7 +3643,7 @@ std::optional<sycl::exception> checkDevSupportJointMatrixMad(
3641
3643
}
3642
3644
3643
3645
std::optional<sycl::exception>
3644
- checkDevSupportDeviceRequirements (const device &Dev,
3646
+ checkDevSupportDeviceRequirements (const device_impl &Dev,
3645
3647
const RTDeviceBinaryImage &Img,
3646
3648
const NDRDescT &NDRDesc) {
3647
3649
auto getPropIt = [&Img](const std::string &PropName) {
@@ -3854,29 +3856,29 @@ checkDevSupportDeviceRequirements(const device &Dev,
3854
3856
}
3855
3857
3856
3858
bool doesImageTargetMatchDevice (const RTDeviceBinaryImage &Img,
3857
- const device_impl * DevImpl) {
3859
+ const device_impl & DevImpl) {
3858
3860
auto PropRange = Img.getDeviceRequirements ();
3859
3861
auto PropIt =
3860
3862
std::find_if (PropRange.begin (), PropRange.end (), [&](const auto &Prop) {
3861
3863
return Prop->Name == std::string_view (" compile_target" );
3862
3864
});
3863
3865
// Device image has no compile_target property, check target.
3864
3866
if (PropIt == PropRange.end ()) {
3865
- sycl::backend BE = DevImpl-> getBackend ();
3867
+ sycl::backend BE = DevImpl. getBackend ();
3866
3868
const char *Target = Img.getRawData ().DeviceTargetSpec ;
3867
3869
if (strcmp (Target, __SYCL_DEVICE_BINARY_TARGET_SPIRV64) == 0 ) {
3868
3870
return (BE == sycl::backend::opencl ||
3869
3871
BE == sycl::backend::ext_oneapi_level_zero);
3870
3872
}
3871
3873
if (strcmp (Target, __SYCL_DEVICE_BINARY_TARGET_SPIRV64_X86_64) == 0 ) {
3872
- return DevImpl-> is_cpu ();
3874
+ return DevImpl. is_cpu ();
3873
3875
}
3874
3876
if (strcmp (Target, __SYCL_DEVICE_BINARY_TARGET_SPIRV64_GEN) == 0 ) {
3875
- return DevImpl-> is_gpu () && (BE == sycl::backend::opencl ||
3876
- BE == sycl::backend::ext_oneapi_level_zero);
3877
+ return DevImpl. is_gpu () && (BE == sycl::backend::opencl ||
3878
+ BE == sycl::backend::ext_oneapi_level_zero);
3877
3879
}
3878
3880
if (strcmp (Target, __SYCL_DEVICE_BINARY_TARGET_SPIRV64_FPGA) == 0 ) {
3879
- return DevImpl-> is_accelerator ();
3881
+ return DevImpl. is_accelerator ();
3880
3882
}
3881
3883
if (strcmp (Target, __SYCL_DEVICE_BINARY_TARGET_NVPTX64) == 0 ||
3882
3884
strcmp (Target, __SYCL_DEVICE_BINARY_TARGET_LLVM_NVPTX64) == 0 ) {
0 commit comments