diff --git a/sycl/include/sycl/ext/oneapi/experimental/graph.hpp b/sycl/include/sycl/ext/oneapi/experimental/graph.hpp index 521e063a1bc5c..45618c2793543 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/graph.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/graph.hpp @@ -263,7 +263,6 @@ class __SYCL_EXPORT executable_command_graph { /// Creates a backend representation of the graph in \p impl member variable. void finalizeImpl(); - int MTag; std::shared_ptr impl; }; } // namespace detail diff --git a/sycl/source/detail/graph_impl.cpp b/sycl/source/detail/graph_impl.cpp index 30d66ffc12e02..6d59b47fda71d 100644 --- a/sycl/source/detail/graph_impl.cpp +++ b/sycl/source/detail/graph_impl.cpp @@ -37,12 +37,12 @@ void connectToExitNodes( std::shared_ptr CurrentNode, const std::vector> &NewInputs) { if (CurrentNode->MSuccessors.size() > 0) { - for (auto Successor : CurrentNode->MSuccessors) { + for (auto &Successor : CurrentNode->MSuccessors) { connectToExitNodes(Successor, NewInputs); } } else { - for (auto Input : NewInputs) { + for (auto &Input : NewInputs) { CurrentNode->registerSuccessor(Input, CurrentNode); } } @@ -75,7 +75,7 @@ bool checkForRequirement(sycl::detail::AccessorImplHost *Req, void exec_graph_impl::schedule() { if (MSchedule.empty()) { - for (auto Node : MGraphImpl->MRoots) { + for (auto &Node : MGraphImpl->MRoots) { Node->sortTopological(Node, MSchedule); } } @@ -97,7 +97,7 @@ std::shared_ptr graph_impl::addSubgraphNodes( // Recursively walk the graph to find exit nodes and connect up the inputs // TODO: Consider caching exit nodes so we don't have to do this - for (auto NodeImpl : MRoots) { + for (auto &NodeImpl : MRoots) { connectToExitNodes(NodeImpl, Inputs); } @@ -118,7 +118,7 @@ graph_impl::add(const std::vector> &Dep) { // TODO: Encapsulate in separate function to avoid duplication if (!Dep.empty()) { - for (auto N : Dep) { + for (auto &N : Dep) { N->registerSuccessor(NodeImpl, N); // register successor this->removeRoot(NodeImpl); // remove receiver from root node // list @@ -180,13 +180,13 @@ graph_impl::add(sycl::detail::CG::CGTYPE CGType, const auto &Requirements = CommandGroup->getRequirements(); for (auto &Req : Requirements) { // Look through the graph for nodes which share this requirement - for (auto NodePtr : MRoots) { + for (auto &NodePtr : MRoots) { checkForRequirement(Req, NodePtr, UniqueDeps); } } // Add any nodes specified by event dependencies into the dependency list - for (auto Dep : CommandGroup->getEvents()) { + for (auto &Dep : CommandGroup->getEvents()) { if (auto NodeImpl = MEventsMap.find(Dep); NodeImpl != MEventsMap.end()) { if (UniqueDeps.find(NodeImpl->second) == UniqueDeps.end()) { UniqueDeps.insert(NodeImpl->second); @@ -204,7 +204,7 @@ graph_impl::add(sycl::detail::CG::CGTYPE CGType, const std::shared_ptr &NodeImpl = std::make_shared(CGType, std::move(CommandGroup)); if (!Deps.empty()) { - for (auto N : Deps) { + for (auto &N : Deps) { N->registerSuccessor(NodeImpl, N); // register successor this->removeRoot(NodeImpl); // remove receiver from root node // list @@ -218,8 +218,10 @@ graph_impl::add(sycl::detail::CG::CGTYPE CGType, bool graph_impl::clearQueues() { bool AnyQueuesCleared = false; for (auto &Queue : MRecordingQueues) { - Queue->setCommandGraph(nullptr); - AnyQueuesCleared = true; + if (Queue) { + Queue->setCommandGraph(nullptr); + AnyQueuesCleared = true; + } } MRecordingQueues.clear(); @@ -516,6 +518,17 @@ modifiable_command_graph::finalize(const sycl::property_list &) const { bool modifiable_command_graph::begin_recording(queue &RecordingQueue) { auto QueueImpl = sycl::detail::getSyclObjImpl(RecordingQueue); + assert(QueueImpl); + if (QueueImpl->get_context() != impl->getContext()) { + throw sycl::exception(sycl::make_error_code(errc::invalid), + "begin_recording called for a queue whose context " + "differs from the graph context."); + } + if (QueueImpl->get_device() != impl->getDevice()) { + throw sycl::exception(sycl::make_error_code(errc::invalid), + "begin_recording called for a queue whose device " + "differs from the graph device."); + } if (QueueImpl->is_in_fusion_mode()) { throw sycl::exception(sycl::make_error_code(errc::invalid), @@ -561,7 +574,7 @@ bool modifiable_command_graph::end_recording() { return impl->clearQueues(); } bool modifiable_command_graph::end_recording(queue &RecordingQueue) { auto QueueImpl = sycl::detail::getSyclObjImpl(RecordingQueue); - if (QueueImpl->getCommandGraph() == impl) { + if (QueueImpl && QueueImpl->getCommandGraph() == impl) { QueueImpl->setCommandGraph(nullptr); impl->removeQueue(QueueImpl); return true; @@ -587,8 +600,7 @@ bool modifiable_command_graph::end_recording( executable_command_graph::executable_command_graph( const std::shared_ptr &Graph, const sycl::context &Ctx) - : MTag(rand()), - impl(std::make_shared(Ctx, Graph)) { + : impl(std::make_shared(Ctx, Graph)) { finalizeImpl(); // Create backend representation for executable graph } diff --git a/sycl/source/detail/graph_impl.hpp b/sycl/source/detail/graph_impl.hpp index 5526aeaccef44..6140af005af5f 100644 --- a/sycl/source/detail/graph_impl.hpp +++ b/sycl/source/detail/graph_impl.hpp @@ -77,7 +77,7 @@ class node_impl { /// @param Schedule Execution ordering to add node to. void sortTopological(std::shared_ptr NodeImpl, std::list> &Schedule) { - for (auto Next : MSuccessors) { + for (auto &Next : MSuccessors) { // Check if we've already scheduled this node if (std::find(Schedule.begin(), Schedule.end(), Next) == Schedule.end()) Next->sortTopological(Next, Schedule); diff --git a/sycl/source/handler.cpp b/sycl/source/handler.cpp index 0c6d422fa83fe..769de1351cc41 100644 --- a/sycl/source/handler.cpp +++ b/sycl/source/handler.cpp @@ -223,7 +223,7 @@ event handler::finalize() { } } - if (MQueue && !MQueue->getCommandGraph() && !MGraph && !MSubgraphNode && + if (MQueue && !MGraph && !MSubgraphNode && !MQueue->getCommandGraph() && !MQueue->is_in_fusion_mode() && CGData.MRequirements.size() + CGData.MEvents.size() + MStreamStorage.size() == @@ -424,7 +424,7 @@ event handler::finalize() { // Empty nodes are handled by Graph like standard nodes // For Standard mode (non-graph), // empty nodes are not sent to the scheduler to save time - if (MGraph || MQueue->getCommandGraph()) { + if (MGraph || (MQueue && MQueue->getCommandGraph())) { CommandGroup.reset( new detail::CG(detail::CG::None, std::move(CGData), MCodeLoc)); } else {