Skip to content

Commit 6eb4a92

Browse files
[NFC][SYCL] std::enable_shared_from_this for device_image_impl (#19182)
1 parent 5074ba7 commit 6eb4a92

File tree

7 files changed

+93
-90
lines changed

7 files changed

+93
-90
lines changed

sycl/source/backend.cpp

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -301,9 +301,9 @@ make_kernel_bundle(ur_native_handle_t NativeHandle,
301301
// do the same to user images, since they may contain references to undefined
302302
// symbols (e.g. when kernel_bundle is supposed to be joined with another).
303303
auto KernelIDs = std::make_shared<std::vector<kernel_id>>();
304-
auto DevImgImpl = std::make_shared<device_image_impl>(
305-
nullptr, TargetContext, Devices, State, KernelIDs, UrProgram,
306-
ImageOriginInterop);
304+
auto DevImgImpl =
305+
device_image_impl::create(nullptr, TargetContext, Devices, State,
306+
KernelIDs, UrProgram, ImageOriginInterop);
307307
device_image_plain DevImg{DevImgImpl};
308308

309309
return kernel_bundle_impl::create(TargetContext, Devices, DevImg);

sycl/source/detail/device_image_impl.cpp

Lines changed: 5 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -15,8 +15,7 @@ namespace detail {
1515

1616
std::shared_ptr<kernel_impl> device_image_impl::tryGetExtensionKernel(
1717
std::string_view Name, const context &Context,
18-
const kernel_bundle_impl &OwnerBundle,
19-
const std::shared_ptr<device_image_impl> &Self) const {
18+
const kernel_bundle_impl &OwnerBundle) {
2019
if (!(getOriginMask() & ImageOriginKernelCompiler) &&
2120
!((getOriginMask() & ImageOriginSYCLBIN) && hasKernelName(Name)))
2221
return nullptr;
@@ -34,9 +33,9 @@ std::shared_ptr<kernel_impl> device_image_impl::tryGetExtensionKernel(
3433
auto [UrKernel, CacheMutex, ArgMask] =
3534
PM.getOrCreateKernel(Context, AdjustedName,
3635
/*PropList=*/{}, UrProgram);
37-
return std::make_shared<kernel_impl>(UrKernel, *getSyclObjImpl(Context),
38-
Self, OwnerBundle.shared_from_this(),
39-
ArgMask, UrProgram, CacheMutex);
36+
return std::make_shared<kernel_impl>(
37+
UrKernel, *getSyclObjImpl(Context), shared_from_this(),
38+
OwnerBundle.shared_from_this(), ArgMask, UrProgram, CacheMutex);
4039
}
4140
return nullptr;
4241
}
@@ -49,7 +48,7 @@ std::shared_ptr<kernel_impl> device_image_impl::tryGetExtensionKernel(
4948
// Kernel created by urKernelCreate is implicitly retained.
5049

5150
return std::make_shared<kernel_impl>(
52-
UrKernel, *detail::getSyclObjImpl(Context), Self,
51+
UrKernel, *detail::getSyclObjImpl(Context), shared_from_this(),
5352
OwnerBundle.shared_from_this(),
5453
/*ArgMask=*/nullptr, UrProgram, /*CacheMutex=*/nullptr);
5554
}

sycl/source/detail/device_image_impl.hpp

Lines changed: 26 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -230,7 +230,12 @@ struct KernelCompilerBinaryInfo {
230230
// The class is impl counterpart for sycl::device_image
231231
// It can represent a program in different states, kernel_id's it has and state
232232
// of specialization constants for it
233-
class device_image_impl {
233+
class device_image_impl
234+
: public std::enable_shared_from_this<device_image_impl> {
235+
struct private_tag {
236+
explicit private_tag() = default;
237+
};
238+
234239
public:
235240
// The struct maps specialization ID to offset in the binary blob where value
236241
// for this spec const should be.
@@ -249,8 +254,7 @@ class device_image_impl {
249254
device_image_impl(const RTDeviceBinaryImage *BinImage, context Context,
250255
std::vector<device> Devices, bundle_state State,
251256
std::shared_ptr<std::vector<kernel_id>> KernelIDs,
252-
ur_program_handle_t Program,
253-
uint8_t Origins = ImageOriginSYCLOffline)
257+
ur_program_handle_t Program, uint8_t Origins, private_tag)
254258
: MBinImage(BinImage), MContext(std::move(Context)),
255259
MDevices(std::move(Devices)), MState(State), MProgram(Program),
256260
MKernelIDs(std::move(KernelIDs)),
@@ -272,7 +276,7 @@ class device_image_impl {
272276
const std::vector<unsigned char> &SpecConstsBlob, uint8_t Origins,
273277
std::optional<KernelCompilerBinaryInfo> &&RTCInfo,
274278
KernelNameSetT &&KernelNames,
275-
std::unique_ptr<DynRTDeviceBinaryImage> &&MergedImageStorage = nullptr)
279+
std::unique_ptr<DynRTDeviceBinaryImage> &&MergedImageStorage, private_tag)
276280
: MBinImage(BinImage), MContext(std::move(Context)),
277281
MDevices(std::move(Devices)), MState(State), MProgram(Program),
278282
MKernelIDs(std::move(KernelIDs)), MKernelNames{std::move(KernelNames)},
@@ -285,7 +289,7 @@ class device_image_impl {
285289
device_image_impl(const RTDeviceBinaryImage *BinImage, const context &Context,
286290
const std::vector<device> &Devices, bundle_state State,
287291
ur_program_handle_t Program, syclex::source_language Lang,
288-
KernelNameSetT &&KernelNames)
292+
KernelNameSetT &&KernelNames, private_tag)
289293
: MBinImage(BinImage), MContext(std::move(Context)),
290294
MDevices(std::move(Devices)), MState(State), MProgram(Program),
291295
MKernelIDs(std::make_shared<std::vector<kernel_id>>()),
@@ -302,7 +306,8 @@ class device_image_impl {
302306
std::shared_ptr<std::vector<kernel_id>> &&KernelIDs,
303307
syclex::source_language Lang, KernelNameSetT &&KernelNames,
304308
MangledKernelNameMapT &&MangledKernelNames, std::string &&Prefix,
305-
std::shared_ptr<ManagedDeviceGlobalsRegistry> &&DeviceGlobalRegistry)
309+
std::shared_ptr<ManagedDeviceGlobalsRegistry> &&DeviceGlobalRegistry,
310+
private_tag)
306311
: MBinImage(BinImage), MContext(std::move(Context)),
307312
MDevices(std::move(Devices)), MState(State), MProgram(nullptr),
308313
MKernelIDs(std::move(KernelIDs)), MKernelNames{std::move(KernelNames)},
@@ -317,7 +322,7 @@ class device_image_impl {
317322
device_image_impl(const std::string &Src, context Context,
318323
const std::vector<device> &Devices,
319324
syclex::source_language Lang,
320-
include_pairs_t &&IncludePairsVec)
325+
include_pairs_t &&IncludePairsVec, private_tag)
321326
: MBinImage(Src), MContext(std::move(Context)),
322327
MDevices(std::move(Devices)), MState(bundle_state::ext_oneapi_source),
323328
MProgram(nullptr),
@@ -331,7 +336,7 @@ class device_image_impl {
331336

332337
device_image_impl(const std::vector<std::byte> &Bytes, const context &Context,
333338
const std::vector<device> &Devices,
334-
syclex::source_language Lang)
339+
syclex::source_language Lang, private_tag)
335340
: MBinImage(Bytes), MContext(std::move(Context)),
336341
MDevices(std::move(Devices)), MState(bundle_state::ext_oneapi_source),
337342
MProgram(nullptr),
@@ -344,7 +349,8 @@ class device_image_impl {
344349

345350
device_image_impl(const context &Context, const std::vector<device> &Devices,
346351
bundle_state State, ur_program_handle_t Program,
347-
syclex::source_language Lang, KernelNameSetT &&KernelNames)
352+
syclex::source_language Lang, KernelNameSetT &&KernelNames,
353+
private_tag)
348354
: MBinImage(static_cast<const RTDeviceBinaryImage *>(nullptr)),
349355
MContext(std::move(Context)), MDevices(std::move(Devices)),
350356
MState(State), MProgram(Program),
@@ -354,6 +360,12 @@ class device_image_impl {
354360
MOrigins(ImageOriginKernelCompiler),
355361
MRTCBinInfo(KernelCompilerBinaryInfo{Lang}) {}
356362

363+
template <typename... Ts>
364+
static std::shared_ptr<device_image_impl> create(Ts &&...args) {
365+
return std::make_shared<device_image_impl>(std::forward<Ts>(args)...,
366+
private_tag{});
367+
}
368+
357369
bool has_kernel(const kernel_id &KernelIDCand) const noexcept {
358370
return std::binary_search(MKernelIDs->begin(), MKernelIDs->end(),
359371
KernelIDCand, LessByHash<kernel_id>{});
@@ -631,8 +643,7 @@ class device_image_impl {
631643

632644
std::shared_ptr<kernel_impl>
633645
tryGetExtensionKernel(std::string_view Name, const context &Context,
634-
const kernel_bundle_impl &OwnerBundle,
635-
const std::shared_ptr<device_image_impl> &Self) const;
646+
const kernel_bundle_impl &OwnerBundle);
636647

637648
bool hasDeviceGlobalName(const std::string &Name) const noexcept {
638649
if (!MRTCBinInfo.has_value())
@@ -752,9 +763,9 @@ class device_image_impl {
752763
*SourceStrPtr, UrProgram);
753764
}
754765
return std::vector<std::shared_ptr<device_image_impl>>{
755-
std::make_shared<device_image_impl>(
756-
MContext, Devices, bundle_state::executable, UrProgram,
757-
MRTCBinInfo->MLanguage, std::move(KernelNameSet))};
766+
device_image_impl::create(MContext, Devices, bundle_state::executable,
767+
UrProgram, MRTCBinInfo->MLanguage,
768+
std::move(KernelNameSet))};
758769
}
759770

760771
std::vector<std::shared_ptr<device_image_impl>> compileFromSource(
@@ -1114,7 +1125,7 @@ class device_image_impl {
11141125

11151126
// Mark the image as input so the program manager will bring it into
11161127
// the right state.
1117-
auto DevImgImpl = std::make_shared<device_image_impl>(
1128+
auto DevImgImpl = device_image_impl::create(
11181129
NewImage, MContext, std::move(SupportingDevs), bundle_state::input,
11191130
std::move(KernelIDs), MRTCBinInfo->MLanguage, std::move(KernelNames),
11201131
std::move(MangledKernelNames), std::string{Prefix},

sycl/source/detail/kernel_bundle_impl.hpp

Lines changed: 6 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -528,7 +528,7 @@ class kernel_bundle_impl
528528
const std::string &Src, include_pairs_t IncludePairsVec,
529529
private_tag)
530530
: MContext(Context), MDevices(Context.get_devices()),
531-
MDeviceImages{device_image_plain{std::make_shared<device_image_impl>(
531+
MDeviceImages{device_image_plain{device_image_impl::create(
532532
Src, MContext, MDevices, Lang, std::move(IncludePairsVec))}},
533533
MUniqueDeviceImages{MDeviceImages[0].getMain()},
534534
MState(bundle_state::ext_oneapi_source) {
@@ -540,8 +540,8 @@ class kernel_bundle_impl
540540
kernel_bundle_impl(const context &Context, syclex::source_language Lang,
541541
const std::vector<std::byte> &Bytes, private_tag)
542542
: MContext(Context), MDevices(Context.get_devices()),
543-
MDeviceImages{device_image_plain{std::make_shared<device_image_impl>(
544-
Bytes, MContext, MDevices, Lang)}},
543+
MDeviceImages{device_image_plain{
544+
device_image_impl::create(Bytes, MContext, MDevices, Lang)}},
545545
MUniqueDeviceImages{MDeviceImages[0].getMain()},
546546
MState(bundle_state::ext_oneapi_source) {
547547
common_ctor_checks();
@@ -583,7 +583,7 @@ class kernel_bundle_impl
583583
SYCLBIN->getBestCompatibleImages(Devs);
584584
MDeviceImages.reserve(BestImages.size());
585585
for (const detail::RTDeviceBinaryImage *Image : BestImages)
586-
MDeviceImages.emplace_back(std::make_shared<detail::device_image_impl>(
586+
MDeviceImages.emplace_back(device_image_impl::create(
587587
Image, Context, Devs, ProgramManager::getBinImageState(Image),
588588
/*KernelIDs=*/nullptr, /*URProgram=*/nullptr, ImageOriginSYCLBIN));
589589
ProgramManager::getInstance().bringSYCLDeviceImagesToState(MDeviceImages,
@@ -669,8 +669,7 @@ class kernel_bundle_impl
669669
const std::shared_ptr<device_image_impl> &DevImgImpl =
670670
getSyclObjImpl(DevImg);
671671
if (std::shared_ptr<kernel_impl> PotentialKernelImpl =
672-
DevImgImpl->tryGetExtensionKernel(Name, MContext, *this,
673-
DevImgImpl))
672+
DevImgImpl->tryGetExtensionKernel(Name, MContext, *this))
674673
return detail::createSyclObjFromImpl<kernel>(
675674
std::move(PotentialKernelImpl));
676675
}
@@ -1008,8 +1007,7 @@ class kernel_bundle_impl
10081007
const std::shared_ptr<device_image_impl> &DevImgImpl =
10091008
getSyclObjImpl(DevImg);
10101009
if (std::shared_ptr<kernel_impl> SourceBasedKernel =
1011-
DevImgImpl->tryGetExtensionKernel(Name, MContext, *this,
1012-
DevImgImpl))
1010+
DevImgImpl->tryGetExtensionKernel(Name, MContext, *this))
10131011
return SourceBasedKernel;
10141012
}
10151013

sycl/source/detail/program_manager/program_manager.cpp

Lines changed: 42 additions & 43 deletions
Original file line number | Diff line number | Diff line change
@@ -838,14 +838,13 @@ ProgramManager::collectDependentDeviceImagesForVirtualFunctions(
838838
return DeviceImagesToLink;
839839
}
840840

841-
static void
842-
setSpecializationConstants(const std::shared_ptr<device_image_impl> &InputImpl,
843-
ur_program_handle_t Prog,
844-
const AdapterPtr &Adapter) {
845-
std::lock_guard<std::mutex> Lock{InputImpl->get_spec_const_data_lock()};
841+
static void setSpecializationConstants(device_image_impl &InputImpl,
842+
ur_program_handle_t Prog,
843+
const AdapterPtr &Adapter) {
844+
std::lock_guard<std::mutex> Lock{InputImpl.get_spec_const_data_lock()};
846845
const std::map<std::string, std::vector<device_image_impl::SpecConstDescT>>
847-
&SpecConstData = InputImpl->get_spec_const_data_ref();
848-
const SerializedObj &SpecConsts = InputImpl->get_spec_const_blob_ref();
846+
&SpecConstData = InputImpl.get_spec_const_data_ref();
847+
const SerializedObj &SpecConsts = InputImpl.get_spec_const_blob_ref();
849848

850849
// Set all specialization IDs from descriptors in the input device image.
851850
for (const auto &[SpecConstNames, SpecConstDescs] : SpecConstData) {
@@ -941,7 +940,7 @@ ur_program_handle_t ProgramManager::getBuiltURProgram(
941940
if (!DeviceCodeWasInCache && MainImg.supportsSpecConstants()) {
942941
enableITTAnnotationsIfNeeded(NativePrg, Adapter);
943942
if (DevImgWithDeps)
944-
setSpecializationConstants(getSyclObjImpl(DevImgWithDeps->getMain()),
943+
setSpecializationConstants(*getSyclObjImpl(DevImgWithDeps->getMain()),
945944
NativePrg, Adapter);
946945
}
947946

@@ -982,7 +981,7 @@ ur_program_handle_t ProgramManager::getBuiltURProgram(
982981
enableITTAnnotationsIfNeeded(NativePrg, Adapter);
983982
if (DevImgWithDeps)
984983
setSpecializationConstants(
985-
getSyclObjImpl(DevImgWithDeps->getAll()[I]), NativePrg,
984+
*getSyclObjImpl(DevImgWithDeps->getAll()[I]), NativePrg,
986985
Adapter);
987986
}
988987
ProgramsToLink.push_back(NativePrg);
@@ -2508,9 +2507,9 @@ device_image_plain ProgramManager::getDeviceImageFromBinaryImage(
25082507
KernelIDs = m_BinImg2KernelIDs[BinImage];
25092508
}
25102509

2511-
DeviceImageImplPtr Impl = std::make_shared<detail::device_image_impl>(
2510+
DeviceImageImplPtr Impl = device_image_impl::create(
25122511
BinImage, Ctx, std::vector<device>{Dev}, ImgState, KernelIDs,
2513-
/*PIProgram=*/nullptr);
2512+
/*PIProgram=*/nullptr, ImageOriginSYCLOffline);
25142513

25152514
return createSyclObjFromImpl<device_image_plain>(std::move(Impl));
25162515
}
@@ -2668,9 +2667,10 @@ ProgramManager::getSYCLDeviceImagesWithCompatibleState(
26682667
if (ImgInfoPair.second.RequirementCounter == 0)
26692668
continue;
26702669

2671-
DeviceImageImplPtr MainImpl = std::make_shared<detail::device_image_impl>(
2670+
DeviceImageImplPtr MainImpl = device_image_impl::create(
26722671
ImgInfoPair.first, Ctx, Devs, ImgInfoPair.second.State,
2673-
ImgInfoPair.second.KernelIDs, /*PIProgram=*/nullptr);
2672+
ImgInfoPair.second.KernelIDs, /*PIProgram=*/nullptr,
2673+
ImageOriginSYCLOffline);
26742674

26752675
std::vector<device_image_plain> Images;
26762676
const std::set<RTDeviceBinaryImage *> &Deps = ImgInfoPair.second.Deps;
@@ -2701,8 +2701,9 @@ device_image_plain ProgramManager::createDependencyImage(
27012701

27022702
assert(DepState == getBinImageState(DepImage) &&
27032703
"State mismatch between main image and its dependency");
2704-
DeviceImageImplPtr DepImpl = std::make_shared<detail::device_image_impl>(
2705-
DepImage, Ctx, Devs, DepState, DepKernelIDs, /*PIProgram=*/nullptr);
2704+
DeviceImageImplPtr DepImpl =
2705+
device_image_impl::create(DepImage, Ctx, Devs, DepState, DepKernelIDs,
2706+
/*PIProgram=*/nullptr, ImageOriginSYCLOffline);
27062707

27072708
return createSyclObjFromImpl<device_image_plain>(std::move(DepImpl));
27082709
}
@@ -2857,42 +2858,42 @@ ProgramManager::compile(const DevImgPlainWithDeps &ImgWithDeps,
28572858
std::vector<device_image_plain> CompiledImages;
28582859
CompiledImages.reserve(ImgWithDeps.size());
28592860
for (const device_image_plain &DeviceImage : ImgWithDeps.getAll()) {
2860-
const std::shared_ptr<device_image_impl> &InputImpl =
2861-
getSyclObjImpl(DeviceImage);
2861+
device_image_impl &InputImpl = *getSyclObjImpl(DeviceImage);
28622862

28632863
const AdapterPtr &Adapter =
2864-
getSyclObjImpl(InputImpl->get_context())->getAdapter();
2864+
getSyclObjImpl(InputImpl.get_context())->getAdapter();
28652865

28662866
ur_program_handle_t Prog =
2867-
createURProgram(*InputImpl->get_bin_image_ref(),
2868-
*getSyclObjImpl(InputImpl->get_context()), Devs);
2867+
createURProgram(*InputImpl.get_bin_image_ref(),
2868+
*getSyclObjImpl(InputImpl.get_context()), Devs);
28692869

2870-
if (InputImpl->get_bin_image_ref()->supportsSpecConstants())
2870+
if (InputImpl.get_bin_image_ref()->supportsSpecConstants())
28712871
setSpecializationConstants(InputImpl, Prog, Adapter);
28722872

2873-
KernelNameSetT KernelNames = InputImpl->getKernelNames();
2873+
KernelNameSetT KernelNames = InputImpl.getKernelNames();
28742874

28752875
std::optional<detail::KernelCompilerBinaryInfo> RTCInfo =
2876-
InputImpl->getRTCInfo();
2877-
DeviceImageImplPtr ObjectImpl = std::make_shared<detail::device_image_impl>(
2878-
InputImpl->get_bin_image_ref(), InputImpl->get_context(),
2876+
InputImpl.getRTCInfo();
2877+
DeviceImageImplPtr ObjectImpl = device_image_impl::create(
2878+
InputImpl.get_bin_image_ref(), InputImpl.get_context(),
28792879
std::vector<device>{Devs}, bundle_state::object,
2880-
InputImpl->get_kernel_ids_ptr(), Prog,
2881-
InputImpl->get_spec_const_data_ref(),
2882-
InputImpl->get_spec_const_blob_ref(), InputImpl->getOriginMask(),
2883-
std::move(RTCInfo), std::move(KernelNames));
2880+
InputImpl.get_kernel_ids_ptr(), Prog,
2881+
InputImpl.get_spec_const_data_ref(),
2882+
InputImpl.get_spec_const_blob_ref(), InputImpl.getOriginMask(),
2883+
std::move(RTCInfo), std::move(KernelNames),
2884+
/*MergedImageStorage = */ nullptr);
28842885

28852886
std::string CompileOptions;
28862887
applyCompileOptionsFromEnvironment(CompileOptions);
28872888
appendCompileOptionsFromImage(
2888-
CompileOptions, *(InputImpl->get_bin_image_ref()), Devs, Adapter);
2889+
CompileOptions, *(InputImpl.get_bin_image_ref()), Devs, Adapter);
28892890
// Should always come last!
28902891
appendCompileEnvironmentVariablesThatAppend(CompileOptions);
2891-
ur_result_t Error = doCompile(
2892-
Adapter, ObjectImpl->get_ur_program_ref(), Devs.size(),
2893-
URDevices.data(),
2894-
getSyclObjImpl(InputImpl->get_context()).get()->getHandleRef(),
2895-
CompileOptions.c_str());
2892+
ur_result_t Error =
2893+
doCompile(Adapter, ObjectImpl->get_ur_program_ref(), Devs.size(),
2894+
URDevices.data(),
2895+
getSyclObjImpl(InputImpl.get_context()).get()->getHandleRef(),
2896+
CompileOptions.c_str());
28962897
if (Error != UR_RESULT_SUCCESS)
28972898
throw sycl::exception(
28982899
make_error_code(errc::build),
@@ -3074,13 +3075,11 @@ ProgramManager::link(const std::vector<device_image_plain> &Imgs,
30743075
}
30753076
auto MergedRTCInfo = detail::KernelCompilerBinaryInfo::Merge(RTCInfoPtrs);
30763077

3077-
DeviceImageImplPtr ExecutableImpl =
3078-
std::make_shared<detail::device_image_impl>(
3079-
NewBinImg, Context, std::vector<device>{Devs},
3080-
bundle_state::executable, std::move(KernelIDs), LinkedProg,
3081-
std::move(NewSpecConstMap), std::move(NewSpecConstBlob),
3082-
CombinedOrigins, std::move(MergedRTCInfo),
3083-
std::move(MergedKernelNames), std::move(MergedImageStorage));
3078+
DeviceImageImplPtr ExecutableImpl = device_image_impl::create(
3079+
NewBinImg, Context, std::vector<device>{Devs}, bundle_state::executable,
3080+
std::move(KernelIDs), LinkedProg, std::move(NewSpecConstMap),
3081+
std::move(NewSpecConstBlob), CombinedOrigins, std::move(MergedRTCInfo),
3082+
std::move(MergedKernelNames), std::move(MergedImageStorage));
30843083

30853084
// TODO: Make multiple sets of device images organized by devices they are
30863085
// compiled for.
@@ -3156,7 +3155,7 @@ ProgramManager::build(const DevImgPlainWithDeps &DevImgWithDeps,
31563155
}
31573156
auto MergedRTCInfo = detail::KernelCompilerBinaryInfo::Merge(RTCInfoPtrs);
31583157

3159-
DeviceImageImplPtr ExecImpl = std::make_shared<detail::device_image_impl>(
3158+
DeviceImageImplPtr ExecImpl = device_image_impl::create(
31603159
ResultBinImg, Context, std::vector<device>{Devs},
31613160
bundle_state::executable, std::move(KernelIDs), ResProgram,
31623161
std::move(SpecConstMap), std::move(SpecConstBlob), CombinedOrigins,

0 commit comments

Comments
 (0)