Skip to content

Commit

Permalink
[SYCL] Fixed issues reported by a static verifier (intel#10835)
Browse files Browse the repository at this point in the history
Signed-off-by: Byoungro So <[email protected]>
  • Loading branch information
bso-intel authored Aug 23, 2023
1 parent 7290092 commit 9357162
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 17 deletions.
1 change: 0 additions & 1 deletion sycl/include/sycl/ext/oneapi/experimental/graph.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,6 @@ class __SYCL_EXPORT executable_command_graph {
/// Creates a backend representation of the graph in \p impl member variable.
void finalizeImpl();

int MTag;
std::shared_ptr<detail::exec_graph_impl> impl;
};
} // namespace detail
Expand Down
38 changes: 25 additions & 13 deletions sycl/source/detail/graph_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,12 @@ void connectToExitNodes(
std::shared_ptr<node_impl> CurrentNode,
const std::vector<std::shared_ptr<node_impl>> &NewInputs) {
if (CurrentNode->MSuccessors.size() > 0) {
for (auto Successor : CurrentNode->MSuccessors) {
for (auto &Successor : CurrentNode->MSuccessors) {
connectToExitNodes(Successor, NewInputs);
}

} else {
for (auto Input : NewInputs) {
for (auto &Input : NewInputs) {
CurrentNode->registerSuccessor(Input, CurrentNode);
}
}
Expand Down Expand Up @@ -75,7 +75,7 @@ bool checkForRequirement(sycl::detail::AccessorImplHost *Req,

void exec_graph_impl::schedule() {
if (MSchedule.empty()) {
for (auto Node : MGraphImpl->MRoots) {
for (auto &Node : MGraphImpl->MRoots) {
Node->sortTopological(Node, MSchedule);
}
}
Expand All @@ -97,7 +97,7 @@ std::shared_ptr<node_impl> graph_impl::addSubgraphNodes(

// Recursively walk the graph to find exit nodes and connect up the inputs
// TODO: Consider caching exit nodes so we don't have to do this
for (auto NodeImpl : MRoots) {
for (auto &NodeImpl : MRoots) {
connectToExitNodes(NodeImpl, Inputs);
}

Expand All @@ -118,7 +118,7 @@ graph_impl::add(const std::vector<std::shared_ptr<node_impl>> &Dep) {

// TODO: Encapsulate in separate function to avoid duplication
if (!Dep.empty()) {
for (auto N : Dep) {
for (auto &N : Dep) {
N->registerSuccessor(NodeImpl, N); // register successor
this->removeRoot(NodeImpl); // remove receiver from root node
// list
Expand Down Expand Up @@ -180,13 +180,13 @@ graph_impl::add(sycl::detail::CG::CGTYPE CGType,
const auto &Requirements = CommandGroup->getRequirements();
for (auto &Req : Requirements) {
// Look through the graph for nodes which share this requirement
for (auto NodePtr : MRoots) {
for (auto &NodePtr : MRoots) {
checkForRequirement(Req, NodePtr, UniqueDeps);
}
}

// Add any nodes specified by event dependencies into the dependency list
for (auto Dep : CommandGroup->getEvents()) {
for (auto &Dep : CommandGroup->getEvents()) {
if (auto NodeImpl = MEventsMap.find(Dep); NodeImpl != MEventsMap.end()) {
if (UniqueDeps.find(NodeImpl->second) == UniqueDeps.end()) {
UniqueDeps.insert(NodeImpl->second);
Expand All @@ -204,7 +204,7 @@ graph_impl::add(sycl::detail::CG::CGTYPE CGType,
const std::shared_ptr<node_impl> &NodeImpl =
std::make_shared<node_impl>(CGType, std::move(CommandGroup));
if (!Deps.empty()) {
for (auto N : Deps) {
for (auto &N : Deps) {
N->registerSuccessor(NodeImpl, N); // register successor
this->removeRoot(NodeImpl); // remove receiver from root node
// list
Expand All @@ -218,8 +218,10 @@ graph_impl::add(sycl::detail::CG::CGTYPE CGType,
bool graph_impl::clearQueues() {
bool AnyQueuesCleared = false;
for (auto &Queue : MRecordingQueues) {
Queue->setCommandGraph(nullptr);
AnyQueuesCleared = true;
if (Queue) {
Queue->setCommandGraph(nullptr);
AnyQueuesCleared = true;
}
}
MRecordingQueues.clear();

Expand Down Expand Up @@ -516,6 +518,17 @@ modifiable_command_graph::finalize(const sycl::property_list &) const {

bool modifiable_command_graph::begin_recording(queue &RecordingQueue) {
auto QueueImpl = sycl::detail::getSyclObjImpl(RecordingQueue);
assert(QueueImpl);
if (QueueImpl->get_context() != impl->getContext()) {
throw sycl::exception(sycl::make_error_code(errc::invalid),
"begin_recording called for a queue whose context "
"differs from the graph context.");
}
if (QueueImpl->get_device() != impl->getDevice()) {
throw sycl::exception(sycl::make_error_code(errc::invalid),
"begin_recording called for a queue whose device "
"differs from the graph device.");
}

if (QueueImpl->is_in_fusion_mode()) {
throw sycl::exception(sycl::make_error_code(errc::invalid),
Expand Down Expand Up @@ -561,7 +574,7 @@ bool modifiable_command_graph::end_recording() { return impl->clearQueues(); }

bool modifiable_command_graph::end_recording(queue &RecordingQueue) {
auto QueueImpl = sycl::detail::getSyclObjImpl(RecordingQueue);
if (QueueImpl->getCommandGraph() == impl) {
if (QueueImpl && QueueImpl->getCommandGraph() == impl) {
QueueImpl->setCommandGraph(nullptr);
impl->removeQueue(QueueImpl);
return true;
Expand All @@ -587,8 +600,7 @@ bool modifiable_command_graph::end_recording(

executable_command_graph::executable_command_graph(
const std::shared_ptr<detail::graph_impl> &Graph, const sycl::context &Ctx)
: MTag(rand()),
impl(std::make_shared<detail::exec_graph_impl>(Ctx, Graph)) {
: impl(std::make_shared<detail::exec_graph_impl>(Ctx, Graph)) {
finalizeImpl(); // Create backend representation for executable graph
}

Expand Down
2 changes: 1 addition & 1 deletion sycl/source/detail/graph_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class node_impl {
/// @param Schedule Execution ordering to add node to.
void sortTopological(std::shared_ptr<node_impl> NodeImpl,
std::list<std::shared_ptr<node_impl>> &Schedule) {
for (auto Next : MSuccessors) {
for (auto &Next : MSuccessors) {
// Check if we've already scheduled this node
if (std::find(Schedule.begin(), Schedule.end(), Next) == Schedule.end())
Next->sortTopological(Next, Schedule);
Expand Down
4 changes: 2 additions & 2 deletions sycl/source/handler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ event handler::finalize() {
}
}

if (MQueue && !MQueue->getCommandGraph() && !MGraph && !MSubgraphNode &&
if (MQueue && !MGraph && !MSubgraphNode && !MQueue->getCommandGraph() &&
!MQueue->is_in_fusion_mode() &&
CGData.MRequirements.size() + CGData.MEvents.size() +
MStreamStorage.size() ==
Expand Down Expand Up @@ -424,7 +424,7 @@ event handler::finalize() {
// Empty nodes are handled by Graph like standard nodes
// For Standard mode (non-graph),
// empty nodes are not sent to the scheduler to save time
if (MGraph || MQueue->getCommandGraph()) {
if (MGraph || (MQueue && MQueue->getCommandGraph())) {
CommandGroup.reset(
new detail::CG(detail::CG::None, std::move(CGData), MCodeLoc));
} else {
Expand Down

0 comments on commit 9357162

Please sign in to comment.