@@ -659,6 +659,14 @@ graph_impl::add(std::shared_ptr<dynamic_command_group_impl> &DynCGImpl,
659
659
return NodeImpl;
660
660
}
661
661
662
+ void graph_impl::addQueue (sycl::detail::queue_impl &RecordingQueue) {
663
+ MRecordingQueues.insert (RecordingQueue.weak_from_this ());
664
+ }
665
+
666
+ void graph_impl::removeQueue (sycl::detail::queue_impl &RecordingQueue) {
667
+ MRecordingQueues.erase (RecordingQueue.weak_from_this ());
668
+ }
669
+
662
670
bool graph_impl::clearQueues () {
663
671
bool AnyQueuesCleared = false ;
664
672
for (auto &Queue : MRecordingQueues) {
@@ -689,6 +697,24 @@ bool graph_impl::checkForCycles() {
689
697
return CycleFound;
690
698
}
691
699
700
+ std::shared_ptr<node_impl>
701
+ graph_impl::getLastInorderNode (sycl::detail::queue_impl *Queue) {
702
+ if (!Queue) {
703
+ assert (0 ==
704
+ MInorderQueueMap.count (std::weak_ptr<sycl::detail::queue_impl>{}));
705
+ return {};
706
+ }
707
+ if (0 == MInorderQueueMap.count (Queue->weak_from_this ())) {
708
+ return {};
709
+ }
710
+ return MInorderQueueMap[Queue->weak_from_this ()];
711
+ }
712
+
713
+ void graph_impl::setLastInorderNode (sycl::detail::queue_impl &Queue,
714
+ std::shared_ptr<node_impl> Node) {
715
+ MInorderQueueMap[Queue.weak_from_this ()] = Node;
716
+ }
717
+
692
718
void graph_impl::makeEdge (std::shared_ptr<node_impl> Src,
693
719
std::shared_ptr<node_impl> Dest) {
694
720
throwIfGraphRecordingQueue (" make_edge()" );
@@ -769,11 +795,10 @@ std::vector<sycl::detail::EventImplPtr> graph_impl::getExitNodesEvents(
769
795
return Events;
770
796
}
771
797
772
- void graph_impl::beginRecording (
773
- const std::shared_ptr<sycl::detail::queue_impl> &Queue) {
798
+ void graph_impl::beginRecording (sycl::detail::queue_impl &Queue) {
774
799
graph_impl::WriteLock Lock (MMutex);
775
- if (!Queue-> hasCommandGraph ()) {
776
- Queue-> setCommandGraph (shared_from_this ());
800
+ if (!Queue. hasCommandGraph ()) {
801
+ Queue. setCommandGraph (shared_from_this ());
777
802
addQueue (Queue);
778
803
}
779
804
}
@@ -1003,7 +1028,7 @@ exec_graph_impl::~exec_graph_impl() {
1003
1028
}
1004
1029
1005
1030
sycl::event
1006
- exec_graph_impl::enqueue (const std::shared_ptr< sycl::detail::queue_impl> &Queue,
1031
+ exec_graph_impl::enqueue (sycl::detail::queue_impl &Queue,
1007
1032
sycl::detail::CG::StorageInitHelper CGData) {
1008
1033
WriteLock Lock (MMutex);
1009
1034
@@ -1012,8 +1037,9 @@ exec_graph_impl::enqueue(const std::shared_ptr<sycl::detail::queue_impl> &Queue,
1012
1037
PartitionsExecutionEvents;
1013
1038
1014
1039
auto CreateNewEvent ([&]() {
1015
- auto NewEvent = std::make_shared<sycl::detail::event_impl>(Queue);
1016
- NewEvent->setContextImpl (Queue->getContextImplPtr ());
1040
+ auto NewEvent =
1041
+ std::make_shared<sycl::detail::event_impl>(Queue.shared_from_this ());
1042
+ NewEvent->setContextImpl (Queue.getContextImplPtr ());
1017
1043
NewEvent->setStateIncomplete ();
1018
1044
return NewEvent;
1019
1045
});
@@ -1035,7 +1061,7 @@ exec_graph_impl::enqueue(const std::shared_ptr<sycl::detail::queue_impl> &Queue,
1035
1061
CGData.MEvents .push_back (PartitionsExecutionEvents[DepPartition]);
1036
1062
}
1037
1063
1038
- auto CommandBuffer = CurrentPartition->MCommandBuffers [Queue-> get_device ()];
1064
+ auto CommandBuffer = CurrentPartition->MCommandBuffers [Queue. get_device ()];
1039
1065
1040
1066
if (CommandBuffer) {
1041
1067
for (std::vector<sycl::detail::EventImplPtr>::iterator It =
@@ -1073,10 +1099,10 @@ exec_graph_impl::enqueue(const std::shared_ptr<sycl::detail::queue_impl> &Queue,
1073
1099
if (CGData.MRequirements .empty () && CGData.MEvents .empty ()) {
1074
1100
NewEvent->setSubmissionTime ();
1075
1101
ur_result_t Res =
1076
- Queue-> getAdapter ()
1102
+ Queue. getAdapter ()
1077
1103
->call_nocheck <
1078
1104
sycl::detail::UrApiKind::urEnqueueCommandBufferExp>(
1079
- Queue-> getHandleRef (), CommandBuffer, 0 , nullptr , &UREvent);
1105
+ Queue. getHandleRef (), CommandBuffer, 0 , nullptr , &UREvent);
1080
1106
NewEvent->setHandle (UREvent);
1081
1107
if (Res == UR_RESULT_ERROR_INVALID_QUEUE_PROPERTIES) {
1082
1108
throw sycl::exception (
@@ -1096,7 +1122,8 @@ exec_graph_impl::enqueue(const std::shared_ptr<sycl::detail::queue_impl> &Queue,
1096
1122
CommandBuffer, nullptr , std::move (CGData));
1097
1123
1098
1124
NewEvent = sycl::detail::Scheduler::getInstance ().addCG (
1099
- std::move (CommandGroup), Queue, /* EventNeeded=*/ true );
1125
+ std::move (CommandGroup), Queue.shared_from_this (),
1126
+ /* EventNeeded=*/ true );
1100
1127
}
1101
1128
NewEvent->setEventFromSubmittedExecCommandBuffer (true );
1102
1129
} else if ((CurrentPartition->MSchedule .size () > 0 ) &&
@@ -1112,10 +1139,11 @@ exec_graph_impl::enqueue(const std::shared_ptr<sycl::detail::queue_impl> &Queue,
1112
1139
// In case of graph, this queue may differ from the actual execution
1113
1140
// queue. We therefore overload this Queue before submitting the task.
1114
1141
static_cast <sycl::detail::CGHostTask &>(*NodeImpl->MCommandGroup .get ())
1115
- .MQueue = Queue;
1142
+ .MQueue = Queue. shared_from_this () ;
1116
1143
1117
1144
NewEvent = sycl::detail::Scheduler::getInstance ().addCG (
1118
- NodeImpl->getCGCopy (), Queue, /* EventNeeded=*/ true );
1145
+ NodeImpl->getCGCopy (), Queue.shared_from_this (),
1146
+ /* EventNeeded=*/ true );
1119
1147
}
1120
1148
PartitionsExecutionEvents[CurrentPartition] = NewEvent;
1121
1149
}
@@ -1844,21 +1872,20 @@ void modifiable_command_graph::begin_recording(
1844
1872
// related to graph at all.
1845
1873
checkGraphPropertiesAndThrow (PropList);
1846
1874
1847
- auto QueueImpl = sycl::detail::getSyclObjImpl (RecordingQueue);
1848
- assert (QueueImpl);
1875
+ queue_impl &QueueImpl = *sycl::detail::getSyclObjImpl (RecordingQueue);
1849
1876
1850
- if (QueueImpl-> hasCommandGraph ()) {
1877
+ if (QueueImpl. hasCommandGraph ()) {
1851
1878
throw sycl::exception (sycl::make_error_code (errc::invalid),
1852
1879
" begin_recording cannot be called for a queue which "
1853
1880
" is already in the recording state." );
1854
1881
}
1855
1882
1856
- if (QueueImpl-> get_context () != impl->getContext ()) {
1883
+ if (QueueImpl. get_context () != impl->getContext ()) {
1857
1884
throw sycl::exception (sycl::make_error_code (errc::invalid),
1858
1885
" begin_recording called for a queue whose context "
1859
1886
" differs from the graph context." );
1860
1887
}
1861
- if (QueueImpl-> get_device () != impl->getDevice ()) {
1888
+ if (QueueImpl. get_device () != impl->getDevice ()) {
1862
1889
throw sycl::exception (sycl::make_error_code (errc::invalid),
1863
1890
" begin_recording called for a queue whose device "
1864
1891
" differs from the graph device." );
@@ -1881,15 +1908,13 @@ void modifiable_command_graph::end_recording() {
1881
1908
}
1882
1909
1883
1910
void modifiable_command_graph::end_recording (queue &RecordingQueue) {
1884
- auto QueueImpl = sycl::detail::getSyclObjImpl (RecordingQueue);
1885
- if (!QueueImpl)
1886
- return ;
1887
- if (QueueImpl->getCommandGraph () == impl) {
1888
- QueueImpl->setCommandGraph (nullptr );
1911
+ queue_impl &QueueImpl = *sycl::detail::getSyclObjImpl (RecordingQueue);
1912
+ if (QueueImpl.getCommandGraph () == impl) {
1913
+ QueueImpl.setCommandGraph (nullptr );
1889
1914
graph_impl::WriteLock Lock (impl->MMutex );
1890
1915
impl->removeQueue (QueueImpl);
1891
1916
}
1892
- if (QueueImpl-> hasCommandGraph ())
1917
+ if (QueueImpl. hasCommandGraph ())
1893
1918
throw sycl::exception (sycl::make_error_code (errc::invalid),
1894
1919
" end_recording called for a queue which is recording "
1895
1920
" to a different graph." );
0 commit comments