Skip to content

Commit

Permalink
[OPENMP] Initial support for 'task_reduction' clause.
Browse files Browse the repository at this point in the history
Parsing/sema analysis of the 'task_reduction' clause.

git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@308352 91177308-0d34-0410-b5e6-96231b3b80d8
  • Loading branch information
alexey-bataev committed Jul 18, 2017
1 parent 04cbc70 commit c874b4b
Show file tree
Hide file tree
Showing 21 changed files with 929 additions and 60 deletions.
211 changes: 211 additions & 0 deletions include/clang/AST/OpenMPClause.h
Original file line number Diff line number Diff line change
Expand Up @@ -1890,6 +1890,217 @@ class OMPReductionClause final
}
};

/// This represents clause 'task_reduction' in the '#pragma omp taskgroup'
/// directives.
///
/// \code
/// #pragma omp taskgroup task_reduction(+:a,b)
/// \endcode
/// In this example directive '#pragma omp taskgroup' has clause
/// 'task_reduction' with operator '+' and the variables 'a' and 'b'.
///
class OMPTaskReductionClause final
: public OMPVarListClause<OMPTaskReductionClause>,
public OMPClauseWithPostUpdate,
private llvm::TrailingObjects<OMPTaskReductionClause, Expr *> {
friend TrailingObjects;
friend OMPVarListClause;
friend class OMPClauseReader;
/// Location of ':'.
SourceLocation ColonLoc;
/// Nested name specifier for C++.
NestedNameSpecifierLoc QualifierLoc;
/// Name of custom operator.
DeclarationNameInfo NameInfo;

/// Build clause with number of variables \a N.
///
/// \param StartLoc Starting location of the clause.
/// \param LParenLoc Location of '('.
/// \param EndLoc Ending location of the clause.
/// \param ColonLoc Location of ':'.
/// \param N Number of the variables in the clause.
/// \param QualifierLoc The nested-name qualifier with location information
/// \param NameInfo The full name info for reduction identifier.
///
OMPTaskReductionClause(SourceLocation StartLoc, SourceLocation LParenLoc,
SourceLocation ColonLoc, SourceLocation EndLoc,
unsigned N, NestedNameSpecifierLoc QualifierLoc,
const DeclarationNameInfo &NameInfo)
: OMPVarListClause<OMPTaskReductionClause>(OMPC_task_reduction, StartLoc,
LParenLoc, EndLoc, N),
OMPClauseWithPostUpdate(this), ColonLoc(ColonLoc),
QualifierLoc(QualifierLoc), NameInfo(NameInfo) {}

/// Build an empty clause.
///
/// \param N Number of variables.
///
explicit OMPTaskReductionClause(unsigned N)
: OMPVarListClause<OMPTaskReductionClause>(
OMPC_task_reduction, SourceLocation(), SourceLocation(),
SourceLocation(), N),
OMPClauseWithPostUpdate(this), ColonLoc(), QualifierLoc(), NameInfo() {}

/// Sets location of ':' symbol in clause.
void setColonLoc(SourceLocation CL) { ColonLoc = CL; }
/// Sets the name info for specified reduction identifier.
void setNameInfo(DeclarationNameInfo DNI) { NameInfo = DNI; }
/// Sets the nested name specifier.
void setQualifierLoc(NestedNameSpecifierLoc NSL) { QualifierLoc = NSL; }

/// Set list of helper expressions, required for proper codegen of the clause.
/// These expressions represent private copy of the reduction variable.
void setPrivates(ArrayRef<Expr *> Privates);

/// Get the list of helper privates.
MutableArrayRef<Expr *> getPrivates() {
return MutableArrayRef<Expr *>(varlist_end(), varlist_size());
}
ArrayRef<const Expr *> getPrivates() const {
return llvm::makeArrayRef(varlist_end(), varlist_size());
}

/// Set list of helper expressions, required for proper codegen of the clause.
/// These expressions represent LHS expression in the final reduction
/// expression performed by the reduction clause.
void setLHSExprs(ArrayRef<Expr *> LHSExprs);

/// Get the list of helper LHS expressions.
MutableArrayRef<Expr *> getLHSExprs() {
return MutableArrayRef<Expr *>(getPrivates().end(), varlist_size());
}
ArrayRef<const Expr *> getLHSExprs() const {
return llvm::makeArrayRef(getPrivates().end(), varlist_size());
}

/// Set list of helper expressions, required for proper codegen of the clause.
/// These expressions represent RHS expression in the final reduction
/// expression performed by the reduction clause. Also, variables in these
/// expressions are used for proper initialization of reduction copies.
void setRHSExprs(ArrayRef<Expr *> RHSExprs);

/// Get the list of helper destination expressions.
MutableArrayRef<Expr *> getRHSExprs() {
return MutableArrayRef<Expr *>(getLHSExprs().end(), varlist_size());
}
ArrayRef<const Expr *> getRHSExprs() const {
return llvm::makeArrayRef(getLHSExprs().end(), varlist_size());
}

/// Set list of helper reduction expressions, required for proper
/// codegen of the clause. These expressions are binary expressions or
/// operator/custom reduction call that calculates new value from source
/// helper expressions to destination helper expressions.
void setReductionOps(ArrayRef<Expr *> ReductionOps);

/// Get the list of helper reduction expressions.
MutableArrayRef<Expr *> getReductionOps() {
return MutableArrayRef<Expr *>(getRHSExprs().end(), varlist_size());
}
ArrayRef<const Expr *> getReductionOps() const {
return llvm::makeArrayRef(getRHSExprs().end(), varlist_size());
}

public:
/// Creates clause with a list of variables \a VL.
///
/// \param StartLoc Starting location of the clause.
/// \param LParenLoc Location of '('.
/// \param ColonLoc Location of ':'.
/// \param EndLoc Ending location of the clause.
/// \param VL The variables in the clause.
/// \param QualifierLoc The nested-name qualifier with location information
/// \param NameInfo The full name info for reduction identifier.
/// \param Privates List of helper expressions for proper generation of
/// private copies.
/// \param LHSExprs List of helper expressions for proper generation of
/// assignment operation required for copyprivate clause. This list represents
/// LHSs of the reduction expressions.
/// \param RHSExprs List of helper expressions for proper generation of
/// assignment operation required for copyprivate clause. This list represents
/// RHSs of the reduction expressions.
/// Also, variables in these expressions are used for proper initialization of
/// reduction copies.
/// \param ReductionOps List of helper expressions that represents reduction
/// expressions:
/// \code
/// LHSExprs binop RHSExprs;
/// operator binop(LHSExpr, RHSExpr);
/// <CutomReduction>(LHSExpr, RHSExpr);
/// \endcode
/// Required for proper codegen of final reduction operation performed by the
/// reduction clause.
/// \param PreInit Statement that must be executed before entering the OpenMP
/// region with this clause.
/// \param PostUpdate Expression that must be executed after exit from the
/// OpenMP region with this clause.
///
static OMPTaskReductionClause *
Create(const ASTContext &C, SourceLocation StartLoc, SourceLocation LParenLoc,
SourceLocation ColonLoc, SourceLocation EndLoc, ArrayRef<Expr *> VL,
NestedNameSpecifierLoc QualifierLoc,
const DeclarationNameInfo &NameInfo, ArrayRef<Expr *> Privates,
ArrayRef<Expr *> LHSExprs, ArrayRef<Expr *> RHSExprs,
ArrayRef<Expr *> ReductionOps, Stmt *PreInit, Expr *PostUpdate);

/// Creates an empty clause with the place for \a N variables.
///
/// \param C AST context.
/// \param N The number of variables.
///
static OMPTaskReductionClause *CreateEmpty(const ASTContext &C, unsigned N);

/// Gets location of ':' symbol in clause.
SourceLocation getColonLoc() const { return ColonLoc; }
/// Gets the name info for specified reduction identifier.
const DeclarationNameInfo &getNameInfo() const { return NameInfo; }
/// Gets the nested name specifier.
NestedNameSpecifierLoc getQualifierLoc() const { return QualifierLoc; }

typedef MutableArrayRef<Expr *>::iterator helper_expr_iterator;
typedef ArrayRef<const Expr *>::iterator helper_expr_const_iterator;
typedef llvm::iterator_range<helper_expr_iterator> helper_expr_range;
typedef llvm::iterator_range<helper_expr_const_iterator>
helper_expr_const_range;

helper_expr_const_range privates() const {
return helper_expr_const_range(getPrivates().begin(), getPrivates().end());
}
helper_expr_range privates() {
return helper_expr_range(getPrivates().begin(), getPrivates().end());
}
helper_expr_const_range lhs_exprs() const {
return helper_expr_const_range(getLHSExprs().begin(), getLHSExprs().end());
}
helper_expr_range lhs_exprs() {
return helper_expr_range(getLHSExprs().begin(), getLHSExprs().end());
}
helper_expr_const_range rhs_exprs() const {
return helper_expr_const_range(getRHSExprs().begin(), getRHSExprs().end());
}
helper_expr_range rhs_exprs() {
return helper_expr_range(getRHSExprs().begin(), getRHSExprs().end());
}
helper_expr_const_range reduction_ops() const {
return helper_expr_const_range(getReductionOps().begin(),
getReductionOps().end());
}
helper_expr_range reduction_ops() {
return helper_expr_range(getReductionOps().begin(),
getReductionOps().end());
}

child_range children() {
return child_range(reinterpret_cast<Stmt **>(varlist_begin()),
reinterpret_cast<Stmt **>(varlist_end()));
}

static bool classof(const OMPClause *T) {
return T->getClauseKind() == OMPC_task_reduction;
}
};

/// \brief This represents clause 'linear' in the '#pragma omp ...'
/// directives.
///
Expand Down
22 changes: 22 additions & 0 deletions include/clang/AST/RecursiveASTVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -3016,6 +3016,28 @@ RecursiveASTVisitor<Derived>::VisitOMPReductionClause(OMPReductionClause *C) {
return true;
}

template <typename Derived>
bool RecursiveASTVisitor<Derived>::VisitOMPTaskReductionClause(
OMPTaskReductionClause *C) {
TRY_TO(TraverseNestedNameSpecifierLoc(C->getQualifierLoc()));
TRY_TO(TraverseDeclarationNameInfo(C->getNameInfo()));
TRY_TO(VisitOMPClauseList(C));
TRY_TO(VisitOMPClauseWithPostUpdate(C));
for (auto *E : C->privates()) {
TRY_TO(TraverseStmt(E));
}
for (auto *E : C->lhs_exprs()) {
TRY_TO(TraverseStmt(E));
}
for (auto *E : C->rhs_exprs()) {
TRY_TO(TraverseStmt(E));
}
for (auto *E : C->reduction_ops()) {
TRY_TO(TraverseStmt(E));
}
return true;
}

template <typename Derived>
bool RecursiveASTVisitor<Derived>::VisitOMPFlushClause(OMPFlushClause *C) {
TRY_TO(VisitOMPClauseList(C));
Expand Down
34 changes: 20 additions & 14 deletions include/clang/AST/StmtOpenMP.h
Original file line number Diff line number Diff line change
Expand Up @@ -1895,47 +1895,53 @@ class OMPTaskwaitDirective : public OMPExecutableDirective {
}
};

/// \brief This represents '#pragma omp taskgroup' directive.
/// This represents '#pragma omp taskgroup' directive.
///
/// \code
/// #pragma omp taskgroup
/// \endcode
///
class OMPTaskgroupDirective : public OMPExecutableDirective {
friend class ASTStmtReader;
/// \brief Build directive with the given start and end location.
/// Build directive with the given start and end location.
///
/// \param StartLoc Starting location of the directive kind.
/// \param EndLoc Ending location of the directive.
/// \param NumClauses Number of clauses.
///
OMPTaskgroupDirective(SourceLocation StartLoc, SourceLocation EndLoc)
OMPTaskgroupDirective(SourceLocation StartLoc, SourceLocation EndLoc,
unsigned NumClauses)
: OMPExecutableDirective(this, OMPTaskgroupDirectiveClass, OMPD_taskgroup,
StartLoc, EndLoc, 0, 1) {}
StartLoc, EndLoc, NumClauses, 1) {}

/// \brief Build an empty directive.
/// Build an empty directive.
/// \param NumClauses Number of clauses.
///
explicit OMPTaskgroupDirective()
explicit OMPTaskgroupDirective(unsigned NumClauses)
: OMPExecutableDirective(this, OMPTaskgroupDirectiveClass, OMPD_taskgroup,
SourceLocation(), SourceLocation(), 0, 1) {}
SourceLocation(), SourceLocation(), NumClauses,
1) {}

public:
/// \brief Creates directive.
/// Creates directive.
///
/// \param C AST context.
/// \param StartLoc Starting location of the directive kind.
/// \param EndLoc Ending Location of the directive.
/// \param Clauses List of clauses.
/// \param AssociatedStmt Statement, associated with the directive.
///
static OMPTaskgroupDirective *Create(const ASTContext &C,
SourceLocation StartLoc,
SourceLocation EndLoc,
Stmt *AssociatedStmt);
static OMPTaskgroupDirective *
Create(const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc,
ArrayRef<OMPClause *> Clauses, Stmt *AssociatedStmt);

/// \brief Creates an empty directive.
/// Creates an empty directive.
///
/// \param C AST context.
/// \param NumClauses Number of clauses.
///
static OMPTaskgroupDirective *CreateEmpty(const ASTContext &C, EmptyShell);
static OMPTaskgroupDirective *CreateEmpty(const ASTContext &C,
unsigned NumClauses, EmptyShell);

static bool classof(const Stmt *T) {
return T->getStmtClass() == OMPTaskgroupDirectiveClass;
Expand Down
6 changes: 3 additions & 3 deletions include/clang/Basic/DiagnosticSemaKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -8655,11 +8655,11 @@ def err_omp_unknown_reduction_identifier : Error<
def err_omp_not_resolved_reduction_identifier : Error<
"unable to resolve declare reduction construct for type %0">;
def err_omp_reduction_ref_type_arg : Error<
"argument of OpenMP clause 'reduction' must reference the same object in all threads">;
"argument of OpenMP clause '%0' must reference the same object in all threads">;
def err_omp_clause_not_arithmetic_type_arg : Error<
"arguments of OpenMP clause 'reduction' for 'min' or 'max' must be of %select{scalar|arithmetic}0 type">;
"arguments of OpenMP clause '%0' for 'min' or 'max' must be of %select{scalar|arithmetic}1 type">;
def err_omp_clause_floating_type_arg : Error<
"arguments of OpenMP clause 'reduction' with bitwise operators cannot be of floating type">;
"arguments of OpenMP clause '%0' with bitwise operators cannot be of floating type">;
def err_omp_once_referenced : Error<
"variable can appear only once in OpenMP '%0' clause">;
def err_omp_once_referenced_in_target_update : Error<
Expand Down
8 changes: 8 additions & 0 deletions include/clang/Basic/OpenMPKinds.def
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,9 @@
#ifndef OPENMP_TARGET_TEAMS_DISTRIBUTE_SIMD_CLAUSE
#define OPENMP_TARGET_TEAMS_DISTRIBUTE_SIMD_CLAUSE(Name)
#endif
#ifndef OPENMP_TASKGROUP_CLAUSE
#define OPENMP_TASKGROUP_CLAUSE(Name)
#endif

// OpenMP directives.
OPENMP_DIRECTIVE(threadprivate)
Expand Down Expand Up @@ -270,6 +273,7 @@ OPENMP_CLAUSE(to, OMPToClause)
OPENMP_CLAUSE(from, OMPFromClause)
OPENMP_CLAUSE(use_device_ptr, OMPUseDevicePtrClause)
OPENMP_CLAUSE(is_device_ptr, OMPIsDevicePtrClause)
OPENMP_CLAUSE(task_reduction, OMPTaskReductionClause)

// Clauses allowed for OpenMP directive 'parallel'.
OPENMP_PARALLEL_CLAUSE(if)
Expand Down Expand Up @@ -848,6 +852,10 @@ OPENMP_TARGET_TEAMS_DISTRIBUTE_SIMD_CLAUSE(aligned)
OPENMP_TARGET_TEAMS_DISTRIBUTE_SIMD_CLAUSE(safelen)
OPENMP_TARGET_TEAMS_DISTRIBUTE_SIMD_CLAUSE(simdlen)

// Clauses allowed for OpenMP directive 'taskgroup'.
OPENMP_TASKGROUP_CLAUSE(task_reduction)

#undef OPENMP_TASKGROUP_CLAUSE
#undef OPENMP_TASKLOOP_SIMD_CLAUSE
#undef OPENMP_TASKLOOP_CLAUSE
#undef OPENMP_LINEAR_KIND
Expand Down
10 changes: 9 additions & 1 deletion include/clang/Sema/Sema.h
Original file line number Diff line number Diff line change
Expand Up @@ -8682,7 +8682,8 @@ class Sema {
StmtResult ActOnOpenMPTaskwaitDirective(SourceLocation StartLoc,
SourceLocation EndLoc);
/// \brief Called on well-formed '\#pragma omp taskgroup'.
StmtResult ActOnOpenMPTaskgroupDirective(Stmt *AStmt, SourceLocation StartLoc,
StmtResult ActOnOpenMPTaskgroupDirective(ArrayRef<OMPClause *> Clauses,
Stmt *AStmt, SourceLocation StartLoc,
SourceLocation EndLoc);
/// \brief Called on well-formed '\#pragma omp flush'.
StmtResult ActOnOpenMPFlushDirective(ArrayRef<OMPClause *> Clauses,
Expand Down Expand Up @@ -9024,6 +9025,13 @@ class Sema {
CXXScopeSpec &ReductionIdScopeSpec,
const DeclarationNameInfo &ReductionId,
ArrayRef<Expr *> UnresolvedReductions = llvm::None);
/// Called on well-formed 'task_reduction' clause.
OMPClause *ActOnOpenMPTaskReductionClause(
ArrayRef<Expr *> VarList, SourceLocation StartLoc,
SourceLocation LParenLoc, SourceLocation ColonLoc, SourceLocation EndLoc,
CXXScopeSpec &ReductionIdScopeSpec,
const DeclarationNameInfo &ReductionId,
ArrayRef<Expr *> UnresolvedReductions = llvm::None);
/// \brief Called on well-formed 'linear' clause.
OMPClause *
ActOnOpenMPLinearClause(ArrayRef<Expr *> VarList, Expr *Step,
Expand Down
Loading

0 comments on commit c874b4b

Please sign in to comment.