Skip to content

Commit

Permalink
[clang-tidy] Support begin/end free functions in modernize-loop-convert
Browse files Browse the repository at this point in the history
The modernize-loop-convert check will now match iterator based loops
that call the free functions 'begin'/'end', as well as matching the
free function 'size' on containers.

Test plan: Added unit test cases matching free function calls on
containers, and a single negative test case for length() which is not
supported.

Reviewed By: PiotrZSL

Differential Revision: https://reviews.llvm.org/D140760
  • Loading branch information
ccotter authored and PiotrZSL committed Aug 5, 2023
1 parent e1a9da3 commit 6a1f8ef
Show file tree
Hide file tree
Showing 7 changed files with 376 additions and 45 deletions.
193 changes: 149 additions & 44 deletions clang-tools-extra/clang-tidy/modernize/LoopConvertCheck.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <cstring>
#include <optional>
#include <tuple>
#include <utility>

using namespace clang::ast_matchers;
Expand Down Expand Up @@ -66,6 +68,15 @@ static const char EndCallName[] = "endCall";
static const char EndVarName[] = "endVar";
static const char DerefByValueResultName[] = "derefByValueResult";
static const char DerefByRefResultName[] = "derefByRefResult";
static const llvm::StringSet<> MemberNames{"begin", "cbegin", "rbegin",
"crbegin", "end", "cend",
"rend", "crend", "size"};
static const llvm::StringSet<> ADLNames{"begin", "cbegin", "rbegin",
"crbegin", "end", "cend",
"rend", "crend", "size"};
static const llvm::StringSet<> StdNames{
"std::begin", "std::cbegin", "std::rbegin", "std::crbegin", "std::end",
"std::cend", "std::rend", "std::crend", "std::size"};

static const StatementMatcher integerComparisonMatcher() {
return expr(ignoringParenImpCasts(
Expand Down Expand Up @@ -129,6 +140,10 @@ StatementMatcher makeArrayLoopMatcher() {
/// e = createIterator(); it != e; ++it) { ... }
/// for (containerType::iterator it = container.begin();
/// it != anotherContainer.end(); ++it) { ... }
/// for (containerType::iterator it = begin(container),
/// e = end(container); it != e; ++it) { ... }
/// for (containerType::iterator it = std::begin(container),
/// e = std::end(container); it != e; ++it) { ... }
/// \endcode
/// The following string identifiers are bound to the parts of the AST:
/// InitVarName: 'it' (as a VarDecl)
Expand All @@ -137,20 +152,31 @@ StatementMatcher makeArrayLoopMatcher() {
/// EndVarName: 'e' (as a VarDecl)
/// In the second example only:
/// EndCallName: 'container.end()' (as a CXXMemberCallExpr)
/// In the third/fourth examples:
/// 'end(container)' or 'std::end(container)' (as a CallExpr)
///
/// Client code will need to make sure that:
/// - The two containers on which 'begin' and 'end' are called are the same.
StatementMatcher makeIteratorLoopMatcher(bool IsReverse) {

auto BeginNameMatcher = IsReverse ? hasAnyName("rbegin", "crbegin")
: hasAnyName("begin", "cbegin");
auto BeginNameMatcherStd = IsReverse
? hasAnyName("::std::rbegin", "::std::crbegin")
: hasAnyName("::std::begin", "::std::cbegin");

auto EndNameMatcher =
IsReverse ? hasAnyName("rend", "crend") : hasAnyName("end", "cend");
auto EndNameMatcherStd = IsReverse ? hasAnyName("::std::rend", "::std::crend")
: hasAnyName("::std::end", "::std::cend");

StatementMatcher BeginCallMatcher =
cxxMemberCallExpr(argumentCountIs(0),
callee(cxxMethodDecl(BeginNameMatcher)))
expr(anyOf(cxxMemberCallExpr(argumentCountIs(0),
callee(cxxMethodDecl(BeginNameMatcher))),
callExpr(argumentCountIs(1),
callee(functionDecl(BeginNameMatcher)), usesADL()),
callExpr(argumentCountIs(1),
callee(functionDecl(BeginNameMatcherStd)))))
.bind(BeginCallName);

DeclarationMatcher InitDeclMatcher =
Expand All @@ -163,8 +189,12 @@ StatementMatcher makeIteratorLoopMatcher(bool IsReverse) {
DeclarationMatcher EndDeclMatcher =
varDecl(hasInitializer(anything())).bind(EndVarName);

StatementMatcher EndCallMatcher = cxxMemberCallExpr(
argumentCountIs(0), callee(cxxMethodDecl(EndNameMatcher)));
StatementMatcher EndCallMatcher = expr(anyOf(
cxxMemberCallExpr(argumentCountIs(0),
callee(cxxMethodDecl(EndNameMatcher))),
callExpr(argumentCountIs(1), callee(functionDecl(EndNameMatcher)),
usesADL()),
callExpr(argumentCountIs(1), callee(functionDecl(EndNameMatcherStd)))));

StatementMatcher IteratorBoundMatcher =
expr(anyOf(ignoringParenImpCasts(
Expand Down Expand Up @@ -223,14 +253,16 @@ StatementMatcher makeIteratorLoopMatcher(bool IsReverse) {
/// \code
/// for (int i = 0, j = container.size(); i < j; ++i) { ... }
/// for (int i = 0; i < container.size(); ++i) { ... }
/// for (int i = 0; i < size(container); ++i) { ... }
/// \endcode
/// The following string identifiers are bound to the parts of the AST:
/// InitVarName: 'i' (as a VarDecl)
/// LoopName: The entire for loop (as a ForStmt)
/// In the first example only:
/// EndVarName: 'j' (as a VarDecl)
/// In the second example only:
/// EndCallName: 'container.size()' (as a CXXMemberCallExpr)
/// EndCallName: 'container.size()' (as a CXXMemberCallExpr) or
/// 'size(contaner)' (as a CallExpr)
///
/// Client code will need to make sure that:
/// - The containers on which 'size()' is called is the container indexed.
Expand Down Expand Up @@ -265,10 +297,15 @@ StatementMatcher makePseudoArrayLoopMatcher() {
hasMethod(hasName("end"))))))))) // qualType
));

StatementMatcher SizeCallMatcher = cxxMemberCallExpr(
argumentCountIs(0), callee(cxxMethodDecl(hasAnyName("size", "length"))),
on(anyOf(hasType(pointsTo(RecordWithBeginEnd)),
hasType(RecordWithBeginEnd))));
StatementMatcher SizeCallMatcher = expr(anyOf(
cxxMemberCallExpr(argumentCountIs(0),
callee(cxxMethodDecl(hasAnyName("size", "length"))),
on(anyOf(hasType(pointsTo(RecordWithBeginEnd)),
hasType(RecordWithBeginEnd)))),
callExpr(argumentCountIs(1), callee(functionDecl(hasName("size"))),
usesADL()),
callExpr(argumentCountIs(1),
callee(functionDecl(hasName("::std::size"))))));

StatementMatcher EndInitMatcher =
expr(anyOf(ignoringParenImpCasts(expr(SizeCallMatcher).bind(EndCallName)),
Expand Down Expand Up @@ -296,36 +333,97 @@ StatementMatcher makePseudoArrayLoopMatcher() {
.bind(LoopNamePseudoArray);
}

enum class IteratorCallKind {
ICK_Member,
ICK_ADL,
ICK_Std,
};

struct ContainerCall {
const Expr *Container;
StringRef Name;
bool IsArrow;
IteratorCallKind CallKind;
};

// Find the Expr likely initializing an iterator.
//
// Call is either a CXXMemberCallExpr ('c.begin()') or CallExpr of a free
// function with the first argument as a container ('begin(c)'), or nullptr.
// Returns at a 3-tuple with the container expr, function name (begin/end/etc),
// and whether the call is made through an arrow (->) for CXXMemberCallExprs.
// The returned Expr* is nullptr if any of the assumptions are not met.
// static std::tuple<const Expr *, StringRef, bool, IteratorCallKind>
static std::optional<ContainerCall> getContainerExpr(const Expr *Call) {
const Expr *Dug = digThroughConstructorsConversions(Call);

IteratorCallKind CallKind = IteratorCallKind::ICK_Member;

if (const auto *TheCall = dyn_cast_or_null<CXXMemberCallExpr>(Dug)) {
CallKind = IteratorCallKind::ICK_Member;
if (const auto *Member = dyn_cast<MemberExpr>(TheCall->getCallee())) {
if (Member->getMemberDecl() == nullptr ||
!MemberNames.contains(Member->getMemberDecl()->getName()))
return std::nullopt;
return ContainerCall{TheCall->getImplicitObjectArgument(),
Member->getMemberDecl()->getName(),
Member->isArrow(), CallKind};
} else {
if (TheCall->getDirectCallee() == nullptr ||
!MemberNames.contains(TheCall->getDirectCallee()->getName()))
return std::nullopt;
return ContainerCall{TheCall->getArg(0),
TheCall->getDirectCallee()->getName(), false,
CallKind};
}
} else if (const auto *TheCall = dyn_cast_or_null<CallExpr>(Dug)) {
if (TheCall->getNumArgs() != 1)
return std::nullopt;

if (TheCall->usesADL()) {
if (TheCall->getDirectCallee() == nullptr ||
!ADLNames.contains(TheCall->getDirectCallee()->getName()))
return std::nullopt;
CallKind = IteratorCallKind::ICK_ADL;
} else {
if (!StdNames.contains(
TheCall->getDirectCallee()->getQualifiedNameAsString()))
return std::nullopt;
CallKind = IteratorCallKind::ICK_Std;
}

if (TheCall->getDirectCallee() == nullptr)
return std::nullopt;

return ContainerCall{TheCall->getArg(0),
TheCall->getDirectCallee()->getName(), false,
CallKind};
}
return std::nullopt;
}

/// Determine whether Init appears to be an initializing an iterator.
///
/// If it is, returns the object whose begin() or end() method is called, and
/// the output parameter isArrow is set to indicate whether the initialization
/// is called via . or ->.
static const Expr *getContainerFromBeginEndCall(const Expr *Init, bool IsBegin,
bool *IsArrow, bool IsReverse) {
static std::pair<const Expr *, IteratorCallKind>
getContainerFromBeginEndCall(const Expr *Init, bool IsBegin, bool *IsArrow,
bool IsReverse) {
// FIXME: Maybe allow declaration/initialization outside of the for loop.
const auto *TheCall = dyn_cast_or_null<CXXMemberCallExpr>(
digThroughConstructorsConversions(Init));
if (!TheCall || TheCall->getNumArgs() != 0)
return nullptr;

const auto *Member = dyn_cast<MemberExpr>(TheCall->getCallee());
if (!Member)
return nullptr;
StringRef Name = Member->getMemberDecl()->getName();
if (!Name.consume_back(IsBegin ? "begin" : "end"))
return nullptr;
if (IsReverse && !Name.consume_back("r"))
return nullptr;
if (!Name.empty() && !Name.equals("c"))
return nullptr;

const Expr *SourceExpr = Member->getBase();
if (!SourceExpr)
return nullptr;

*IsArrow = Member->isArrow();
return SourceExpr;
std::optional<ContainerCall> Call = getContainerExpr(Init);
if (!Call)
return {};

*IsArrow = Call->IsArrow;
if (!Call->Name.consume_back(IsBegin ? "begin" : "end"))
return {};
if (IsReverse && !Call->Name.consume_back("r"))
return {};
if (!Call->Name.empty() && !Call->Name.equals("c"))
return {};
return std::make_pair(Call->Container, Call->CallKind);
}

/// Determines the container whose begin() and end() functions are called
Expand All @@ -341,13 +439,16 @@ static const Expr *findContainer(ASTContext *Context, const Expr *BeginExpr,
// valid.
bool BeginIsArrow = false;
bool EndIsArrow = false;
const Expr *BeginContainerExpr = getContainerFromBeginEndCall(
auto [BeginContainerExpr, BeginCallKind] = getContainerFromBeginEndCall(
BeginExpr, /*IsBegin=*/true, &BeginIsArrow, IsReverse);
if (!BeginContainerExpr)
return nullptr;

const Expr *EndContainerExpr = getContainerFromBeginEndCall(
auto [EndContainerExpr, EndCallKind] = getContainerFromBeginEndCall(
EndExpr, /*IsBegin=*/false, &EndIsArrow, IsReverse);
if (BeginCallKind != EndCallKind)
return nullptr;

// Disallow loops that try evil things like this (note the dot and arrow):
// for (IteratorType It = Obj.begin(), E = Obj->end(); It != E; ++It) { }
if (!EndContainerExpr || BeginIsArrow != EndIsArrow ||
Expand Down Expand Up @@ -832,10 +933,10 @@ bool LoopConvertCheck::isConvertible(ASTContext *Context,
QualType InitVarType = InitVar->getType();
QualType CanonicalInitVarType = InitVarType.getCanonicalType();

const auto *BeginCall = Nodes.getNodeAs<CXXMemberCallExpr>(BeginCallName);
const auto *BeginCall = Nodes.getNodeAs<CallExpr>(BeginCallName);
assert(BeginCall && "Bad Callback. No begin call expression");
QualType CanonicalBeginType =
BeginCall->getMethodDecl()->getReturnType().getCanonicalType();
BeginCall->getDirectCallee()->getReturnType().getCanonicalType();
if (CanonicalBeginType->isPointerType() &&
CanonicalInitVarType->isPointerType()) {
// If the initializer and the variable are both pointers check if the
Expand All @@ -846,10 +947,12 @@ bool LoopConvertCheck::isConvertible(ASTContext *Context,
return false;
}
} else if (FixerKind == LFK_PseudoArray) {
// This call is required to obtain the container.
const auto *EndCall = Nodes.getNodeAs<CXXMemberCallExpr>(EndCallName);
if (!EndCall || !isa<MemberExpr>(EndCall->getCallee()))
return false;
if (const auto *EndCall = Nodes.getNodeAs<CXXMemberCallExpr>(EndCallName)) {
// This call is required to obtain the container.
if (!isa<MemberExpr>(EndCall->getCallee()))
return false;
}
return Nodes.getNodeAs<CallExpr>(EndCallName) != nullptr;
}
return true;
}
Expand Down Expand Up @@ -888,7 +991,7 @@ void LoopConvertCheck::check(const MatchFinder::MatchResult &Result) {

// If the end comparison isn't a variable, we can try to work with the
// expression the loop variable is being tested against instead.
const auto *EndCall = Nodes.getNodeAs<CXXMemberCallExpr>(EndCallName);
const auto *EndCall = Nodes.getNodeAs<Expr>(EndCallName);
const auto *BoundExpr = Nodes.getNodeAs<Expr>(ConditionBoundName);

// Find container expression of iterators and pseudoarrays, and determine if
Expand All @@ -902,9 +1005,11 @@ void LoopConvertCheck::check(const MatchFinder::MatchResult &Result) {
&Descriptor.ContainerNeedsDereference,
/*IsReverse=*/FixerKind == LFK_ReverseIterator);
} else if (FixerKind == LFK_PseudoArray) {
ContainerExpr = EndCall->getImplicitObjectArgument();
Descriptor.ContainerNeedsDereference =
dyn_cast<MemberExpr>(EndCall->getCallee())->isArrow();
std::optional<ContainerCall> Call = getContainerExpr(EndCall);
if (Call) {
ContainerExpr = Call->Container;
Descriptor.ContainerNeedsDereference = Call->IsArrow;
}
}

// We must know the container or an array length bound.
Expand Down
4 changes: 4 additions & 0 deletions clang-tools-extra/docs/ReleaseNotes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,10 @@ Changes in existing checks
<clang-tidy/checks/llvm/namespace-comment>` check to provide fixes for
``inline`` namespaces in the same format as :program:`clang-format`.

- Improved :doc:`modernize-loop-convert
<clang-tidy/checks/modernize/loop-convert>` to support for-loops with
iterators initialized by free functions like ``begin``, ``end``, or ``size``.

Removed checks
^^^^^^^^^^^^^^

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,22 @@ Original:
for (vector<int>::iterator it = v.begin(); it != v.end(); ++it)
cout << *it;
// reasonable conversion
for (vector<int>::iterator it = begin(v); it != end(v); ++it)
cout << *it;
// reasonable conversion
for (vector<int>::iterator it = std::begin(v); it != std::end(v); ++it)
cout << *it;
// reasonable conversion
for (int i = 0; i < v.size(); ++i)
cout << v[i];

// reasonable conversion
for (int i = 0; i < size(v); ++i)
cout << v[i];

After applying the check with minimum confidence level set to `reasonable` (default):

.. code-block:: c++
Expand Down
Loading

0 comments on commit 6a1f8ef

Please sign in to comment.