From cb8ffd254bc9412587bcda136dcebe64ec361168 Mon Sep 17 00:00:00 2001 From: "Duncan P. N. Exon Smith" Date: Tue, 7 Mar 2017 21:56:32 +0000 Subject: [PATCH] ADT: Fix SmallPtrSet iterators in reverse mode Fix SmallPtrSet::iterator behaviour and creation ReverseIterate is true. - Any function that creates an iterator now uses SmallPtrSet::makeIterator, which creates an iterator that dereferences to the given pointer. - In reverse-iterate mode, initialze iterator::End with "CurArray" instead of EndPointer. - In reverse-iterate mode, the current node is iterator::Buffer[-1]. iterator::operator* and SmallPtrSet::makeIterator are the only ones that need to know. - Fix the assertions for reverse-iterate mode. This fixes the tests Danny B added in r297182, and adds a couple of others to confirm that dereferencing does the right thing, regardless of how the iterator was found, and that iteration works correctly from each return from find. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@297234 91177308-0d34-0410-b5e6-96231b3b80d8 --- include/llvm/ADT/SmallPtrSet.h | 45 ++++++++++++++++--------------- unittests/ADT/SmallPtrSetTest.cpp | 32 +++++++++++++++++----- 2 files changed, 48 insertions(+), 29 deletions(-) diff --git a/include/llvm/ADT/SmallPtrSet.h b/include/llvm/ADT/SmallPtrSet.h index 7234f0fbded91..b98cf6c376b5d 100644 --- a/include/llvm/ADT/SmallPtrSet.h +++ b/include/llvm/ADT/SmallPtrSet.h @@ -260,11 +260,10 @@ class SmallPtrSetIteratorImpl { } #if LLVM_ENABLE_ABI_BREAKING_CHECKS void RetreatIfNotValid() { - --Bucket; - assert(Bucket <= End); + assert(Bucket >= End); while (Bucket != End && - (*Bucket == SmallPtrSetImplBase::getEmptyMarker() || - *Bucket == SmallPtrSetImplBase::getTombstoneMarker())) { + (Bucket[-1] == SmallPtrSetImplBase::getEmptyMarker() || + Bucket[-1] == SmallPtrSetImplBase::getTombstoneMarker())) { --Bucket; } } @@ -289,6 +288,12 @@ class SmallPtrSetIterator : public SmallPtrSetIteratorImpl { // Most methods provided by baseclass. const PtrTy operator*() const { +#if LLVM_ENABLE_ABI_BREAKING_CHECKS + if (ReverseIterate::value) { + assert(Bucket > End); + return PtrTraits::getFromVoidPointer(const_cast(Bucket[-1])); + } +#endif assert(Bucket < End); return PtrTraits::getFromVoidPointer(const_cast(*Bucket)); } @@ -296,6 +301,7 @@ class SmallPtrSetIterator : public SmallPtrSetIteratorImpl { inline SmallPtrSetIterator& operator++() { // Preincrement #if LLVM_ENABLE_ABI_BREAKING_CHECKS if (ReverseIterate::value) { + --Bucket; RetreatIfNotValid(); return *this; } @@ -370,7 +376,7 @@ class SmallPtrSetImpl : public SmallPtrSetImplBase { /// the element equal to Ptr. std::pair insert(PtrType Ptr) { auto p = insert_imp(PtrTraits::getAsVoidPointer(Ptr)); - return std::make_pair(iterator(p.first, EndPointer()), p.second); + return std::make_pair(makeIterator(p.first), p.second); } /// erase - If the set contains the specified pointer, remove it and return @@ -379,12 +385,9 @@ class SmallPtrSetImpl : public SmallPtrSetImplBase { return erase_imp(PtrTraits::getAsVoidPointer(Ptr)); } /// count - Return 1 if the specified pointer is in the set, 0 otherwise. - size_type count(ConstPtrType Ptr) const { - return find(Ptr) != endPtr() ? 1 : 0; - } + size_type count(ConstPtrType Ptr) const { return find(Ptr) != end() ? 1 : 0; } iterator find(ConstPtrType Ptr) const { - auto *P = find_imp(ConstPtrTraits::getAsVoidPointer(Ptr)); - return iterator(P, EndPointer()); + return makeIterator(find_imp(ConstPtrTraits::getAsVoidPointer(Ptr))); } template @@ -397,25 +400,23 @@ class SmallPtrSetImpl : public SmallPtrSetImplBase { insert(IL.begin(), IL.end()); } - inline iterator begin() const { + iterator begin() const { #if LLVM_ENABLE_ABI_BREAKING_CHECKS if (ReverseIterate::value) - return endPtr(); + return makeIterator(EndPointer() - 1); #endif - return iterator(CurArray, EndPointer()); + return makeIterator(CurArray); } - inline iterator end() const { + iterator end() const { return makeIterator(EndPointer()); } + +private: + /// Create an iterator that dereferences to same place as the given pointer. + iterator makeIterator(const void *const *P) const { #if LLVM_ENABLE_ABI_BREAKING_CHECKS if (ReverseIterate::value) - return iterator(CurArray, CurArray); + return iterator(P == EndPointer() ? CurArray : P + 1, CurArray); #endif - return endPtr(); - } - -private: - inline iterator endPtr() const { - const void *const *End = EndPointer(); - return iterator(End, End); + return iterator(P, EndPointer()); } }; diff --git a/unittests/ADT/SmallPtrSetTest.cpp b/unittests/ADT/SmallPtrSetTest.cpp index bb9ee67b7eba3..fc14c684d67f3 100644 --- a/unittests/ADT/SmallPtrSetTest.cpp +++ b/unittests/ADT/SmallPtrSetTest.cpp @@ -282,6 +282,28 @@ TEST(SmallPtrSetTest, EraseTest) { checkEraseAndIterators(A); } +// Verify that dereferencing and iteration work. +TEST(SmallPtrSetTest, dereferenceAndIterate) { + int Ints[] = {0, 1, 2, 3, 4, 5, 6, 7}; + SmallPtrSet S; + for (int &I : Ints) { + EXPECT_EQ(&I, *S.insert(&I).first); + EXPECT_EQ(&I, *S.find(&I)); + } + + // Iterate from each and count how many times each element is found. + int Found[sizeof(Ints)/sizeof(int)] = {0}; + for (int &I : Ints) + for (auto F = S.find(&I), E = S.end(); F != E; ++F) + ++Found[*F - Ints]; + + // Sort. We should hit the first element just once and the final element N + // times. + std::sort(std::begin(Found), std::end(Found)); + for (auto F = std::begin(Found), E = std::end(Found); F != E; ++F) + EXPECT_EQ(F - Found + 1, *F); +} + // Verify that const pointers work for count and find even when the underlying // SmallPtrSet is not for a const pointer type. TEST(SmallPtrSetTest, ConstTest) { @@ -292,10 +314,8 @@ TEST(SmallPtrSetTest, ConstTest) { IntSet.insert(B); EXPECT_EQ(IntSet.count(B), 1u); EXPECT_EQ(IntSet.count(C), 1u); - // FIXME: We can't unit test find right now because ABI_BREAKING_CHECKS breaks - // find(). - // EXPECT_NE(IntSet.find(B), IntSet.end()); - // EXPECT_NE(IntSet.find(C), IntSet.end()); + EXPECT_NE(IntSet.find(B), IntSet.end()); + EXPECT_NE(IntSet.find(C), IntSet.end()); } // Verify that we automatically get the const version of PointerLikeTypeTraits @@ -308,7 +328,5 @@ TEST(SmallPtrSetTest, ConstNonPtrTest) { TestPair Pair(&A[0], 1); IntSet.insert(Pair); EXPECT_EQ(IntSet.count(Pair), 1u); - // FIXME: We can't unit test find right now because ABI_BREAKING_CHECKS breaks - // find(). - // EXPECT_NE(IntSet.find(Pair), IntSet.end()); + EXPECT_NE(IntSet.find(Pair), IntSet.end()); }