Skip to content

Commit 74782fe

Browse files
[NFCI][SYCL] Cleanup after enable_shared_from_this for queue_impl
1 parent 00a716a commit 74782fe

File tree

5 files changed

+134
-123
lines changed

5 files changed

+134
-123
lines changed

sycl/source/detail/async_alloc.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ std::vector<std::shared_ptr<detail::node_impl>> getDepGraphNodes(
4646
// If this is being recorded from an in-order queue we need to get the last
4747
// in-order node if any, since this will later become a dependency of the
4848
// node being processed here.
49-
if (const auto &LastInOrderNode = Graph->getLastInorderNode(Queue);
49+
if (const auto &LastInOrderNode = Graph->getLastInorderNode(Queue.get());
5050
LastInOrderNode) {
5151
DepNodes.push_back(LastInOrderNode);
5252
}

sycl/source/detail/graph_impl.cpp

Lines changed: 49 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -659,6 +659,14 @@ graph_impl::add(std::shared_ptr<dynamic_command_group_impl> &DynCGImpl,
659659
return NodeImpl;
660660
}
661661

662+
void graph_impl::addQueue(sycl::detail::queue_impl &RecordingQueue) {
663+
MRecordingQueues.insert(RecordingQueue.weak_from_this());
664+
}
665+
666+
void graph_impl::removeQueue(sycl::detail::queue_impl &RecordingQueue) {
667+
MRecordingQueues.erase(RecordingQueue.weak_from_this());
668+
}
669+
662670
bool graph_impl::clearQueues() {
663671
bool AnyQueuesCleared = false;
664672
for (auto &Queue : MRecordingQueues) {
@@ -689,6 +697,24 @@ bool graph_impl::checkForCycles() {
689697
return CycleFound;
690698
}
691699

700+
std::shared_ptr<node_impl>
701+
graph_impl::getLastInorderNode(sycl::detail::queue_impl *Queue) {
702+
if (!Queue) {
703+
assert(0 ==
704+
MInorderQueueMap.count(std::weak_ptr<sycl::detail::queue_impl>{}));
705+
return {};
706+
}
707+
if (0 == MInorderQueueMap.count(Queue->weak_from_this())) {
708+
return {};
709+
}
710+
return MInorderQueueMap[Queue->weak_from_this()];
711+
}
712+
713+
void graph_impl::setLastInorderNode(sycl::detail::queue_impl &Queue,
714+
std::shared_ptr<node_impl> Node) {
715+
MInorderQueueMap[Queue.weak_from_this()] = Node;
716+
}
717+
692718
void graph_impl::makeEdge(std::shared_ptr<node_impl> Src,
693719
std::shared_ptr<node_impl> Dest) {
694720
throwIfGraphRecordingQueue("make_edge()");
@@ -769,11 +795,10 @@ std::vector<sycl::detail::EventImplPtr> graph_impl::getExitNodesEvents(
769795
return Events;
770796
}
771797

772-
void graph_impl::beginRecording(
773-
const std::shared_ptr<sycl::detail::queue_impl> &Queue) {
798+
void graph_impl::beginRecording(sycl::detail::queue_impl &Queue) {
774799
graph_impl::WriteLock Lock(MMutex);
775-
if (!Queue->hasCommandGraph()) {
776-
Queue->setCommandGraph(shared_from_this());
800+
if (!Queue.hasCommandGraph()) {
801+
Queue.setCommandGraph(shared_from_this());
777802
addQueue(Queue);
778803
}
779804
}
@@ -1003,7 +1028,7 @@ exec_graph_impl::~exec_graph_impl() {
10031028
}
10041029

10051030
sycl::event
1006-
exec_graph_impl::enqueue(const std::shared_ptr<sycl::detail::queue_impl> &Queue,
1031+
exec_graph_impl::enqueue(sycl::detail::queue_impl &Queue,
10071032
sycl::detail::CG::StorageInitHelper CGData) {
10081033
WriteLock Lock(MMutex);
10091034

@@ -1012,8 +1037,9 @@ exec_graph_impl::enqueue(const std::shared_ptr<sycl::detail::queue_impl> &Queue,
10121037
PartitionsExecutionEvents;
10131038

10141039
auto CreateNewEvent([&]() {
1015-
auto NewEvent = std::make_shared<sycl::detail::event_impl>(Queue);
1016-
NewEvent->setContextImpl(Queue->getContextImplPtr());
1040+
auto NewEvent =
1041+
std::make_shared<sycl::detail::event_impl>(Queue.shared_from_this());
1042+
NewEvent->setContextImpl(Queue.getContextImplPtr());
10171043
NewEvent->setStateIncomplete();
10181044
return NewEvent;
10191045
});
@@ -1035,7 +1061,7 @@ exec_graph_impl::enqueue(const std::shared_ptr<sycl::detail::queue_impl> &Queue,
10351061
CGData.MEvents.push_back(PartitionsExecutionEvents[DepPartition]);
10361062
}
10371063

1038-
auto CommandBuffer = CurrentPartition->MCommandBuffers[Queue->get_device()];
1064+
auto CommandBuffer = CurrentPartition->MCommandBuffers[Queue.get_device()];
10391065

10401066
if (CommandBuffer) {
10411067
for (std::vector<sycl::detail::EventImplPtr>::iterator It =
@@ -1073,10 +1099,10 @@ exec_graph_impl::enqueue(const std::shared_ptr<sycl::detail::queue_impl> &Queue,
10731099
if (CGData.MRequirements.empty() && CGData.MEvents.empty()) {
10741100
NewEvent->setSubmissionTime();
10751101
ur_result_t Res =
1076-
Queue->getAdapter()
1102+
Queue.getAdapter()
10771103
->call_nocheck<
10781104
sycl::detail::UrApiKind::urEnqueueCommandBufferExp>(
1079-
Queue->getHandleRef(), CommandBuffer, 0, nullptr, &UREvent);
1105+
Queue.getHandleRef(), CommandBuffer, 0, nullptr, &UREvent);
10801106
NewEvent->setHandle(UREvent);
10811107
if (Res == UR_RESULT_ERROR_INVALID_QUEUE_PROPERTIES) {
10821108
throw sycl::exception(
@@ -1096,7 +1122,8 @@ exec_graph_impl::enqueue(const std::shared_ptr<sycl::detail::queue_impl> &Queue,
10961122
CommandBuffer, nullptr, std::move(CGData));
10971123

10981124
NewEvent = sycl::detail::Scheduler::getInstance().addCG(
1099-
std::move(CommandGroup), Queue, /*EventNeeded=*/true);
1125+
std::move(CommandGroup), Queue.shared_from_this(),
1126+
/*EventNeeded=*/true);
11001127
}
11011128
NewEvent->setEventFromSubmittedExecCommandBuffer(true);
11021129
} else if ((CurrentPartition->MSchedule.size() > 0) &&
@@ -1112,10 +1139,11 @@ exec_graph_impl::enqueue(const std::shared_ptr<sycl::detail::queue_impl> &Queue,
11121139
// In case of graph, this queue may differ from the actual execution
11131140
// queue. We therefore overload this Queue before submitting the task.
11141141
static_cast<sycl::detail::CGHostTask &>(*NodeImpl->MCommandGroup.get())
1115-
.MQueue = Queue;
1142+
.MQueue = Queue.shared_from_this();
11161143

11171144
NewEvent = sycl::detail::Scheduler::getInstance().addCG(
1118-
NodeImpl->getCGCopy(), Queue, /*EventNeeded=*/true);
1145+
NodeImpl->getCGCopy(), Queue.shared_from_this(),
1146+
/*EventNeeded=*/true);
11191147
}
11201148
PartitionsExecutionEvents[CurrentPartition] = NewEvent;
11211149
}
@@ -1844,21 +1872,20 @@ void modifiable_command_graph::begin_recording(
18441872
// related to graph at all.
18451873
checkGraphPropertiesAndThrow(PropList);
18461874

1847-
auto QueueImpl = sycl::detail::getSyclObjImpl(RecordingQueue);
1848-
assert(QueueImpl);
1875+
queue_impl &QueueImpl = *sycl::detail::getSyclObjImpl(RecordingQueue);
18491876

1850-
if (QueueImpl->hasCommandGraph()) {
1877+
if (QueueImpl.hasCommandGraph()) {
18511878
throw sycl::exception(sycl::make_error_code(errc::invalid),
18521879
"begin_recording cannot be called for a queue which "
18531880
"is already in the recording state.");
18541881
}
18551882

1856-
if (QueueImpl->get_context() != impl->getContext()) {
1883+
if (QueueImpl.get_context() != impl->getContext()) {
18571884
throw sycl::exception(sycl::make_error_code(errc::invalid),
18581885
"begin_recording called for a queue whose context "
18591886
"differs from the graph context.");
18601887
}
1861-
if (QueueImpl->get_device() != impl->getDevice()) {
1888+
if (QueueImpl.get_device() != impl->getDevice()) {
18621889
throw sycl::exception(sycl::make_error_code(errc::invalid),
18631890
"begin_recording called for a queue whose device "
18641891
"differs from the graph device.");
@@ -1881,15 +1908,13 @@ void modifiable_command_graph::end_recording() {
18811908
}
18821909

18831910
void modifiable_command_graph::end_recording(queue &RecordingQueue) {
1884-
auto QueueImpl = sycl::detail::getSyclObjImpl(RecordingQueue);
1885-
if (!QueueImpl)
1886-
return;
1887-
if (QueueImpl->getCommandGraph() == impl) {
1888-
QueueImpl->setCommandGraph(nullptr);
1911+
queue_impl &QueueImpl = *sycl::detail::getSyclObjImpl(RecordingQueue);
1912+
if (QueueImpl.getCommandGraph() == impl) {
1913+
QueueImpl.setCommandGraph(nullptr);
18891914
graph_impl::WriteLock Lock(impl->MMutex);
18901915
impl->removeQueue(QueueImpl);
18911916
}
1892-
if (QueueImpl->hasCommandGraph())
1917+
if (QueueImpl.hasCommandGraph())
18931918
throw sycl::exception(sycl::make_error_code(errc::invalid),
18941919
"end_recording called for a queue which is recording "
18951920
"to a different graph.");

sycl/source/detail/graph_impl.hpp

Lines changed: 7 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -878,18 +878,12 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
878878
/// Add a queue to the set of queues which are currently recording to this
879879
/// graph.
880880
/// @param RecordingQueue Queue to add to set.
881-
void
882-
addQueue(const std::shared_ptr<sycl::detail::queue_impl> &RecordingQueue) {
883-
MRecordingQueues.insert(RecordingQueue);
884-
}
881+
void addQueue(sycl::detail::queue_impl &RecordingQueue);
885882

886883
/// Remove a queue from the set of queues which are currently recording to
887884
/// this graph.
888885
/// @param RecordingQueue Queue to remove from set.
889-
void
890-
removeQueue(const std::shared_ptr<sycl::detail::queue_impl> &RecordingQueue) {
891-
MRecordingQueues.erase(RecordingQueue);
892-
}
886+
void removeQueue(sycl::detail::queue_impl &RecordingQueue);
893887

894888
/// Remove all queues which are recording to this graph, also sets all queues
895889
/// cleared back to the executing state.
@@ -1001,22 +995,13 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
1001995
/// @return Last node in this graph added from \p Queue recording, or empty
1002996
/// shared pointer if none.
1003997
std::shared_ptr<node_impl>
1004-
getLastInorderNode(std::shared_ptr<sycl::detail::queue_impl> Queue) {
1005-
std::weak_ptr<sycl::detail::queue_impl> QueueWeakPtr(Queue);
1006-
if (0 == MInorderQueueMap.count(QueueWeakPtr)) {
1007-
return {};
1008-
}
1009-
return MInorderQueueMap[QueueWeakPtr];
1010-
}
998+
getLastInorderNode(sycl::detail::queue_impl *Queue);
1011999

10121000
/// Track the last node added to this graph from an in-order queue.
10131001
/// @param Queue In-order queue to register \p Node for.
10141002
/// @param Node Last node that was added to this graph from \p Queue.
1015-
void setLastInorderNode(std::shared_ptr<sycl::detail::queue_impl> Queue,
1016-
std::shared_ptr<node_impl> Node) {
1017-
std::weak_ptr<sycl::detail::queue_impl> QueueWeakPtr(Queue);
1018-
MInorderQueueMap[QueueWeakPtr] = Node;
1019-
}
1003+
void setLastInorderNode(sycl::detail::queue_impl &Queue,
1004+
std::shared_ptr<node_impl> Node);
10201005

10211006
/// Prints the contents of the graph to a text file in DOT format.
10221007
/// @param FilePath Path to the output file.
@@ -1176,7 +1161,7 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
11761161
/// Sets the Queue state to queue_state::recording. Adds the queue to the list
11771162
/// of recording queues associated with this graph.
11781163
/// @param[in] Queue The queue to be recorded from.
1179-
void beginRecording(const std::shared_ptr<sycl::detail::queue_impl> &Queue);
1164+
void beginRecording(sycl::detail::queue_impl &Queue);
11801165

11811166
/// Store the last barrier node that was submitted to the queue.
11821167
/// @param[in] Queue The queue the barrier was recorded from.
@@ -1346,7 +1331,7 @@ class exec_graph_impl {
13461331
/// @param Queue Command-queue to schedule execution on.
13471332
/// @param CGData Command-group data provided by the sycl::handler
13481333
/// @return Event associated with the execution of the graph.
1349-
sycl::event enqueue(const std::shared_ptr<sycl::detail::queue_impl> &Queue,
1334+
sycl::event enqueue(sycl::detail::queue_impl &Queue,
13501335
sycl::detail::CG::StorageInitHelper CGData);
13511336

13521337
/// Turns the internal graph representation into UR command-buffers for a

sycl/source/handler.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -783,7 +783,7 @@ event handler::finalize() {
783783

784784
} else {
785785
event GraphCompletionEvent =
786-
impl->MExecGraph->enqueue(MQueue, std::move(impl->CGData));
786+
impl->MExecGraph->enqueue(impl->get_queue(), std::move(impl->CGData));
787787

788788
#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
789789
MLastEvent = getSyclObjImpl(GraphCompletionEvent);
@@ -870,15 +870,16 @@ event handler::finalize() {
870870
// node can set it as a predecessor.
871871
std::vector<std::shared_ptr<ext::oneapi::experimental::detail::node_impl>>
872872
Deps;
873-
if (auto DependentNode = GraphImpl->getLastInorderNode(MQueue)) {
873+
if (auto DependentNode =
874+
GraphImpl->getLastInorderNode(impl->get_queue_or_null())) {
874875
Deps.push_back(std::move(DependentNode));
875876
}
876877
NodeImpl = GraphImpl->add(NodeType, std::move(CommandGroup), Deps);
877878

878879
// If we are recording an in-order queue remember the new node, so it
879880
// can be used as a dependency for any more nodes recorded from this
880881
// queue.
881-
GraphImpl->setLastInorderNode(MQueue, NodeImpl);
882+
GraphImpl->setLastInorderNode(*MQueue, NodeImpl);
882883
} else {
883884
auto LastBarrierRecordedFromQueue = GraphImpl->getBarrierDep(MQueue);
884885
std::vector<std::shared_ptr<ext::oneapi::experimental::detail::node_impl>>
@@ -1988,7 +1989,7 @@ void handler::depends_on(const detail::EventImplPtr &EventImpl) {
19881989
// we need to set it to recording (implements the transitive queue recording
19891990
// feature).
19901991
if (!QueueGraph) {
1991-
EventGraph->beginRecording(MQueue);
1992+
EventGraph->beginRecording(impl->get_queue());
19921993
}
19931994
}
19941995

0 commit comments

Comments
 (0)