diff --git a/test/cpp/monitor/test_counters.cpp b/test/cpp/monitor/test_counters.cpp index 8b729b8b0314e..e9f8a35bd56ee 100644 --- a/test/cpp/monitor/test_counters.cpp +++ b/test/cpp/monitor/test_counters.cpp @@ -1,23 +1,25 @@ #include +#include + #include +#include using namespace torch::monitor; TEST(MonitorTest, CounterDouble) { - Stat a{ + FixedCountStat a{ "a", {MEAN, COUNT}, + 2, }; a.add(5.0); ASSERT_EQ(a.count(), 1); a.add(6.0); - ASSERT_EQ(a.count(), 2); - a.closeWindow(); - auto stats = a.get(); ASSERT_EQ(a.count(), 0); - std::vector> want = { + auto stats = a.get(); + std::unordered_map want = { {MEAN, 5.5}, {COUNT, 2.0}, }; @@ -25,96 +27,96 @@ TEST(MonitorTest, CounterDouble) { } TEST(MonitorTest, CounterInt64Sum) { - Stat a{ + FixedCountStat a{ "a", {SUM}, + 2, }; a.add(5); a.add(6); - a.closeWindow(); auto stats = a.get(); - std::vector> want = { + std::unordered_map want = { {SUM, 11}, }; ASSERT_EQ(stats, want); } TEST(MonitorTest, CounterInt64Value) { - Stat a{ + FixedCountStat a{ "a", {VALUE}, + 2, }; a.add(5); a.add(6); - a.closeWindow(); auto stats = a.get(); - std::vector> want = { + std::unordered_map want = { {VALUE, 6}, }; ASSERT_EQ(stats, want); } TEST(MonitorTest, CounterInt64Mean) { - Stat a{ + FixedCountStat a{ "a", {MEAN}, + 2, }; - a.add(0); - a.add(10); - { - a.closeWindow(); + // zero samples case auto stats = a.get(); - std::vector> want = { - {MEAN, 5}, + std::unordered_map want = { + {MEAN, 0}, }; ASSERT_EQ(stats, want); } + a.add(0); + a.add(10); + { - // zero samples case - a.closeWindow(); auto stats = a.get(); - std::vector> want = { - {MEAN, 0}, + std::unordered_map want = { + {MEAN, 5}, }; ASSERT_EQ(stats, want); } } TEST(MonitorTest, CounterInt64Count) { - Stat a{ + FixedCountStat a{ "a", {COUNT}, + 2, }; ASSERT_EQ(a.count(), 0); a.add(0); ASSERT_EQ(a.count(), 1); a.add(10); - ASSERT_EQ(a.count(), 2); - a.closeWindow(); - auto stats = a.get(); ASSERT_EQ(a.count(), 0); - std::vector> want = { + + auto stats = a.get(); + std::unordered_map want = { {COUNT, 2}, }; ASSERT_EQ(stats, want); } TEST(MonitorTest, CounterInt64MinMax) { - Stat a{ + FixedCountStat a{ "a", {MIN, MAX}, + 6, }; { - a.closeWindow(); auto stats = a.get(); - std::vector> want = { + std::unordered_map want = { {MAX, 0}, {MIN, 0}, }; ASSERT_EQ(stats, want); } + a.add(0); a.add(5); a.add(-5); @@ -122,9 +124,8 @@ TEST(MonitorTest, CounterInt64MinMax) { a.add(9); a.add(2); { - a.closeWindow(); auto stats = a.get(); - std::vector> want = { + std::unordered_map want = { {MAX, 9}, {MIN, -6}, }; @@ -133,7 +134,7 @@ TEST(MonitorTest, CounterInt64MinMax) { } TEST(MonitorTest, CounterInt64WindowSize) { - Stat a{ + FixedCountStat a{ "a", {COUNT, SUM}, /*windowSize=*/3, @@ -144,54 +145,187 @@ TEST(MonitorTest, CounterInt64WindowSize) { a.add(3); ASSERT_EQ(a.count(), 0); - a.closeWindow(); + a.add(4); + ASSERT_EQ(a.count(), 1); + auto stats = a.get(); - std::vector> want = { + std::unordered_map want = { {COUNT, 3}, {SUM, 6}, }; ASSERT_EQ(stats, want); - a.closeWindow(); - ASSERT_EQ(stats, a.get()); } -TEST(MonitorTest, CloseAndGetStats) { - Stat a{ +template +struct TestIntervalStat : public IntervalStat { + uint64_t mockWindowId{0}; + + TestIntervalStat( + std::string name, + std::initializer_list aggregations, + std::chrono::milliseconds windowSize) + : IntervalStat(name, aggregations, windowSize) {} + + uint64_t currentWindowId() const override { + return mockWindowId; + } +}; + +struct AggregatingEventHandler : public EventHandler { + std::vector events; + + void handle(const Event& e) override { + events.emplace_back(e); + } +}; + +template +struct HandlerGuard { + std::shared_ptr handler; + + HandlerGuard() : handler(std::make_shared()) { + registerEventHandler(handler); + } + + ~HandlerGuard() { + unregisterEventHandler(handler); + } +}; + +TEST(MonitorTest, IntervalStat) { + HandlerGuard guard; + + IntervalStat a{ "a", {COUNT, SUM}, - /*windowSize=*/3, + std::chrono::milliseconds(1), }; - Stat b{ - "b", - {MIN, MAX}, - 2, + ASSERT_EQ(guard.handler->events.size(), 0); + + a.add(1); + ASSERT_LE(a.count(), 1); + + std::this_thread::sleep_for(std::chrono::milliseconds(2)); + a.add(2); + ASSERT_LE(a.count(), 1); + + ASSERT_GE(guard.handler->events.size(), 1); + ASSERT_LE(guard.handler->events.size(), 2); +} + +TEST(MonitorTest, IntervalStatEvent) { + HandlerGuard guard; + + TestIntervalStat a{ + "a", + {COUNT, SUM}, + std::chrono::milliseconds(1), }; + ASSERT_EQ(guard.handler->events.size(), 0); a.add(1); - b.add(1); + ASSERT_EQ(a.count(), 1); + a.add(2); + ASSERT_EQ(a.count(), 2); + ASSERT_EQ(guard.handler->events.size(), 0); + + a.mockWindowId = 100; + + a.add(3); + ASSERT_LE(a.count(), 1); + + ASSERT_EQ(guard.handler->events.size(), 1); + Event e = guard.handler->events.at(0); + ASSERT_EQ(e.type, "torch.monitor.Stat"); + ASSERT_EQ(e.message, "a"); + ASSERT_NE(e.timestamp, std::chrono::system_clock::time_point{}); + std::unordered_map metadata{ + {"a.sum", 3L}, + {"a.count", 2L}, + }; + ASSERT_EQ(e.metadata, metadata); +} + +TEST(MonitorTest, IntervalStatEventDestruction) { + HandlerGuard guard; { - auto out = closeAndGetStats(); - std::pair< - std::unordered_map, - std::unordered_map> - want = { - {{"a.count", 1}, {"a.sum", 1}}, - {{"b.min", 0}, {"b.max", 0}}, - }; + TestIntervalStat a{ + "a", + {COUNT, SUM}, + std::chrono::hours(10), + }; + a.add(1); + ASSERT_EQ(a.count(), 1); + ASSERT_EQ(guard.handler->events.size(), 0); } + ASSERT_EQ(guard.handler->events.size(), 1); + Event e = guard.handler->events.at(0); + ASSERT_EQ(e.type, "torch.monitor.Stat"); + ASSERT_EQ(e.message, "a"); + ASSERT_NE(e.timestamp, std::chrono::system_clock::time_point{}); + std::unordered_map metadata{ + {"a.sum", 1L}, + {"a.count", 1L}, + }; + ASSERT_EQ(e.metadata, metadata); +} + +TEST(MonitorTest, FixedCountStatEvent) { + HandlerGuard guard; + + FixedCountStat a{ + "a", + {COUNT, SUM}, + 3, + }; + ASSERT_EQ(guard.handler->events.size(), 0); + + a.add(1); + ASSERT_EQ(a.count(), 1); a.add(2); - b.add(2); + ASSERT_EQ(a.count(), 2); + ASSERT_EQ(guard.handler->events.size(), 0); + + a.add(1); + ASSERT_EQ(a.count(), 0); + ASSERT_EQ(guard.handler->events.size(), 1); + + Event e = guard.handler->events.at(0); + ASSERT_EQ(e.type, "torch.monitor.Stat"); + ASSERT_EQ(e.message, "a"); + ASSERT_NE(e.timestamp, std::chrono::system_clock::time_point{}); + std::unordered_map metadata{ + {"a.sum", 4L}, + {"a.count", 3L}, + }; + ASSERT_EQ(e.metadata, metadata); +} + +TEST(MonitorTest, FixedCountStatEventDestruction) { + HandlerGuard guard; { - auto out = closeAndGetStats(); - std::pair< - std::unordered_map, - std::unordered_map> - want = { - {{"a.count", 1}, {"a.sum", 2}}, - {{"b.min", 1}, {"b.max", 2}}, - }; + FixedCountStat a{ + "a", + {COUNT, SUM}, + 3, + }; + ASSERT_EQ(guard.handler->events.size(), 0); + a.add(1); + ASSERT_EQ(a.count(), 1); + ASSERT_EQ(guard.handler->events.size(), 0); } + ASSERT_EQ(guard.handler->events.size(), 1); + + Event e = guard.handler->events.at(0); + ASSERT_EQ(e.type, "torch.monitor.Stat"); + ASSERT_EQ(e.message, "a"); + ASSERT_NE(e.timestamp, std::chrono::system_clock::time_point{}); + std::unordered_map metadata{ + {"a.sum", 1L}, + {"a.count", 1L}, + }; + ASSERT_EQ(e.metadata, metadata); } diff --git a/test/cpp/monitor/test_events.cpp b/test/cpp/monitor/test_events.cpp new file mode 100644 index 0000000000000..30216e8756b13 --- /dev/null +++ b/test/cpp/monitor/test_events.cpp @@ -0,0 +1,39 @@ +#include + +#include + +using namespace torch::monitor; + +struct AggregatingEventHandler : public EventHandler { + std::vector events; + + void handle(const Event& e) override { + events.emplace_back(e); + } +}; + +TEST(EventsTest, EventHandler) { + Event e; + e.type = "test"; + e.message = "test message"; + e.timestamp = std::chrono::system_clock::now(); + e.metadata["string"] = "asdf"; + e.metadata["double"] = 1234.5678; + e.metadata["int"] = 1234L; + e.metadata["bool"] = true; + + // log to nothing + logEvent(e); + + auto handler = std::make_shared(); + registerEventHandler(handler); + + logEvent(e); + ASSERT_EQ(handler->events.size(), 1); + ASSERT_EQ(e, handler->events.at(0)); + + unregisterEventHandler(handler); + logEvent(e); + // handler unregister, didn't log + ASSERT_EQ(handler->events.size(), 1); +} diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl index 0b09ce849da1a..04d4eeedf1154 100644 --- a/tools/build_variables.bzl +++ b/tools/build_variables.bzl @@ -133,6 +133,7 @@ libtorch_profiler_sources = [ "torch/csrc/autograd/profiler_legacy.cpp", "torch/csrc/autograd/profiler_kineto.cpp", "torch/csrc/monitor/counters.cpp", + "torch/csrc/monitor/events.cpp", ] libtorch_edge_profiler_sources = libtorch_profiler_sources + [ diff --git a/torch/csrc/monitor/counters.cpp b/torch/csrc/monitor/counters.cpp index 5ac621a41d425..5bd7489182872 100644 --- a/torch/csrc/monitor/counters.cpp +++ b/torch/csrc/monitor/counters.cpp @@ -1,4 +1,5 @@ #include +#include #include #include @@ -12,6 +13,8 @@ const char* aggregationName(Aggregation agg) { return "none"; case VALUE: return "value"; + case MEAN: + return "mean"; case COUNT: return "count"; case SUM: @@ -62,39 +65,5 @@ void unregisterStat(Stat* stat) { } } // namespace detail -template -void closeAndGetStat(Stat* s, std::unordered_map& m) { - s->closeWindow(); - auto out = s->get(); - for (auto& kv : out) { - std::stringstream key; - key << s->name(); - key << "."; - key << aggregationName(kv.first); - m[key.str()] = kv.second; - } -} - -std::pair< - std::unordered_map, - std::unordered_map> -closeAndGetStats() noexcept { - std::pair< - std::unordered_map, - std::unordered_map> - out; - - std::lock_guard guard(stats().mu); - - for (auto* s : stats().doubles) { - closeAndGetStat(s, out.first); - } - for (auto* s : stats().int64s) { - closeAndGetStat(s, out.second); - } - - return out; -} - } // namespace monitor } // namespace torch diff --git a/torch/csrc/monitor/counters.h b/torch/csrc/monitor/counters.h index f0e002ca119e7..5753031727657 100644 --- a/torch/csrc/monitor/counters.h +++ b/torch/csrc/monitor/counters.h @@ -2,9 +2,12 @@ #include #include +#include #include #include +#include + namespace torch { namespace monitor { @@ -32,6 +35,8 @@ enum Aggregation { MIN = 6, }; +// aggregationName returns the human readable name corresponding to the +// aggregation. const char* aggregationName(Aggregation agg); template @@ -55,6 +60,16 @@ void unregisterStat(Stat* stat); void unregisterStat(Stat* stat); } // namespace detail +// Stat is a base class for stats. These stats are used to compute summary +// statistics in a performant way over repeating intervals. When the window +// closes the stats are logged via the event handlers as a `torch.monitor.Stat` +// event. +// +// Stats support double and int64_t data types depending on what needs to be +// logged and needs to be templatized with one of them. +// +// When the Stat is destructed it will log any remaining data even if the window +// hasn't elapsed. template class Stat { private: @@ -67,23 +82,24 @@ class Stat { }; public: - Stat( - std::string name, - std::initializer_list aggregations, - int64_t windowSize = -1) - : name_(std::move(name)), - aggregations_(merge(aggregations)), - windowSize_(windowSize) { + Stat(std::string name, std::initializer_list aggregations) + : name_(std::move(name)), aggregations_(merge(aggregations)) { detail::registerStat(this); } - ~Stat() { + virtual ~Stat() { + { + // on destruction log if there's unlogged data + std::lock_guard guard(mu_); + logLocked(); + } detail::unregisterStat(this); } // add adds the value v to the current window. - void add(T v) noexcept { + void add(T v) { std::lock_guard guard(mu_); + maybeLogLocked(); if (aggregations_.test(VALUE)) { current_.value = v; @@ -104,19 +120,13 @@ class Stat { } current_.count += 1; - if (windowSize_ > 0 && current_.count >= windowSize_) { - saveCurrentLocked(); - } + maybeLogLocked(); } const std::string& name() const noexcept { return name_; } - int64_t windowSize() const noexcept { - return windowSize_; - } - // count returns the number of items in the current open window. int64_t count() noexcept { std::lock_guard guard(mu_); @@ -124,68 +134,131 @@ class Stat { return current_.count; } - // closeWindow finalizes the collected stats window so they can be accessed - // via get(). - // If the Stat has a windowSize specified this doesn't do anything since the - // window is automatically closed when enough samples have been logged. - void closeWindow() noexcept { - if (windowSize_ <= 0) { - std::lock_guard guard(mu_); + std::unordered_map get() noexcept { + std::lock_guard guard(mu_); + return getLocked(); + } + + protected: + virtual void maybeLogLocked() = 0; + + void logLocked() { + prev_ = current_; + current_ = Values(); - saveCurrentLocked(); + // don't log event if there's no data + if (prev_.count == 0) { + return; } + + Event e; + e.type = "torch.monitor.Stat"; + e.message = name_; + e.timestamp = std::chrono::system_clock::now(); + + auto stats = getLocked(); + e.metadata.reserve(stats.size()); + for (auto& kv : stats) { + std::stringstream key; + key << name_; + key << "."; + key << aggregationName(kv.first); + e.metadata[key.str()] = kv.second; + } + + logEvent(e); } - std::vector> get() noexcept { - std::vector> out; + std::unordered_map getLocked() const noexcept { + std::unordered_map out; out.reserve(aggregations_.count()); - std::lock_guard guard(mu_); - if (aggregations_.test(VALUE)) { - out.emplace_back(VALUE, prev_.value); + out.emplace(VALUE, prev_.value); } if (aggregations_.test(MEAN)) { if (prev_.count == 0) { - out.emplace_back(MEAN, 0); + out.emplace(MEAN, 0); } else { - out.emplace_back(MEAN, prev_.sum / prev_.count); + out.emplace(MEAN, prev_.sum / prev_.count); } } if (aggregations_.test(COUNT)) { - out.emplace_back(COUNT, prev_.count); + out.emplace(COUNT, prev_.count); } if (aggregations_.test(SUM)) { - out.emplace_back(SUM, prev_.sum); + out.emplace(SUM, prev_.sum); } if (aggregations_.test(MAX)) { - out.emplace_back(MAX, prev_.max); + out.emplace(MAX, prev_.max); } if (aggregations_.test(MIN)) { - out.emplace_back(MIN, prev_.min); + out.emplace(MIN, prev_.min); } return out; } - private: - void saveCurrentLocked() { - prev_ = current_; - current_ = Values(); - } - const std::string name_; const std::bitset aggregations_; - const int64_t windowSize_; std::mutex mu_; Values current_; Values prev_; }; -std::pair< - std::unordered_map, - std::unordered_map> -closeAndGetStats() noexcept; +// IntervalStat is a Stat that logs the stat once every `windowSize` duration. +// This should be set to something relatively high to avoid a huge number of +// events being logged. Ex: 60s. +template +class IntervalStat : public Stat { + public: + IntervalStat( + std::string name, + std::initializer_list aggregations, + std::chrono::milliseconds windowSize) + : Stat(std::move(name), aggregations), windowSize_(windowSize) {} + + protected: + virtual uint64_t currentWindowId() const { + auto now = std::chrono::steady_clock::now().time_since_epoch(); + return now / windowSize_; + } + + private: + void maybeLogLocked() override { + auto windowId = currentWindowId(); + if (windowId_ != windowId) { + Stat::logLocked(); + windowId_ = windowId; + } + } + + uint64_t windowId_{0}; + const std::chrono::milliseconds windowSize_; +}; + +// FixedCountStat is a Stat that logs the stat every `windowSize` number of add +// calls. For high performance stats this window size should be fairly large to +// ensure that the event logging frequency is in the range of 1s to 60s under +// normal usage. Core stats should error on the side of less frequent. +template +class FixedCountStat : public Stat { + public: + FixedCountStat( + std::string name, + std::initializer_list aggregations, + int64_t windowSize) + : Stat(std::move(name), aggregations), windowSize_(windowSize) {} + + private: + void maybeLogLocked() override { + if (Stat::current_.count >= windowSize_) { + Stat::logLocked(); + } + } + + const int64_t windowSize_; +}; } // namespace monitor } // namespace torch diff --git a/torch/csrc/monitor/events.cpp b/torch/csrc/monitor/events.cpp new file mode 100644 index 0000000000000..f8f6355864325 --- /dev/null +++ b/torch/csrc/monitor/events.cpp @@ -0,0 +1,61 @@ +#include + +#include +#include +#include +#include +#include + +namespace torch { +namespace monitor { + +namespace { +class EventHandlers { + public: + void registerEventHandler(std::shared_ptr handler) noexcept { + std::unique_lock lock(mu_); + + handlers_.emplace_back(std::move(handler)); + } + + void unregisterEventHandler( + const std::shared_ptr& handler) noexcept { + std::unique_lock lock(mu_); + + auto it = std::find(handlers_.begin(), handlers_.end(), handler); + handlers_.erase(it); + } + + void logEvent(const Event& e) { + std::unique_lock lock(mu_); + + for (auto& handler : handlers_) { + handler->handle(e); + } + } + + static EventHandlers& get() noexcept { + static EventHandlers ehs; + return ehs; + } + + private: + std::mutex mu_{}; + std::vector> handlers_{}; +}; +} // namespace + +void logEvent(const Event& e) { + EventHandlers::get().logEvent(e); +} + +void registerEventHandler(std::shared_ptr p) { + EventHandlers::get().registerEventHandler(std::move(p)); +} + +void unregisterEventHandler(const std::shared_ptr& p) { + EventHandlers::get().unregisterEventHandler(p); +} + +} // namespace monitor +} // namespace torch diff --git a/torch/csrc/monitor/events.h b/torch/csrc/monitor/events.h new file mode 100644 index 0000000000000..db6ddfcd4bc82 --- /dev/null +++ b/torch/csrc/monitor/events.h @@ -0,0 +1,75 @@ +#pragma once + +#include +#include +#include + +#include + +namespace torch { +namespace monitor { + +// metadata_value_t is the type for Event metadata values. +using metadata_value_t = c10::variant; + +// Event represents a single event that can be logged out to an external +// tracker. This does acquire a lock on logging so should be used relatively +// infrequently to avoid performance issues. +struct Event { + // type is the type of the event. This is a static string that's used to + // differentiate between event types for programmatic access. The type should + // be in the format of a fully qualified Python-style class name. + // Ex: torch.monitor.MonitorEvent + std::string type; + + // message is a human readable name. This is optional for machine intended + // stats. + std::string message; + + // timestamp is a timestamp relative to the Unix epoch time. + std::chrono::system_clock::time_point timestamp; + + // metadata contains rich information about the event. The contents are event + // specific so you should check the type to ensure it's what you expect before + // accessing the metadata. + // + // NOTE: these events are not versioned and it's up to the consumer of the + // events to check the fields to ensure backwards compatibility. + std::unordered_map metadata; +}; + +inline bool operator==(const Event& lhs, const Event& rhs) { + return lhs.type == rhs.type && lhs.message == rhs.message && + lhs.timestamp == rhs.timestamp && lhs.metadata == rhs.metadata; +} + +// EventHandler represents an abstract event handler that can be registered to +// capture events. Every time an event is logged every handler will be called +// with the events contents. +// +// NOTE: The handlers should avoid any IO, blocking calls or heavy computation +// as this may block the main thread and cause performance issues. +class EventHandler { + public: + virtual ~EventHandler() = default; + + // handle needs to be implemented to handle the events. This may be called + // from multiple threads so needs to be thread safe. + virtual void handle(const Event& e) = 0; +}; + +// logEvent calls each registered event handler with the event. This method can +// be called from concurrently from multiple threads. +void logEvent(const Event& e); + +// registerEventHandler registers an EventHandler so it receives any logged +// events. Typically an EventHandler will be registered during program +// setup and unregistered at the end. +void registerEventHandler(std::shared_ptr p); + +// unregisterEventHandler unregisters the event handler pointed to by the +// shared_ptr. +void unregisterEventHandler(const std::shared_ptr& p); + +} // namespace monitor +} // namespace torch