Skip to content

Commit 629d3a3

Browse files
[NFC][SYCL] Pass context_impl by raw ptr/ref in scheduler.hpp (#18980)
Part of the ongoing refactoring to prefer raw ptr/ref for SYCL RT objects by default with explicit `shared_from_this` when lifetimes need to be extended.
1 parent a4c5ace commit 629d3a3

File tree

6 files changed

+18
-16
lines changed

6 files changed

+18
-16
lines changed

sycl/source/detail/queue_impl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -449,7 +449,7 @@ event queue_impl::submitMemOpHelper(const std::vector<event> &DepEvents,
449449
// If we have a command graph set we need to capture the op through the
450450
// handler rather than by-passing the scheduler.
451451
if (MGraph.expired() && Scheduler::areEventsSafeForSchedulerBypass(
452-
ExpandedDepEvents, MContext)) {
452+
ExpandedDepEvents, *MContext)) {
453453
auto isNoEventsMode = trySwitchingToNoEventsMode();
454454
if (!CallerNeedsEvent && isNoEventsMode) {
455455
NestedCallsTracker tracker;

sycl/source/detail/queue_impl.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -725,7 +725,7 @@ class queue_impl : public std::enable_shared_from_this<queue_impl> {
725725
return false;
726726

727727
if (MDefaultGraphDeps.LastEventPtr != nullptr &&
728-
!Scheduler::CheckEventReadiness(MContext,
728+
!Scheduler::CheckEventReadiness(*MContext,
729729
MDefaultGraphDeps.LastEventPtr))
730730
return false;
731731

sycl/source/detail/scheduler/graph_builder.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -225,16 +225,16 @@ Scheduler::GraphBuilder::getOrInsertMemObjRecord(const QueueImplPtr &Queue,
225225
Dev, InteropCtxPtr, async_handler{}, property_list{});
226226

227227
MemObject->MRecord.reset(
228-
new MemObjRecord{InteropCtxPtr, LeafLimit, AllocateDependency});
228+
new MemObjRecord{InteropCtxPtr.get(), LeafLimit, AllocateDependency});
229229
std::vector<Command *> ToEnqueue;
230230
getOrCreateAllocaForReq(MemObject->MRecord.get(), Req, InteropQueuePtr,
231231
ToEnqueue);
232232
assert(ToEnqueue.empty() && "Creation of the first alloca for a record "
233233
"shouldn't lead to any enqueuing (no linked "
234234
"alloca or exceeding the leaf limit).");
235235
} else
236-
MemObject->MRecord.reset(new MemObjRecord{queue_impl::getContext(Queue),
237-
LeafLimit, AllocateDependency});
236+
MemObject->MRecord.reset(new MemObjRecord{
237+
queue_impl::getContext(Queue).get(), LeafLimit, AllocateDependency});
238238

239239
MMemObjs.push_back(MemObject);
240240
return MemObject->MRecord.get();

sycl/source/detail/scheduler/scheduler.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -678,7 +678,7 @@ EventImplPtr Scheduler::addCommandGraphUpdate(
678678
return NewCmdEvent;
679679
}
680680

681-
bool Scheduler::CheckEventReadiness(const ContextImplPtr &Context,
681+
bool Scheduler::CheckEventReadiness(context_impl &Context,
682682
const EventImplPtr &SyclEventImplPtr) {
683683
// Events that don't have an initialized context are throwaway events that
684684
// don't represent actual dependencies. Calling getContextImpl() would set
@@ -691,7 +691,7 @@ bool Scheduler::CheckEventReadiness(const ContextImplPtr &Context,
691691
return SyclEventImplPtr->isCompleted();
692692
}
693693
// Cross-context dependencies can't be passed to the backend directly.
694-
if (SyclEventImplPtr->getContextImpl() != Context)
694+
if (SyclEventImplPtr->getContextImpl().get() != &Context)
695695
return false;
696696

697697
// A nullptr here means that the commmand does not produce a UR event or it
@@ -700,7 +700,7 @@ bool Scheduler::CheckEventReadiness(const ContextImplPtr &Context,
700700
}
701701

702702
bool Scheduler::areEventsSafeForSchedulerBypass(
703-
const std::vector<sycl::event> &DepEvents, const ContextImplPtr &Context) {
703+
const std::vector<sycl::event> &DepEvents, context_impl &Context) {
704704

705705
return std::all_of(
706706
DepEvents.begin(), DepEvents.end(), [&Context](const sycl::event &Event) {
@@ -710,7 +710,7 @@ bool Scheduler::areEventsSafeForSchedulerBypass(
710710
}
711711

712712
bool Scheduler::areEventsSafeForSchedulerBypass(
713-
const std::vector<EventImplPtr> &DepEvents, const ContextImplPtr &Context) {
713+
const std::vector<EventImplPtr> &DepEvents, context_impl &Context) {
714714

715715
return std::all_of(DepEvents.begin(), DepEvents.end(),
716716
[&Context](const EventImplPtr &SyclEventImplPtr) {

sycl/source/detail/scheduler/scheduler.hpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#pragma once
1010

1111
#include <detail/cg.hpp>
12+
#include <detail/context_impl.hpp>
1213
#include <detail/scheduler/commands.hpp>
1314
#include <detail/scheduler/leaves_collection.hpp>
1415
#include <detail/sycl_mem_obj_i.hpp>
@@ -198,10 +199,11 @@ using CommandPtr = std::unique_ptr<Command>;
198199
///
199200
/// \ingroup sycl_graph
200201
struct MemObjRecord {
201-
MemObjRecord(ContextImplPtr Ctx, std::size_t LeafLimit,
202+
MemObjRecord(context_impl *Ctx, std::size_t LeafLimit,
202203
LeavesCollection::AllocateDependencyF AllocateDependency)
203204
: MReadLeaves{this, LeafLimit, AllocateDependency},
204-
MWriteLeaves{this, LeafLimit, AllocateDependency}, MCurContext{Ctx} {}
205+
MWriteLeaves{this, LeafLimit, AllocateDependency},
206+
MCurContext{Ctx ? Ctx->shared_from_this() : nullptr} {}
205207
// Contains all allocation commands for the memory object.
206208
std::vector<AllocaCommandBase *> MAllocaCommands;
207209

@@ -212,7 +214,7 @@ struct MemObjRecord {
212214
LeavesCollection MWriteLeaves;
213215

214216
// The context which has the latest state of the memory object.
215-
ContextImplPtr MCurContext;
217+
std::shared_ptr<context_impl> MCurContext;
216218

217219
// The mode this object can be accessed from the host (host_accessor).
218220
// Valid only if the current usage is on host.
@@ -477,15 +479,15 @@ class Scheduler {
477479
const QueueImplPtr &Queue, std::vector<Requirement *> Requirements,
478480
std::vector<detail::EventImplPtr> &Events);
479481

480-
static bool CheckEventReadiness(const ContextImplPtr &Context,
482+
static bool CheckEventReadiness(context_impl &Context,
481483
const EventImplPtr &SyclEventImplPtr);
482484

483485
static bool
484486
areEventsSafeForSchedulerBypass(const std::vector<sycl::event> &DepEvents,
485-
const ContextImplPtr &Context);
487+
context_impl &Context);
486488
static bool
487489
areEventsSafeForSchedulerBypass(const std::vector<EventImplPtr> &DepEvents,
488-
const ContextImplPtr &Context);
490+
context_impl &Context);
489491

490492
protected:
491493
using RWLockT = std::shared_timed_mutex;

sycl/source/handler.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,7 @@ event handler::finalize() {
411411
(Queue && !Graph && !impl->MSubgraphNode && !Queue->hasCommandGraph() &&
412412
!impl->CGData.MRequirements.size() && !MStreamStorage.size() &&
413413
detail::Scheduler::areEventsSafeForSchedulerBypass(
414-
impl->CGData.MEvents, Queue->getContextImplPtr()));
414+
impl->CGData.MEvents, Queue->getContextImpl()));
415415

416416
// Extract arguments from the kernel lambda, if required.
417417
// Skipping this is currently limited to simple kernels on the fast path.

0 commit comments

Comments
 (0)