Skip to content

Commit 6e55f3b

Browse files
[NFCI][SYCL][Graph] Cleanup after enable_shared_from_this for queue_impl (#18748)
`foo(const std::shared_ptr<queue_impl> &)` doesn't provide any information about whether a `nullptr` argument is possible. It could signal that the callee may create a copy if it were used sparingly, but when the entire codebase passes all objects like this, that signal is meaningless too. Instead, the WIP refactoring across the entire codebase is to pass by raw ptr/ref (depending on whether a `nullptr` value is possible) and to extend lifetimes with an explicit `shared_from_this()`. This PR is limited to the `graph_impl.hpp` interfaces accepting `queue_impl`, as a cleanup after #18715.
1 parent 77773ad commit 6e55f3b

File tree

5 files changed

+134
-123
lines changed

5 files changed

+134
-123
lines changed

sycl/source/detail/async_alloc.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ std::vector<std::shared_ptr<detail::node_impl>> getDepGraphNodes(
4646
// If this is being recorded from an in-order queue we need to get the last
4747
// in-order node if any, since this will later become a dependency of the
4848
// node being processed here.
49-
if (const auto &LastInOrderNode = Graph->getLastInorderNode(Queue);
49+
if (const auto &LastInOrderNode = Graph->getLastInorderNode(Queue.get());
5050
LastInOrderNode) {
5151
DepNodes.push_back(LastInOrderNode);
5252
}

sycl/source/detail/graph_impl.cpp

Lines changed: 49 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -659,6 +659,14 @@ graph_impl::add(std::shared_ptr<dynamic_command_group_impl> &DynCGImpl,
659659
return NodeImpl;
660660
}
661661

662+
void graph_impl::addQueue(sycl::detail::queue_impl &RecordingQueue) {
663+
MRecordingQueues.insert(RecordingQueue.weak_from_this());
664+
}
665+
666+
void graph_impl::removeQueue(sycl::detail::queue_impl &RecordingQueue) {
667+
MRecordingQueues.erase(RecordingQueue.weak_from_this());
668+
}
669+
662670
bool graph_impl::clearQueues() {
663671
bool AnyQueuesCleared = false;
664672
for (auto &Queue : MRecordingQueues) {
@@ -689,6 +697,24 @@ bool graph_impl::checkForCycles() {
689697
return CycleFound;
690698
}
691699

700+
std::shared_ptr<node_impl>
701+
graph_impl::getLastInorderNode(sycl::detail::queue_impl *Queue) {
702+
if (!Queue) {
703+
assert(0 ==
704+
MInorderQueueMap.count(std::weak_ptr<sycl::detail::queue_impl>{}));
705+
return {};
706+
}
707+
if (0 == MInorderQueueMap.count(Queue->weak_from_this())) {
708+
return {};
709+
}
710+
return MInorderQueueMap[Queue->weak_from_this()];
711+
}
712+
713+
void graph_impl::setLastInorderNode(sycl::detail::queue_impl &Queue,
714+
std::shared_ptr<node_impl> Node) {
715+
MInorderQueueMap[Queue.weak_from_this()] = Node;
716+
}
717+
692718
void graph_impl::makeEdge(std::shared_ptr<node_impl> Src,
693719
std::shared_ptr<node_impl> Dest) {
694720
throwIfGraphRecordingQueue("make_edge()");
@@ -769,11 +795,10 @@ std::vector<sycl::detail::EventImplPtr> graph_impl::getExitNodesEvents(
769795
return Events;
770796
}
771797

772-
void graph_impl::beginRecording(
773-
const std::shared_ptr<sycl::detail::queue_impl> &Queue) {
798+
void graph_impl::beginRecording(sycl::detail::queue_impl &Queue) {
774799
graph_impl::WriteLock Lock(MMutex);
775-
if (!Queue->hasCommandGraph()) {
776-
Queue->setCommandGraph(shared_from_this());
800+
if (!Queue.hasCommandGraph()) {
801+
Queue.setCommandGraph(shared_from_this());
777802
addQueue(Queue);
778803
}
779804
}
@@ -1003,7 +1028,7 @@ exec_graph_impl::~exec_graph_impl() {
10031028
}
10041029

10051030
sycl::event
1006-
exec_graph_impl::enqueue(const std::shared_ptr<sycl::detail::queue_impl> &Queue,
1031+
exec_graph_impl::enqueue(sycl::detail::queue_impl &Queue,
10071032
sycl::detail::CG::StorageInitHelper CGData) {
10081033
WriteLock Lock(MMutex);
10091034

@@ -1012,8 +1037,9 @@ exec_graph_impl::enqueue(const std::shared_ptr<sycl::detail::queue_impl> &Queue,
10121037
PartitionsExecutionEvents;
10131038

10141039
auto CreateNewEvent([&]() {
1015-
auto NewEvent = std::make_shared<sycl::detail::event_impl>(Queue);
1016-
NewEvent->setContextImpl(Queue->getContextImplPtr());
1040+
auto NewEvent =
1041+
std::make_shared<sycl::detail::event_impl>(Queue.shared_from_this());
1042+
NewEvent->setContextImpl(Queue.getContextImplPtr());
10171043
NewEvent->setStateIncomplete();
10181044
return NewEvent;
10191045
});
@@ -1035,7 +1061,7 @@ exec_graph_impl::enqueue(const std::shared_ptr<sycl::detail::queue_impl> &Queue,
10351061
CGData.MEvents.push_back(PartitionsExecutionEvents[DepPartition]);
10361062
}
10371063

1038-
auto CommandBuffer = CurrentPartition->MCommandBuffers[Queue->get_device()];
1064+
auto CommandBuffer = CurrentPartition->MCommandBuffers[Queue.get_device()];
10391065

10401066
if (CommandBuffer) {
10411067
for (std::vector<sycl::detail::EventImplPtr>::iterator It =
@@ -1073,10 +1099,10 @@ exec_graph_impl::enqueue(const std::shared_ptr<sycl::detail::queue_impl> &Queue,
10731099
if (CGData.MRequirements.empty() && CGData.MEvents.empty()) {
10741100
NewEvent->setSubmissionTime();
10751101
ur_result_t Res =
1076-
Queue->getAdapter()
1102+
Queue.getAdapter()
10771103
->call_nocheck<
10781104
sycl::detail::UrApiKind::urEnqueueCommandBufferExp>(
1079-
Queue->getHandleRef(), CommandBuffer, 0, nullptr, &UREvent);
1105+
Queue.getHandleRef(), CommandBuffer, 0, nullptr, &UREvent);
10801106
NewEvent->setHandle(UREvent);
10811107
if (Res == UR_RESULT_ERROR_INVALID_QUEUE_PROPERTIES) {
10821108
throw sycl::exception(
@@ -1096,7 +1122,8 @@ exec_graph_impl::enqueue(const std::shared_ptr<sycl::detail::queue_impl> &Queue,
10961122
CommandBuffer, nullptr, std::move(CGData));
10971123

10981124
NewEvent = sycl::detail::Scheduler::getInstance().addCG(
1099-
std::move(CommandGroup), Queue, /*EventNeeded=*/true);
1125+
std::move(CommandGroup), Queue.shared_from_this(),
1126+
/*EventNeeded=*/true);
11001127
}
11011128
NewEvent->setEventFromSubmittedExecCommandBuffer(true);
11021129
} else if ((CurrentPartition->MSchedule.size() > 0) &&
@@ -1112,10 +1139,11 @@ exec_graph_impl::enqueue(const std::shared_ptr<sycl::detail::queue_impl> &Queue,
11121139
// In case of graph, this queue may differ from the actual execution
11131140
// queue. We therefore overload this Queue before submitting the task.
11141141
static_cast<sycl::detail::CGHostTask &>(*NodeImpl->MCommandGroup.get())
1115-
.MQueue = Queue;
1142+
.MQueue = Queue.shared_from_this();
11161143

11171144
NewEvent = sycl::detail::Scheduler::getInstance().addCG(
1118-
NodeImpl->getCGCopy(), Queue, /*EventNeeded=*/true);
1145+
NodeImpl->getCGCopy(), Queue.shared_from_this(),
1146+
/*EventNeeded=*/true);
11191147
}
11201148
PartitionsExecutionEvents[CurrentPartition] = NewEvent;
11211149
}
@@ -1844,21 +1872,20 @@ void modifiable_command_graph::begin_recording(
18441872
// related to graph at all.
18451873
checkGraphPropertiesAndThrow(PropList);
18461874

1847-
auto QueueImpl = sycl::detail::getSyclObjImpl(RecordingQueue);
1848-
assert(QueueImpl);
1875+
queue_impl &QueueImpl = *sycl::detail::getSyclObjImpl(RecordingQueue);
18491876

1850-
if (QueueImpl->hasCommandGraph()) {
1877+
if (QueueImpl.hasCommandGraph()) {
18511878
throw sycl::exception(sycl::make_error_code(errc::invalid),
18521879
"begin_recording cannot be called for a queue which "
18531880
"is already in the recording state.");
18541881
}
18551882

1856-
if (QueueImpl->get_context() != impl->getContext()) {
1883+
if (QueueImpl.get_context() != impl->getContext()) {
18571884
throw sycl::exception(sycl::make_error_code(errc::invalid),
18581885
"begin_recording called for a queue whose context "
18591886
"differs from the graph context.");
18601887
}
1861-
if (QueueImpl->get_device() != impl->getDevice()) {
1888+
if (QueueImpl.get_device() != impl->getDevice()) {
18621889
throw sycl::exception(sycl::make_error_code(errc::invalid),
18631890
"begin_recording called for a queue whose device "
18641891
"differs from the graph device.");
@@ -1881,15 +1908,13 @@ void modifiable_command_graph::end_recording() {
18811908
}
18821909

18831910
void modifiable_command_graph::end_recording(queue &RecordingQueue) {
1884-
auto QueueImpl = sycl::detail::getSyclObjImpl(RecordingQueue);
1885-
if (!QueueImpl)
1886-
return;
1887-
if (QueueImpl->getCommandGraph() == impl) {
1888-
QueueImpl->setCommandGraph(nullptr);
1911+
queue_impl &QueueImpl = *sycl::detail::getSyclObjImpl(RecordingQueue);
1912+
if (QueueImpl.getCommandGraph() == impl) {
1913+
QueueImpl.setCommandGraph(nullptr);
18891914
graph_impl::WriteLock Lock(impl->MMutex);
18901915
impl->removeQueue(QueueImpl);
18911916
}
1892-
if (QueueImpl->hasCommandGraph())
1917+
if (QueueImpl.hasCommandGraph())
18931918
throw sycl::exception(sycl::make_error_code(errc::invalid),
18941919
"end_recording called for a queue which is recording "
18951920
"to a different graph.");

sycl/source/detail/graph_impl.hpp

Lines changed: 7 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -878,18 +878,12 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
878878
/// Add a queue to the set of queues which are currently recording to this
879879
/// graph.
880880
/// @param RecordingQueue Queue to add to set.
881-
void
882-
addQueue(const std::shared_ptr<sycl::detail::queue_impl> &RecordingQueue) {
883-
MRecordingQueues.insert(RecordingQueue);
884-
}
881+
void addQueue(sycl::detail::queue_impl &RecordingQueue);
885882

886883
/// Remove a queue from the set of queues which are currently recording to
887884
/// this graph.
888885
/// @param RecordingQueue Queue to remove from set.
889-
void
890-
removeQueue(const std::shared_ptr<sycl::detail::queue_impl> &RecordingQueue) {
891-
MRecordingQueues.erase(RecordingQueue);
892-
}
886+
void removeQueue(sycl::detail::queue_impl &RecordingQueue);
893887

894888
/// Remove all queues which are recording to this graph, also sets all queues
895889
/// cleared back to the executing state.
@@ -1001,22 +995,13 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
1001995
/// @return Last node in this graph added from \p Queue recording, or empty
1002996
/// shared pointer if none.
1003997
std::shared_ptr<node_impl>
1004-
getLastInorderNode(std::shared_ptr<sycl::detail::queue_impl> Queue) {
1005-
std::weak_ptr<sycl::detail::queue_impl> QueueWeakPtr(Queue);
1006-
if (0 == MInorderQueueMap.count(QueueWeakPtr)) {
1007-
return {};
1008-
}
1009-
return MInorderQueueMap[QueueWeakPtr];
1010-
}
998+
getLastInorderNode(sycl::detail::queue_impl *Queue);
1011999

10121000
/// Track the last node added to this graph from an in-order queue.
10131001
/// @param Queue In-order queue to register \p Node for.
10141002
/// @param Node Last node that was added to this graph from \p Queue.
1015-
void setLastInorderNode(std::shared_ptr<sycl::detail::queue_impl> Queue,
1016-
std::shared_ptr<node_impl> Node) {
1017-
std::weak_ptr<sycl::detail::queue_impl> QueueWeakPtr(Queue);
1018-
MInorderQueueMap[QueueWeakPtr] = Node;
1019-
}
1003+
void setLastInorderNode(sycl::detail::queue_impl &Queue,
1004+
std::shared_ptr<node_impl> Node);
10201005

10211006
/// Prints the contents of the graph to a text file in DOT format.
10221007
/// @param FilePath Path to the output file.
@@ -1176,7 +1161,7 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
11761161
/// Sets the Queue state to queue_state::recording. Adds the queue to the list
11771162
/// of recording queues associated with this graph.
11781163
/// @param[in] Queue The queue to be recorded from.
1179-
void beginRecording(const std::shared_ptr<sycl::detail::queue_impl> &Queue);
1164+
void beginRecording(sycl::detail::queue_impl &Queue);
11801165

11811166
/// Store the last barrier node that was submitted to the queue.
11821167
/// @param[in] Queue The queue the barrier was recorded from.
@@ -1346,7 +1331,7 @@ class exec_graph_impl {
13461331
/// @param Queue Command-queue to schedule execution on.
13471332
/// @param CGData Command-group data provided by the sycl::handler
13481333
/// @return Event associated with the execution of the graph.
1349-
sycl::event enqueue(const std::shared_ptr<sycl::detail::queue_impl> &Queue,
1334+
sycl::event enqueue(sycl::detail::queue_impl &Queue,
13501335
sycl::detail::CG::StorageInitHelper CGData);
13511336

13521337
/// Turns the internal graph representation into UR command-buffers for a

sycl/source/handler.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -783,7 +783,7 @@ event handler::finalize() {
783783

784784
} else {
785785
event GraphCompletionEvent =
786-
impl->MExecGraph->enqueue(MQueue, std::move(impl->CGData));
786+
impl->MExecGraph->enqueue(impl->get_queue(), std::move(impl->CGData));
787787

788788
#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
789789
MLastEvent = getSyclObjImpl(GraphCompletionEvent);
@@ -870,15 +870,16 @@ event handler::finalize() {
870870
// node can set it as a predecessor.
871871
std::vector<std::shared_ptr<ext::oneapi::experimental::detail::node_impl>>
872872
Deps;
873-
if (auto DependentNode = GraphImpl->getLastInorderNode(MQueue)) {
873+
if (auto DependentNode =
874+
GraphImpl->getLastInorderNode(impl->get_queue_or_null())) {
874875
Deps.push_back(std::move(DependentNode));
875876
}
876877
NodeImpl = GraphImpl->add(NodeType, std::move(CommandGroup), Deps);
877878

878879
// If we are recording an in-order queue remember the new node, so it
879880
// can be used as a dependency for any more nodes recorded from this
880881
// queue.
881-
GraphImpl->setLastInorderNode(MQueue, NodeImpl);
882+
GraphImpl->setLastInorderNode(*MQueue, NodeImpl);
882883
} else {
883884
auto LastBarrierRecordedFromQueue = GraphImpl->getBarrierDep(MQueue);
884885
std::vector<std::shared_ptr<ext::oneapi::experimental::detail::node_impl>>
@@ -1988,7 +1989,7 @@ void handler::depends_on(const detail::EventImplPtr &EventImpl) {
19881989
// we need to set it to recording (implements the transitive queue recording
19891990
// feature).
19901991
if (!QueueGraph) {
1991-
EventGraph->beginRecording(MQueue);
1992+
EventGraph->beginRecording(impl->get_queue());
19921993
}
19931994
}
19941995

0 commit comments

Comments
 (0)