@@ -838,14 +838,13 @@ ProgramManager::collectDependentDeviceImagesForVirtualFunctions(
838
838
return DeviceImagesToLink;
839
839
}
840
840
841
- static void
842
- setSpecializationConstants (const std::shared_ptr<device_image_impl> &InputImpl,
843
- ur_program_handle_t Prog,
844
- const AdapterPtr &Adapter) {
845
- std::lock_guard<std::mutex> Lock{InputImpl->get_spec_const_data_lock ()};
841
+ static void setSpecializationConstants (device_image_impl &InputImpl,
842
+ ur_program_handle_t Prog,
843
+ const AdapterPtr &Adapter) {
844
+ std::lock_guard<std::mutex> Lock{InputImpl.get_spec_const_data_lock ()};
846
845
const std::map<std::string, std::vector<device_image_impl::SpecConstDescT>>
847
- &SpecConstData = InputImpl-> get_spec_const_data_ref ();
848
- const SerializedObj &SpecConsts = InputImpl-> get_spec_const_blob_ref ();
846
+ &SpecConstData = InputImpl. get_spec_const_data_ref ();
847
+ const SerializedObj &SpecConsts = InputImpl. get_spec_const_blob_ref ();
849
848
850
849
// Set all specialization IDs from descriptors in the input device image.
851
850
for (const auto &[SpecConstNames, SpecConstDescs] : SpecConstData) {
@@ -941,7 +940,7 @@ ur_program_handle_t ProgramManager::getBuiltURProgram(
941
940
if (!DeviceCodeWasInCache && MainImg.supportsSpecConstants ()) {
942
941
enableITTAnnotationsIfNeeded (NativePrg, Adapter);
943
942
if (DevImgWithDeps)
944
- setSpecializationConstants (getSyclObjImpl (DevImgWithDeps->getMain ()),
943
+ setSpecializationConstants (* getSyclObjImpl (DevImgWithDeps->getMain ()),
945
944
NativePrg, Adapter);
946
945
}
947
946
@@ -982,7 +981,7 @@ ur_program_handle_t ProgramManager::getBuiltURProgram(
982
981
enableITTAnnotationsIfNeeded (NativePrg, Adapter);
983
982
if (DevImgWithDeps)
984
983
setSpecializationConstants (
985
- getSyclObjImpl (DevImgWithDeps->getAll ()[I]), NativePrg,
984
+ * getSyclObjImpl (DevImgWithDeps->getAll ()[I]), NativePrg,
986
985
Adapter);
987
986
}
988
987
ProgramsToLink.push_back (NativePrg);
@@ -2508,9 +2507,9 @@ device_image_plain ProgramManager::getDeviceImageFromBinaryImage(
2508
2507
KernelIDs = m_BinImg2KernelIDs[BinImage];
2509
2508
}
2510
2509
2511
- DeviceImageImplPtr Impl = std::make_shared<detail::device_image_impl> (
2510
+ DeviceImageImplPtr Impl = device_image_impl::create (
2512
2511
BinImage, Ctx, std::vector<device>{Dev}, ImgState, KernelIDs,
2513
- /* PIProgram=*/ nullptr );
2512
+ /* PIProgram=*/ nullptr , ImageOriginSYCLOffline );
2514
2513
2515
2514
return createSyclObjFromImpl<device_image_plain>(std::move (Impl));
2516
2515
}
@@ -2668,9 +2667,10 @@ ProgramManager::getSYCLDeviceImagesWithCompatibleState(
2668
2667
if (ImgInfoPair.second .RequirementCounter == 0 )
2669
2668
continue ;
2670
2669
2671
- DeviceImageImplPtr MainImpl = std::make_shared<detail::device_image_impl> (
2670
+ DeviceImageImplPtr MainImpl = device_image_impl::create (
2672
2671
ImgInfoPair.first , Ctx, Devs, ImgInfoPair.second .State ,
2673
- ImgInfoPair.second .KernelIDs , /* PIProgram=*/ nullptr );
2672
+ ImgInfoPair.second .KernelIDs , /* PIProgram=*/ nullptr ,
2673
+ ImageOriginSYCLOffline);
2674
2674
2675
2675
std::vector<device_image_plain> Images;
2676
2676
const std::set<RTDeviceBinaryImage *> &Deps = ImgInfoPair.second .Deps ;
@@ -2701,8 +2701,9 @@ device_image_plain ProgramManager::createDependencyImage(
2701
2701
2702
2702
assert (DepState == getBinImageState (DepImage) &&
2703
2703
" State mismatch between main image and its dependency" );
2704
- DeviceImageImplPtr DepImpl = std::make_shared<detail::device_image_impl>(
2705
- DepImage, Ctx, Devs, DepState, DepKernelIDs, /* PIProgram=*/ nullptr );
2704
+ DeviceImageImplPtr DepImpl =
2705
+ device_image_impl::create (DepImage, Ctx, Devs, DepState, DepKernelIDs,
2706
+ /* PIProgram=*/ nullptr , ImageOriginSYCLOffline);
2706
2707
2707
2708
return createSyclObjFromImpl<device_image_plain>(std::move (DepImpl));
2708
2709
}
@@ -2857,42 +2858,42 @@ ProgramManager::compile(const DevImgPlainWithDeps &ImgWithDeps,
2857
2858
std::vector<device_image_plain> CompiledImages;
2858
2859
CompiledImages.reserve (ImgWithDeps.size ());
2859
2860
for (const device_image_plain &DeviceImage : ImgWithDeps.getAll ()) {
2860
- const std::shared_ptr<device_image_impl> &InputImpl =
2861
- getSyclObjImpl (DeviceImage);
2861
+ device_image_impl &InputImpl = *getSyclObjImpl (DeviceImage);
2862
2862
2863
2863
const AdapterPtr &Adapter =
2864
- getSyclObjImpl (InputImpl-> get_context ())->getAdapter ();
2864
+ getSyclObjImpl (InputImpl. get_context ())->getAdapter ();
2865
2865
2866
2866
ur_program_handle_t Prog =
2867
- createURProgram (*InputImpl-> get_bin_image_ref (),
2868
- *getSyclObjImpl (InputImpl-> get_context ()), Devs);
2867
+ createURProgram (*InputImpl. get_bin_image_ref (),
2868
+ *getSyclObjImpl (InputImpl. get_context ()), Devs);
2869
2869
2870
- if (InputImpl-> get_bin_image_ref ()->supportsSpecConstants ())
2870
+ if (InputImpl. get_bin_image_ref ()->supportsSpecConstants ())
2871
2871
setSpecializationConstants (InputImpl, Prog, Adapter);
2872
2872
2873
- KernelNameSetT KernelNames = InputImpl-> getKernelNames ();
2873
+ KernelNameSetT KernelNames = InputImpl. getKernelNames ();
2874
2874
2875
2875
std::optional<detail::KernelCompilerBinaryInfo> RTCInfo =
2876
- InputImpl-> getRTCInfo ();
2877
- DeviceImageImplPtr ObjectImpl = std::make_shared<detail::device_image_impl> (
2878
- InputImpl-> get_bin_image_ref (), InputImpl-> get_context (),
2876
+ InputImpl. getRTCInfo ();
2877
+ DeviceImageImplPtr ObjectImpl = device_image_impl::create (
2878
+ InputImpl. get_bin_image_ref (), InputImpl. get_context (),
2879
2879
std::vector<device>{Devs}, bundle_state::object,
2880
- InputImpl->get_kernel_ids_ptr (), Prog,
2881
- InputImpl->get_spec_const_data_ref (),
2882
- InputImpl->get_spec_const_blob_ref (), InputImpl->getOriginMask (),
2883
- std::move (RTCInfo), std::move (KernelNames));
2880
+ InputImpl.get_kernel_ids_ptr (), Prog,
2881
+ InputImpl.get_spec_const_data_ref (),
2882
+ InputImpl.get_spec_const_blob_ref (), InputImpl.getOriginMask (),
2883
+ std::move (RTCInfo), std::move (KernelNames),
2884
+ /* MergedImageStorage = */ nullptr );
2884
2885
2885
2886
std::string CompileOptions;
2886
2887
applyCompileOptionsFromEnvironment (CompileOptions);
2887
2888
appendCompileOptionsFromImage (
2888
- CompileOptions, *(InputImpl-> get_bin_image_ref ()), Devs, Adapter);
2889
+ CompileOptions, *(InputImpl. get_bin_image_ref ()), Devs, Adapter);
2889
2890
// Should always come last!
2890
2891
appendCompileEnvironmentVariablesThatAppend (CompileOptions);
2891
- ur_result_t Error = doCompile (
2892
- Adapter, ObjectImpl->get_ur_program_ref (), Devs.size (),
2893
- URDevices.data (),
2894
- getSyclObjImpl (InputImpl-> get_context ()).get ()->getHandleRef (),
2895
- CompileOptions.c_str ());
2892
+ ur_result_t Error =
2893
+ doCompile ( Adapter, ObjectImpl->get_ur_program_ref (), Devs.size (),
2894
+ URDevices.data (),
2895
+ getSyclObjImpl (InputImpl. get_context ()).get ()->getHandleRef (),
2896
+ CompileOptions.c_str ());
2896
2897
if (Error != UR_RESULT_SUCCESS)
2897
2898
throw sycl::exception (
2898
2899
make_error_code (errc::build),
@@ -3074,13 +3075,11 @@ ProgramManager::link(const std::vector<device_image_plain> &Imgs,
3074
3075
}
3075
3076
auto MergedRTCInfo = detail::KernelCompilerBinaryInfo::Merge (RTCInfoPtrs);
3076
3077
3077
- DeviceImageImplPtr ExecutableImpl =
3078
- std::make_shared<detail::device_image_impl>(
3079
- NewBinImg, Context, std::vector<device>{Devs},
3080
- bundle_state::executable, std::move (KernelIDs), LinkedProg,
3081
- std::move (NewSpecConstMap), std::move (NewSpecConstBlob),
3082
- CombinedOrigins, std::move (MergedRTCInfo),
3083
- std::move (MergedKernelNames), std::move (MergedImageStorage));
3078
+ DeviceImageImplPtr ExecutableImpl = device_image_impl::create (
3079
+ NewBinImg, Context, std::vector<device>{Devs}, bundle_state::executable,
3080
+ std::move (KernelIDs), LinkedProg, std::move (NewSpecConstMap),
3081
+ std::move (NewSpecConstBlob), CombinedOrigins, std::move (MergedRTCInfo),
3082
+ std::move (MergedKernelNames), std::move (MergedImageStorage));
3084
3083
3085
3084
// TODO: Make multiple sets of device images organized by devices they are
3086
3085
// compiled for.
@@ -3156,7 +3155,7 @@ ProgramManager::build(const DevImgPlainWithDeps &DevImgWithDeps,
3156
3155
}
3157
3156
auto MergedRTCInfo = detail::KernelCompilerBinaryInfo::Merge (RTCInfoPtrs);
3158
3157
3159
- DeviceImageImplPtr ExecImpl = std::make_shared<detail::device_image_impl> (
3158
+ DeviceImageImplPtr ExecImpl = device_image_impl::create (
3160
3159
ResultBinImg, Context, std::vector<device>{Devs},
3161
3160
bundle_state::executable, std::move (KernelIDs), ResProgram,
3162
3161
std::move (SpecConstMap), std::move (SpecConstBlob), CombinedOrigins,
0 commit comments