diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml new file mode 100644 index 0000000..1ffcb0c --- /dev/null +++ b/.github/workflows/codeql.yml @@ -0,0 +1,83 @@ +# For most projects, this workflow file will not need changing; you simply need +# to commit it to your repository. +# +# You may wish to alter this file to override the set of languages analyzed, +# or to provide custom queries or build logic. +# +# ******** NOTE ******** +# We have attempted to detect the languages in your repository. Please check +# the `language` matrix defined below to confirm you have the correct set of +# supported CodeQL languages. +# +name: "CodeQL" + +on: + push: + branches: [ "master" ] + pull_request: + # The branches below must be a subset of the branches above + branches: [ "master" ] +jobs: + analyze: + name: Analyze + runs-on: ubuntu-latest + permissions: + actions: read + contents: read + security-events: write + + strategy: + fail-fast: false + matrix: + language: [ 'cpp' ] + + steps: + - name: Checkout repository + uses: actions/checkout@v3 + + # Initializes the CodeQL tools for scanning. + - name: Initialize CodeQL + uses: github/codeql-action/init@v2 + with: + languages: ${{ matrix.language }} + # If you wish to specify custom queries, you can do so here or in a config file. + # By default, queries listed here will override any specified in a config file. + # Prefix the list here with "+" to use these queries and those in the config file. + + # Details on CodeQL's query packs refer to : https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs + # queries: security-extended,security-and-quality + + + - name: Install dependencies + run: | + sudo apt-get update && \ + sudo apt-get -qq install clang cmake git openssl libssl-dev libhiredis-dev \ + libspdlog-dev libfmt-dev ninja-build + + - name: Install redis-plus-plus + run: | + mkdir -p /tmp/redis-plus-plus && cd /tmp/redis-plus-plus && \ + git clone https://github.com/sewenew/redis-plus-plus.git . && \ + git checkout tags/1.3.3 && \ + mkdir compile && cd compile && cmake -GNinja -DCMAKE_BUILD_TYPE=Release .. && \ + ninja && sudo ninja install + + # Autobuild attempts to build any compiled languages (C/C++, C#, Go, or Java). + # If this step fails, then you should remove it and run the build manually (see below) + - name: Autobuild + uses: github/codeql-action/autobuild@v2 + + # ℹī¸ Command-line programs to run using the OS shell. + # 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun + + # If the Autobuild fails above, remove it and uncomment the following three lines. + # modify them (or add more) to build your code if your project, please refer to the EXAMPLE below for guidance. 
+ + # - run: | + # echo "Run, Build Application using script" + # ./location_of_script_within_repo/buildscript.sh + + - name: Perform CodeQL Analysis + uses: github/codeql-action/analyze@v2 + with: + category: "/language:${{matrix.language}}" diff --git a/.gitignore b/.gitignore index 2ca1c63..7a4bd44 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ build-dbg/* eventhub.code-workspace tests/integration/clienttest/clienttest tests/integration/clienttest/vendor/* +eventhub.conf \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt index 3d32e98..e3907e1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,11 +1,11 @@ -cmake_minimum_required (VERSION 2.9) +cmake_minimum_required(VERSION 3.5) project (eventhub) set(CMAKE_MODULE_PATH "${CMAKE_SOURCE_DIR}/cmake") set(CMAKE_C_FLAGS "-Wall") set(CMAKE_C_COMPILER "clang") set(CMAKE_CXX_COMPILER "clang++") -set(CMAKE_CXX_FLAGS "-Wall -std=c++17 -DSPDLOG_FMT_EXTERNAL") +set(CMAKE_CXX_FLAGS "-Wall -std=c++17 -DSPDLOG_FMT_EXTERNAL=ON -Wno-deprecated-declarations") set(CMAKE_RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}") set(CMAKE_EXPORT_COMPILE_COMMANDS ON) @@ -62,4 +62,4 @@ add_custom_target( -style=file -i ${ALL_SOURCE_FILES} -) \ No newline at end of file +) diff --git a/Dockerfile b/Dockerfile index ebff9cf..817181a 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM debian:bullseye-slim +FROM debian:bookworm-slim ENV DEBIAN_FRONTEND noninteractive RUN apt-get update && \ @@ -33,4 +33,4 @@ RUN apt-get -qq remove clang cmake git ninja-build && \ USER eventhub -ENTRYPOINT [ "/usr/bin/eventhub" ] \ No newline at end of file +ENTRYPOINT [ "/usr/bin/eventhub" ] diff --git a/Dockerfile.debug b/Dockerfile.debug index 1d8def3..2f98b02 100644 --- a/Dockerfile.debug +++ b/Dockerfile.debug @@ -2,7 +2,7 @@ FROM debian:bullseye-slim ENV DEBIAN_FRONTEND noninteractive RUN apt-get update && \ - apt-get -qq install clang cmake git openssl libssl-dev libhiredis-dev gdb bash vim psmisc procps htop curl sudo \ + apt-get -qq install gcc g++ cmake git openssl libssl-dev libhiredis-dev gdb bash vim psmisc procps htop curl sudo \ libspdlog-dev libfmt-dev ninja-build RUN mkdir -p /usr/src/redis-plus-plus && cd /usr/src/redis-plus-plus && \ @@ -17,8 +17,10 @@ WORKDIR /usr/src/eventhub COPY . . RUN mkdir -p build && cd build && \ - cmake -DSKIP_TESTS=1 -GNinja -DCMAKE_BUILD_TYPE=RelWithDebInfo .. && \ - ninja && \ + sed -i 's/clang++/g++/' ../CMakeLists.txt && \ + sed -i 's/clang/gcc/' ../CMakeLists.txt && \ + cmake -GNinja -DSKIP_TESTS=1 -DCMAKE_BUILD_TYPE=Debug .. && \ + ninja -j0 && \ cp -a eventhub /usr/bin/eventhub WORKDIR /tmp @@ -28,9 +30,6 @@ RUN addgroup --system eventhub && \ echo "eventhub ALL=(ALL) NOPASSWD: ALL" >> /etc/sudoers && \ mkdir -p /tmp/coredumps; chown -R eventhub:eventhub /tmp/coredumps -RUN apt-get -qq remove clang cmake git ninja-build && \ - apt-get -qq -f autoremove - USER eventhub ENTRYPOINT [ "/usr/bin/eventhub" ] \ No newline at end of file diff --git a/README.md b/README.md index ddc7a49..34d2bf1 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ Eventhub is a pub/sub over WebSocket server written in modern C++. It implements the [publish-subscribe pattern](https://en.wikipedia.org/wiki/Publish%E2%80%93subscribe_pattern) and concept of topics. -The key focus of the project is to deliver high performance, availability, and easy integration. +The key focus of the project is to deliver high performance and ease of integration.

@@ -42,7 +42,7 @@ Eventhub use the same layout for patterns as MQTT where ```+``` matches a single ## Eventlog -Eventhub stores all published messages into a log that can be requested by clients who want to get all events in or since a given time frame. For example if a client gets disconnected it can request this log to get all new events since the last event that was received. +Eventhub stores all published messages into a log that can be requested by clients who want to get all events in or since a given time frame. For example if a client is disconnected it can request this log when it reconnects to get all new events since the last event that was received. ## Authentication @@ -63,8 +63,16 @@ This token wil allow subscription to all channels under ```topic1``` and ```topi Eventhub does not have a interface or API to generate these tokens for you yet. So you have to generate them in your backend or through a JWT token generator like [jwt.io](https://jwt.io/). +## Key/value store + +Most applications require some kind of state to be stored. +Eventhub implements a simple key/value store for this purpose. This is a feature separate from the pub/sub functionality. + +The read/write ACL's in the JWT-token defines what keys you have access to. + # Clients * [JavaScript (Browser/Node.js)](https://github.com/olesku/eventhub-jsclient) +* [Python](https://github.com/olesku/eventhub-pyclient) * [Swift (iOS 13+)](https://github.com/shtrihstr/EventHub) #### Implementing your own client @@ -95,18 +103,20 @@ All configuration options can also be set using [environment variables](https:// |handshake_timeout | Client handshake timeout | 15 |disable_auth | Disable client authentication | false |[enable_sse](docs/sse.md) | Enable Server-Sent-Events support | false -|enable_cache | Enable retained cache for topics. | true +|enable_cache | Enable retained cache for topics. | false |prometheus_metric_prefix | Prometheus prefix | eventhub |default_cache_ttl | Default message TTL | 60 -|max_cache_request_limit | Default returned cache result limit | 1000 +|max_cache_request_limit | Default returned cache result limit | 100 |log_level | Log level to use | info |enable_ssl | Enable SSL | false +|ssl_listen_port | Port to listen on for SSL requests | 8443 |ssl_certificate | Path to certificate for SSL | None |ssl_private_key | Path to private key for SSL | None |ssl_ca_certificate | Path to CA certificate | None |ssl_cert_auto_reload | Reload ssl cert when changed on disk | false |ssl_cert_check_interval | How often to check for cert changes | 300 |disable_unsecure_listener | Disable unsecure listener when ssl is enabled | false +|enable_kvstore | Enable key/value store functionality | true ## Docker The easiest way is to use our docker image. diff --git a/docs/protocol.md b/docs/protocol.md index 794bc8d..a41a042 100644 --- a/docs/protocol.md +++ b/docs/protocol.md @@ -1,15 +1,19 @@ # Protocol specification Eventhub uses [JSON-RPC](http://www.jsonrpc.org/) over WebSocket as transport protocol. -| RPC method | Parameters | Description | -|-------------------------------------|----------------|-------------| -| [subscribe](#subscribe) | *topic, since* | Subscribe to a topic or pattern. -| [publish](#publish) | *topic, message* | Publish to a topic. -| [unsubscribe](#unsubscribe) | *topic* | Unsubscribe from a topic or pattern. -| [unsubscribeall](#unsubscribeall) | *None* | Unsubscribe from all current subscriptions. -| [list](#list) | *None* | List all current subscriptions. 
-| [ping](#ping) | *None* | Ping the server. -| [disconnect](#disconnect) | *None* | Disconnect from the server. +| RPC method | Parameters | Description | +|-------------------------------------|------------------------------ |---------------------------------------------| +| [subscribe](#subscribe) | *topic, since* | Subscribe to a topic or pattern. +| [publish](#publish) | *topic, message* | Publish to a topic. +| [unsubscribe](#unsubscribe) | *topic* | Unsubscribe from a topic or pattern. +| [unsubscribeall](#unsubscribeall) | *None* | Unsubscribe from all current subscriptions. +| [list](#list) | *None* | List all current subscriptions. +| [eventlog](#eventlog) | *topic, since, sinceEventId, limit* | Request event history for a topic. +| [get](#get) | *key* | Get key from key/value store. +| [set](#set) | *key, value, ttl* | Set key in key/value store. +| [del](#del) | *key* | Delete key in key/value store. +| [ping](#ping) | *None* | Ping the server. +| [disconnect](#disconnect) | *None* | Disconnect from the server. **Important:** Each request must have a unique `id` attribute as specified by JSON-RPC. It can be a number or a string. @@ -24,14 +28,14 @@ If you are implementing your own client I can recommend using the nice [websocat "id": 1, "jsonrpc": "2.0", "method": "subscribe", - "params": { - "topic": "my/topic1", + "params": { + "topic": "my/topic1", "since": 0 } } ``` -*The `since` attribute can be set to a timestamp or a message id to get all events from the history log since that period. If unset or set to 0 no history will be requested.* +*The `since` attribute can be set to a timestamp or a message id to get all events from the eventlog since that period. If unset or set to 0 eventlog will not be requested.* **Confirmation response:** ```json @@ -154,13 +158,150 @@ All messages received on a subscribed topic or pattern will have the same `id` a } ``` -## ping +## eventlog +**Request:** +```json5 +{ + "id": 1, + "jsonrpc": "2.0", + "method": "eventlog", + "params": { + // All events from the past 60 seconds. + // 'since' can also be a literal unix timestamp in milliseconds or + // you can use 'sinceEventId' to get all events since a given + // event id. + "since": -60000, + + // Limit result to 100 latest events in given time period. 
+ "limit": 100 + } +} +``` + +**Response:** +```json +{ + "id": 2, + "jsonrpc": "2.0", + "result": { + "action": "eventlog", + "items": [ + { + "id": "1661265352086-0", + "message": "Event 1", + "topic": "my/topic1" + }, + { + "id": "1661265374910-0", + "message": "Event 1", + "topic": "my/topic1" + }, + { + "id": "1661265379198-0", + "message": "Event 2", + "topic": "my/topic1" + }, + { + "id": "1661265383286-0", + "message": "Event 3", + "topic": "my/topic1" + } + ], + "status": "ok", + "topic": "my/topic1" + } +} +``` +## get **Request:** ```json { "id": 6, "jsonrpc": "2.0", + "method": "get", + "params": { + "key": "my/key" + } +} +``` + +**Response:** +```json +{ + "id": 6, + "jsonrpc": "2.0", + "result": { + "action": "get", + "key": "my/key", + "value": "some value" + } +} +``` + +## set +**Request:** +```json +{ + "id": 7, + "jsonrpc": "2.0", + "method": "set", + "params": { + "key": "my/key", + "value": "some value", + "ttl": 3600 + } +} +``` + +*If `ttl` attribute is omitted or set to `0` it means the key is stored without any expirity time.* + +**Response:** +```json +{ + "id": 7, + "jsonrpc": "2.0", + "result": { + "action": "set", + "key": "my/key", + "success": true + } +} +``` + +## del +**Request:** +```json +{ + "id": 8, + "jsonrpc": "2.0", + "method": "del", + "params": { + "key": "my/key" + } +} +``` + +**Response:** +```json +{ + "id": 8, + "jsonrpc": "2.0", + "result": { + "action": "del", + "key": "my/key", + "success": true + } +} +``` + +## ping + +**Request:** +```json +{ + "id": 9, + "jsonrpc": "2.0", "method": "ping", "params": [] } @@ -169,7 +310,7 @@ All messages received on a subscribed topic or pattern will have the same `id` a **Response:** ```json { - "id": 6, + "id": 9, "jsonrpc": "2.0", "result": { "pong": 1574846750424 @@ -183,7 +324,7 @@ Contents of the `pong` attribute is the server time since epoch in milliseconds. **Request:** ```json { - "id": 7, + "id": 10, "jsonrpc": "2.0", "method": "disconnect", "params": [] diff --git a/docs/rate-limiting.md b/docs/rate-limiting.md new file mode 100644 index 0000000..f375838 --- /dev/null +++ b/docs/rate-limiting.md @@ -0,0 +1,40 @@ +# Rate limiting + +Eventhub allows you to rate limit how many messages a user/token is allowed to publish within a given time period (interval). This is configured by adding ```rlimit``` configuration to the token used by the publisher. + +#### Syntax +```json5 + "sub": "user@domain.com", // Must be defined and unique for limits to work. + "write": [ "topic1/#" ], + "read": [ "topic1/#" ], + "rlimit": [ + { + "topic": "topic1/#", // Topic or pattern to limit. + "interval": 10, // Bucket interval. + "max": 10 // Max allowed publishes within this interval. + } + ] +``` + +You can have multiple limit configuration under ```rlimit```. + +#### Example +```json + "sub": "user@domain.com", + "write": [ "topic1/#", "topic2" ], + "read": [ "topic1/#", "topic2" ], + "rlimit": [ + { + "topic": "topic1/#", + "interval": 10, + "max": 10 + }, + { + "topic": "topic2", + "interval": 10, + "max": 10 + } + ] +``` + +In cases where you have multiple limits that matches a given topic, i.e patterns and distinct topic name, the closest match will be used. diff --git a/docs/sse.md b/docs/sse.md index 9e3ebbf..f1fa2f0 100644 --- a/docs/sse.md +++ b/docs/sse.md @@ -28,7 +28,7 @@ data: Foobar // Message content : // Ping event. 
``` -## Requesting cache / event history +## Requesting cache / eventlog | Header | Query parameter | Description | |---------------|-----------------|---------------------------------------------------------------------| | Last-Event-ID | since | Get all events since specified lastevent id when connecting | diff --git a/eventhub.conf.example b/eventhub.conf.example index 5128cf8..4588735 100644 --- a/eventhub.conf.example +++ b/eventhub.conf.example @@ -5,6 +5,7 @@ jwt_secret = FooBarBaz log_level = info disable_auth = true prometheus_metric_prefix = eventhub +enable_kvstore = true # Redis settings. redis_host = localhost diff --git a/include/AccessController.hpp b/include/AccessController.hpp index 4fd4d5a..bceea37 100644 --- a/include/AccessController.hpp +++ b/include/AccessController.hpp @@ -1,17 +1,39 @@ -#ifndef INCLUDE_ACCESSCONTROLLER_HPP_ -#define INCLUDE_ACCESSCONTROLLER_HPP_ +#pragma once #include #include #include +#include -#include "Config.hpp" #include "Forward.hpp" #include "EventhubBase.hpp" #include "jwt/jwt.hpp" namespace eventhub { +struct rlimit_config_t { + std::string topic; + unsigned long interval; + unsigned long max; +}; + +typedef struct rlimit_config_t rlimit_config_t; + +struct NoRateLimitForTopic : public std::exception { + const char* what() const throw() { + return "Token has no rate limits defined"; + } +}; + +class RateLimitConfig final { + private: + std::vector _limitConfigs; + + public: + bool loadFromJSON(const nlohmann::json::array_t& config); + const rlimit_config_t getRateLimitForTopic(const std::string& topic); +}; + class AccessController final : public EventhubBase { private: bool _token_loaded; @@ -20,6 +42,7 @@ class AccessController final : public EventhubBase { jwt::jwt_object _token; std::vector _publish_acl; std::vector _subscribe_acl; + RateLimitConfig _rlimit; public: AccessController(Config &cfg) : @@ -31,8 +54,9 @@ class AccessController final : public EventhubBase { bool allowSubscribe(const std::string& topic); bool allowCreateToken(const std::string& path); const std::string& subject(); + RateLimitConfig& getRateLimitConfig() { return _rlimit; }; }; } // namespace eventhub -#endif // INCLUDE_ACCESSCONTROLLER_HPP_ + diff --git a/include/Common.hpp b/include/Common.hpp index a264b85..9b34158 100644 --- a/include/Common.hpp +++ b/include/Common.hpp @@ -1,5 +1,4 @@ -#ifndef INCLUDE_COMMON_HPP_ -#define INCLUDE_COMMON_HPP_ +#pragma once #include #include @@ -14,22 +13,22 @@ static constexpr unsigned int EPOLL_MAX_TIMEOUT = 100; static constexpr unsigned int MAXEVENTS = 1024; // Read buffer size. -static constexpr size_t NET_READ_BUFFER_SIZE = 128; +static constexpr std::size_t NET_READ_BUFFER_SIZE = 512; // Max write buffer size. -static constexpr size_t NET_WRITE_BUFFER_MAX = (1024 * 1000) * 8; +static constexpr std::size_t NET_WRITE_BUFFER_MAX = (1024 * 1000) * 8; // Hangup connection if data frame is larger than this. -static constexpr size_t MAX_DATA_FRAME_SIZE = (1024 * 1000) * 8; +static constexpr std::size_t MAX_DATA_FRAME_SIZE = (1024 * 1000) * 8; // String used in Sec-WebSocket-Accept header during websocket handshake. -static constexpr const char* WS_MAGIC_STRING = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; +static constexpr const char* WS_MAGIC_STRING = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11\0"; // Will split up into continuation frames above this threshold. -static constexpr size_t WS_MAX_CHUNK_SIZE = 1 << 15; +static constexpr std::size_t WS_MAX_CHUNK_SIZE = 1 << 15; // Hangup connection if control frame is larger than this. 
-static constexpr size_t WS_MAX_CONTROL_FRAME_SIZE = 1024; +static constexpr std::size_t WS_MAX_CONTROL_FRAME_SIZE = 1024; // Delay metric sample rate. static constexpr unsigned int METRIC_DELAY_SAMPLE_RATE_MS = 5000; @@ -40,4 +39,4 @@ static constexpr unsigned int CACHE_PURGER_INTERVAL_MS = (60 * 1000); // Maximum SSL handshake retries. static const unsigned int SSL_MAX_HANDSHAKE_RETRY = 5; -#endif // INCLUDE_COMMON_HPP_ + diff --git a/include/Config.hpp b/include/Config.hpp index 05be28a..c6614bb 100644 --- a/include/Config.hpp +++ b/include/Config.hpp @@ -1,6 +1,7 @@ #ifndef __INCLUDE_CONFIG_HPP__ #define __INCLUDE_CONFIG_HPP__ +#include #include #include #include @@ -8,6 +9,8 @@ #include #include #include +#include +#include namespace eventhub { @@ -194,4 +197,4 @@ class Config final { }; } -#endif \ No newline at end of file +#endif diff --git a/include/Connection.hpp b/include/Connection.hpp index 7a7040a..8119617 100644 --- a/include/Connection.hpp +++ b/include/Connection.hpp @@ -1,5 +1,4 @@ -#ifndef INCLUDE_CONNECTION_HPP_ -#define INCLUDE_CONNECTION_HPP_ +#pragma once #include #include @@ -9,7 +8,6 @@ #include "EpollWrapper.hpp" #endif #include - #include #include #include @@ -19,16 +17,13 @@ #include #include -#include "EventhubBase.hpp" #include "Forward.hpp" -#include "AccessController.hpp" -#include "Common.hpp" -#include "http/Parser.hpp" +#include "EventhubBase.hpp" +#include "websocket/Types.hpp" +#include "http/Types.hpp" #include "jsonrpc/jsonrpcpp.hpp" -#include "websocket/Parser.hpp" namespace eventhub { - using ConnectionPtr = std::shared_ptr; using ConnectionWeakPtr = std::weak_ptr; using ConnectionListIterator = std::list::iterator; @@ -59,7 +54,7 @@ class Connection : public EventhubBase, public std::enable_shared_from_this::iterator connectionIterator); ConnectionListIterator getConnectionListIterator(); ConnectionPtr getSharedPtr(); @@ -67,7 +62,7 @@ class Connection : public EventhubBase, public std::enable_shared_from_this listSubscriptions(); void onHTTPRequest(http::ParserCallback callback); @@ -87,8 +82,8 @@ class Connection : public EventhubBase, public std::enable_shared_from_this _http_parser; - websocket::Parser _websocket_parser; - AccessController _access_controller; + std::unique_ptr _websocket_parser; + std::unique_ptr _access_controller; ConnectionState _state; bool _is_shutdown; bool _is_shutdown_after_flush; @@ -97,10 +92,8 @@ class Connection : public EventhubBase, public std::enable_shared_from_this #include @@ -7,15 +6,14 @@ #include #include #include +#include #include "Forward.hpp" +#include "metrics/Types.hpp" #include "EventhubBase.hpp" -#include "Config.hpp" -#include "Connection.hpp" #include "EventLoop.hpp" -#include "TopicManager.hpp" #include "Worker.hpp" -#include "metrics/Types.hpp" +#include "Connection.hpp" namespace eventhub { @@ -26,7 +24,7 @@ class Worker final : public EventhubBase, public WorkerBase { Worker(Server* srv, unsigned int workerId); ~Worker(); - TopicManager& getTopicManager() { return _topic_manager; } + TopicManager* getTopicManager() { return _topic_manager.get(); } void subscribeConnection(ConnectionPtr conn, const std::string& topicFilterName); void publish(const std::string& topicName, const std::string& data); @@ -39,10 +37,10 @@ class Worker final : public EventhubBase, public WorkerBase { unsigned int _workerId; Server* _server; int _epoll_fd; - EventLoop _ev; + std::unique_ptr _ev; ConnectionList _connection_list; std::mutex _connection_list_mutex; - TopicManager _topic_manager; + std::unique_ptr 
_topic_manager; metrics::WorkerMetrics _metrics; int64_t _ev_delay_sample_start; @@ -56,4 +54,4 @@ class Worker final : public EventhubBase, public WorkerBase { } // namespace eventhub -#endif // INCLUDE_CONNECTIONWORKER_HPP_ + diff --git a/include/EventLoop.hpp b/include/EventLoop.hpp index 666cc83..87852e1 100644 --- a/include/EventLoop.hpp +++ b/include/EventLoop.hpp @@ -1,5 +1,4 @@ -#ifndef INCLUDE_EVENTLOOP_HPP_ -#define INCLUDE_EVENTLOOP_HPP_ +#pragma once #include #include @@ -124,5 +123,3 @@ class EventLoop final { }; } // namespace eventhub - -#endif // INCLUDE_EVENTLOOP_HPP_ diff --git a/include/EventhubBase.hpp b/include/EventhubBase.hpp index ea6e49b..2884156 100644 --- a/include/EventhubBase.hpp +++ b/include/EventhubBase.hpp @@ -1,5 +1,4 @@ -#ifndef INCLUDE_EVENTHUBBASE_HPP_ -#define INCLUDE_EVENTHUBBASE_HPP_ +#pragma once #include "Config.hpp" @@ -19,6 +18,4 @@ class EventhubBase { }; -} - -#endif \ No newline at end of file +} \ No newline at end of file diff --git a/include/Forward.hpp b/include/Forward.hpp index a91a201..bb92d53 100644 --- a/include/Forward.hpp +++ b/include/Forward.hpp @@ -1,13 +1,29 @@ -#ifndef INCLUDE_FORWARD_HPP_ -#define INCLUDE_FORWARD_HPP_ - namespace eventhub { - -class Server; -class Worker; +class Config; class Connection; +class HandlerContext; +class KVStore; +class Redis; +class Server; class Topic; +class Worker; +class TopicManager; +class AccessController; + +namespace http { +class Parser; +class Response; +enum class RequestState; +} +namespace websocket { +class Handler; +class Parser; +class Response; } -#endif \ No newline at end of file +namespace sse { +class Handler; +class Response; +} +} \ No newline at end of file diff --git a/include/HandlerContext.hpp b/include/HandlerContext.hpp index 8af6303..1c424ac 100644 --- a/include/HandlerContext.hpp +++ b/include/HandlerContext.hpp @@ -1,16 +1,17 @@ -#ifndef INCLUDE_HANDLERCONTEXT_HPP_ -#define INCLUDE_HANDLERCONTEXT_HPP_ +#pragma once #include -#include "Server.hpp" + +#include "Forward.hpp" #include "EventhubBase.hpp" +#include "Config.hpp" namespace eventhub { class HandlerContext final : public EventhubBase { public: - HandlerContext(Server* server, Worker* worker, std::shared_ptr connection) : - EventhubBase(server->config()), _server(server), _worker(worker), _connection(connection) {}; + HandlerContext(Config& cfg, Server* server, Worker* worker, std::shared_ptr connection) : + EventhubBase(cfg), _server(server), _worker(worker), _connection(connection) {}; ~HandlerContext() {} @@ -26,4 +27,4 @@ class HandlerContext final : public EventhubBase { } // namespace eventhub -#endif // INCLUDE_HANDLERCONTEXT_HPP_ + diff --git a/include/KVStore.hpp b/include/KVStore.hpp new file mode 100644 index 0000000..64b76fc --- /dev/null +++ b/include/KVStore.hpp @@ -0,0 +1,29 @@ +#pragma once + +#include + +#include "Forward.hpp" +#include "EventhubBase.hpp" +#include "Config.hpp" + +namespace eventhub { + class KVStore final : public EventhubBase { + private: + std::string _prefix; + Redis& _redis; + + const std::string _prefix_key(const std::string& key) const; + + public: + KVStore(Config& cfg, Redis& redis) : + EventhubBase(cfg), + _redis(redis) { + _prefix = config().get("redis_prefix"); + } + + bool is_enabled(); + const std::string get(const std::string& key) const; + bool set(const std::string& key, const std::string& value, unsigned long ttl = 0) const; + long long del(const std::string& key) const; + }; +} diff --git a/include/Logger.hpp b/include/Logger.hpp index f10138b..ee7c839 
100644 --- a/include/Logger.hpp +++ b/include/Logger.hpp @@ -1,5 +1,4 @@ -#ifndef INCLUDE_LOGGER_HPP_ -#define INCLUDE_LOGGER_HPP_ +#pragma once #include #include @@ -56,4 +55,4 @@ class Logger final { #define LOG Logger::getInstance().getLogger() } // namespace eventhub -#endif // INCLUDE_LOGGER_HPP_ + diff --git a/include/RPCHandler.hpp b/include/RPCHandler.hpp index cd11e02..ba08e0e 100644 --- a/include/RPCHandler.hpp +++ b/include/RPCHandler.hpp @@ -1,14 +1,15 @@ -#ifndef INCLUDE_RPCHANDLER_HPP_ -#define INCLUDE_RPCHANDLER_HPP_ +#pragma once #include #include #include #include +#include "Forward.hpp" #include "Connection.hpp" #include "HandlerContext.hpp" #include "jsonrpc/jsonrpcpp.hpp" +#include "jwt/json/json.hpp" namespace eventhub { @@ -20,22 +21,22 @@ class RPCHandler final { static RPCMethod getHandler(const std::string& methodName); private: - RPCHandler(); - ~RPCHandler(); - static void _sendSuccessResponse(HandlerContext& hCtx, jsonrpcpp::request_ptr req, const nlohmann::json& result); static void _sendInvalidParamsError(HandlerContext& hCtx, jsonrpcpp::request_ptr req, const std::string& message); + static void _sendCacheToClient(HandlerContext &hCtx, jsonrpcpp::request_ptr req, const std::string& topic); + static unsigned long long _calculateRelativeSince(long long since); static void _handleSubscribe(HandlerContext& hCtx, jsonrpcpp::request_ptr req); static void _handleUnsubscribe(HandlerContext& hCtx, jsonrpcpp::request_ptr req); static void _handleUnsubscribeAll(HandlerContext& hCtx, jsonrpcpp::request_ptr req); static void _handlePublish(HandlerContext& hCtx, jsonrpcpp::request_ptr req); static void _handleList(HandlerContext& hCtx, jsonrpcpp::request_ptr req); - static void _handleHistory(HandlerContext& hCtx, jsonrpcpp::request_ptr req); + static void _handleEventlog(HandlerContext& hCtx, jsonrpcpp::request_ptr req); + static void _handleGet(HandlerContext& hCtx, jsonrpcpp::request_ptr req); + static void _handleSet(HandlerContext& hCtx, jsonrpcpp::request_ptr req); + static void _handleDelete(HandlerContext& hCtx, jsonrpcpp::request_ptr req); static void _handlePing(HandlerContext& hCtx, jsonrpcpp::request_ptr req); static void _handleDisconnect(HandlerContext& hCtx, jsonrpcpp::request_ptr req); }; } // namespace eventhub - -#endif // INCLUDE_RPCHANDLER_HPP_ diff --git a/include/Redis.hpp b/include/Redis.hpp index 65dcb87..02e7ce8 100644 --- a/include/Redis.hpp +++ b/include/Redis.hpp @@ -1,16 +1,16 @@ -#ifndef INCLUDE_REDIS_HPP_ -#define INCLUDE_REDIS_HPP_ +#pragma once #include - +#include +#include +#include #include #include #include #include #include -#include +#include -#include "Forward.hpp" #include "EventhubBase.hpp" #include "jwt/json/json.hpp" @@ -22,7 +22,7 @@ using RedisMsgCallback = std::function 0) ? 
_prefix + ":" + key : key) #define REDIS_CACHE_SCORE_PATH(key) std::string(REDIS_PREFIX(key) + ":scores") #define REDIS_CACHE_DATA_PATH(key) std::string(REDIS_PREFIX(key) + ":cache") +#define REDIS_RATE_LIMIT_PATH(key, subject, topic) std::string(REDIS_PREFIX(key) + ":rlimit:" + topic + ":" + subject) public: explicit Redis(Config &cfg); @@ -50,20 +51,23 @@ class Redis final : public EventhubBase { void publishMessage(const std::string& topic, const std::string& id, const std::string& payload, const std::string& origin=""); void psubscribe(const std::string& pattern, RedisMsgCallback callback); - const std::string cacheMessage(const std::string& topic, const std::string& payload, const std::string& origin, long long timestamp = 0, unsigned int ttl = 0); - size_t getCacheSince(const std::string& topicPattern, long long since, long long limit, bool isPattern, nlohmann::json& result); - size_t getCacheSinceId(const std::string& topicPattern, const std::string& sinceId, long long limit, bool isPattern, nlohmann::json& result); - size_t purgeExpiredCacheItems(); + const std::string cacheMessage(const std::string& topic, const std::string& payload, const std::string& origin, long long timestamp = 0, unsigned long ttl = 0); + std::size_t getCacheSince(const std::string& topicPattern, long long since, long long limit, bool isPattern, nlohmann::json& result); + std::size_t getCacheSinceId(const std::string& topicPattern, const std::string& sinceId, long long limit, bool isPattern, nlohmann::json& result); + std::size_t purgeExpiredCacheItems(); void consume(); void resetSubscribers(); - sw::redis::Redis* getRedisInstance() { return _redisInstance.get(); } + std::shared_ptr connection() { return _redisInstance; } void _incrTopicPubCount(const std::string& topicName); std::vector _getTopicsSeen(const std::string& topicPattern); const std::string _getNextCacheId(long long timestamp); + bool isRateLimited(const std::string& topic, const std::string& subject, unsigned long max); + void incrementLimitCount(const std::string& topic, const std::string& subject, unsigned long interval); + private: - std::unique_ptr _redisInstance; + std::shared_ptr _redisInstance; std::unique_ptr _redisSubscriber; std::string _prefix; std::mutex _publish_mtx; @@ -71,4 +75,4 @@ class Redis final : public EventhubBase { } // namespace eventhub -#endif // INCLUDE_REDIS_HPP_ + diff --git a/include/SSLConnection.hpp b/include/SSLConnection.hpp index c73d184..7af0280 100644 --- a/include/SSLConnection.hpp +++ b/include/SSLConnection.hpp @@ -1,11 +1,12 @@ -#ifndef INCLUDE_SSL_CONNECTION_HPP_ -#define INCLUDE_SSL_CONNECTION_HPP_ +#pragma once -#include "Config.hpp" -#include "Forward.hpp" -#include "Connection.hpp" #include #include +#include +#include + +#include "Forward.hpp" +#include "Connection.hpp" namespace eventhub { @@ -27,5 +28,3 @@ class SSLConnection final : public Connection { }; } // namespace eventhub - -#endif \ No newline at end of file diff --git a/include/Server.hpp b/include/Server.hpp index 62bd042..8e39de7 100644 --- a/include/Server.hpp +++ b/include/Server.hpp @@ -1,16 +1,18 @@ -#ifndef INCLUDE_SERVER_HPP_ -#define INCLUDE_SERVER_HPP_ +#pragma once +#include +#include +#include #include #include -#include #include -#include "ConnectionWorker.hpp" -#include "EventLoop.hpp" -#include "Redis.hpp" +#include "Forward.hpp" +#include "KVStore.hpp" #include "Worker.hpp" #include "metrics/Types.hpp" +#include "EventLoop.hpp" +#include "Redis.hpp" namespace eventhub { @@ -27,6 +29,7 @@ class Server final { Worker* 
getWorker(); void publish(const std::string& topicName, const std::string& data); Redis& getRedis() { return _redis; } + KVStore* getKVStore() { return _kv_store.get(); } metrics::AggregatedMetrics getAggregatedMetrics(); int getSSLServerSocket() { return _server_socket_ssl; }; @@ -44,11 +47,12 @@ class Server final { bool _ssl_enabled; SSL_CTX* _ssl_ctx; std::string _ssl_cert_md5_hash; - std:: string _ssl_priv_key_md5_hash; + std::string _ssl_priv_key_md5_hash; WorkerGroup _connection_workers; WorkerGroup::iterator _cur_worker; std::mutex _connection_workers_lock; Redis _redis; + std::unique_ptr _kv_store; metrics::ServerMetrics _metrics; EventLoop _ev; @@ -62,4 +66,4 @@ class Server final { } // namespace eventhub -#endif // INCLUDE_SERVER_HPP_ + diff --git a/include/Topic.hpp b/include/Topic.hpp index 171d564..a85da73 100644 --- a/include/Topic.hpp +++ b/include/Topic.hpp @@ -1,6 +1,6 @@ -#ifndef INCLUDE_TOPIC_HPP_ -#define INCLUDE_TOPIC_HPP_ +#pragma once +#include #include #include #include @@ -23,7 +23,7 @@ class Topic final { TopicSubscriberList::iterator addSubscriber(ConnectionPtr conn, const jsonrpcpp::Id subscriptionRequestId); void deleteSubscriberByIterator(TopicSubscriberList::iterator it); void publish(const std::string& data); - size_t getSubscriberCount(); + std::size_t getSubscriberCount(); private: std::string _id; @@ -33,4 +33,4 @@ class Topic final { }; // namespace eventhub -#endif // INCLUDE_TOPIC_HPP_ + diff --git a/include/TopicManager.hpp b/include/TopicManager.hpp index ea5ecad..abfd2da 100644 --- a/include/TopicManager.hpp +++ b/include/TopicManager.hpp @@ -1,5 +1,4 @@ -#ifndef INCLUDE_TOPICMANAGER_HPP_ -#define INCLUDE_TOPICMANAGER_HPP_ +#pragma once #include #include @@ -7,9 +6,11 @@ #include #include +#include "Forward.hpp" #include "Common.hpp" #include "Connection.hpp" #include "Topic.hpp" +#include "jsonrpc/jsonrpcpp.hpp" namespace eventhub { @@ -30,6 +31,4 @@ class TopicManager final { TopicList _topic_list; std::mutex _topic_list_lock; }; -} // namespace eventhub - -#endif // INCLUDE_TOPICMANAGER_HPP_ +} // namespace eventhub \ No newline at end of file diff --git a/include/Util.hpp b/include/Util.hpp index 03f9f60..95f0b37 100644 --- a/include/Util.hpp +++ b/include/Util.hpp @@ -1,6 +1,7 @@ -#ifndef INCLUDE_UTIL_HPP_ -#define INCLUDE_UTIL_HPP_ +#pragma once +#include +#include #include #include @@ -8,7 +9,7 @@ namespace eventhub { class Util final { public: - static const std::string base64Encode(const unsigned char* buffer, size_t length); + static const std::string base64Encode(const unsigned char* buffer, std::size_t length); static const std::string uriDecode(const std::string& str); static std::string& strToLower(std::string& s); static int64_t getTimeSinceEpoch(); @@ -22,4 +23,4 @@ class Util final { } // namespace eventhub -#endif // INCLUDE_UTIL_HPP_ + diff --git a/include/Worker.hpp b/include/Worker.hpp index 3412c07..3c09476 100644 --- a/include/Worker.hpp +++ b/include/Worker.hpp @@ -1,5 +1,4 @@ -#ifndef INCLUDE_WORKER_HPP_ -#define INCLUDE_WORKER_HPP_ +#pragma once #include #include @@ -92,4 +91,4 @@ class WorkerGroup { worker_list_t _workers; }; -#endif // INCLUDE_WORKER_HPP_ + diff --git a/include/http/Handler.hpp b/include/http/Handler.hpp index beb08f1..26cca82 100644 --- a/include/http/Handler.hpp +++ b/include/http/Handler.hpp @@ -1,15 +1,9 @@ -#ifndef INCLUDE_HTTP_HANDLER_HPP_ -#define INCLUDE_HTTP_HANDLER_HPP_ +#pragma once #include #include -#include "Connection.hpp" -#include "ConnectionWorker.hpp" -#include "HandlerContext.hpp" 
-#include "TopicManager.hpp" -#include "http/Parser.hpp" -#include "http/Response.hpp" +#include "Forward.hpp" namespace eventhub { namespace http { @@ -31,4 +25,4 @@ class Handler final { } // namespace http } // namespace eventhub -#endif // INCLUDE_HTTP_HANDLER_HPP_ + diff --git a/include/http/Parser.hpp b/include/http/Parser.hpp index b3c75fb..373b1d8 100644 --- a/include/http/Parser.hpp +++ b/include/http/Parser.hpp @@ -1,5 +1,4 @@ -#ifndef INCLUDE_HTTP_PARSER_HPP_ -#define INCLUDE_HTTP_PARSER_HPP_ +#pragma once #include #include @@ -7,19 +6,11 @@ #include #include "http/picohttpparser.h" +#include "http/Types.hpp" namespace eventhub { namespace http { -enum class RequestState { - REQ_FAILED, - REQ_INCOMPLETE, - REQ_TO_BIG, - REQ_OK -}; - -using ParserCallback = std::function; - class Parser final { #define HTTP_BUFSIZ 8192 #define HTTP_REQUEST_MAX_HEADERS 100 @@ -27,13 +18,13 @@ class Parser final { public: Parser(); ~Parser(); - void parse(const char* data, int len); + void parse(const char* data, std::size_t len); const std::string& getPath(); const std::string& getMethod(); const std::map& getHeaders(); const std::string getHeader(std::string header); const std::string getQueryString(std::string param); - size_t numQueryString(); + std::size_t numQueryString(); const std::string& getErrorMessage(); void setCallback(ParserCallback callback); @@ -44,7 +35,7 @@ class Parser final { bool _is_complete; const char *_phr_method, *_phr_path; struct phr_header _phr_headers[HTTP_REQUEST_MAX_HEADERS]; - size_t _phr_num_headers, _phr_method_len, _phr_path_len; + std::size_t _phr_num_headers, _phr_method_len, _phr_path_len; int _phr_minor_version; std::string _path; std::string _method; @@ -53,11 +44,11 @@ class Parser final { std::map _query_parameters; std::map _qsmap; - size_t _parse_query_string(const std::string& buf); + std::size_t _parse_query_string(const std::string& buf); ParserCallback _callback; }; } // namespace http } // namespace eventhub -#endif // INCLUDE_HTTP_PARSER_HPP_ + diff --git a/include/http/Response.hpp b/include/http/Response.hpp index 841d16d..b255a89 100644 --- a/include/http/Response.hpp +++ b/include/http/Response.hpp @@ -1,5 +1,4 @@ -#ifndef INCLUDE_HTTP_RESPONSE_HPP_ -#define INCLUDE_HTTP_RESPONSE_HPP_ +#pragma once #include #include @@ -30,4 +29,3 @@ class Response final { } // namespace http } // namespace eventhub -#endif // INCLUDE_HTTP_RESPONSE_HPP_ \ No newline at end of file diff --git a/include/http/Types.hpp b/include/http/Types.hpp new file mode 100644 index 0000000..5283d33 --- /dev/null +++ b/include/http/Types.hpp @@ -0,0 +1,22 @@ +#pragma once + +#include +#include +#include + +namespace eventhub { +namespace http { + +enum class RequestState { + REQ_FAILED, + REQ_INCOMPLETE, + REQ_TO_BIG, + REQ_OK +}; + +using ParserCallback = std::function; + +} // namespace http +} // namespace eventhub + + diff --git a/include/metrics/JsonRenderer.hpp b/include/metrics/JsonRenderer.hpp index 083cddc..8d6995f 100644 --- a/include/metrics/JsonRenderer.hpp +++ b/include/metrics/JsonRenderer.hpp @@ -1,5 +1,4 @@ -#ifndef INCLUDE_METRICS_JSON_RENDERER_HPP_ -#define INCLUDE_METRICS_JSON_RENDERER_HPP_ +#pragma once #include @@ -17,4 +16,4 @@ class JsonRenderer final { } // namespace metrics } // namespace eventhub -#endif // INCLUDE_METRICS_JSON_RENDERER_HPP_ + diff --git a/include/metrics/PrometheusRenderer.hpp b/include/metrics/PrometheusRenderer.hpp index f16b70c..163932f 100644 --- a/include/metrics/PrometheusRenderer.hpp +++ 
b/include/metrics/PrometheusRenderer.hpp @@ -1,12 +1,12 @@ -#ifndef INCLUDE_METRICS_PROMETHEUS_RENDERER_HPP_ -#define INCLUDE_METRICS_PROMETHEUS_RENDERER_HPP_ +#pragma once #include -#include "metrics/Types.hpp" #include "Forward.hpp" +#include "metrics/Types.hpp" namespace eventhub { + namespace metrics { class PrometheusRenderer final { @@ -17,4 +17,4 @@ class PrometheusRenderer final { } // namespace metrics } // namespace eventhub -#endif // INCLUDE_METRICS_PROMETHEUS_RENDERER_HPP_ + diff --git a/include/metrics/Types.hpp b/include/metrics/Types.hpp index 4f75d21..8e9d70a 100644 --- a/include/metrics/Types.hpp +++ b/include/metrics/Types.hpp @@ -1,5 +1,4 @@ -#ifndef INCLUDE_METRICS_TYPES_HPP_ -#define INCLUDE_METRICS_TYPES_HPP_ +#pragma once #include #include @@ -49,4 +48,4 @@ struct AggregatedMetrics { } // namespace metrics } // namespace eventhub -#endif // INCLUDE_METRICS_TYPES_HPP_ + diff --git a/include/sse/Handler.hpp b/include/sse/Handler.hpp index 448a52d..18cf8cd 100644 --- a/include/sse/Handler.hpp +++ b/include/sse/Handler.hpp @@ -1,15 +1,11 @@ -#ifndef INCLUDE_SSE_HANDLER_HPP_ -#define INCLUDE_SSE_HANDLER_HPP_ +#pragma once #include #include #include #include -#include "Connection.hpp" -#include "ConnectionWorker.hpp" -#include "HandlerContext.hpp" -#include "http/Parser.hpp" +#include "Forward.hpp" namespace eventhub { namespace sse { @@ -26,4 +22,4 @@ class Handler final { } // namespace sse } // namespace eventhub -#endif // INCLUDE_SSE_HANDLER_HPP_ + diff --git a/include/sse/Response.hpp b/include/sse/Response.hpp index 423ff8e..9c48d60 100644 --- a/include/sse/Response.hpp +++ b/include/sse/Response.hpp @@ -1,5 +1,4 @@ -#ifndef INCLUDE_SSE_RESPONSE_HPP_ -#define INCLUDE_SSE_RESPONSE_HPP_ +#pragma once #include #include @@ -20,4 +19,4 @@ class Response final { } // namespace sse } // namespace eventhub -#endif // INCLUDE_SSE_RESPONSE_HPP_ + diff --git a/include/websocket/Handler.hpp b/include/websocket/Handler.hpp index c1f50ad..c3b3abe 100644 --- a/include/websocket/Handler.hpp +++ b/include/websocket/Handler.hpp @@ -1,17 +1,16 @@ -#ifndef INCLUDE_WEBSOCKET_HANDLER_HPP_ -#define INCLUDE_WEBSOCKET_HANDLER_HPP_ +#pragma once #include #include #include #include -#include "Connection.hpp" -#include "ConnectionWorker.hpp" +#include "Forward.hpp" #include "HandlerContext.hpp" #include "websocket/Types.hpp" namespace eventhub { + namespace websocket { class Handler final { @@ -28,4 +27,4 @@ class Handler final { } // namespace websocket } // namespace eventhub -#endif // INCLUDE_WEBSOCKET_HANDLER_HPP_ + diff --git a/include/websocket/Parser.hpp b/include/websocket/Parser.hpp index e5ca9e7..fc6aeaf 100644 --- a/include/websocket/Parser.hpp +++ b/include/websocket/Parser.hpp @@ -1,6 +1,6 @@ -#ifndef INCLUDE_WEBSOCKET_PARSER_HPP_ -#define INCLUDE_WEBSOCKET_PARSER_HPP_ +#pragma once +#include #include #include @@ -15,13 +15,13 @@ class Parser final { Parser(); ~Parser() {} - void parse(char* buf, size_t len); + void parse(char* buf, std::size_t len); void clearDataPayload(); void clearControlPayload(); - void appendDataPayload(const char* data, size_t len); - void appendControlPayload(const char* data, size_t len); + void appendDataPayload(const char* data, std::size_t len); + void appendControlPayload(const char* data, std::size_t len); void setControlFrameType(FrameType frameType); void setDataFrameType(FrameType frameType); @@ -48,4 +48,4 @@ class Parser final { } // namespace websocket } // namespace eventhub -#endif // INCLUDE_WEBSOCKET_PARSER_HPP_ + diff --git 
a/include/websocket/Response.hpp b/include/websocket/Response.hpp index d91b280..e8c36df 100644 --- a/include/websocket/Response.hpp +++ b/include/websocket/Response.hpp @@ -1,6 +1,6 @@ -#ifndef INCLUDE_WEBSOCKET_RESPONSE_HPP_ -#define INCLUDE_WEBSOCKET_RESPONSE_HPP_ +#pragma once +#include #include #include @@ -21,4 +21,4 @@ class Response final { } // namespace websocket } // namespace eventhub -#endif // INCLUDE_WEBSOCKET_RESPONSE_HPP_ + diff --git a/include/websocket/Types.hpp b/include/websocket/Types.hpp index 9d5e5c6..eeafb8b 100644 --- a/include/websocket/Types.hpp +++ b/include/websocket/Types.hpp @@ -1,5 +1,4 @@ -#ifndef INCLUDE_WEBSOCKET_TYPES_HPP_ -#define INCLUDE_WEBSOCKET_TYPES_HPP_ +#pragma once #include #include @@ -28,4 +27,4 @@ using ParserCallback = std::function +#include #include #include #include +#include +#include +#include +#include "AccessController.hpp" #include "Config.hpp" #include "TopicManager.hpp" -#include "Server.hpp" +#include "Logger.hpp" +#include "jwt/impl/jwt.ipp" +#include "jwt/json/json.hpp" +#include "jwt/parameters.hpp" namespace eventhub { @@ -44,12 +51,19 @@ bool AccessController::authenticate(const std::string& jwtToken, const std::stri } } + if ((_subscribe_acl.size() + _publish_acl.size()) == 0) { + throw std::invalid_argument("No publish or subscribe ACL defined in JWT token."); + } + if (payload.has_claim("sub")) { _subject = payload.get_claim_value("sub"); } - if ((_subscribe_acl.size() + _publish_acl.size()) == 0) { - throw std::invalid_argument("No publish or subscribe ACL defined in JWT token."); + if (payload.has_claim("rlimit")) { + auto payload_json = payload.create_json_obj(); + if (payload_json.is_object() && payload_json["rlimit"].is_array()) { + _rlimit.loadFromJSON(payload_json["rlimit"]); + } } } catch (std::exception& e) { LOG->trace("Error in AccessController: ", e.what()); @@ -106,4 +120,60 @@ bool AccessController::allowCreateToken(const std::string& path) { return true; } + +bool RateLimitConfig::loadFromJSON(const nlohmann::json::array_t& config) { + for (const auto& rlimit: config) { + if (!rlimit.is_object()) + continue; + + try { + auto topic = rlimit["topic"].get(); + auto interval = rlimit["interval"].get(); + auto max = rlimit["max"].get(); + + _limitConfigs.push_back(rlimit_config_t{topic, interval, max}); + } catch (...) { + continue; + } + } + + return true; +} + +// Returns limits for given topic if there are any. +// If no limits is defined we throw NoRateLimitForTopic exception. +const rlimit_config_t RateLimitConfig::getRateLimitForTopic(const std::string& topic) { + rlimit_config_t rlimit; + bool found = false; + std::size_t matchedPatternLen = 0; + + // Exit early if no limits is present in token. + if (_limitConfigs.empty()) + throw(NoRateLimitForTopic{}); + + /* + Check if we have any limits defined for topic or matching pattern. + */ + for (const auto& limit : _limitConfigs) { + // Exact match has highest precedence, use that if we have one. + if (limit.topic.compare(topic) == 0) { + rlimit = limit; + found = true; + break; + } + + // We check the length here to chose the closest matching pattern if there are more than one. 
+ if (TopicManager::isFilterMatched(limit.topic, topic) && limit.topic.length() > matchedPatternLen) { + rlimit = limit; + found = true; + matchedPatternLen = limit.topic.length(); + } + } + + if (!found) + throw(NoRateLimitForTopic{}); + + return rlimit; +} + } // namespace eventhub diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 521ac57..cbad25e 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -14,6 +14,7 @@ set (SOURCES Config.cpp RPCHandler.cpp Redis.cpp + KVStore.cpp Util.cpp Topic.cpp TopicManager.cpp diff --git a/src/Config.cpp b/src/Config.cpp index 0bbffab..ad86253 100644 --- a/src/Config.cpp +++ b/src/Config.cpp @@ -1,15 +1,12 @@ -#include "Config.hpp" -#include -#include +#include +#include #include #include -#include +#include #include -#include -#include #include -#include -#include + +#include "Config.hpp" namespace eventhub { Config::Config() {} @@ -21,7 +18,7 @@ void Config::_loadFromStream(std::istream& data, const std::string& path) { while (std::getline(data, line)) { std::string parsedKey, parsedValue; - unsigned int pos = 0; + std::size_t pos = 0; lineNo++; @@ -89,13 +86,13 @@ void Config::_loadFromEnv() { for (const auto& it_opt : _options) { auto optName = it_opt.first; auto val = getenv(optName.c_str()); - if (val != NULL) { + if (val != nullptr) { it_opt.second->_set(val); it_opt.second->_hasValue = true; } else { strToUpper(optName); val = getenv(optName.c_str()); - if (val != NULL) { + if (val != nullptr) { it_opt.second->_set(val); it_opt.second->_hasValue = true; } @@ -125,4 +122,4 @@ void Config::clearValues() { opt.second->clear(); } } -} \ No newline at end of file +} diff --git a/src/Connection.cpp b/src/Connection.cpp index 21fc2bb..6b5ea21 100644 --- a/src/Connection.cpp +++ b/src/Connection.cpp @@ -1,35 +1,31 @@ -#include "Connection.hpp" - #include #include #include #include -#include -#include #include #include #include - -#include +#include #include #include #include #include +#include "Forward.hpp" +#include "Connection.hpp" #include "Common.hpp" -#include "Config.hpp" #include "ConnectionWorker.hpp" -#include "Server.hpp" #include "Topic.hpp" #include "TopicManager.hpp" -#include "Util.hpp" #include "http/Parser.hpp" #include "websocket/Parser.hpp" +#include "AccessController.hpp" +#include "Logger.hpp" namespace eventhub { Connection::Connection(int fd, struct sockaddr_in* csin, Worker* worker, Config& cfg) : - EventhubBase(cfg), _fd(fd), _worker(worker), _access_controller(cfg) { + EventhubBase(cfg), _fd(fd), _worker(worker) { _is_shutdown = false; _is_shutdown_after_flush = false; @@ -55,6 +51,8 @@ Connection::Connection(int fd, struct sockaddr_in* csin, Worker* worker, Config& LOG->trace("Client {} connected.", getIP()); _http_parser = std::make_unique(); + _websocket_parser = std::make_unique(); + _access_controller = std::make_unique(cfg); // Set initial state. setState(ConnectionState::HTTP); @@ -92,7 +90,7 @@ void Connection::_disableEpollOut() { /** * Remove n bytes from the beginning og the write buffer. */ -size_t Connection::_pruneWriteBuffer(size_t bytes) { +std::size_t Connection::_pruneWriteBuffer(std::size_t bytes) { if (_write_buffer.length() < 1) { return 0; } @@ -116,7 +114,7 @@ void Connection::read() { return; } - size_t bytesRead = 0; + ssize_t bytesRead = 0; bytesRead = ::read(_fd, _read_buffer.data(), NET_READ_BUFFER_SIZE); if (bytesRead <= 0) { @@ -133,7 +131,7 @@ void Connection::read() { /** * Parse the request present in our read buffer and call the correct handler. 
*/ -void Connection::_parseRequest(size_t bytesRead) { +void Connection::_parseRequest(std::size_t bytesRead) { // Redirect request to either HTTP handler or websocket handler // based on which state the client is in. switch (getState()) { @@ -142,7 +140,7 @@ void Connection::_parseRequest(size_t bytesRead) { break; case ConnectionState::WEBSOCKET: - _websocket_parser.parse(_read_buffer.data(), bytesRead); + _websocket_parser->parse(_read_buffer.data(), bytesRead); break; default: @@ -187,7 +185,7 @@ ssize_t Connection::flushSendBuffer() { return 0; } - int ret = ::write(_fd, _write_buffer.c_str(), _write_buffer.length()); + ssize_t ret = ::write(_fd, _write_buffer.c_str(), _write_buffer.length()); if (ret <= 0) { if (errno != EAGAIN && errno != EWOULDBLOCK) { @@ -196,7 +194,7 @@ ssize_t Connection::flushSendBuffer() { } else { _enableEpollOut(); } - } else if ((unsigned int)ret < _write_buffer.length()) { + } else if ((std::size_t)ret < _write_buffer.length()) { LOG->trace("Client {} could not write() entire buffer, wrote {} of {} bytes.", ret, _write_buffer.length()); _pruneWriteBuffer(ret); _enableEpollOut(); @@ -245,7 +243,7 @@ const std::string Connection::getIP() { int Connection::addToEpoll(uint32_t epollEvents) { _epoll_event.events = epollEvents; _epoll_event.data.fd = _fd; - _epoll_event.data.ptr = reinterpret_cast(this); + _epoll_event.data.ptr = static_cast(this); int ret = epoll_ctl(_worker->getEpollFileDescriptor(), EPOLL_CTL_ADD, _fd, &_epoll_event); @@ -271,13 +269,13 @@ ConnectionState Connection::setState(ConnectionState newState) { void Connection::subscribe(const std::string& topicPattern, const jsonrpcpp::Id subscriptionRequestId) { std::lock_guard lock(_subscription_list_lock); - auto& tm = _worker->getTopicManager(); + auto tm = _worker->getTopicManager(); if (_subscribedTopics.count(topicPattern)) { return; } - auto topicSubscription = tm.subscribeConnection(getSharedPtr(), topicPattern, subscriptionRequestId); + auto topicSubscription = tm->subscribeConnection(getSharedPtr(), topicPattern, subscriptionRequestId); _subscribedTopics.insert(std::make_pair(topicPattern, TopicSubscription{topicSubscription.first, topicSubscription.second, subscriptionRequestId})); } @@ -286,15 +284,15 @@ ConnectionState Connection::getState() { } void Connection::onWebsocketRequest(websocket::ParserCallback callback) { - _websocket_parser.setCallback(callback); + _websocket_parser->setCallback(callback); } void Connection::onHTTPRequest(http::ParserCallback callback) { _http_parser->setCallback(callback); } -AccessController& Connection::getAccessController() { - return _access_controller; +AccessController* Connection::getAccessController() { + return _access_controller.get(); } void Connection::assignConnectionListIterator(std::list::iterator connectionIterator) { @@ -311,7 +309,7 @@ ConnectionPtr Connection::getSharedPtr() { bool Connection::unsubscribe(const std::string& topicPattern) { std::lock_guard lock(_subscription_list_lock); - auto& tm = _worker->getTopicManager(); + auto tm = _worker->getTopicManager(); if (_subscribedTopics.count(topicPattern) == 0) { return false; @@ -323,7 +321,7 @@ bool Connection::unsubscribe(const std::string& topicPattern) { subscription.topic->deleteSubscriberByIterator(subscription.topicListIterator); if (subscription.topic->getSubscriberCount() == 0) { - tm.deleteTopic(topicPattern); + tm->deleteTopic(topicPattern); } _subscribedTopics.erase(it); @@ -331,17 +329,17 @@ bool Connection::unsubscribe(const std::string& topicPattern) { return true; } 
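The closest-match rule described in docs/rate-limiting.md and implemented in `RateLimitConfig::getRateLimitForTopic` above (an exact topic wins over any pattern; among matching patterns the longest one is chosen) can be exercised in a small standalone sketch. This example is illustrative only and is not part of the patch; it assumes it is compiled inside the eventhub source tree so that `AccessController.hpp` and the bundled nlohmann JSON header are available.

```cpp
// Illustrative sketch only -- not part of this patch. Assumes it is
// compiled inside the eventhub source tree so that AccessController.hpp
// and the bundled nlohmann JSON header are available.
#include <iostream>

#include "AccessController.hpp"
#include "jwt/json/json.hpp"

int main() {
  using namespace eventhub;

  // Two limits from a token's "rlimit" claim: a broad pattern and a
  // more specific one.
  nlohmann::json::array_t limits = {
      {{"topic", "topic1/#"}, {"interval", 10}, {"max", 10}},
      {{"topic", "topic1/sensors/#"}, {"interval", 10}, {"max", 2}},
  };

  RateLimitConfig rlimits;
  rlimits.loadFromJSON(limits);

  // Both patterns match, but "topic1/sensors/#" is the longer one,
  // so its limit (max=2) is the one returned.
  auto limit = rlimits.getRateLimitForTopic("topic1/sensors/temp");
  std::cout << limit.topic << " max=" << limit.max << "\n";

  // Topics without a matching limit raise NoRateLimitForTopic.
  try {
    rlimits.getRateLimitForTopic("other/topic");
  } catch (const NoRateLimitForTopic& e) {
    std::cout << e.what() << "\n";
  }
  return 0;
}
```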
-unsigned int Connection::unsubscribeAll() { +std::size_t Connection::unsubscribeAll() { std::lock_guard lock(_subscription_list_lock); - auto& tm = _worker->getTopicManager(); - unsigned int count = _subscribedTopics.size(); + auto tm = _worker->getTopicManager(); + auto count = _subscribedTopics.size(); for (auto it = _subscribedTopics.begin(); it != _subscribedTopics.end();) { auto& subscription = it->second; subscription.topic->deleteSubscriberByIterator(subscription.topicListIterator); if (subscription.topic->getSubscriberCount() == 0) { - tm.deleteTopic(it->first); + tm->deleteTopic(it->first); } it = _subscribedTopics.erase(it); diff --git a/src/ConnectionWorker.cpp b/src/ConnectionWorker.cpp index 2187e90..8187451 100644 --- a/src/ConnectionWorker.cpp +++ b/src/ConnectionWorker.cpp @@ -1,9 +1,12 @@ -#include "ConnectionWorker.hpp" - #include -#include #include #include +#include + +#include "ConnectionWorker.hpp" +#include "Logger.hpp" +#include "http/Parser.hpp" +#include "websocket/Types.hpp" #ifdef __linux__ #include #else @@ -12,12 +15,12 @@ #include #include #include - #include #include #include -#include #include +#include +#include #include "Common.hpp" #include "Config.hpp" @@ -26,12 +29,11 @@ #include "HandlerContext.hpp" #include "SSLConnection.hpp" #include "Server.hpp" +#include "TopicManager.hpp" #include "Util.hpp" -#include "Worker.hpp" #include "http/Handler.hpp" #include "sse/Response.hpp" #include "websocket/Handler.hpp" -#include "websocket/Parser.hpp" #include "websocket/Response.hpp" namespace eventhub { @@ -39,6 +41,9 @@ namespace eventhub { Worker::Worker(Server* srv, unsigned int workerId) : EventhubBase(srv->config()), _workerId(workerId) { _server = srv; _epoll_fd = epoll_create1(0); + + _ev = std::make_unique(); + _topic_manager = std::make_unique(); } Worker::~Worker() { @@ -56,7 +61,7 @@ Worker::~Worker() { } void Worker::addTimer(int64_t delay, std::function callback, bool repeat) { - _ev.addTimer(delay, callback, repeat); + _ev->addTimer(delay, callback, repeat); } /** @@ -119,7 +124,7 @@ ConnectionPtr Worker::_addConnection(int fd, struct sockaddr_in* csin, bool ssl) auto c = wptrClient.lock(); if (!c) return; - http::Handler::HandleRequest(HandlerContext(_server, this, c), req, reqState); + http::Handler::HandleRequest(HandlerContext(_config, _server, this, c), req, reqState); }); // Set up websocket request callback. @@ -129,7 +134,7 @@ ConnectionPtr Worker::_addConnection(int fd, struct sockaddr_in* csin, bool ssl) auto c = wptrClient.lock(); if (!c) return; - websocket::Handler::HandleRequest(HandlerContext(_server, this, c), + websocket::Handler::HandleRequest(HandlerContext(_config, _server, this, c), status, frameType, data); }); @@ -195,8 +200,8 @@ void Worker::_removeConnection(ConnectionPtr conn) { } void Worker::publish(const std::string& topicName, const std::string& data) { - _ev.addJob([this, topicName, data]() { - _topic_manager.publish(topicName, data); + _ev->addJob([this, topicName, data]() { + _topic_manager->publish(topicName, data); }); } @@ -214,7 +219,7 @@ void Worker::_workerMain() { _ev_delay_sample_start = Util::getTimeSinceEpoch(); // Sample eventloop delay every and store it in our metrics. 
- _ev.addTimer( + _ev->addTimer( METRIC_DELAY_SAMPLE_RATE_MS, [&](TimerCtx* ctx) { const auto epoch = Util::getTimeSinceEpoch(); long diff = epoch - _ev_delay_sample_start - METRIC_DELAY_SAMPLE_RATE_MS; @@ -234,9 +239,11 @@ void Worker::_workerMain() { serverSocketEvent.events = EPOLLIN | EPOLLEXCLUSIVE; serverSocketEvent.data.fd = _server->getServerSocket(); - if (epoll_ctl(_epoll_fd, EPOLL_CTL_ADD, _server->getServerSocket(), &serverSocketEvent) == -1) { - LOG->critical("Failed to add serversocket to epoll in AcceptWorker {}: {}", getWorkerId(), strerror(errno)); - exit(1); + if (!config().get("disable_unsecure_listener")) { + if (epoll_ctl(_epoll_fd, EPOLL_CTL_ADD, _server->getServerSocket(), &serverSocketEvent) == -1) { + LOG->critical("Failed to add serversocket to epoll in AcceptWorker {}: {}", getWorkerId(), strerror(errno)); + exit(1); + } } // Add server listening socket to epoll. @@ -251,10 +258,10 @@ void Worker::_workerMain() { } while (!stopRequested()) { - unsigned int timeout = EPOLL_MAX_TIMEOUT; + std::size_t timeout = EPOLL_MAX_TIMEOUT; - if (_ev.hasWork() && _ev.getNextTimerDelay().count() < EPOLL_MAX_TIMEOUT) { - timeout = _ev.getNextTimerDelay().count(); + if (_ev->hasWork() && _ev->getNextTimerDelay().count() < EPOLL_MAX_TIMEOUT) { + timeout = _ev->getNextTimerDelay().count(); } int n = epoll_wait(_epoll_fd, eventConnectionList, MAXEVENTS, timeout); @@ -297,7 +304,7 @@ void Worker::_workerMain() { } // Process timers and jobs. - _ev.process(); + _ev->process(); } } } // namespace eventhub diff --git a/src/KVStore.cpp b/src/KVStore.cpp new file mode 100644 index 0000000..f564957 --- /dev/null +++ b/src/KVStore.cpp @@ -0,0 +1,43 @@ +#include +#include +#include +#include +#include + +#include "KVStore.hpp" +#include "Config.hpp" +#include "Redis.hpp" + +namespace eventhub { + const std::string KVStore::_prefix_key(const std::string& key) const { + return _prefix.empty() ? std::string(_prefix + ":kv:" + key) + : std::string("kv:" + key); + } + + bool KVStore::is_enabled() { + return config().get("enable_kvstore"); + } + + const std::string KVStore::get(const std::string& key) const { + const std::string value = _redis.connection()->get(_prefix_key(key)).value(); + return value; + } + + bool KVStore::set(const std::string& key, const std::string& value, unsigned long ttl) const { + if (ttl > 0) { + return _redis.connection()->set(_prefix_key(key), value, std::chrono::seconds(ttl)); + } else { + return _redis.connection()->set(_prefix_key(key), value); + } + } + + long long KVStore::del(const std::string& key) const { + long long ret = 0; + + try { + ret = _redis.connection()->del(_prefix_key(key)); + } catch(...) 
{} + + return ret; + } +} \ No newline at end of file diff --git a/src/RPCHandler.cpp b/src/RPCHandler.cpp index 1cb27b5..7fa2f27 100644 --- a/src/RPCHandler.cpp +++ b/src/RPCHandler.cpp @@ -1,18 +1,26 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + #include "RPCHandler.hpp" -#include "Common.hpp" #include "Config.hpp" #include "Connection.hpp" -#include "ConnectionWorker.hpp" #include "HandlerContext.hpp" #include "Redis.hpp" #include "Server.hpp" #include "TopicManager.hpp" #include "Util.hpp" #include "websocket/Response.hpp" - -#include -#include -#include +#include "AccessController.hpp" +#include "KVStore.hpp" +#include "Logger.hpp" +#include "websocket/Types.hpp" namespace eventhub { @@ -28,7 +36,10 @@ RPCMethod RPCHandler::getHandler(const std::string& methodName) { {"unsubscribeall", _handleUnsubscribeAll}, {"publish", _handlePublish}, {"list", _handleList}, - {"history", _handleHistory}, + {"eventlog", _handleEventlog}, + {"get", _handleGet}, + {"set", _handleSet}, + {"del", _handleDelete}, {"ping", _handlePing}, {"disconnect", _handleDisconnect}}; @@ -57,30 +68,49 @@ void RPCHandler::_sendSuccessResponse(HandlerContext& ctx, jsonrpcpp::request_pt } /** - * Handle subscribe RPC command. - * Subscribe client to given topic pattern. - * @param ctx Client issuing request. - * @param req RPC request. + * If provided a negative number calculate the relative since from + * epoch now in milliseconds - abs(since). + * + * @param since negative number to subtract from epoch now. + * @return epoch now in milliseconds - abs(since). */ -void RPCHandler::_handleSubscribe(HandlerContext& ctx, jsonrpcpp::request_ptr req) { - auto& accessController = ctx.connection()->getAccessController(); - auto params = req->params(); - std::string topicName; +unsigned long long RPCHandler::_calculateRelativeSince(long long since) { + if (since >= 0) + return since; + + auto now_ms = std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch()).count(); + + return now_ms - abs(since); +} + +/** + * Helper function for sending cached events to client if requested. + */ +void RPCHandler::_sendCacheToClient(HandlerContext &ctx, jsonrpcpp::request_ptr req, const std::string& topic) { + // Return early if cache is not enabled. + if (!ctx.config().get("enable_cache")) { + return; + } + std::string sinceEventId; unsigned long long since, limit; - std::stringstream msg; + auto params = req->params(); try { - topicName = params.get("topic").get(); sinceEventId = params.get("sinceEventId").get(); } catch (...) {} try { since = params.get("since").get(); + since = _calculateRelativeSince(since); } catch (...) { since = 0; } + if (sinceEventId.empty() && since == 0) { + return; + } + try { limit = params.get("limit").get(); } catch (...) { @@ -91,21 +121,50 @@ void RPCHandler::_handleSubscribe(HandlerContext& ctx, jsonrpcpp::request_ptr re limit = ctx.config().get("max_cache_request_limit"); } + try { + nlohmann::json result; + auto& redis = ctx.server()->getRedis(); + if (!sinceEventId.empty()) + redis.getCacheSinceId(topic, sinceEventId, limit, TopicManager::isValidTopicFilter(topic), result); + else + redis.getCacheSince(topic, since, limit, TopicManager::isValidTopicFilter(topic), result); + + for (auto& cacheItem : result) { + _sendSuccessResponse(ctx, req, cacheItem); + } + } catch (std::exception& e) { + LOG->error("Error while looking up cache: {}.", e.what()); + } +} + +/** + * Handle subscribe RPC command. 
+ * Subscribe client to given topic pattern. + * @param ctx Client issuing request. + * @param req RPC request. + */ +void RPCHandler::_handleSubscribe(HandlerContext& ctx, jsonrpcpp::request_ptr req) { + auto accessController = ctx.connection()->getAccessController(); + auto params = req->params(); + std::string topicName; + std::stringstream msg; + + try { + topicName = params.get("topic").get(); + } catch (...) {} + if (topicName.empty()) { - _sendInvalidParamsError(ctx, req, "You must specify 'topic' to subscribe to."); - return; + return _sendInvalidParamsError(ctx, req, "You must specify 'topic' to subscribe to."); } if (!TopicManager::isValidTopicOrFilter(topicName)) { msg << "Invalid topic in request: " << topicName; - _sendInvalidParamsError(ctx, req, msg.str()); - return; + return _sendInvalidParamsError(ctx, req, msg.str()); } - if (!accessController.allowSubscribe(topicName)) { + if (!accessController->allowSubscribe(topicName)) { msg << "You are not allowed to subscribe to topic: " << topicName; - _sendInvalidParamsError(ctx, req, msg.str()); - return; + return _sendInvalidParamsError(ctx, req, msg.str()); } ctx.connection()->subscribe(topicName, req->id()); @@ -118,25 +177,8 @@ void RPCHandler::_handleSubscribe(HandlerContext& ctx, jsonrpcpp::request_ptr re _sendSuccessResponse(ctx, req, result); - // Send cached events if since is set. - if (!sinceEventId.empty() || since > 0) { - try { - nlohmann::json result; - auto& redis = ctx.server()->getRedis(); - - if (!sinceEventId.empty()) { - redis.getCacheSinceId(topicName, sinceEventId, limit, TopicManager::isValidTopicFilter(topicName), result); - } else { - redis.getCacheSince(topicName, since, limit, TopicManager::isValidTopicFilter(topicName), result); - } - - for (auto& cacheItem : result) { - _sendSuccessResponse(ctx, req, cacheItem); - } - } catch (std::exception& e) { - LOG->error("Redis error while looking up cache: {}.", e.what()); - } - } + // Send cached events if requested. 
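+ // _sendCacheToClient() is a no-op unless caching is enabled ('enable_cache') and the request carries a 'since' or 'sinceEventId' parameter.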
+ _sendCacheToClient(ctx, req, topicName); } /** @@ -149,14 +191,13 @@ void RPCHandler::_handleUnsubscribe(HandlerContext& ctx, jsonrpcpp::request_ptr auto accessController = ctx.connection()->getAccessController(); if (!req->params().is_array()) { - _sendInvalidParamsError(ctx, req, "Parameter is not array of topics to unsubscribe from."); - return; + return _sendInvalidParamsError(ctx, req, "Parameter is not array of topics to unsubscribe from."); } auto topics = req->params().to_json(); - unsigned int count = 0; + std::size_t count = 0; for (auto topic : topics) { - if (!TopicManager::isValidTopicOrFilter(topic) || !accessController.allowSubscribe(topic)) { + if (!TopicManager::isValidTopicOrFilter(topic) || !accessController->allowSubscribe(topic)) { continue; } @@ -195,7 +236,7 @@ void RPCHandler::_handlePublish(HandlerContext& ctx, jsonrpcpp::request_ptr req) std::string message; std::stringstream msg; long long timestamp; - unsigned int ttl; + std::size_t ttl; auto accessController = ctx.connection()->getAccessController(); auto params = req->params(); @@ -207,20 +248,17 @@ void RPCHandler::_handlePublish(HandlerContext& ctx, jsonrpcpp::request_ptr req) if (topicName.empty() || message.empty()) { msg << "You need to specify topic and message to publish to."; - _sendInvalidParamsError(ctx, req, msg.str()); - return; + return _sendInvalidParamsError(ctx, req, msg.str()); } - if (!accessController.allowPublish(topicName)) { + if (!accessController->allowPublish(topicName)) { msg << "Insufficient access to topic: " << topicName; - _sendInvalidParamsError(ctx, req, msg.str()); - return; + return _sendInvalidParamsError(ctx, req, msg.str()); } if (!TopicManager::isValidTopic(topicName)) { msg << topicName << " is not a valid topic."; - _sendInvalidParamsError(ctx, req, msg.str()); - return; + return _sendInvalidParamsError(ctx, req, msg.str()); } try { @@ -230,22 +268,41 @@ void RPCHandler::_handlePublish(HandlerContext& ctx, jsonrpcpp::request_ptr req) } try { - ttl = params.get("ttl").get(); + ttl = params.get("ttl").get(); } catch (...) { ttl = 0; } try { auto& redis = ctx.server()->getRedis(); - auto id = redis.cacheMessage(topicName, message, accessController.subject(), timestamp, ttl); + const auto& subject = accessController->subject(); + + if (!subject.empty()) { + try { + const auto limits = accessController->getRateLimitConfig().getRateLimitForTopic(topicName); + + if (redis.isRateLimited(limits.topic, subject, limits.max)) { + LOG->trace("PUBLISH {}: User {} is currently ratelimited. 
Interval: {} Max: {} Matched ratelimit pattern: {}", topicName, subject, limits.interval, limits.max, limits.topic); + nlohmann::json result; + result["action"] = "publish"; + result["topic"] = topicName; + result["status"] = "ERR_RATE_LIMIT_EXCEEDED"; + + return _sendSuccessResponse(ctx, req, result); + } else { + redis.incrementLimitCount(limits.topic, subject, limits.interval); + } + } catch (NoRateLimitForTopic) {} + } + + auto id = redis.cacheMessage(topicName, message, accessController->subject(), timestamp, ttl); if (id.length() == 0) { msg << "Failed to cache message in Redis, discarding."; - _sendInvalidParamsError(ctx, req, msg.str()); - return; + return _sendInvalidParamsError(ctx, req, msg.str()); } - redis.publishMessage(topicName, id, message, accessController.subject()); + redis.publishMessage(topicName, id, message, accessController->subject()); LOG->debug("{} - PUBLISH {}", ctx.connection()->getIP(), topicName); nlohmann::json result; @@ -256,8 +313,8 @@ void RPCHandler::_handlePublish(HandlerContext& ctx, jsonrpcpp::request_ptr req) _sendSuccessResponse(ctx, req, result); } catch (std::exception& e) { - LOG->error("Redis error while publishing message: {}.", e.what()); - msg << "Redis error while publishing message: " << e.what(); + LOG->error("Error while publishing message: {}.", e.what()); + msg << "Error while publishing message: " << e.what(); _sendInvalidParamsError(ctx, req, msg.str()); } } @@ -279,13 +336,189 @@ void RPCHandler::_handleList(HandlerContext& ctx, jsonrpcpp::request_ptr req) { } /** - * Handle history RPC command. - * Send history cache for topic to client. + * Handle eventlog RPC command. + * Send eventlog for topic to client. * @param ctx Client issuing request. * @param req RPC request. */ -void RPCHandler::_handleHistory(HandlerContext& ctx, jsonrpcpp::request_ptr req) { - LOG->trace("handleHistory: {}", req->to_json().dump(2)); +void RPCHandler::_handleEventlog(HandlerContext& ctx, jsonrpcpp::request_ptr req) { + auto accessController = ctx.connection()->getAccessController(); + auto params = req->params(); + std::string topicName; + std::string sinceEventId; + unsigned long limit; + long long since; + std::stringstream msg; + + if (!ctx.config().get("enable_cache")) { + msg << "Cache is not enabled in server config"; + return _sendInvalidParamsError(ctx, req, msg.str()); + } + + try { + topicName = params.get("topic").get(); + sinceEventId = params.get("sinceEventId").get(); + } catch (...) {} + + try { + since = params.get("since").get(); + since = _calculateRelativeSince(since); + } catch (...) { + since = 0; + } + + try { + limit = params.get("limit").get(); + } catch (...) 
{ + limit = ctx.config().get("max_cache_request_limit"); + } + + if (limit > (unsigned long)ctx.config().get("max_cache_request_limit")) { + limit = ctx.config().get("max_cache_request_limit"); + } + + if (sinceEventId.empty() && since == 0) { + msg << "You must specify either 'since' or 'sinceEventId'."; + return _sendInvalidParamsError(ctx, req, msg.str()); + } + + if (topicName.empty()) { + return _sendInvalidParamsError(ctx, req, "You must specify topic."); + } + + if (!TopicManager::isValidTopicOrFilter(topicName)) { + msg << "Invalid topic in request: " << topicName; + return _sendInvalidParamsError(ctx, req, msg.str()); + } + + if (!accessController->allowSubscribe(topicName)) { + msg << "You are not allowed to read from topic: " << topicName; + return _sendInvalidParamsError(ctx, req, msg.str()); + } + + LOG->trace("{} - EVENTLOG {} since: {} sinceEventId: {} limit: {}", ctx.connection()->getIP(), topicName, since, sinceEventId, limit); + + nlohmann::json items; + try { + auto& redis = ctx.server()->getRedis(); + if (!sinceEventId.empty()) + redis.getCacheSinceId(topicName, sinceEventId, limit, TopicManager::isValidTopicFilter(topicName), items); + else + redis.getCacheSince(topicName, since, limit, TopicManager::isValidTopicFilter(topicName), items); + } catch (std::exception& e) { + msg << "Error while looking up cache: " << e.what(); + LOG->error(msg.str()); + return _sendInvalidParamsError(ctx, req, msg.str()); + } + + _sendSuccessResponse(ctx, req,{ + {"action", "eventlog"}, + {"topic", topicName}, + {"status", "ok"}, + {"items", items} + }); +} + +/** + * Handle kv-store read request. + * @param ctx Client issuing request. + * @param req RPC request. + */ +void RPCHandler::_handleGet(HandlerContext& ctx, jsonrpcpp::request_ptr req) { + if (!ctx.server()->getKVStore()->is_enabled()) + return _sendInvalidParamsError(ctx, req, "KVStore is not enabled."); + + auto accessController = ctx.connection()->getAccessController(); + auto kvStore = ctx.server()->getKVStore(); + auto params = req->params(); + + try { + const auto key = params.get("key").get(); + + if (!accessController->allowSubscribe(key)) { + return _sendInvalidParamsError(ctx, req, fmt::format("You are not allowed to read key {}", key)); + } + + const auto val = kvStore->get(key); + + _sendSuccessResponse(ctx, req, { + {"action", "get"}, + {"key", key}, + {"value", val} + }); + } catch(const std::exception& e) { + _sendInvalidParamsError(ctx, req, e.what()); + } +} + +/** + * Handle kv-store write request. + * @param ctx Client issuing request. + * @param req RPC request. + */ +void RPCHandler::_handleSet(HandlerContext& ctx, jsonrpcpp::request_ptr req) { + if (!ctx.server()->getKVStore()->is_enabled()) + return _sendInvalidParamsError(ctx, req, "KVStore is not enabled."); + + auto accessController = ctx.connection()->getAccessController(); + auto kvStore = ctx.server()->getKVStore(); + auto params = req->params(); + unsigned long ttl = 0; + + try { + ttl = params.get("ttl").get(); + } catch (...) 
{} + + try { + const auto key = params.get("key").get(); + const auto value = params.get("value").get(); + + if (!accessController->allowPublish(key)) { + return _sendInvalidParamsError(ctx, req, fmt::format("You are not allowed to write key {}", key)); + } + + auto ret = kvStore->set(key, value, ttl); + + _sendSuccessResponse(ctx, req, { + {"action", "set"}, + {"key", key}, + {"success", ret} + }); + } catch(const std::exception& e) { + _sendInvalidParamsError(ctx, req, e.what()); + } +} + +/** + * Handle kv-store delete request. + * @param ctx Client issuing request. + * @param req RPC request. + */ +void RPCHandler::_handleDelete(HandlerContext& ctx, jsonrpcpp::request_ptr req) { + if (!ctx.server()->getKVStore()->is_enabled()) + return _sendInvalidParamsError(ctx, req, "KVStore is not enabled."); + + auto accessController = ctx.connection()->getAccessController(); + auto kvStore = ctx.server()->getKVStore(); + auto params = req->params(); + + try { + const auto key = params.get("key").get(); + + if (!accessController->allowPublish(key)) { + return _sendInvalidParamsError(ctx, req, fmt::format("You are not allowed to delete key {}", key)); + } + + auto ret = kvStore->del(key); + + _sendSuccessResponse(ctx, req, { + {"action", "del"}, + {"key", key}, + {"success", ret > 0 ? true : false} + }); + } catch(const std::exception& e) { + _sendInvalidParamsError(ctx, req, e.what()); + } } /** diff --git a/src/Redis.cpp b/src/Redis.cpp index b7d2c1d..702bd15 100644 --- a/src/Redis.cpp +++ b/src/Redis.cpp @@ -1,26 +1,32 @@ -#include "Redis.hpp" - +#include +#include +#include +#include +#include +#include +#include +#include +#include #include #include -#include #include #include -#include #include #include #include #include #include - -#include "Common.hpp" +#include +#include +#include +#include "Redis.hpp" #include "Config.hpp" #include "TopicManager.hpp" #include "Util.hpp" -#include "Server.hpp" #include "jwt/json/json.hpp" +#include "Logger.hpp" namespace eventhub { - Redis::Redis(Config &cfg) : EventhubBase(cfg) { sw::redis::ConnectionOptions connOpts; sw::redis::ConnectionPoolOptions poolOpts; @@ -41,7 +47,7 @@ Redis::Redis(Config &cfg) : EventhubBase(cfg) { poolOpts.size = config().get("redis_pool_size"); poolOpts.wait_timeout = std::chrono::seconds(5); - _redisInstance = std::make_unique(connOpts, poolOpts); + _redisInstance = std::make_shared(connOpts, poolOpts); _redisSubscriber = nullptr; } @@ -74,7 +80,7 @@ const std::string Redis::_getNextCacheId(long long timestamp) { } // Add a message to the cache. -const std::string Redis::cacheMessage(const std::string& topic, const std::string& payload, const std::string& origin, long long timestamp, unsigned int ttl) { +const std::string Redis::cacheMessage(const std::string& topic, const std::string& payload, const std::string& origin, long long timestamp, unsigned long ttl) { if (timestamp == 0) { timestamp = Util::getTimeSinceEpoch(); } @@ -104,7 +110,7 @@ const std::string Redis::cacheMessage(const std::string& topic, const std::strin // GetCache returns all matching cached messages for topics matching topicPattern // @param since List all messages since Unix timestamp or message ID // @param limit Limit resultset to at most @limit elements. 
-size_t Redis::getCacheSince(const std::string& topicPattern, long long since, long long limit, bool isPattern, nlohmann::json& result) { +std::size_t Redis::getCacheSince(const std::string& topicPattern, long long since, long long limit, bool isPattern, nlohmann::json& result) { std::vector topics; result = nlohmann::json::array(); @@ -171,7 +177,7 @@ size_t Redis::getCacheSince(const std::string& topicPattern, long long since, lo continue; } - for (unsigned int i = 0; i < cacheItems.size(); i++) { + for (std::size_t i = 0; i < cacheItems.size(); i++) { // Key returned from ZSET does not exist in the HSET anymore. // continue on to the next key. if (cacheItems[i].value().empty()) { @@ -209,7 +215,7 @@ std::pair _splitIdAndSeq(const std::string& cacheId) { } // Get cached messages after a given message ID. -size_t Redis::getCacheSinceId(const std::string& topicPattern, const std::string& sinceId, long long limit, bool isPattern, nlohmann::json& result) { +std::size_t Redis::getCacheSinceId(const std::string& topicPattern, const std::string& sinceId, long long limit, bool isPattern, nlohmann::json& result) { result = nlohmann::json::array(); // If cache is not enabled simply return an empty set. @@ -285,7 +291,7 @@ size_t Redis::getCacheSinceId(const std::string& topicPattern, const std::string } // Delete expired items from the cache. -size_t Redis::purgeExpiredCacheItems() { +std::size_t Redis::purgeExpiredCacheItems() { std::vector allTopics; std::vector> expiredItems; auto now = Util::getTimeSinceEpoch(); @@ -388,14 +394,53 @@ std::vector Redis::_getTopicsSeen(const std::string& topicPattern) return matchingTopics; } -CacheItemMeta::CacheItemMeta(const std::string& id, long expireAt, const std::string& origin) : +/* + Check if a user is ratelimited. +*/ +bool Redis::isRateLimited(const std::string& topic, const std::string& subject, unsigned long max) { + if (max == 0) + return false; + + const auto key = REDIS_RATE_LIMIT_PATH(_prefix, subject, topic); + auto count = _redisInstance->get(key); + if (count) { + auto c = std::stoull(count.value(), nullptr, 10); + if (c >= max) { + return true; + } + } + + return false; +} + +/* + Increment publish count for user. +*/ +void Redis::incrementLimitCount(const std::string& topic, const std::string& subject, unsigned long interval) { + if (interval == 0) + return; + + const auto key = REDIS_RATE_LIMIT_PATH(_prefix, subject, topic); + + // FIXME: We might be able to optimize away this call to get by using the value from the previous get call + // in the isRateLimited() function. 
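+ // If a counter already exists for this subject/topic it is incremented; otherwise it is created with SETEX so that it expires, and the count resets, after the configured interval (in seconds).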
+ auto count = _redisInstance->get(key); + + if (count) { + _redisInstance->incr(key); + } else { + _redisInstance->setex(key, interval, "1"); + } +} + +CacheItemMeta::CacheItemMeta(const std::string& id, unsigned long expireAt, const std::string& origin) : _id(id), _expireAt(expireAt), _origin(origin) {} CacheItemMeta::CacheItemMeta(const std::string& metaStr) { std::string expireAtStr; - unsigned int j = 0; + std::size_t j = 0; - for (unsigned int i = 0; i < metaStr.length(); i++) { + for (std::size_t i = 0; i < metaStr.length(); i++) { if (metaStr[i] == ':') { j++; continue; @@ -410,7 +455,7 @@ CacheItemMeta::CacheItemMeta(const std::string& metaStr) { throw std::runtime_error("Invalid CacheItemMetaStr '" + metaStr + "'"); } - _expireAt = std::stol(expireAtStr, nullptr, 10); + _expireAt = std::stoul(expireAtStr, nullptr, 10); } const std::string CacheItemMeta::toStr() { diff --git a/src/SSLConnection.cpp b/src/SSLConnection.cpp index 29add96..8f007ef 100644 --- a/src/SSLConnection.cpp +++ b/src/SSLConnection.cpp @@ -1,8 +1,17 @@ #include - +#include +#include +#include +#include +#include +#include +#include + +#include "Forward.hpp" #include "SSLConnection.hpp" -#include "Server.hpp" #include "Util.hpp" +#include "Common.hpp" +#include "Logger.hpp" namespace eventhub { @@ -21,7 +30,7 @@ SSLConnection::~SSLConnection() { void SSLConnection::_init() { _ssl = SSL_new(_ssl_ctx); - if (_ssl == NULL) { + if (_ssl == nullptr) { _ssl = nullptr; LOG->error("Failed to initialize SSL object for client {}", getIP()); shutdown(); @@ -63,7 +72,7 @@ ssize_t SSLConnection::flushSendBuffer() { return 0; } - unsigned int pcktSize = _write_buffer.length() > NET_READ_BUFFER_SIZE ? NET_READ_BUFFER_SIZE : _write_buffer.length(); + std::size_t pcktSize = _write_buffer.length() > NET_READ_BUFFER_SIZE ? NET_READ_BUFFER_SIZE : _write_buffer.length(); int ret = SSL_write(_ssl, _write_buffer.c_str(), pcktSize); if (ret > 0) { @@ -100,13 +109,13 @@ void SSLConnection::read() { return; } - size_t bytesRead = 0; + std::size_t bytesRead = 0; int ret = 0; do { // If more read buffer capacity is required increase it by chunks of NET_READ_BUFFER_SIZE. if ((bytesRead + NET_READ_BUFFER_SIZE) > _read_buffer.capacity()) { - size_t newCapacity = _read_buffer.capacity() + NET_READ_BUFFER_SIZE; + std::size_t newCapacity = _read_buffer.capacity() + NET_READ_BUFFER_SIZE; if (newCapacity > MAX_DATA_FRAME_SIZE + NET_READ_BUFFER_SIZE) { LOG->error("Client {} exceeded max buffer size. 
Disconnecting.", getIP()); @@ -147,4 +156,4 @@ void SSLConnection::read() { _parseRequest(bytesRead); } -} // namespace eventhub \ No newline at end of file +} // namespace eventhub diff --git a/src/Server.cpp b/src/Server.cpp index 75be8b3..13a541f 100644 --- a/src/Server.cpp +++ b/src/Server.cpp @@ -1,28 +1,39 @@ -#include "Server.hpp" - #include #include #include #include -#include #include #include #include #include - +#include +#include +#include +#include +#include +#include +#include +#include +#include #include #include -#include #include #include #include +#include +#include +#include +#include +#include "Server.hpp" #include "Common.hpp" #include "Config.hpp" #include "Util.hpp" -#include "Server.hpp" #include "jwt/json/json.hpp" #include "metrics/Types.hpp" +#include "ConnectionWorker.hpp" +#include "KVStore.hpp" +#include "Logger.hpp" namespace eventhub { @@ -30,7 +41,7 @@ std::atomic stopEventhub{false}; std::atomic reloadEventhub{false}; unsigned const char alpn_protocol[] = "http/1.1"; -unsigned int alpn_protocol_length = 8; +std::size_t alpn_protocol_length = 8; Server::Server(Config& cfg) : _config(cfg), _server_socket(-1), _server_socket_ssl(-1), _ssl_enabled(false), _ssl_ctx(nullptr), _redis(cfg) { @@ -66,7 +77,7 @@ void Server::start() { // Start the connection workers. _connection_workers_lock.lock(); - unsigned int numWorkerThreads = config().get("worker_threads") == 0 ? std::thread::hardware_concurrency() : config().get("worker_threads"); + unsigned int numWorkerThreads = config().get("worker_threads") == 0 ? std::thread::hardware_concurrency() : config().get("worker_threads"); for (unsigned i = 0; i < numWorkerThreads; i++) { _connection_workers.addWorker(new Worker(this, i + 1)); @@ -120,6 +131,9 @@ void Server::start() { // Connect to redis. _redis.psubscribe("*", cb); + // Instanciate KVStore. + _kv_store = std::make_unique(_config, _redis); + // Add cache purge cronjob if cache functionality is enabled. 
if (config().get("enable_cache")) { _ev.addTimer( @@ -158,7 +172,7 @@ void Server::start() { reconnect = false; _redis.resetSubscribers(); _redis.psubscribe("*", cb); - LOG->info("Connection to Redis regained."); + LOG->info("Connection to Redis restored."); } _redis.consume(); @@ -185,7 +199,7 @@ void Server::reload() { int alpn_cb(SSL* ssl, const unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen, void* arg) { - auto reqProto = fmt::format("{}", in); + auto reqProto = fmt::format("{}", reinterpret_cast(in)); if (reqProto.find(reinterpret_cast(alpn_protocol)) != std::string::npos) { *out = alpn_protocol; @@ -237,8 +251,8 @@ void Server::_listenerInit() { void Server::_sslListenerInit() { _server_socket_ssl = socket(AF_INET, SOCK_STREAM, 0); - if (_server_socket == -1) { - LOG->critical("Could not create server socket: {}.", strerror(errno)); + if (_server_socket_ssl == -1) { + LOG->critical("Could not create SSL server socket: {}.", strerror(errno)); exit(1); } @@ -274,7 +288,7 @@ void Server::_initSSL() { const SSL_METHOD* method = TLS_server_method(); _ssl_ctx = SSL_CTX_new(method); - if (_ssl_ctx == NULL) { + if (_ssl_ctx == nullptr) { LOG->critical("Could not initialize SSL context: {}", Util::getSSLErrorString(ERR_get_error())); exit(1); } @@ -283,7 +297,7 @@ void Server::_initSSL() { SSL_CTX_set_ecdh_auto(_ssl_ctx, 1); SSL_CTX_set_options(_ssl_ctx, SSL_OP_CIPHER_SERVER_PREFERENCE); - SSL_CTX_set_alpn_select_cb(_ssl_ctx, alpn_cb, NULL); + SSL_CTX_set_alpn_select_cb(_ssl_ctx, alpn_cb, nullptr); _ssl_enabled = true; } @@ -300,7 +314,7 @@ void Server::_loadSSLCertificates() { if (caCert.empty()) { SSL_CTX_set_default_verify_paths(_ssl_ctx); } else { - if (SSL_CTX_load_verify_locations(_ssl_ctx, caCert.c_str(), NULL) <= 0) { + if (SSL_CTX_load_verify_locations(_ssl_ctx, caCert.c_str(), nullptr) <= 0) { LOG->error("Error loading CA certificate: {}", Util::getSSLErrorString(ERR_get_error())); stop(); exit(EXIT_FAILURE); @@ -342,7 +356,7 @@ void Server::_checkSSLCertUpdated() { auto fp = fopen(config().get("ssl_certificate").c_str(), "r"); if (fp) { - PEM_read_X509(fp, NULL, NULL, NULL); + PEM_read_X509(fp, nullptr, nullptr, nullptr); auto err = ERR_get_error(); if (err != 0) { @@ -358,7 +372,7 @@ void Server::_checkSSLCertUpdated() { LOG->info("Change to private key file " + config().get("ssl_private_key") + " detected."); auto fp = fopen(config().get("ssl_private_key").c_str(), "r"); if (fp) { - PEM_read_PrivateKey(fp, NULL, NULL, NULL); + PEM_read_PrivateKey(fp, nullptr, nullptr, nullptr); auto err = ERR_get_error(); if (err != 0) { diff --git a/src/Topic.cpp b/src/Topic.cpp index 20466f0..c9423a9 100644 --- a/src/Topic.cpp +++ b/src/Topic.cpp @@ -1,14 +1,19 @@ -#include "Topic.hpp" - +#include +#include #include #include #include +#include +#include +#include -#include "Common.hpp" +#include "Topic.hpp" #include "Connection.hpp" #include "sse/Response.hpp" #include "websocket/Response.hpp" #include "websocket/Types.hpp" +#include "Logger.hpp" +#include "jwt/json/json.hpp" namespace eventhub { Topic::~Topic() {} @@ -70,7 +75,7 @@ void Topic::deleteSubscriberByIterator(TopicSubscriberList::iterator it) { /** * Returns the number of subscribers on the topic. 
*/ -size_t Topic::getSubscriberCount() { +std::size_t Topic::getSubscriberCount() { std::lock_guard lock(_subscriber_lock); return _subscriber_list.size(); } diff --git a/src/TopicManager.cpp b/src/TopicManager.cpp index 1d94d79..c59cd05 100644 --- a/src/TopicManager.cpp +++ b/src/TopicManager.cpp @@ -1,14 +1,14 @@ -#include "TopicManager.hpp" - #include - +#include #include #include #include #include +#include -#include "Common.hpp" +#include "TopicManager.hpp" #include "Topic.hpp" +#include "Logger.hpp" namespace eventhub { /* @@ -169,6 +169,10 @@ bool TopicManager::isFilterMatched(const std::string& filterName, const std::str // We have reached the end of the topic. if (tnIt + 1 == topicName.end() && fnIt + 1 != filterName.end()) { + if (*fnIt != *tnIt) { + return false; + } + // Match the root topic in addition to every subtopic // when we have a match-all (#) on that path. // Example: topic/foo/# should also match topic/foo. diff --git a/src/Util.cpp b/src/Util.cpp index d37fef3..1b75c64 100644 --- a/src/Util.cpp +++ b/src/Util.cpp @@ -1,5 +1,3 @@ -#include "Util.hpp" - #include #include #include @@ -7,19 +5,21 @@ #include #include #include -#include #include - +#include +#include #include #include #include #include #include +#include -#include "Common.hpp" +#include "Util.hpp" +#include "Logger.hpp" namespace eventhub { -const std::string Util::base64Encode(const unsigned char* buffer, size_t length) { +const std::string Util::base64Encode(const unsigned char* buffer, std::size_t length) { BIO *bio, *b64; BUF_MEM* bufferPtr; std::string s; diff --git a/src/http/Handler.cpp b/src/http/Handler.cpp index 84015df..fbcb8ba 100644 --- a/src/http/Handler.cpp +++ b/src/http/Handler.cpp @@ -1,14 +1,13 @@ #include - +#include #include #include +#include #include "Common.hpp" #include "Config.hpp" -#include "ConnectionWorker.hpp" #include "HandlerContext.hpp" #include "Server.hpp" -#include "TopicManager.hpp" #include "Util.hpp" #include "http/Handler.hpp" #include "http/Parser.hpp" @@ -16,6 +15,8 @@ #include "metrics/JsonRenderer.hpp" #include "metrics/PrometheusRenderer.hpp" #include "sse/Handler.hpp" +#include "AccessController.hpp" +#include "Connection.hpp" namespace eventhub { namespace http { @@ -109,7 +110,7 @@ void Handler::_handlePath(HandlerContext& ctx, Parser* req) { return; } - if (!ctx.connection()->getAccessController().authenticate(authToken, ctx.server()->config().get("jwt_secret"))) { + if (!ctx.connection()->getAccessController()->authenticate(authToken, ctx.server()->config().get("jwt_secret"))) { _badRequest(ctx, "Authentication failed.", 401); return; } @@ -136,14 +137,10 @@ bool Handler::_websocketHandshake(HandlerContext& ctx, Parser* req) { return false; } - size_t keyLen = secWsKey.length() + 36; - unsigned char key[keyLen]; + const std::string key = secWsKey + WS_MAGIC_STRING; unsigned char keySha1[SHA_DIGEST_LENGTH] = {0}; - memcpy(key, secWsKey.c_str(), secWsKey.length()); - memcpy(key + secWsKey.length(), WS_MAGIC_STRING, 36); - - SHA1(key, keyLen, keySha1); + SHA1(reinterpret_cast(key.c_str()), key.length(), keySha1); const std::string secWsAccept = Util::base64Encode(keySha1, SHA_DIGEST_LENGTH); Response resp; diff --git a/src/http/Parser.cpp b/src/http/Parser.cpp index 6536759..7becfac 100644 --- a/src/http/Parser.cpp +++ b/src/http/Parser.cpp @@ -1,15 +1,13 @@ -#include "http/Parser.hpp" - #include - -#include +#include #include -#include #include +#include -#include "Common.hpp" +#include "http/Parser.hpp" #include "Util.hpp" #include "http/picohttpparser.h" 
+#include "Logger.hpp" namespace eventhub { namespace http { @@ -23,7 +21,7 @@ Parser::Parser() { _is_complete = false; _error_message = ""; _callback = [](http::Parser* req, http::RequestState reqState) { - LOG->error("Websocket parser callback was called before it was initialized."); + LOG->error("HTTP parser callback was called before it was initialized."); }; } @@ -37,7 +35,7 @@ Parser::~Parser() {} @param data Raw http request data. @param len Length of data. **/ -void Parser::parse(const char* data, int len) { +void Parser::parse(const char* data, std::size_t len) { int pret; if (_is_complete) @@ -76,7 +74,7 @@ void Parser::parse(const char* data, int len) { std::string rawPath; rawPath.insert(0, _phr_path, _phr_path_len); - size_t qsPos = rawPath.find_first_of('?', 0); + std::size_t qsPos = rawPath.find_first_of('?', 0); if (qsPos != std::string::npos) { std::string qStr; qStr = rawPath.substr(qsPos + 1, std::string::npos); @@ -137,12 +135,12 @@ const std::map& Parser::getHeaders() { Extracts query parameters from a string if they exist. @param buf The string to parse. **/ -size_t Parser::_parse_query_string(const std::string& buf) { - size_t prevpos = 0, eqlpos = 0; +std::size_t Parser::_parse_query_string(const std::string& buf) { + std::size_t prevpos = 0, eqlpos = 0; while ((eqlpos = buf.find("=", prevpos)) != std::string::npos) { std::string param, val; - size_t len; + std::size_t len; len = buf.find("&", eqlpos); @@ -183,7 +181,7 @@ const std::string Parser::getQueryString(std::string param) { /** Returns number of query strings in the request. **/ -size_t Parser::numQueryString() { +std::size_t Parser::numQueryString() { return _qsmap.size(); } diff --git a/src/http/Response.cpp b/src/http/Response.cpp index 2c7494a..4255509 100644 --- a/src/http/Response.cpp +++ b/src/http/Response.cpp @@ -1,9 +1,8 @@ -#include "http/Response.hpp" - #include #include +#include -#include "Common.hpp" +#include "http/Response.hpp" namespace eventhub { namespace http { diff --git a/src/main.cpp b/src/main.cpp index 472f9b5..17f4c71 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -1,17 +1,17 @@ #include -#include #include -#include -#include #include - +#include +#include +#include +#include #include #include -#include +#include -#include "Common.hpp" #include "Config.hpp" #include "Server.hpp" +#include "Logger.hpp" namespace eventhub { extern std::atomic stopEventhub; @@ -56,10 +56,10 @@ int main(int argc, char** argv) { sa.sa_flags = 0; sigemptyset(&(sa.sa_mask)); - sigaction(SIGINT, &sa, NULL); - sigaction(SIGQUIT, &sa, NULL); - sigaction(SIGTERM, &sa, NULL); - sigaction(SIGHUP, &sa, NULL); + sigaction(SIGINT, &sa, nullptr); + sigaction(SIGQUIT, &sa, nullptr); + sigaction(SIGTERM, &sa, nullptr); + sigaction(SIGHUP, &sa, nullptr); struct option long_options[] = { { "config", required_argument, 0, 'c' }, @@ -113,7 +113,8 @@ int main(int argc, char** argv) { { "ssl_private_key", ConfigValueType::STRING, "", ConfigValueSettings::OPTIONAL }, { "ssl_cert_auto_reload", ConfigValueType::BOOL, "false", ConfigValueSettings::OPTIONAL }, { "ssl_cert_check_interval", ConfigValueType::INT, "300", ConfigValueSettings::OPTIONAL }, - { "disable_unsecure_listener", ConfigValueType::BOOL, "false", ConfigValueSettings::OPTIONAL } + { "disable_unsecure_listener", ConfigValueType::BOOL, "false", ConfigValueSettings::OPTIONAL }, + { "enable_kvstore", ConfigValueType::BOOL, "true", ConfigValueSettings::REQUIRED } }; Config cfg(cfgMap); @@ -136,4 +137,4 @@ int main(int argc, char** argv) { server.start(); return 0; 
-} \ No newline at end of file +} diff --git a/src/metrics/JsonRenderer.cpp b/src/metrics/JsonRenderer.cpp index 52c0376..e9fdec9 100644 --- a/src/metrics/JsonRenderer.cpp +++ b/src/metrics/JsonRenderer.cpp @@ -1,4 +1,7 @@ +#include #include +#include +#include #include "jwt/json/json.hpp" #include "metrics/JsonRenderer.hpp" diff --git a/src/metrics/PrometheusRenderer.cpp b/src/metrics/PrometheusRenderer.cpp index 0b8a206..66a53db 100644 --- a/src/metrics/PrometheusRenderer.cpp +++ b/src/metrics/PrometheusRenderer.cpp @@ -17,28 +17,33 @@ const std::string PrometheusRenderer::RenderMetrics(Server* server) { auto metrics = server->getAggregatedMetrics(); auto& config = server->config(); - std::vector> metricList = { - {"worker_count", metrics.worker_count}, - {"publish_count", metrics.publish_count}, - {"redis_connection_fail_count", metrics.redis_connection_fail_count}, - {"redis_publish_delay_ms", metrics.redis_publish_delay_ms}, + std::vector> metricList = { + {"worker_count", "gauge", metrics.worker_count}, + {"publish_count", "counter", metrics.publish_count}, + {"redis_connection_fail_count", "counter", metrics.redis_connection_fail_count}, + {"redis_publish_delay_ms", "gauge", metrics.redis_publish_delay_ms}, - {"current_connections_count", metrics.current_connections_count}, - {"total_connect_count", metrics.total_connect_count}, - {"total_disconnect_count", metrics.total_disconnect_count}, - {"eventloop_delay_ms", metrics.eventloop_delay_ms}}; + {"current_connections_count", "gauge", metrics.current_connections_count}, + {"total_connect_count", "counter", metrics.total_connect_count}, + {"total_disconnect_count", "counter", metrics.total_disconnect_count}, + {"eventloop_delay_ms", "gauge", metrics.eventloop_delay_ms}}; char h_buf[128] = {0}; std::stringstream ss; gethostname(h_buf, sizeof(h_buf)); - for (auto& m : metricList) { - // Add prefix to metric key if set in configuration. - const std::string metricKey = !config.get("prometheus_metric_prefix").empty() ? 
(config.get("prometheus_metric_prefix") + "_" + m.first) : m.first; + for (const auto& metric : metricList) { + // Add prefix provided in configuration to metric name + const std::string& metricName = config.get("prometheus_metric_prefix") + "_" + std::get<0>(metric); + const std::string& metricType = std::get<1>(metric); + const long long& metricValue = std::get<2>(metric); - ss << metricKey << "{instance=\"" << h_buf << ":" << config.get("listen_port") << "\"" - << "} " << m.second << "\n"; + // Output the type of each metric + ss << "# TYPE " << metricName << " " << metricType << "\n"; + + ss << metricName << "{instance=\"" << h_buf << ":" << config.get("listen_port") << "\"" + << "} " << metricValue << "\n"; } return ss.str(); diff --git a/src/sse/Handler.cpp b/src/sse/Handler.cpp index 3073a46..70d11a0 100644 --- a/src/sse/Handler.cpp +++ b/src/sse/Handler.cpp @@ -1,20 +1,21 @@ -#include +#include #include #include #include +#include -#include "Common.hpp" #include "Config.hpp" #include "Connection.hpp" -#include "ConnectionWorker.hpp" #include "HandlerContext.hpp" #include "Redis.hpp" #include "Server.hpp" #include "TopicManager.hpp" #include "Util.hpp" -#include "http/Response.hpp" #include "sse/Handler.hpp" #include "sse/Response.hpp" +#include "AccessController.hpp" +#include "http/Parser.hpp" +#include "jwt/json/json.hpp" namespace eventhub { namespace sse { @@ -22,7 +23,7 @@ namespace sse { void Handler::HandleRequest(HandlerContext& ctx, http::Parser* req) { auto conn = ctx.connection(); auto& redis = ctx.server()->getRedis(); - auto& accessController = conn->getAccessController(); + auto accessController = conn->getAccessController(); auto path = Util::uriDecode(req->getPath()); auto lastEventId = req->getHeader("Last-Event-ID"); @@ -40,7 +41,7 @@ void Handler::HandleRequest(HandlerContext& ctx, http::Parser* req) { } // Check authorization. 
- if (!accessController.allowSubscribe(path)) { + if (!accessController->allowSubscribe(path)) { Response::error(conn, "Insufficient access.", 401); return; } diff --git a/src/sse/Response.cpp b/src/sse/Response.cpp index 37e9ea7..07b5ceb 100644 --- a/src/sse/Response.cpp +++ b/src/sse/Response.cpp @@ -1,11 +1,14 @@ -#include "sse/Response.hpp" -#include "http/Response.hpp" - #include +#include #include #include +#include +#include +#include "sse/Response.hpp" +#include "http/Response.hpp" #include "Connection.hpp" +#include "jwt/json/json.hpp" namespace eventhub { namespace sse { @@ -46,4 +49,4 @@ void Response::error(ConnectionPtr conn, const std::string& message, unsigned in } } // namespace sse -} // namespace eventhub \ No newline at end of file +} // namespace eventhub diff --git a/src/websocket/Handler.cpp b/src/websocket/Handler.cpp index 83583cc..ce25305 100644 --- a/src/websocket/Handler.cpp +++ b/src/websocket/Handler.cpp @@ -1,18 +1,17 @@ -#include "websocket/Handler.hpp" - +#include #include #include -#include +#include +#include -#include "Common.hpp" -#include "Config.hpp" +#include "websocket/Handler.hpp" #include "Connection.hpp" #include "HandlerContext.hpp" #include "RPCHandler.hpp" #include "jsonrpc/jsonrpcpp.hpp" -#include "jwt/json/json.hpp" #include "websocket/Response.hpp" #include "websocket/Types.hpp" +#include "Logger.hpp" namespace eventhub { namespace websocket { diff --git a/src/websocket/Parser.cpp b/src/websocket/Parser.cpp index 05e0ce2..e89bf92 100644 --- a/src/websocket/Parser.cpp +++ b/src/websocket/Parser.cpp @@ -1,10 +1,12 @@ -#include "websocket/Parser.hpp" - +#include +#include #include +#include "websocket/Parser.hpp" #include "Common.hpp" #include "websocket/Types.hpp" #include "websocket/ws_parser.h" +#include "Logger.hpp" namespace eventhub { namespace websocket { @@ -15,7 +17,7 @@ static int parserOnDataBegin(void* userData, uint8_t frameType) { return 0; } -static int parserOnDataPayload(void* userData, const char* buff, size_t len) { +static int parserOnDataPayload(void* userData, const char* buff, std::size_t len) { auto obj = static_cast(userData); obj->appendDataPayload(buff, len); return 0; @@ -34,7 +36,7 @@ static int parserOnControlBegin(void* userData, uint8_t frameType) { return 0; } -static int parserOnControlPayload(void* userData, const char* buff, size_t len) { +static int parserOnControlPayload(void* userData, const char* buff, std::size_t len) { auto obj = static_cast(userData); obj->appendControlPayload(buff, len); return 0; @@ -71,7 +73,7 @@ void Parser::clearControlPayload() { _control_payload_buf.clear(); } -void Parser::appendDataPayload(const char* data, size_t len) { +void Parser::appendDataPayload(const char* data, std::size_t len) { _data_payload_buf.insert(_data_payload_buf.size(), data, len); if (_data_payload_buf.size() > MAX_DATA_FRAME_SIZE) { @@ -79,7 +81,7 @@ void Parser::appendDataPayload(const char* data, size_t len) { } } -void Parser::appendControlPayload(const char* data, size_t len) { +void Parser::appendControlPayload(const char* data, std::size_t len) { _control_payload_buf.insert(_control_payload_buf.size(), data, len); if (_control_payload_buf.size() > WS_MAX_CONTROL_FRAME_SIZE) { @@ -111,7 +113,7 @@ const std::string& Parser::getControlPayload() { return _control_payload_buf; } -void Parser::parse(char* buf, size_t len) { +void Parser::parse(char* buf, std::size_t len) { ws_parser_execute(&_ws_parser, &_ws_parser_callbacks, this, buf, len); } diff --git a/src/websocket/Response.cpp 
b/src/websocket/Response.cpp index 358fb42..85aee1d 100644 --- a/src/websocket/Response.cpp +++ b/src/websocket/Response.cpp @@ -1,9 +1,9 @@ -#include "websocket/Response.hpp" - -#include #include +#include #include +#include +#include "websocket/Response.hpp" #include "Common.hpp" #include "websocket/Types.hpp" @@ -12,8 +12,8 @@ namespace websocket { void Response::_sendFragment(ConnectionPtr conn, const std::string& fragment, uint8_t frameType, bool fin) { std::string sndBuf; char header[8]; - size_t headerSize = 0; - size_t fragmentSize = fragment.size(); + std::size_t headerSize = 0; + std::size_t fragmentSize = fragment.size(); header[0] = fin << 7; header[0] = header[0] | (0xF & frameType); @@ -40,7 +40,7 @@ void Response::_sendFragment(ConnectionPtr conn, const std::string& fragment, ui } void Response::sendData(ConnectionPtr conn, const std::string& data, FrameType frameType) { - size_t dataSize = data.size(); + std::size_t dataSize = data.size(); if (dataSize < WS_MAX_CHUNK_SIZE) { _sendFragment(conn, data, (uint8_t)frameType, true); @@ -48,11 +48,11 @@ void Response::sendData(ConnectionPtr conn, const std::string& data, FrameType f // First: fin = false, frameType = frameType // Following: fin = false, frameType = CONTINUATION_FRAME // Last: fin = true, frameType = CONTINUATION_FRAME - size_t nChunks = dataSize / WS_MAX_CHUNK_SIZE; + std::size_t nChunks = dataSize / WS_MAX_CHUNK_SIZE; for (unsigned i = 0; i < nChunks; i++) { uint8_t chunkFrametype = uint8_t((i == 0) ? frameType : FrameType::CONTINUATION_FRAME); bool fin = (i < (nChunks - 1)) ? false : true; - size_t len = (i < (nChunks - 1)) ? WS_MAX_CHUNK_SIZE : std::string::npos; + std::size_t len = (i < (nChunks - 1)) ? WS_MAX_CHUNK_SIZE : std::string::npos; auto const chunk = data.substr(i * WS_MAX_CHUNK_SIZE, len); _sendFragment(conn, chunk, chunkFrametype, fin); } diff --git a/src/websocket/ws_parser.cpp b/src/websocket/ws_parser.cpp index 2f71ac0..55c41c5 100644 --- a/src/websocket/ws_parser.cpp +++ b/src/websocket/ws_parser.cpp @@ -294,6 +294,6 @@ ws_parser_error(int rc) { PARSER_ERROR_CODES(XX) #undef XX - return NULL; + return nullptr; } } // namespace eventhub diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 9997cc8..2163d64 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -8,6 +8,7 @@ set(SOURCES src/RedisTest.cpp src/AccessControllerTest.cpp src/UtilTest.cpp + src/KVStoreTest.cpp src/main.cpp ) diff --git a/tests/include/catch.hpp b/tests/include/catch.hpp index 0384171..d2a1242 100644 --- a/tests/include/catch.hpp +++ b/tests/include/catch.hpp @@ -1,9 +1,9 @@ /* - * Catch v2.13.4 - * Generated: 2020-12-29 14:48:00.116107 + * Catch v2.13.9 + * Generated: 2022-04-12 22:37:23.260201 * ---------------------------------------------------------- * This file has been merged from multiple headers. Please don't edit it directly - * Copyright (c) 2020 Two Blue Cubes Ltd. All rights reserved. + * Copyright (c) 2022 Two Blue Cubes Ltd. All rights reserved. * * Distributed under the Boost Software License, Version 1.0. 
(See accompanying * file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) @@ -15,7 +15,7 @@ #define CATCH_VERSION_MAJOR 2 #define CATCH_VERSION_MINOR 13 -#define CATCH_VERSION_PATCH 4 +#define CATCH_VERSION_PATCH 9 #ifdef __clang__ # pragma clang system_header @@ -66,13 +66,16 @@ #if !defined(CATCH_CONFIG_IMPL_ONLY) // start catch_platform.h +// See e.g.: +// https://opensource.apple.com/source/CarbonHeaders/CarbonHeaders-18.1/TargetConditionals.h.auto.html #ifdef __APPLE__ -# include -# if TARGET_OS_OSX == 1 -# define CATCH_PLATFORM_MAC -# elif TARGET_OS_IPHONE == 1 -# define CATCH_PLATFORM_IPHONE -# endif +# include +# if (defined(TARGET_OS_OSX) && TARGET_OS_OSX == 1) || \ + (defined(TARGET_OS_MAC) && TARGET_OS_MAC == 1) +# define CATCH_PLATFORM_MAC +# elif (defined(TARGET_OS_IPHONE) && TARGET_OS_IPHONE == 1) +# define CATCH_PLATFORM_IPHONE +# endif #elif defined(linux) || defined(__linux) || defined(__linux__) # define CATCH_PLATFORM_LINUX @@ -132,9 +135,9 @@ namespace Catch { #endif -// We have to avoid both ICC and Clang, because they try to mask themselves -// as gcc, and we want only GCC in this block -#if defined(__GNUC__) && !defined(__clang__) && !defined(__ICC) && !defined(__CUDACC__) +// Only GCC compiler should be used in this block, so other compilers trying to +// mask themselves as GCC should be ignored. +#if defined(__GNUC__) && !defined(__clang__) && !defined(__ICC) && !defined(__CUDACC__) && !defined(__LCC__) # define CATCH_INTERNAL_START_WARNINGS_SUPPRESSION _Pragma( "GCC diagnostic push" ) # define CATCH_INTERNAL_STOP_WARNINGS_SUPPRESSION _Pragma( "GCC diagnostic pop" ) @@ -237,9 +240,6 @@ namespace Catch { // Visual C++ #if defined(_MSC_VER) -# define CATCH_INTERNAL_START_WARNINGS_SUPPRESSION __pragma( warning(push) ) -# define CATCH_INTERNAL_STOP_WARNINGS_SUPPRESSION __pragma( warning(pop) ) - // Universal Windows platform does not support SEH // Or console colours (or console at all...) # if defined(WINAPI_FAMILY) && (WINAPI_FAMILY == WINAPI_FAMILY_APP) @@ -248,13 +248,18 @@ namespace Catch { # define CATCH_INTERNAL_CONFIG_WINDOWS_SEH # endif +# if !defined(__clang__) // Handle Clang masquerading for msvc + // MSVC traditional preprocessor needs some workaround for __VA_ARGS__ // _MSVC_TRADITIONAL == 0 means new conformant preprocessor // _MSVC_TRADITIONAL == 1 means old traditional non-conformant preprocessor -# if !defined(__clang__) // Handle Clang masquerading for msvc # if !defined(_MSVC_TRADITIONAL) || (defined(_MSVC_TRADITIONAL) && _MSVC_TRADITIONAL) # define CATCH_INTERNAL_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR # endif // MSVC_TRADITIONAL + +// Only do this if we're not using clang on Windows, which uses `diagnostic push` & `diagnostic pop` +# define CATCH_INTERNAL_START_WARNINGS_SUPPRESSION __pragma( warning(push) ) +# define CATCH_INTERNAL_STOP_WARNINGS_SUPPRESSION __pragma( warning(pop) ) # endif // __clang__ #endif // _MSC_VER @@ -323,7 +328,7 @@ namespace Catch { // Check if byte is available and usable # if __has_include() && defined(CATCH_CPP17_OR_GREATER) # include - # if __cpp_lib_byte > 0 + # if defined(__cpp_lib_byte) && (__cpp_lib_byte > 0) # define CATCH_INTERNAL_CONFIG_CPP17_BYTE # endif # endif // __has_include() && defined(CATCH_CPP17_OR_GREATER) @@ -1007,34 +1012,34 @@ struct AutoReg : NonCopyable { #ifndef CATCH_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR #define INTERNAL_CATCH_TEMPLATE_TEST_CASE_NO_REGISTRATION(Name, Tags, ...) 
\ - INTERNAL_CATCH_TEMPLATE_TEST_CASE_NO_REGISTRATION_2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____F_U_N_C____ ), Name, Tags, typename TestType, __VA_ARGS__ ) + INTERNAL_CATCH_TEMPLATE_TEST_CASE_NO_REGISTRATION_2( INTERNAL_CATCH_UNIQUE_NAME( C_A_T_C_H_T_E_M_P_L_A_T_E_T_E_S_T_ ), INTERNAL_CATCH_UNIQUE_NAME( C_A_T_C_H_T_E_M_P_L_A_T_E_T_E_S_T_F_U_N_C_ ), Name, Tags, typename TestType, __VA_ARGS__ ) #else #define INTERNAL_CATCH_TEMPLATE_TEST_CASE_NO_REGISTRATION(Name, Tags, ...) \ - INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_TEST_CASE_NO_REGISTRATION_2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____F_U_N_C____ ), Name, Tags, typename TestType, __VA_ARGS__ ) ) + INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_TEST_CASE_NO_REGISTRATION_2( INTERNAL_CATCH_UNIQUE_NAME( C_A_T_C_H_T_E_M_P_L_A_T_E_T_E_S_T_ ), INTERNAL_CATCH_UNIQUE_NAME( C_A_T_C_H_T_E_M_P_L_A_T_E_T_E_S_T_F_U_N_C_ ), Name, Tags, typename TestType, __VA_ARGS__ ) ) #endif #ifndef CATCH_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR #define INTERNAL_CATCH_TEMPLATE_TEST_CASE_SIG_NO_REGISTRATION(Name, Tags, Signature, ...) \ - INTERNAL_CATCH_TEMPLATE_TEST_CASE_NO_REGISTRATION_2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____F_U_N_C____ ), Name, Tags, Signature, __VA_ARGS__ ) + INTERNAL_CATCH_TEMPLATE_TEST_CASE_NO_REGISTRATION_2( INTERNAL_CATCH_UNIQUE_NAME( C_A_T_C_H_T_E_M_P_L_A_T_E_T_E_S_T_ ), INTERNAL_CATCH_UNIQUE_NAME( C_A_T_C_H_T_E_M_P_L_A_T_E_T_E_S_T_F_U_N_C_ ), Name, Tags, Signature, __VA_ARGS__ ) #else #define INTERNAL_CATCH_TEMPLATE_TEST_CASE_SIG_NO_REGISTRATION(Name, Tags, Signature, ...) \ - INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_TEST_CASE_NO_REGISTRATION_2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____F_U_N_C____ ), Name, Tags, Signature, __VA_ARGS__ ) ) + INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_TEST_CASE_NO_REGISTRATION_2( INTERNAL_CATCH_UNIQUE_NAME( C_A_T_C_H_T_E_M_P_L_A_T_E_T_E_S_T_ ), INTERNAL_CATCH_UNIQUE_NAME( C_A_T_C_H_T_E_M_P_L_A_T_E_T_E_S_T_F_U_N_C_ ), Name, Tags, Signature, __VA_ARGS__ ) ) #endif #ifndef CATCH_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR #define INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD_NO_REGISTRATION( ClassName, Name, Tags,... ) \ - INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD_NO_REGISTRATION_2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____C_L_A_S_S____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ) , ClassName, Name, Tags, typename T, __VA_ARGS__ ) + INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD_NO_REGISTRATION_2( INTERNAL_CATCH_UNIQUE_NAME( C_A_T_C_H_T_E_M_P_L_A_T_E_T_E_S_T_C_L_A_S_S_ ), INTERNAL_CATCH_UNIQUE_NAME( C_A_T_C_H_T_E_M_P_L_A_T_E_T_E_S_T_ ) , ClassName, Name, Tags, typename T, __VA_ARGS__ ) #else #define INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD_NO_REGISTRATION( ClassName, Name, Tags,... 
) \ - INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD_NO_REGISTRATION_2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____C_L_A_S_S____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ) , ClassName, Name, Tags, typename T, __VA_ARGS__ ) ) + INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD_NO_REGISTRATION_2( INTERNAL_CATCH_UNIQUE_NAME( C_A_T_C_H_T_E_M_P_L_A_T_E_T_E_S_T_C_L_A_S_S_ ), INTERNAL_CATCH_UNIQUE_NAME( C_A_T_C_H_T_E_M_P_L_A_T_E_T_E_S_T_ ) , ClassName, Name, Tags, typename T, __VA_ARGS__ ) ) #endif #ifndef CATCH_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR #define INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD_SIG_NO_REGISTRATION( ClassName, Name, Tags, Signature, ... ) \ - INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD_NO_REGISTRATION_2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____C_L_A_S_S____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ) , ClassName, Name, Tags, Signature, __VA_ARGS__ ) + INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD_NO_REGISTRATION_2( INTERNAL_CATCH_UNIQUE_NAME( C_A_T_C_H_T_E_M_P_L_A_T_E_T_E_S_T_C_L_A_S_S_ ), INTERNAL_CATCH_UNIQUE_NAME( C_A_T_C_H_T_E_M_P_L_A_T_E_T_E_S_T_ ) , ClassName, Name, Tags, Signature, __VA_ARGS__ ) #else #define INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD_SIG_NO_REGISTRATION( ClassName, Name, Tags, Signature, ... ) \ - INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD_NO_REGISTRATION_2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____C_L_A_S_S____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ) , ClassName, Name, Tags, Signature, __VA_ARGS__ ) ) + INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD_NO_REGISTRATION_2( INTERNAL_CATCH_UNIQUE_NAME( C_A_T_C_H_T_E_M_P_L_A_T_E_T_E_S_T_C_L_A_S_S_ ), INTERNAL_CATCH_UNIQUE_NAME( C_A_T_C_H_T_E_M_P_L_A_T_E_T_E_S_T_ ) , ClassName, Name, Tags, Signature, __VA_ARGS__ ) ) #endif #endif @@ -1047,7 +1052,7 @@ struct AutoReg : NonCopyable { CATCH_INTERNAL_STOP_WARNINGS_SUPPRESSION \ static void TestName() #define INTERNAL_CATCH_TESTCASE( ... ) \ - INTERNAL_CATCH_TESTCASE2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_S_T____ ), __VA_ARGS__ ) + INTERNAL_CATCH_TESTCASE2( INTERNAL_CATCH_UNIQUE_NAME( C_A_T_C_H_T_E_S_T_ ), __VA_ARGS__ ) /////////////////////////////////////////////////////////////////////////////// #define INTERNAL_CATCH_METHOD_AS_TEST_CASE( QualifiedMethod, ... ) \ @@ -1069,7 +1074,7 @@ struct AutoReg : NonCopyable { CATCH_INTERNAL_STOP_WARNINGS_SUPPRESSION \ void TestName::test() #define INTERNAL_CATCH_TEST_CASE_METHOD( ClassName, ... ) \ - INTERNAL_CATCH_TEST_CASE_METHOD2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_S_T____ ), ClassName, __VA_ARGS__ ) + INTERNAL_CATCH_TEST_CASE_METHOD2( INTERNAL_CATCH_UNIQUE_NAME( C_A_T_C_H_T_E_S_T_ ), ClassName, __VA_ARGS__ ) /////////////////////////////////////////////////////////////////////////////// #define INTERNAL_CATCH_REGISTER_TESTCASE( Function, ... ) \ @@ -1110,18 +1115,18 @@ struct AutoReg : NonCopyable { #ifndef CATCH_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR #define INTERNAL_CATCH_TEMPLATE_TEST_CASE(Name, Tags, ...) 
\ - INTERNAL_CATCH_TEMPLATE_TEST_CASE_2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____F_U_N_C____ ), Name, Tags, typename TestType, __VA_ARGS__ ) + INTERNAL_CATCH_TEMPLATE_TEST_CASE_2( INTERNAL_CATCH_UNIQUE_NAME( C_A_T_C_H_T_E_M_P_L_A_T_E_T_E_S_T_ ), INTERNAL_CATCH_UNIQUE_NAME( C_A_T_C_H_T_E_M_P_L_A_T_E_T_E_S_T_F_U_N_C_ ), Name, Tags, typename TestType, __VA_ARGS__ ) #else #define INTERNAL_CATCH_TEMPLATE_TEST_CASE(Name, Tags, ...) \ - INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_TEST_CASE_2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____F_U_N_C____ ), Name, Tags, typename TestType, __VA_ARGS__ ) ) + INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_TEST_CASE_2( INTERNAL_CATCH_UNIQUE_NAME( C_A_T_C_H_T_E_M_P_L_A_T_E_T_E_S_T_ ), INTERNAL_CATCH_UNIQUE_NAME( C_A_T_C_H_T_E_M_P_L_A_T_E_T_E_S_T_F_U_N_C_ ), Name, Tags, typename TestType, __VA_ARGS__ ) ) #endif #ifndef CATCH_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR #define INTERNAL_CATCH_TEMPLATE_TEST_CASE_SIG(Name, Tags, Signature, ...) \ - INTERNAL_CATCH_TEMPLATE_TEST_CASE_2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____F_U_N_C____ ), Name, Tags, Signature, __VA_ARGS__ ) + INTERNAL_CATCH_TEMPLATE_TEST_CASE_2( INTERNAL_CATCH_UNIQUE_NAME( C_A_T_C_H_T_E_M_P_L_A_T_E_T_E_S_T_ ), INTERNAL_CATCH_UNIQUE_NAME( C_A_T_C_H_T_E_M_P_L_A_T_E_T_E_S_T_F_U_N_C_ ), Name, Tags, Signature, __VA_ARGS__ ) #else #define INTERNAL_CATCH_TEMPLATE_TEST_CASE_SIG(Name, Tags, Signature, ...) \ - INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_TEST_CASE_2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____F_U_N_C____ ), Name, Tags, Signature, __VA_ARGS__ ) ) + INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_TEST_CASE_2( INTERNAL_CATCH_UNIQUE_NAME( C_A_T_C_H_T_E_M_P_L_A_T_E_T_E_S_T_ ), INTERNAL_CATCH_UNIQUE_NAME( C_A_T_C_H_T_E_M_P_L_A_T_E_T_E_S_T_F_U_N_C_ ), Name, Tags, Signature, __VA_ARGS__ ) ) #endif #define INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE2(TestName, TestFuncName, Name, Tags, Signature, TmplTypes, TypesList) \ @@ -1159,18 +1164,18 @@ struct AutoReg : NonCopyable { #ifndef CATCH_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR #define INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE(Name, Tags, ...)\ - INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE2(INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____F_U_N_C____ ), Name, Tags, typename T,__VA_ARGS__) + INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE2(INTERNAL_CATCH_UNIQUE_NAME( C_A_T_C_H_T_E_M_P_L_A_T_E_T_E_S_T_ ), INTERNAL_CATCH_UNIQUE_NAME( C_A_T_C_H_T_E_M_P_L_A_T_E_T_E_S_T_F_U_N_C_ ), Name, Tags, typename T,__VA_ARGS__) #else #define INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE(Name, Tags, ...)\ - INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____F_U_N_C____ ), Name, Tags, typename T, __VA_ARGS__ ) ) + INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE2( INTERNAL_CATCH_UNIQUE_NAME( C_A_T_C_H_T_E_M_P_L_A_T_E_T_E_S_T_ 
), INTERNAL_CATCH_UNIQUE_NAME( C_A_T_C_H_T_E_M_P_L_A_T_E_T_E_S_T_F_U_N_C_ ), Name, Tags, typename T, __VA_ARGS__ ) ) #endif #ifndef CATCH_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR #define INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE_SIG(Name, Tags, Signature, ...)\ - INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE2(INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____F_U_N_C____ ), Name, Tags, Signature, __VA_ARGS__) + INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE2(INTERNAL_CATCH_UNIQUE_NAME( C_A_T_C_H_T_E_M_P_L_A_T_E_T_E_S_T_ ), INTERNAL_CATCH_UNIQUE_NAME( C_A_T_C_H_T_E_M_P_L_A_T_E_T_E_S_T_F_U_N_C_ ), Name, Tags, Signature, __VA_ARGS__) #else #define INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE_SIG(Name, Tags, Signature, ...)\ - INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____F_U_N_C____ ), Name, Tags, Signature, __VA_ARGS__ ) ) + INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE2( INTERNAL_CATCH_UNIQUE_NAME( C_A_T_C_H_T_E_M_P_L_A_T_E_T_E_S_T_ ), INTERNAL_CATCH_UNIQUE_NAME( C_A_T_C_H_T_E_M_P_L_A_T_E_T_E_S_T_F_U_N_C_ ), Name, Tags, Signature, __VA_ARGS__ ) ) #endif #define INTERNAL_CATCH_TEMPLATE_LIST_TEST_CASE_2(TestName, TestFunc, Name, Tags, TmplList)\ @@ -1201,7 +1206,7 @@ struct AutoReg : NonCopyable { static void TestFunc() #define INTERNAL_CATCH_TEMPLATE_LIST_TEST_CASE(Name, Tags, TmplList) \ - INTERNAL_CATCH_TEMPLATE_LIST_TEST_CASE_2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____F_U_N_C____ ), Name, Tags, TmplList ) + INTERNAL_CATCH_TEMPLATE_LIST_TEST_CASE_2( INTERNAL_CATCH_UNIQUE_NAME( C_A_T_C_H_T_E_M_P_L_A_T_E_T_E_S_T_ ), INTERNAL_CATCH_UNIQUE_NAME( C_A_T_C_H_T_E_M_P_L_A_T_E_T_E_S_T_F_U_N_C_ ), Name, Tags, TmplList ) #define INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD_2( TestNameClass, TestName, ClassName, Name, Tags, Signature, ... ) \ CATCH_INTERNAL_START_WARNINGS_SUPPRESSION \ @@ -1234,18 +1239,18 @@ struct AutoReg : NonCopyable { #ifndef CATCH_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR #define INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD( ClassName, Name, Tags,... ) \ - INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD_2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____C_L_A_S_S____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ) , ClassName, Name, Tags, typename T, __VA_ARGS__ ) + INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD_2( INTERNAL_CATCH_UNIQUE_NAME( C_A_T_C_H_T_E_M_P_L_A_T_E_T_E_S_T_C_L_A_S_S_ ), INTERNAL_CATCH_UNIQUE_NAME( C_A_T_C_H_T_E_M_P_L_A_T_E_T_E_S_T_ ) , ClassName, Name, Tags, typename T, __VA_ARGS__ ) #else #define INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD( ClassName, Name, Tags,... 
) \ - INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD_2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____C_L_A_S_S____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ) , ClassName, Name, Tags, typename T, __VA_ARGS__ ) ) + INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD_2( INTERNAL_CATCH_UNIQUE_NAME( C_A_T_C_H_T_E_M_P_L_A_T_E_T_E_S_T_C_L_A_S_S_ ), INTERNAL_CATCH_UNIQUE_NAME( C_A_T_C_H_T_E_M_P_L_A_T_E_T_E_S_T_ ) , ClassName, Name, Tags, typename T, __VA_ARGS__ ) ) #endif #ifndef CATCH_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR #define INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD_SIG( ClassName, Name, Tags, Signature, ... ) \ - INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD_2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____C_L_A_S_S____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ) , ClassName, Name, Tags, Signature, __VA_ARGS__ ) + INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD_2( INTERNAL_CATCH_UNIQUE_NAME( C_A_T_C_H_T_E_M_P_L_A_T_E_T_E_S_T_C_L_A_S_S_ ), INTERNAL_CATCH_UNIQUE_NAME( C_A_T_C_H_T_E_M_P_L_A_T_E_T_E_S_T_ ) , ClassName, Name, Tags, Signature, __VA_ARGS__ ) #else #define INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD_SIG( ClassName, Name, Tags, Signature, ... ) \ - INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD_2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____C_L_A_S_S____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ) , ClassName, Name, Tags, Signature, __VA_ARGS__ ) ) + INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD_2( INTERNAL_CATCH_UNIQUE_NAME( C_A_T_C_H_T_E_M_P_L_A_T_E_T_E_S_T_C_L_A_S_S_ ), INTERNAL_CATCH_UNIQUE_NAME( C_A_T_C_H_T_E_M_P_L_A_T_E_T_E_S_T_ ) , ClassName, Name, Tags, Signature, __VA_ARGS__ ) ) #endif #define INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE_METHOD_2(TestNameClass, TestName, ClassName, Name, Tags, Signature, TmplTypes, TypesList)\ @@ -1286,18 +1291,18 @@ struct AutoReg : NonCopyable { #ifndef CATCH_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR #define INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE_METHOD( ClassName, Name, Tags, ... )\ - INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE_METHOD_2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____F_U_N_C____ ), ClassName, Name, Tags, typename T, __VA_ARGS__ ) + INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE_METHOD_2( INTERNAL_CATCH_UNIQUE_NAME( C_A_T_C_H_T_E_M_P_L_A_T_E_T_E_S_T_ ), INTERNAL_CATCH_UNIQUE_NAME( C_A_T_C_H_T_E_M_P_L_A_T_E_T_E_S_T_F_U_N_C_ ), ClassName, Name, Tags, typename T, __VA_ARGS__ ) #else #define INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE_METHOD( ClassName, Name, Tags, ... 
)\ - INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE_METHOD_2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____F_U_N_C____ ), ClassName, Name, Tags, typename T,__VA_ARGS__ ) ) + INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE_METHOD_2( INTERNAL_CATCH_UNIQUE_NAME( C_A_T_C_H_T_E_M_P_L_A_T_E_T_E_S_T_ ), INTERNAL_CATCH_UNIQUE_NAME( C_A_T_C_H_T_E_M_P_L_A_T_E_T_E_S_T_F_U_N_C_ ), ClassName, Name, Tags, typename T,__VA_ARGS__ ) ) #endif #ifndef CATCH_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR #define INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE_METHOD_SIG( ClassName, Name, Tags, Signature, ... )\ - INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE_METHOD_2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____F_U_N_C____ ), ClassName, Name, Tags, Signature, __VA_ARGS__ ) + INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE_METHOD_2( INTERNAL_CATCH_UNIQUE_NAME( C_A_T_C_H_T_E_M_P_L_A_T_E_T_E_S_T_ ), INTERNAL_CATCH_UNIQUE_NAME( C_A_T_C_H_T_E_M_P_L_A_T_E_T_E_S_T_F_U_N_C_ ), ClassName, Name, Tags, Signature, __VA_ARGS__ ) #else #define INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE_METHOD_SIG( ClassName, Name, Tags, Signature, ... )\ - INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE_METHOD_2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____F_U_N_C____ ), ClassName, Name, Tags, Signature,__VA_ARGS__ ) ) + INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE_METHOD_2( INTERNAL_CATCH_UNIQUE_NAME( C_A_T_C_H_T_E_M_P_L_A_T_E_T_E_S_T_ ), INTERNAL_CATCH_UNIQUE_NAME( C_A_T_C_H_T_E_M_P_L_A_T_E_T_E_S_T_F_U_N_C_ ), ClassName, Name, Tags, Signature,__VA_ARGS__ ) ) #endif #define INTERNAL_CATCH_TEMPLATE_LIST_TEST_CASE_METHOD_2( TestNameClass, TestName, ClassName, Name, Tags, TmplList) \ @@ -1331,7 +1336,7 @@ struct AutoReg : NonCopyable { void TestName::test() #define INTERNAL_CATCH_TEMPLATE_LIST_TEST_CASE_METHOD(ClassName, Name, Tags, TmplList) \ - INTERNAL_CATCH_TEMPLATE_LIST_TEST_CASE_METHOD_2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____F_U_N_C____ ), ClassName, Name, Tags, TmplList ) + INTERNAL_CATCH_TEMPLATE_LIST_TEST_CASE_METHOD_2( INTERNAL_CATCH_UNIQUE_NAME( C_A_T_C_H_T_E_M_P_L_A_T_E_T_E_S_T_ ), INTERNAL_CATCH_UNIQUE_NAME( C_A_T_C_H_T_E_M_P_L_A_T_E_T_E_S_T_F_U_N_C_ ), ClassName, Name, Tags, TmplList ) // end catch_test_registry.h // start catch_capture.hpp @@ -3088,7 +3093,7 @@ namespace Detail { Approx operator-() const; template ::value>::type> - Approx operator()( T const& value ) { + Approx operator()( T const& value ) const { Approx approx( static_cast(value) ); approx.m_epsilon = m_epsilon; approx.m_margin = m_margin; @@ -4160,7 +4165,7 @@ namespace Generators { if (!m_predicate(m_generator.get())) { // It might happen that there are no values that pass the // filter. In that case we throw an exception. 
- auto has_initial_value = next(); + auto has_initial_value = nextImpl(); if (!has_initial_value) { Catch::throw_exception(GeneratorException("No valid value found in filtered generator")); } @@ -4172,6 +4177,11 @@ namespace Generators { } bool next() override { + return nextImpl(); + } + + private: + bool nextImpl() { bool success = m_generator.next(); if (!success) { return false; @@ -5455,6 +5465,8 @@ namespace Catch { } // namespace Catch // end catch_outlier_classification.hpp + +#include #endif // CATCH_CONFIG_ENABLE_BENCHMARKING #include @@ -6339,9 +6351,10 @@ namespace Catch { void writeTestCase(TestCaseNode const& testCaseNode); - void writeSection(std::string const& className, - std::string const& rootName, - SectionNode const& sectionNode); + void writeSection( std::string const& className, + std::string const& rootName, + SectionNode const& sectionNode, + bool testOkToFail ); void writeAssertions(SectionNode const& sectionNode); void writeAssertion(AssertionStats const& stats); @@ -6876,7 +6889,7 @@ namespace Catch { } iters *= 2; } - throw optimized_away_error{}; + Catch::throw_exception(optimized_away_error{}); } } // namespace Detail } // namespace Benchmark @@ -6884,6 +6897,7 @@ namespace Catch { // end catch_run_for_at_least.hpp #include +#include namespace Catch { namespace Benchmark { @@ -7054,8 +7068,8 @@ namespace Catch { double b2 = bias - z1; double a1 = a(b1); double a2 = a(b2); - auto lo = std::max(cumn(a1), 0); - auto hi = std::min(cumn(a2), n - 1); + auto lo = (std::max)(cumn(a1), 0); + auto hi = (std::min)(cumn(a2), n - 1); return { point, resample[lo], resample[hi], confidence_level }; } @@ -7124,7 +7138,9 @@ namespace Catch { } template EnvironmentEstimate> estimate_clock_cost(FloatDuration resolution) { - auto time_limit = std::min(resolution * clock_cost_estimation_tick_limit, FloatDuration(clock_cost_estimation_time_limit)); + auto time_limit = (std::min)( + resolution * clock_cost_estimation_tick_limit, + FloatDuration(clock_cost_estimation_time_limit)); auto time_clock = [](int k) { return Detail::measure([k] { for (int i = 0; i < k; ++i) { @@ -7771,7 +7787,7 @@ namespace Catch { double sb = stddev.point; double mn = mean.point / n; double mg_min = mn / 2.; - double sg = std::min(mg_min / 4., sb / std::sqrt(n)); + double sg = (std::min)(mg_min / 4., sb / std::sqrt(n)); double sg2 = sg * sg; double sb2 = sb * sb; @@ -7790,7 +7806,7 @@ namespace Catch { return (nc / n) * (sb2 - nc * sg2); }; - return std::min(var_out(1), var_out(std::min(c_max(0.), c_max(mg_min)))) / sb2; + return (std::min)(var_out(1), var_out((std::min)(c_max(0.), c_max(mg_min)))) / sb2; } bootstrap_analysis analyse_samples(double confidence_level, int n_resamples, std::vector::iterator first, std::vector::iterator last) { @@ -7980,86 +7996,58 @@ namespace Catch { // start catch_fatal_condition.h -// start catch_windows_h_proxy.h - - -#if defined(CATCH_PLATFORM_WINDOWS) - -#if !defined(NOMINMAX) && !defined(CATCH_CONFIG_NO_NOMINMAX) -# define CATCH_DEFINED_NOMINMAX -# define NOMINMAX -#endif -#if !defined(WIN32_LEAN_AND_MEAN) && !defined(CATCH_CONFIG_NO_WIN32_LEAN_AND_MEAN) -# define CATCH_DEFINED_WIN32_LEAN_AND_MEAN -# define WIN32_LEAN_AND_MEAN -#endif - -#ifdef __AFXDLL -#include -#else -#include -#endif - -#ifdef CATCH_DEFINED_NOMINMAX -# undef NOMINMAX -#endif -#ifdef CATCH_DEFINED_WIN32_LEAN_AND_MEAN -# undef WIN32_LEAN_AND_MEAN -#endif - -#endif // defined(CATCH_PLATFORM_WINDOWS) - -// end catch_windows_h_proxy.h -#if defined( CATCH_CONFIG_WINDOWS_SEH ) +#include namespace Catch { 
- struct FatalConditionHandler { - - static LONG CALLBACK handleVectoredException(PEXCEPTION_POINTERS ExceptionInfo); + // Wrapper for platform-specific fatal error (signals/SEH) handlers + // + // Tries to be cooperative with other handlers, and not step over + // other handlers. This means that unknown structured exceptions + // are passed on, previous signal handlers are called, and so on. + // + // Can only be instantiated once, and assumes that once a signal + // is caught, the binary will end up terminating. Thus, there + class FatalConditionHandler { + bool m_started = false; + + // Install/disengage implementation for specific platform. + // Should be if-defed to work on current platform, can assume + // engage-disengage 1:1 pairing. + void engage_platform(); + void disengage_platform(); + public: + // Should also have platform-specific implementations as needed FatalConditionHandler(); - static void reset(); ~FatalConditionHandler(); - private: - static bool isSet; - static ULONG guaranteeSize; - static PVOID exceptionHandlerHandle; - }; - -} // namespace Catch - -#elif defined ( CATCH_CONFIG_POSIX_SIGNALS ) - -#include - -namespace Catch { - - struct FatalConditionHandler { - - static bool isSet; - static struct sigaction oldSigActions[]; - static stack_t oldSigStack; - static char altStackMem[]; - - static void handleSignal( int sig ); + void engage() { + assert(!m_started && "Handler cannot be installed twice."); + m_started = true; + engage_platform(); + } - FatalConditionHandler(); - ~FatalConditionHandler(); - static void reset(); + void disengage() { + assert(m_started && "Handler cannot be uninstalled without being installed first"); + m_started = false; + disengage_platform(); + } }; -} // namespace Catch - -#else - -namespace Catch { - struct FatalConditionHandler { - void reset(); + //! 
Simple RAII guard for (dis)engaging the FatalConditionHandler + class FatalConditionHandlerGuard { + FatalConditionHandler* m_handler; + public: + FatalConditionHandlerGuard(FatalConditionHandler* handler): + m_handler(handler) { + m_handler->engage(); + } + ~FatalConditionHandlerGuard() { + m_handler->disengage(); + } }; -} -#endif +} // end namespace Catch // end catch_fatal_condition.h #include @@ -8185,6 +8173,7 @@ namespace Catch { std::vector m_unfinishedSections; std::vector m_activeSections; TrackerContext m_trackerContext; + FatalConditionHandler m_fatalConditionhandler; bool m_lastAssertionPassed = false; bool m_shouldReportUnexpected = true; bool m_includeSuccessfulResults; @@ -10057,6 +10046,36 @@ namespace Catch { } // end catch_errno_guard.h +// start catch_windows_h_proxy.h + + +#if defined(CATCH_PLATFORM_WINDOWS) + +#if !defined(NOMINMAX) && !defined(CATCH_CONFIG_NO_NOMINMAX) +# define CATCH_DEFINED_NOMINMAX +# define NOMINMAX +#endif +#if !defined(WIN32_LEAN_AND_MEAN) && !defined(CATCH_CONFIG_NO_WIN32_LEAN_AND_MEAN) +# define CATCH_DEFINED_WIN32_LEAN_AND_MEAN +# define WIN32_LEAN_AND_MEAN +#endif + +#ifdef __AFXDLL +#include +#else +#include +#endif + +#ifdef CATCH_DEFINED_NOMINMAX +# undef NOMINMAX +#endif +#ifdef CATCH_DEFINED_WIN32_LEAN_AND_MEAN +# undef WIN32_LEAN_AND_MEAN +#endif + +#endif // defined(CATCH_PLATFORM_WINDOWS) + +// end catch_windows_h_proxy.h #include namespace Catch { @@ -10573,7 +10592,7 @@ namespace Catch { // Extracts the actual name part of an enum instance // In other words, it returns the Blue part of Bikeshed::Colour::Blue StringRef extractInstanceName(StringRef enumInstance) { - // Find last occurence of ":" + // Find last occurrence of ":" size_t name_start = enumInstance.size(); while (name_start > 0 && enumInstance[name_start - 1] != ':') { --name_start; @@ -10735,25 +10754,47 @@ namespace Catch { // end catch_exception_translator_registry.cpp // start catch_fatal_condition.cpp -#if defined(__GNUC__) -# pragma GCC diagnostic push -# pragma GCC diagnostic ignored "-Wmissing-field-initializers" -#endif +#include + +#if !defined( CATCH_CONFIG_WINDOWS_SEH ) && !defined( CATCH_CONFIG_POSIX_SIGNALS ) + +namespace Catch { + + // If neither SEH nor signal handling is required, the handler impls + // do not have to do anything, and can be empty. + void FatalConditionHandler::engage_platform() {} + void FatalConditionHandler::disengage_platform() {} + FatalConditionHandler::FatalConditionHandler() = default; + FatalConditionHandler::~FatalConditionHandler() = default; + +} // end namespace Catch + +#endif // !CATCH_CONFIG_WINDOWS_SEH && !CATCH_CONFIG_POSIX_SIGNALS + +#if defined( CATCH_CONFIG_WINDOWS_SEH ) && defined( CATCH_CONFIG_POSIX_SIGNALS ) +#error "Inconsistent configuration: Windows' SEH handling and POSIX signals cannot be enabled at the same time" +#endif // CATCH_CONFIG_WINDOWS_SEH && CATCH_CONFIG_POSIX_SIGNALS #if defined( CATCH_CONFIG_WINDOWS_SEH ) || defined( CATCH_CONFIG_POSIX_SIGNALS ) namespace { - // Report the error condition + //! Signals fatal error message to the run context void reportFatal( char const * const message ) { Catch::getCurrentContext().getResultCapture()->handleFatalErrorCondition( message ); } -} -#endif // signals/SEH handling + //! Minimal size Catch2 needs for its own fatal error handling. + //! Picked anecdotally, so it might not be sufficient on all + //! platforms, and for all configurations. 
+ constexpr std::size_t minStackSizeForErrors = 32 * 1024; +} // end unnamed namespace + +#endif // CATCH_CONFIG_WINDOWS_SEH || CATCH_CONFIG_POSIX_SIGNALS #if defined( CATCH_CONFIG_WINDOWS_SEH ) namespace Catch { + struct SignalDefs { DWORD id; const char* name; }; // There is no 1-1 mapping between signals and windows exceptions. @@ -10766,7 +10807,7 @@ namespace Catch { { static_cast(EXCEPTION_INT_DIVIDE_BY_ZERO), "Divide by zero error" }, }; - LONG CALLBACK FatalConditionHandler::handleVectoredException(PEXCEPTION_POINTERS ExceptionInfo) { + static LONG CALLBACK handleVectoredException(PEXCEPTION_POINTERS ExceptionInfo) { for (auto const& def : signalDefs) { if (ExceptionInfo->ExceptionRecord->ExceptionCode == def.id) { reportFatal(def.name); @@ -10777,38 +10818,50 @@ namespace Catch { return EXCEPTION_CONTINUE_SEARCH; } + // Since we do not support multiple instantiations, we put these + // into global variables and rely on cleaning them up in outlined + // constructors/destructors + static PVOID exceptionHandlerHandle = nullptr; + + // For MSVC, we reserve part of the stack memory for handling + // memory overflow structured exception. FatalConditionHandler::FatalConditionHandler() { - isSet = true; - // 32k seems enough for Catch to handle stack overflow, - // but the value was found experimentally, so there is no strong guarantee - guaranteeSize = 32 * 1024; - exceptionHandlerHandle = nullptr; + ULONG guaranteeSize = static_cast(minStackSizeForErrors); + if (!SetThreadStackGuarantee(&guaranteeSize)) { + // We do not want to fully error out, because needing + // the stack reserve should be rare enough anyway. + Catch::cerr() + << "Failed to reserve piece of stack." + << " Stack overflows will not be reported successfully."; + } + } + + // We do not attempt to unset the stack guarantee, because + // Windows does not support lowering the stack size guarantee. + FatalConditionHandler::~FatalConditionHandler() = default; + + void FatalConditionHandler::engage_platform() { // Register as first handler in current chain exceptionHandlerHandle = AddVectoredExceptionHandler(1, handleVectoredException); - // Pass in guarantee size to be filled - SetThreadStackGuarantee(&guaranteeSize); + if (!exceptionHandlerHandle) { + CATCH_RUNTIME_ERROR("Could not register vectored exception handler"); + } } - void FatalConditionHandler::reset() { - if (isSet) { - RemoveVectoredExceptionHandler(exceptionHandlerHandle); - SetThreadStackGuarantee(&guaranteeSize); - exceptionHandlerHandle = nullptr; - isSet = false; + void FatalConditionHandler::disengage_platform() { + if (!RemoveVectoredExceptionHandler(exceptionHandlerHandle)) { + CATCH_RUNTIME_ERROR("Could not unregister vectored exception handler"); } + exceptionHandlerHandle = nullptr; } - FatalConditionHandler::~FatalConditionHandler() { - reset(); - } +} // end namespace Catch -bool FatalConditionHandler::isSet = false; -ULONG FatalConditionHandler::guaranteeSize = 0; -PVOID FatalConditionHandler::exceptionHandlerHandle = nullptr; +#endif // CATCH_CONFIG_WINDOWS_SEH -} // namespace Catch +#if defined( CATCH_CONFIG_POSIX_SIGNALS ) -#elif defined( CATCH_CONFIG_POSIX_SIGNALS ) +#include namespace Catch { @@ -10817,10 +10870,6 @@ namespace Catch { const char* name; }; - // 32kb for the alternate stack seems to be sufficient. However, this value - // is experimentally determined, so that's not guaranteed. - static constexpr std::size_t sigStackSize = 32768 >= MINSIGSTKSZ ? 
32768 : MINSIGSTKSZ; - static SignalDefs signalDefs[] = { { SIGINT, "SIGINT - Terminal interrupt signal" }, { SIGILL, "SIGILL - Illegal instruction signal" }, @@ -10830,7 +10879,32 @@ namespace Catch { { SIGABRT, "SIGABRT - Abort (abnormal termination) signal" } }; - void FatalConditionHandler::handleSignal( int sig ) { +// Older GCCs trigger -Wmissing-field-initializers for T foo = {} +// which is zero initialization, but not explicit. We want to avoid +// that. +#if defined(__GNUC__) +# pragma GCC diagnostic push +# pragma GCC diagnostic ignored "-Wmissing-field-initializers" +#endif + + static char* altStackMem = nullptr; + static std::size_t altStackSize = 0; + static stack_t oldSigStack{}; + static struct sigaction oldSigActions[sizeof(signalDefs) / sizeof(SignalDefs)]{}; + + static void restorePreviousSignalHandlers() { + // We set signal handlers back to the previous ones. Hopefully + // nobody overwrote them in the meantime, and doesn't expect + // their signal handlers to live past ours given that they + // installed them after ours.. + for (std::size_t i = 0; i < sizeof(signalDefs) / sizeof(SignalDefs); ++i) { + sigaction(signalDefs[i].id, &oldSigActions[i], nullptr); + } + // Return the old stack + sigaltstack(&oldSigStack, nullptr); + } + + static void handleSignal( int sig ) { char const * name = ""; for (auto const& def : signalDefs) { if (sig == def.id) { @@ -10838,16 +10912,33 @@ namespace Catch { break; } } - reset(); - reportFatal(name); + // We need to restore previous signal handlers and let them do + // their thing, so that the users can have the debugger break + // when a signal is raised, and so on. + restorePreviousSignalHandlers(); + reportFatal( name ); raise( sig ); } FatalConditionHandler::FatalConditionHandler() { - isSet = true; + assert(!altStackMem && "Cannot initialize POSIX signal handler when one already exists"); + if (altStackSize == 0) { + altStackSize = std::max(static_cast(SIGSTKSZ), minStackSizeForErrors); + } + altStackMem = new char[altStackSize](); + } + + FatalConditionHandler::~FatalConditionHandler() { + delete[] altStackMem; + // We signal that another instance can be constructed by zeroing + // out the pointer. 
+ altStackMem = nullptr; + } + + void FatalConditionHandler::engage_platform() { stack_t sigStack; sigStack.ss_sp = altStackMem; - sigStack.ss_size = sigStackSize; + sigStack.ss_size = altStackSize; sigStack.ss_flags = 0; sigaltstack(&sigStack, &oldSigStack); struct sigaction sa = { }; @@ -10859,40 +10950,17 @@ namespace Catch { } } - FatalConditionHandler::~FatalConditionHandler() { - reset(); - } +#if defined(__GNUC__) +# pragma GCC diagnostic pop +#endif - void FatalConditionHandler::reset() { - if( isSet ) { - // Set signals back to previous values -- hopefully nobody overwrote them in the meantime - for( std::size_t i = 0; i < sizeof(signalDefs)/sizeof(SignalDefs); ++i ) { - sigaction(signalDefs[i].id, &oldSigActions[i], nullptr); - } - // Return the old stack - sigaltstack(&oldSigStack, nullptr); - isSet = false; - } + void FatalConditionHandler::disengage_platform() { + restorePreviousSignalHandlers(); } - bool FatalConditionHandler::isSet = false; - struct sigaction FatalConditionHandler::oldSigActions[sizeof(signalDefs)/sizeof(SignalDefs)] = {}; - stack_t FatalConditionHandler::oldSigStack = {}; - char FatalConditionHandler::altStackMem[sigStackSize] = {}; - -} // namespace Catch - -#else - -namespace Catch { - void FatalConditionHandler::reset() {} -} - -#endif // signals/SEH handling +} // end namespace Catch -#if defined(__GNUC__) -# pragma GCC diagnostic pop -#endif +#endif // CATCH_CONFIG_POSIX_SIGNALS // end catch_fatal_condition.cpp // start catch_generators.cpp @@ -11447,7 +11515,8 @@ namespace { return lhs == rhs; } - auto ulpDiff = std::abs(lc - rc); + // static cast as a workaround for IBM XLC + auto ulpDiff = std::abs(static_cast(lc - rc)); return static_cast(ulpDiff) <= maxUlpDiff; } @@ -11621,7 +11690,6 @@ Floating::WithinRelMatcher WithinRel(float target) { } // namespace Matchers } // namespace Catch - // end catch_matchers_floating.cpp // start catch_matchers_generic.cpp @@ -12955,9 +13023,8 @@ namespace Catch { } void RunContext::invokeActiveTestCase() { - FatalConditionHandler fatalConditionHandler; // Handle signals + FatalConditionHandlerGuard _(&m_fatalConditionhandler); m_activeTestCase->invoke(); - fatalConditionHandler.reset(); } void RunContext::handleUnfinishedSections() { @@ -13325,6 +13392,10 @@ namespace Catch { filename.erase(0, lastSlash); filename[0] = '#'; } + else + { + filename.insert(0, "#"); + } auto lastDot = filename.find_last_of('.'); if (lastDot != std::string::npos) { @@ -15320,7 +15391,7 @@ namespace Catch { } Version const& libraryVersion() { - static Version version( 2, 13, 4, "", 0 ); + static Version version( 2, 13, 9, "", 0 ); return version; } @@ -16733,6 +16804,7 @@ CATCH_REGISTER_REPORTER("console", ConsoleReporter) #include #include #include +#include namespace Catch { @@ -16760,7 +16832,7 @@ namespace Catch { #else std::strftime(timeStamp, timeStampSize, fmt, timeInfo); #endif - return std::string(timeStamp); + return std::string(timeStamp, timeStampSize-1); } std::string fileNameTag(const std::vector &tags) { @@ -16771,6 +16843,17 @@ namespace Catch { return it->substr(1); return std::string(); } + + // Formats the duration in seconds to 3 decimal places. + // This is done because some genius defined Maven Surefire schema + // in a way that only accepts 3 decimal places, and tools like + // Jenkins use that schema for validation JUnit reporter output. 
+ std::string formatDuration( double seconds ) { + ReusableStringStream rss; + rss << std::fixed << std::setprecision( 3 ) << seconds; + return rss.str(); + } + } // anonymous namespace JunitReporter::JunitReporter( ReporterConfig const& _config ) @@ -16840,7 +16923,7 @@ namespace Catch { if( m_config->showDurations() == ShowDurations::Never ) xml.writeAttribute( "time", "" ); else - xml.writeAttribute( "time", suiteTime ); + xml.writeAttribute( "time", formatDuration( suiteTime ) ); xml.writeAttribute( "timestamp", getCurrentTimestamp() ); // Write properties if there are any @@ -16885,12 +16968,13 @@ namespace Catch { if ( !m_config->name().empty() ) className = m_config->name() + "." + className; - writeSection( className, "", rootSection ); + writeSection( className, "", rootSection, stats.testInfo.okToFail() ); } - void JunitReporter::writeSection( std::string const& className, - std::string const& rootName, - SectionNode const& sectionNode ) { + void JunitReporter::writeSection( std::string const& className, + std::string const& rootName, + SectionNode const& sectionNode, + bool testOkToFail) { std::string name = trim( sectionNode.stats.sectionInfo.name ); if( !rootName.empty() ) name = rootName + '/' + name; @@ -16907,13 +16991,18 @@ namespace Catch { xml.writeAttribute( "classname", className ); xml.writeAttribute( "name", name ); } - xml.writeAttribute( "time", ::Catch::Detail::stringify( sectionNode.stats.durationInSeconds ) ); + xml.writeAttribute( "time", formatDuration( sectionNode.stats.durationInSeconds ) ); // This is not ideal, but it should be enough to mimic gtest's // junit output. // Ideally the JUnit reporter would also handle `skipTest` // events and write those out appropriately. xml.writeAttribute( "status", "run" ); + if (sectionNode.stats.assertions.failedButOk) { + xml.scopedElement("skipped") + .writeAttribute("message", "TEST_CASE tagged with !mayfail"); + } + writeAssertions( sectionNode ); if( !sectionNode.stdOut.empty() ) @@ -16923,9 +17012,9 @@ namespace Catch { } for( auto const& childNode : sectionNode.childSections ) if( className.empty() ) - writeSection( name, "", *childNode ); + writeSection( name, "", *childNode, testOkToFail ); else - writeSection( className, name, *childNode ); + writeSection( className, name, *childNode, testOkToFail ); } void JunitReporter::writeAssertions( SectionNode const& sectionNode ) { @@ -17570,9 +17659,9 @@ int main (int argc, char * const argv[]) { #if defined(CATCH_CONFIG_ENABLE_BENCHMARKING) #define CATCH_BENCHMARK(...) \ - INTERNAL_CATCH_BENCHMARK(INTERNAL_CATCH_UNIQUE_NAME(____C_A_T_C_H____B_E_N_C_H____), INTERNAL_CATCH_GET_1_ARG(__VA_ARGS__,,), INTERNAL_CATCH_GET_2_ARG(__VA_ARGS__,,)) + INTERNAL_CATCH_BENCHMARK(INTERNAL_CATCH_UNIQUE_NAME(C_A_T_C_H_B_E_N_C_H_), INTERNAL_CATCH_GET_1_ARG(__VA_ARGS__,,), INTERNAL_CATCH_GET_2_ARG(__VA_ARGS__,,)) #define CATCH_BENCHMARK_ADVANCED(name) \ - INTERNAL_CATCH_BENCHMARK_ADVANCED(INTERNAL_CATCH_UNIQUE_NAME(____C_A_T_C_H____B_E_N_C_H____), name) + INTERNAL_CATCH_BENCHMARK_ADVANCED(INTERNAL_CATCH_UNIQUE_NAME(C_A_T_C_H_B_E_N_C_H_), name) #endif // CATCH_CONFIG_ENABLE_BENCHMARKING // If CATCH_CONFIG_PREFIX_ALL is not defined then the CATCH_ prefix is not required @@ -17674,9 +17763,9 @@ int main (int argc, char * const argv[]) { #if defined(CATCH_CONFIG_ENABLE_BENCHMARKING) #define BENCHMARK(...) 
\ - INTERNAL_CATCH_BENCHMARK(INTERNAL_CATCH_UNIQUE_NAME(____C_A_T_C_H____B_E_N_C_H____), INTERNAL_CATCH_GET_1_ARG(__VA_ARGS__,,), INTERNAL_CATCH_GET_2_ARG(__VA_ARGS__,,)) + INTERNAL_CATCH_BENCHMARK(INTERNAL_CATCH_UNIQUE_NAME(C_A_T_C_H_B_E_N_C_H_), INTERNAL_CATCH_GET_1_ARG(__VA_ARGS__,,), INTERNAL_CATCH_GET_2_ARG(__VA_ARGS__,,)) #define BENCHMARK_ADVANCED(name) \ - INTERNAL_CATCH_BENCHMARK_ADVANCED(INTERNAL_CATCH_UNIQUE_NAME(____C_A_T_C_H____B_E_N_C_H____), name) + INTERNAL_CATCH_BENCHMARK_ADVANCED(INTERNAL_CATCH_UNIQUE_NAME(C_A_T_C_H_B_E_N_C_H_), name) #endif // CATCH_CONFIG_ENABLE_BENCHMARKING using Catch::Detail::Approx; @@ -17723,8 +17812,8 @@ using Catch::Detail::Approx; #define CATCH_WARN( msg ) (void)(0) #define CATCH_CAPTURE( msg ) (void)(0) -#define CATCH_TEST_CASE( ... ) INTERNAL_CATCH_TESTCASE_NO_REGISTRATION(INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_S_T____ )) -#define CATCH_TEST_CASE_METHOD( className, ... ) INTERNAL_CATCH_TESTCASE_NO_REGISTRATION(INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_S_T____ )) +#define CATCH_TEST_CASE( ... ) INTERNAL_CATCH_TESTCASE_NO_REGISTRATION(INTERNAL_CATCH_UNIQUE_NAME( C_A_T_C_H_T_E_S_T_ )) +#define CATCH_TEST_CASE_METHOD( className, ... ) INTERNAL_CATCH_TESTCASE_NO_REGISTRATION(INTERNAL_CATCH_UNIQUE_NAME( C_A_T_C_H_T_E_S_T_ )) #define CATCH_METHOD_AS_TEST_CASE( method, ... ) #define CATCH_REGISTER_TEST_CASE( Function, ... ) (void)(0) #define CATCH_SECTION( ... ) @@ -17733,7 +17822,7 @@ using Catch::Detail::Approx; #define CATCH_FAIL_CHECK( ... ) (void)(0) #define CATCH_SUCCEED( ... ) (void)(0) -#define CATCH_ANON_TEST_CASE() INTERNAL_CATCH_TESTCASE_NO_REGISTRATION(INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_S_T____ )) +#define CATCH_ANON_TEST_CASE() INTERNAL_CATCH_TESTCASE_NO_REGISTRATION(INTERNAL_CATCH_UNIQUE_NAME( C_A_T_C_H_T_E_S_T_ )) #ifndef CATCH_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR #define CATCH_TEMPLATE_TEST_CASE( ... ) INTERNAL_CATCH_TEMPLATE_TEST_CASE_NO_REGISTRATION(__VA_ARGS__) @@ -17756,8 +17845,8 @@ using Catch::Detail::Approx; #endif // "BDD-style" convenience wrappers -#define CATCH_SCENARIO( ... ) INTERNAL_CATCH_TESTCASE_NO_REGISTRATION(INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_S_T____ )) -#define CATCH_SCENARIO_METHOD( className, ... ) INTERNAL_CATCH_TESTCASE_METHOD_NO_REGISTRATION(INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_S_T____ ), className ) +#define CATCH_SCENARIO( ... ) INTERNAL_CATCH_TESTCASE_NO_REGISTRATION(INTERNAL_CATCH_UNIQUE_NAME( C_A_T_C_H_T_E_S_T_ )) +#define CATCH_SCENARIO_METHOD( className, ... ) INTERNAL_CATCH_TESTCASE_METHOD_NO_REGISTRATION(INTERNAL_CATCH_UNIQUE_NAME( C_A_T_C_H_T_E_S_T_ ), className ) #define CATCH_GIVEN( desc ) #define CATCH_AND_GIVEN( desc ) #define CATCH_WHEN( desc ) @@ -17805,10 +17894,10 @@ using Catch::Detail::Approx; #define INFO( msg ) (void)(0) #define UNSCOPED_INFO( msg ) (void)(0) #define WARN( msg ) (void)(0) -#define CAPTURE( msg ) (void)(0) +#define CAPTURE( ... ) (void)(0) -#define TEST_CASE( ... ) INTERNAL_CATCH_TESTCASE_NO_REGISTRATION(INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_S_T____ )) -#define TEST_CASE_METHOD( className, ... ) INTERNAL_CATCH_TESTCASE_NO_REGISTRATION(INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_S_T____ )) +#define TEST_CASE( ... ) INTERNAL_CATCH_TESTCASE_NO_REGISTRATION(INTERNAL_CATCH_UNIQUE_NAME( C_A_T_C_H_T_E_S_T_ )) +#define TEST_CASE_METHOD( className, ... ) INTERNAL_CATCH_TESTCASE_NO_REGISTRATION(INTERNAL_CATCH_UNIQUE_NAME( C_A_T_C_H_T_E_S_T_ )) #define METHOD_AS_TEST_CASE( method, ... 
) #define REGISTER_TEST_CASE( Function, ... ) (void)(0) #define SECTION( ... ) @@ -17816,7 +17905,7 @@ using Catch::Detail::Approx; #define FAIL( ... ) (void)(0) #define FAIL_CHECK( ... ) (void)(0) #define SUCCEED( ... ) (void)(0) -#define ANON_TEST_CASE() INTERNAL_CATCH_TESTCASE_NO_REGISTRATION(INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_S_T____ )) +#define ANON_TEST_CASE() INTERNAL_CATCH_TESTCASE_NO_REGISTRATION(INTERNAL_CATCH_UNIQUE_NAME( C_A_T_C_H_T_E_S_T_ )) #ifndef CATCH_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR #define TEMPLATE_TEST_CASE( ... ) INTERNAL_CATCH_TEMPLATE_TEST_CASE_NO_REGISTRATION(__VA_ARGS__) @@ -17846,8 +17935,8 @@ using Catch::Detail::Approx; #define CATCH_TRANSLATE_EXCEPTION( signature ) INTERNAL_CATCH_TRANSLATE_EXCEPTION_NO_REG( INTERNAL_CATCH_UNIQUE_NAME( catch_internal_ExceptionTranslator ), signature ) // "BDD-style" convenience wrappers -#define SCENARIO( ... ) INTERNAL_CATCH_TESTCASE_NO_REGISTRATION(INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_S_T____ ) ) -#define SCENARIO_METHOD( className, ... ) INTERNAL_CATCH_TESTCASE_METHOD_NO_REGISTRATION(INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_S_T____ ), className ) +#define SCENARIO( ... ) INTERNAL_CATCH_TESTCASE_NO_REGISTRATION(INTERNAL_CATCH_UNIQUE_NAME( C_A_T_C_H_T_E_S_T_ ) ) +#define SCENARIO_METHOD( className, ... ) INTERNAL_CATCH_TESTCASE_METHOD_NO_REGISTRATION(INTERNAL_CATCH_UNIQUE_NAME( C_A_T_C_H_T_E_S_T_ ), className ) #define GIVEN( desc ) #define AND_GIVEN( desc ) diff --git a/tests/src/AccessControllerTest.cpp b/tests/src/AccessControllerTest.cpp index 947aebe..7526eec 100644 --- a/tests/src/AccessControllerTest.cpp +++ b/tests/src/AccessControllerTest.cpp @@ -104,4 +104,45 @@ TEST_CASE("Test authorization", "[access_controller]") { REQUIRE(acs.allowPublish("my/very/private/channel")); } } + + GIVEN("We have a rlimit") { + Config cfg(cfgMap); + cfg << "disable_auth = 0"; + cfg.load(); + auto token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJ1c2VyQGRvbWFpbi5jb20iLCJyZWFkIjpbInRvcGljLyMiLCJ0b3BpYzIvIyJdLCJ3cml0ZSI6WyJ0b3BpYzEvIyJdLCJybGltaXQiOlt7InRvcGljIjoidG9waWMxLyMiLCJpbnRlcnZhbCI6MTAsIm1heCI6MTB9LHsidG9waWMiOiJ0b3BpYzIiLCJpbnRlcnZhbCI6MTAwLCJtYXgiOjEwfV19.i938ZYQL4NR1VUUfrtAwPaivd3cldW6Pegdo9ofpjcE"; + AccessController acs(cfg); + bool tokenLoaded = acs.authenticate(token, "eventhub_secret"); + + THEN("The token should be successfully loaded") { + REQUIRE(tokenLoaded == true); + } + + THEN("We should have the correct limits set for topics defined in token.") { + /* + "rlimit": [{ + "topic": "topic1/#", + "interval": 10, + "max": 10 + }, + { + "topic": "topic2", + "interval": 100, + "max": 10 + } + */ + + REQUIRE(acs.getRateLimitConfig().getRateLimitForTopic("topic1").interval == 10); + REQUIRE(acs.getRateLimitConfig().getRateLimitForTopic("topic1").max == 10); + REQUIRE(acs.getRateLimitConfig().getRateLimitForTopic("topic1/foo").interval == 10); + REQUIRE(acs.getRateLimitConfig().getRateLimitForTopic("topic1/foo").max == 10); + REQUIRE(acs.getRateLimitConfig().getRateLimitForTopic("topic1/foo/bar").interval == 10); + REQUIRE(acs.getRateLimitConfig().getRateLimitForTopic("topic1/foo/bar").max == 10); + REQUIRE(acs.getRateLimitConfig().getRateLimitForTopic("topic2").interval == 100); + REQUIRE(acs.getRateLimitConfig().getRateLimitForTopic("topic2").max == 10); + + CHECK_THROWS_AS(acs.getRateLimitConfig().getRateLimitForTopic("topic2/foo"), NoRateLimitForTopic); + CHECK_THROWS_AS(acs.getRateLimitConfig().getRateLimitForTopic("topic3"), NoRateLimitForTopic); + 
CHECK_THROWS_AS(acs.getRateLimitConfig().getRateLimitForTopic("topic3"), NoRateLimitForTopic); + } + } } diff --git a/tests/src/ConfigTest.cpp b/tests/src/ConfigTest.cpp index 5c621c9..bbd6e19 100644 --- a/tests/src/ConfigTest.cpp +++ b/tests/src/ConfigTest.cpp @@ -1,11 +1,12 @@ +#include +#include +#include +#include +#include + #include "Config.hpp" #include "catch.hpp" -#include -#include -#include -#include - using namespace eventhub; using Catch::Matchers::Contains; @@ -111,7 +112,7 @@ TEST_CASE("Config test") { REQUIRE_THROWS_AS(cfg.load(), RequiredOptionMissingException); } - SECTION("Ceheck that we support quoted values") { + SECTION("Check that we support quoted values") { Config cfg; cfg.defineOption("test_string_double_quote", ConfigValueSettings::REQUIRED); cfg.defineOption("test_string_single_quote", ConfigValueSettings::REQUIRED); @@ -256,4 +257,4 @@ TEST_CASE("Config test") { } } -} \ No newline at end of file +} diff --git a/tests/src/EventLoopTest.cpp b/tests/src/EventLoopTest.cpp index 53712b9..f6bf282 100644 --- a/tests/src/EventLoopTest.cpp +++ b/tests/src/EventLoopTest.cpp @@ -1,5 +1,6 @@ #include #include +#include #include "EventLoop.hpp" #include "catch.hpp" diff --git a/tests/src/KVStoreTest.cpp b/tests/src/KVStoreTest.cpp new file mode 100644 index 0000000..23df955 --- /dev/null +++ b/tests/src/KVStoreTest.cpp @@ -0,0 +1,13 @@ +#include + +#include "catch.hpp" + +namespace eventhub { + +TEST_CASE("Test KVStore") { + SECTION("Test set and get") { + // To be implemented. + } +} + +} diff --git a/tests/src/RedisTest.cpp b/tests/src/RedisTest.cpp index e86b7e3..7e2fa8e 100644 --- a/tests/src/RedisTest.cpp +++ b/tests/src/RedisTest.cpp @@ -1,12 +1,21 @@ -#include "Common.hpp" -#include "Config.hpp" -#include "catch.hpp" -#include "jwt/json/json.hpp" +#include +#include +#include +#include #include #include +#include +#include +#include +#include +#include +#include +#include +#include "Config.hpp" +#include "catch.hpp" +#include "jwt/json/json.hpp" #include "Redis.hpp" -#include "Util.hpp" using namespace eventhub; @@ -19,7 +28,7 @@ TEST_CASE("Test redis", "[Redis") { { "redis_pool_size", ConfigValueType::INT, "5", ConfigValueSettings::REQUIRED }, { "max_cache_length", ConfigValueType::INT, "1000", ConfigValueSettings::REQUIRED }, { "max_cache_request_limit", ConfigValueType::INT, "100", ConfigValueSettings::REQUIRED }, - { "default_cache_ttl", ConfigValueType::INT, "60", ConfigValueSettings::REQUIRED }, + { "default_cache_ttl", ConfigValueType::INT, "60", ConfigValueSettings::REQUIRED }, { "enable_cache", ConfigValueType::BOOL, "true", ConfigValueSettings::REQUIRED } }; @@ -32,7 +41,7 @@ TEST_CASE("Test redis", "[Redis") { bool connected = true; try { - redis.getRedisInstance()->ping(); + redis.connection()->ping(); } catch (std::exception& e) { connected = false; } @@ -41,11 +50,11 @@ TEST_CASE("Test redis", "[Redis") { } GIVEN("That we increase pub count for test/channel1") { - redis.getRedisInstance()->hdel("eventhub_test:pub_count", "test/channel1"); + redis.connection()->hdel("eventhub_test:pub_count", "test/channel1"); redis._incrTopicPubCount("test/channel1"); THEN("Hashentry eventhub_test.test/channel1 should be larger than 0") { - auto countStr = redis.getRedisInstance()->hget("eventhub_test:pub_count", "test/channel1"); + auto countStr = redis.connection()->hget("eventhub_test:pub_count", "test/channel1"); int count = 0; try { @@ -86,21 +95,21 @@ TEST_CASE("Test redis", "[Redis") { THEN("Cache size should be larger than 0 when requesting a matching pattern") 
{ nlohmann::json j; - size_t cacheSize = redis.getCacheSince("test/#", 0, -1, true, j); + std::size_t cacheSize = redis.getCacheSince("test/#", 0, -1, true, j); REQUIRE(cacheSize > 0); REQUIRE(j.size() > 0); } THEN("Cache size should be larger than 0 when requesting the actual topic") { nlohmann::json j; - size_t cacheSize = redis.getCacheSince("test/channel1", 0, -1, false, j); + std::size_t cacheSize = redis.getCacheSince("test/channel1", 0, -1, false, j); REQUIRE(cacheSize > 0); REQUIRE(j.size() > 0); } } GIVEN("That we publish 2 messages") { - unsigned int msgRcvd = 0; + std::size_t msgRcvd = 0; redis.psubscribe("*", [&msgRcvd](const std::string& pattern, const std::string& topic, const std::string& msg) { REQUIRE(pattern.compare("eventhub_test:*") == 0); REQUIRE((topic.compare("test/topic1") == 0 || topic.compare("test/topic2") == 0)); @@ -167,14 +176,14 @@ TEST_CASE("Test redis", "[Redis") { std::vector cacheIds; auto firstId = redis.cacheMessage("test/topic1", "31337", "petter@testmann.no"); - for (unsigned int i = 0; i < 10; i++) { + for (std::size_t i = 0; i < 10; i++) { cacheIds.push_back(redis.cacheMessage("test/topic1", "31337", "petter@testmann.no")); } THEN("We should get the expected results back") { redis.getCacheSinceId("test/topic1", firstId, 100, false, res); - unsigned int i = 0; + std::size_t i = 0; for (auto item : res) { REQUIRE(cacheIds[i] == item["id"]); i++; diff --git a/tests/src/TopicTest.cpp b/tests/src/TopicTest.cpp index 33d0b16..aac2d33 100644 --- a/tests/src/TopicTest.cpp +++ b/tests/src/TopicTest.cpp @@ -1,7 +1,10 @@ -#include "TopicManager.hpp" -#include "catch.hpp" #include #include +#include + +#include "TopicManager.hpp" +#include "catch.hpp" +#include "Config.hpp" using namespace eventhub; @@ -135,6 +138,7 @@ TEST_CASE("isFilterMatched", "[topic_manager]") { SHOULD_NOT_MATCH("test1/foo", "test1/foobar"); SHOULD_NOT_MATCH("test1/+/test", "test1/test"); SHOULD_NOT_MATCH("+", "foobar/baz"); + SHOULD_NOT_MATCH("topic1/#", "topic2"); SECTION("Should match") { for (auto p : should_match) { diff --git a/tests/src/UtilTest.cpp b/tests/src/UtilTest.cpp index e28c81b..994c2cb 100644 --- a/tests/src/UtilTest.cpp +++ b/tests/src/UtilTest.cpp @@ -1,7 +1,9 @@ +#include +#include +#include + #include "Util.hpp" #include "catch.hpp" -#include -#include namespace eventhub {