Skip to content

Commit

Permalink
Bug 1877672 - Make FunctionRef a MOZ_TEMPORARY_CLASS. r=nika
Browse files Browse the repository at this point in the history
This prevents it from being used in the foot-gunny way described in
comment 0.

This in turn allows us to add a constructor for temporary callables.

Turns out we only had test usages of non-temporary FunctionRefs, so this
is much simpler than the initial approach I considered.

Fix the tests to keep compiling, and add a test for the new constructor.

Differential Revision: https://phabricator.services.mozilla.com/D200157
  • Loading branch information
emilio committed Feb 2, 2024
1 parent 0f8563b commit f67e589
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 56 deletions.
4 changes: 2 additions & 2 deletions mfbt/Attributes.h
Original file line number Diff line number Diff line change
Expand Up @@ -612,8 +612,8 @@
* need not be provided in such cases.
* MOZ_TEMPORARY_CLASS: Applies to all classes. Any class with this annotation
* is expected to only live in a temporary. If another class inherits from
* this class, then it is considered to be a non-temporary class as well,
* although this attribute need not be provided in such cases.
* this class, then it is considered to be a temporary class as well, although
* this attribute need not be provided in such cases.
* MOZ_RAII: Applies to all classes. Any class with this annotation is assumed
* to be a RAII guard, which is expected to live on the stack in an automatic
* allocation. It is prohibited from being allocated in a temporary, static
Expand Down
31 changes: 12 additions & 19 deletions mfbt/FunctionRef.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,10 @@ using EnableFunctionTag = std::enable_if_t<
* FunctionRef.
*/
template <typename Fn>
class FunctionRef;
class MOZ_TEMPORARY_CLASS FunctionRef;

template <typename Ret, typename... Params>
class FunctionRef<Ret(Params...)> {
class MOZ_TEMPORARY_CLASS FunctionRef<Ret(Params...)> {
union Payload;

// |FunctionRef| stores an adaptor function pointer, determined by the
Expand Down Expand Up @@ -166,25 +166,18 @@ class FunctionRef<Ret(Params...)> {
* state). For example:
*
* int x = 5;
* auto doSideEffect = [&x]{ x++; }; // state is captured reference to |x|
* FunctionRef<void()> f(doSideEffect);
* DoSomething([&x] { x++; });
*/
template <
typename Callable,
typename = detail::EnableFunctionTag<detail::MatchingFunctorTag, Callable,
Ret, Params...>,
typename std::enable_if_t<!std::is_same_v<
typename std::remove_reference_t<typename std::remove_cv_t<Callable>>,
FunctionRef>>* = nullptr>
MOZ_IMPLICIT FunctionRef(Callable& aCallable) noexcept
template <typename Callable,
typename = detail::EnableFunctionTag<detail::MatchingFunctorTag,
Callable, Ret, Params...>,
typename std::enable_if_t<!std::is_same_v<
std::remove_cv_t<std::remove_reference_t<Callable>>,
FunctionRef>>* = nullptr>
MOZ_IMPLICIT FunctionRef(Callable&& aCallable) noexcept
: mAdaptor([](const Payload& aPayload, Params... aParams) {
auto& func = *static_cast<Callable*>(aPayload.mObject);
// Unable to use std::forward here due to llvm windows bug
// https://bugs.llvm.org/show_bug.cgi?id=28299
//
// This prevents use of move-only arguments for functors and lambdas.
// Move only arguments can be used when using function pointers
return static_cast<Ret>(func(static_cast<Params>(aParams)...));
auto& func = *static_cast<std::remove_reference_t<Callable>*>(aPayload.mObject);
return static_cast<Ret>(func(std::forward<Params>(aParams)...));
}) {
::new (KnownNotNull, &mPayload.mObject) void*(&aCallable);
}
Expand Down
72 changes: 37 additions & 35 deletions mfbt/tests/TestFunctionRef.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
#include "mozilla/FunctionRef.h"
#include "mozilla/UniquePtr.h"

using mozilla::FunctionRef;

#define CHECK(c) \
do { \
bool cond = !!(c); \
Expand All @@ -34,99 +36,99 @@ struct Incrementor {
int operator()(int arg) { return arg + 1; }
};

template <typename Fn>
struct Caller;

template <typename Fn, typename... Params>
std::invoke_result_t<Fn, Params...> CallFunctionRef(FunctionRef<Fn> aRef,
Params... aParams) {
return aRef(std::forward<Params>(aParams)...);
}

static void TestNonmemberFunction() {
mozilla::FunctionRef<int(int)> f(&increment);
CHECK(f(42) == 43);
CHECK(CallFunctionRef<int(int)>(increment, 42) == 43);
}

static void TestStaticMemberFunction() {
mozilla::FunctionRef<int(int)> f(&S::increment);
CHECK(f(42) == 43);
CHECK(CallFunctionRef<int(int)>(&S::increment, 42) == 43);
}

static void TestFunctionObject() {
auto incrementor = Incrementor();
mozilla::FunctionRef<int(int)> f(incrementor);
CHECK(f(42) == 43);
CHECK(CallFunctionRef<int(int)>(incrementor, 42) == 43);
}

static void TestFunctionObjectTemporary() {
CHECK(CallFunctionRef<int(int)>(Incrementor(), 42) == 43);
}

static void TestLambda() {
// Test non-capturing lambda
auto lambda1 = [](int arg) { return arg + 1; };
mozilla::FunctionRef<int(int)> f(lambda1);
CHECK(f(42) == 43);
CHECK(CallFunctionRef<int(int)>(lambda1, 42) == 43);

// Test capturing lambda
int one = 1;
auto lambda2 = [one](int arg) { return arg + one; };
mozilla::FunctionRef<int(int)> g(lambda2);
CHECK(g(42) == 43);
CHECK(CallFunctionRef<int(int)>(lambda2, 42) == 43);

mozilla::FunctionRef<int(int)> h([](int arg) { return arg + 1; });
CHECK(h(42) == 43);
CHECK(CallFunctionRef<int(int)>([](int arg) { return arg + 1; }, 42) == 43);
}

static void TestOperatorBool() {
mozilla::FunctionRef<int(int)> f1;
CHECK(!static_cast<bool>(f1));

mozilla::FunctionRef<int(int)> f2 = increment;
CHECK(static_cast<bool>(f2));

mozilla::FunctionRef<int(int)> f3 = nullptr;
CHECK(!static_cast<bool>(f3));
auto ToBool = [](FunctionRef<int(int)> aRef) {
return static_cast<bool>(aRef);
};
CHECK(!ToBool({}));
CHECK(ToBool(increment));
CHECK(!ToBool(nullptr));
}

static void TestReferenceParameters() {
mozilla::FunctionRef<int(const int&, const int&)> f = &addConstRefs;
int x = 1;
int y = 2;
CHECK(f(x, y) == 3);
CHECK(CallFunctionRef<int(const int&, const int&)>(addConstRefs, x, y) == 3);
}

static void TestVoidNoParameters() {
mozilla::FunctionRef<void()> f = &helloWorld;
CHECK(!helloWorldCalled);
f();
CallFunctionRef<void()>(helloWorld);
CHECK(helloWorldCalled);
}

static void TestPointerParameters() {
mozilla::FunctionRef<void(int*)> f = &incrementPointer;
int x = 1;
f(&x);
CallFunctionRef<void(int*)>(incrementPointer, &x);
CHECK(x == 2);
}

static void TestImplicitFunctorTypeConversion() {
auto incrementor = Incrementor();
mozilla::FunctionRef<long(short)> f = incrementor;
short x = 1;
CHECK(f(x) == 2);
CHECK(CallFunctionRef<long(short)>(incrementor, x) == 2);
}

static void TestImplicitLambdaTypeConversion() {
mozilla::FunctionRef<long(short)> f = [](short arg) { return arg + 1; };
short x = 1;
CHECK(f(x) == 2);
CHECK(CallFunctionRef<long(short)>([](short arg) { return arg + 1; }, x) ==
2);
}

static void TestImplicitFunctionPointerTypeConversion() {
mozilla::FunctionRef<long(short)> f = &increment;
short x = 1;
CHECK(f(x) == 2);
CHECK(CallFunctionRef<long(short)>(&increment, x) == 2);
}

static void TestMoveOnlyArguments() {
mozilla::FunctionRef<int(mozilla::UniquePtr<int>)> f(&incrementUnique);

CHECK(f(mozilla::MakeUnique<int>(5)) == 6);
CHECK(CallFunctionRef<int(mozilla::UniquePtr<int>)>(
&incrementUnique, mozilla::MakeUnique<int>(5)) == 6);
}

int main() {
TestNonmemberFunction();
TestStaticMemberFunction();
TestFunctionObject();
TestFunctionObjectTemporary();
TestLambda();
TestOperatorBool();
TestReferenceParameters();
Expand Down

0 comments on commit f67e589

Please sign in to comment.