Skip to content

Commit

Permalink
[C++]: add auth tests for archive C++ API.
Browse files Browse the repository at this point in the history
  • Loading branch information
tmontgomery committed Dec 4, 2019
1 parent 451dd9f commit b6985af
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 14 deletions.
11 changes: 6 additions & 5 deletions aeron-archive/src/main/cpp/client/ArchiveConfiguration.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,13 @@ inline void defaultCredentialsOnFree(std::pair<const char *, std::uint32_t> cred
/**
* Structure to hold credential callbacks.
*/
typedef struct CredentialsSupplierDefn
struct CredentialsSupplier
{
credentials_encoded_credentials_supplier_t m_encodedCredentials = defaultCredentialsEncodedCredentials;
credentials_challenge_supplier_t m_onChallenge = defaultCredentialsOnChallenge;
credentials_free_t m_onFree = defaultCredentialsOnFree;

explicit CredentialsSupplierDefn(
explicit CredentialsSupplier(
credentials_encoded_credentials_supplier_t encodedCredentials = defaultCredentialsEncodedCredentials,
credentials_challenge_supplier_t onChallenge = defaultCredentialsOnChallenge,
credentials_free_t onFree = defaultCredentialsOnFree) :
Expand All @@ -96,8 +96,7 @@ typedef struct CredentialsSupplierDefn
m_onFree(std::move(onFree))
{
}
}
CredentialsSupplier;
};

namespace Configuration
{
Expand Down Expand Up @@ -518,7 +517,9 @@ class Context
*/
inline this_t& credentialsSupplier(const CredentialsSupplier& supplier)
{
m_credentialsSupplier = supplier;
m_credentialsSupplier.m_encodedCredentials = supplier.m_encodedCredentials;
m_credentialsSupplier.m_onChallenge = supplier.m_onChallenge;
m_credentialsSupplier.m_onFree = supplier.m_onFree;
return *this;
}

Expand Down
1 change: 1 addition & 0 deletions aeron-archive/src/main/cpp/client/ControlResponsePoller.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ class ControlResponsePoller
m_isControlResponse = false;
m_wasChallenged = false;
delete [] m_encodedChallenge.first;
m_encodedChallenge.first = nullptr;
m_encodedChallenge.second = 0;

return m_subscription->controlledPoll(m_fragmentHandler, m_fragmentLimit);
Expand Down
123 changes: 114 additions & 9 deletions aeron-archive/src/test/cpp/AeronArchiveTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ class AeronArchiveTest : public testing::Test
"-Daeron.term.buffer.sparse.file=true",
"-Daeron.driver.termination.validator=io.aeron.driver.DefaultAllowTerminationValidator",
"-Daeron.term.buffer.length=64k",
"-Daeron.archive.authenticator.supplier=io.aeron.samples.archive.TestAuthenticatorSupplier",
("-Daeron.archive.dir=" + m_archiveDir).c_str(),
"-cp",
m_aeronAllJar.c_str(),
Expand All @@ -103,6 +104,19 @@ class AeronArchiveTest : public testing::Test
}
}

auto onEncodedCrdentials = []() -> std::pair<const char *, std::uint32_t>
{
std::string creds("admin:admin");

char *arr = new char[creds.length() + 1];
std::strcpy(arr, creds.data());
arr[creds.length()] = '\0';

return { arr, creds.length() };
};

m_context.credentialsSupplier(CredentialsSupplier(onEncodedCrdentials));

m_stream << "ArchivingMediaDriver PID " << std::to_string(m_pid) << std::endl;
}

Expand Down Expand Up @@ -221,6 +235,8 @@ class AeronArchiveTest : public testing::Test

const int m_fragmentLimit = 10;

AeronArchive::Context_t m_context;

pid_t m_pid = 0;

std::ostringstream m_stream;
Expand All @@ -239,12 +255,12 @@ TEST_F(AeronArchiveTest, shouldSpinUpArchiveAndShutdown)

TEST_F(AeronArchiveTest, shouldBeAbleToConnectToArchive)
{
std::shared_ptr<AeronArchive> aeronArchive = AeronArchive::connect();
std::shared_ptr<AeronArchive> aeronArchive = AeronArchive::connect(m_context);
}

TEST_F(AeronArchiveTest, shouldBeAbleToConnectToArchiveViaAsync)
{
std::shared_ptr<AeronArchive::AsyncConnect> asyncConnect = AeronArchive::asyncConnect();
std::shared_ptr<AeronArchive::AsyncConnect> asyncConnect = AeronArchive::asyncConnect(m_context);
aeron::concurrent::YieldingIdleStrategy idle;

std::shared_ptr<AeronArchive> aeronArchive = asyncConnect->poll();
Expand All @@ -263,7 +279,7 @@ TEST_F(AeronArchiveTest, shouldRecordPublicationAndFindRecording)
std::int64_t recordingIdFromCounter = aeron::NULL_VALUE;
std::int64_t stopPosition = aeron::NULL_VALUE;

std::shared_ptr<AeronArchive> aeronArchive = AeronArchive::connect();
std::shared_ptr<AeronArchive> aeronArchive = AeronArchive::connect(m_context);

const std::int64_t subscriptionId = aeronArchive->startRecording(
m_recordingChannel, m_recordingStreamId, AeronArchive::SourceLocation::LOCAL);
Expand Down Expand Up @@ -326,7 +342,7 @@ TEST_F(AeronArchiveTest, shouldRecordThenReplay)
std::int64_t recordingIdFromCounter = aeron::NULL_VALUE;
std::int64_t stopPosition = aeron::NULL_VALUE;

std::shared_ptr<AeronArchive> aeronArchive = AeronArchive::connect();
std::shared_ptr<AeronArchive> aeronArchive = AeronArchive::connect(m_context);

const std::int64_t subscriptionId = aeronArchive->startRecording(
m_recordingChannel, m_recordingStreamId, AeronArchive::SourceLocation::LOCAL);
Expand Down Expand Up @@ -383,7 +399,7 @@ TEST_F(AeronArchiveTest, shouldRecordThenReplayThenTruncate)
std::int64_t recordingIdFromCounter = aeron::NULL_VALUE;
std::int64_t stopPosition = aeron::NULL_VALUE;

std::shared_ptr<AeronArchive> aeronArchive = AeronArchive::connect();
std::shared_ptr<AeronArchive> aeronArchive = AeronArchive::connect(m_context);

const std::int64_t subscriptionId = aeronArchive->startRecording(
m_recordingChannel, m_recordingStreamId, AeronArchive::SourceLocation::LOCAL);
Expand Down Expand Up @@ -468,7 +484,7 @@ TEST_F(AeronArchiveTest, shouldRecordAndCancelReplayEarly)
std::int64_t recordingId = aeron::NULL_VALUE;
std::int64_t stopPosition = aeron::NULL_VALUE;

std::shared_ptr<AeronArchive> aeronArchive = AeronArchive::connect();
std::shared_ptr<AeronArchive> aeronArchive = AeronArchive::connect(m_context);

{
std::shared_ptr<Subscription> subscription = addSubscription(
Expand Down Expand Up @@ -515,7 +531,7 @@ TEST_F(AeronArchiveTest, shouldReplayRecordingFromLateJoinPosition)
const std::string messagePrefix = "Message ";
const std::size_t messageCount = 10;

std::shared_ptr<AeronArchive> aeronArchive = AeronArchive::connect();
std::shared_ptr<AeronArchive> aeronArchive = AeronArchive::connect(m_context);

const std::int64_t subscriptionId = aeronArchive->startRecording(
m_recordingChannel, m_recordingStreamId, AeronArchive::SourceLocation::LOCAL);
Expand Down Expand Up @@ -598,7 +614,7 @@ TEST_F(AeronArchiveTest, shouldListRegisteredRecordingSubscriptions)
const std::string channelTwo = "aeron:udp?endpoint=localhost:5678";
const std::string channelThree = "aeron:udp?endpoint=localhost:4321";

std::shared_ptr<AeronArchive> aeronArchive = AeronArchive::connect();
std::shared_ptr<AeronArchive> aeronArchive = AeronArchive::connect(m_context);

const std::int64_t subIdOne = aeronArchive->startRecording(
channelOne, expectedStreamId, AeronArchive::SourceLocation::LOCAL);
Expand Down Expand Up @@ -690,7 +706,7 @@ TEST_F(AeronArchiveTest, shouldMergeFromReplayToLive)
const std::size_t totalMessageCount = initialMessageCount + subsequentMessageCount;
aeron::concurrent::YieldingIdleStrategy idle;

std::shared_ptr<AeronArchive> aeronArchive = AeronArchive::connect();
std::shared_ptr<AeronArchive> aeronArchive = AeronArchive::connect(m_context);

std::shared_ptr<Publication> publication = addPublication(
*aeronArchive->context().aeron(), publicationChannel.build(), m_recordingStreamId);
Expand Down Expand Up @@ -769,3 +785,92 @@ TEST_F(AeronArchiveTest, shouldMergeFromReplayToLive)

aeronArchive->stopRecording(recordingSubscriptionId);
}

TEST_F(AeronArchiveTest, shouldExceptionForIncorrectInitialCredentials)
{
auto onEncodedCrdentials = []() -> std::pair<const char *, std::uint32_t>
{
std::string creds("admin:NotAdmin");

char *arr = new char[creds.length() + 1];
std::strcpy(arr, creds.data());
arr[creds.length()] = '\0';

return { arr, creds.length() };
};

m_context.credentialsSupplier(CredentialsSupplier(onEncodedCrdentials));

ASSERT_THROW(
{
std::shared_ptr<AeronArchive> aeronArchive = AeronArchive::connect(m_context);
},
ArchiveException);
}

TEST_F(AeronArchiveTest, shouldBeAbleToHandleBeingChallenged)
{
auto onEncodedCrdentials = []() -> std::pair<const char *, std::uint32_t>
{
std::string creds("admin:adminC");

char *arr = new char[creds.length() + 1];
std::strcpy(arr, creds.data());
arr[creds.length()] = '\0';

return { arr, creds.length() };
};

auto onChallenge = [](std::pair<const char *, std::uint32_t> encodedChallenge) ->
std::pair<const char *, std::uint32_t>
{
std::string creds("admin:CSadmin");

char *arr = new char[creds.length() + 1];
std::strcpy(arr, creds.data());
arr[creds.length()] = '\0';

return { arr, creds.length() };
};

m_context.credentialsSupplier(CredentialsSupplier(onEncodedCrdentials, onChallenge));

ASSERT_NO_THROW(
{
std::shared_ptr<AeronArchive> aeronArchive = AeronArchive::connect(m_context);
});
}

TEST_F(AeronArchiveTest, shouldExceptionForIncorrectChallengeCredentials)
{
auto onEncodedCrdentials = []() -> std::pair<const char *, std::uint32_t>
{
std::string creds("admin:adminC");

char *arr = new char[creds.length() + 1];
std::strcpy(arr, creds.data());
arr[creds.length()] = '\0';

return { arr, creds.length() };
};

auto onChallenge = [](std::pair<const char *, std::uint32_t> encodedChallenge) ->
std::pair<const char *, std::uint32_t>
{
std::string creds("admin:adminNoCS");

char *arr = new char[creds.length() + 1];
std::strcpy(arr, creds.data());
arr[creds.length()] = '\0';

return { arr, creds.length() };
};

m_context.credentialsSupplier(CredentialsSupplier(onEncodedCrdentials, onChallenge));

ASSERT_THROW(
{
std::shared_ptr<AeronArchive> aeronArchive = AeronArchive::connect(m_context);
},
ArchiveException);
}

0 comments on commit b6985af

Please sign in to comment.