diff --git a/rtc_base/BUILD.gn b/rtc_base/BUILD.gn index 561cb9ae7f..ac27ee8251 100644 --- a/rtc_base/BUILD.gn +++ b/rtc_base/BUILD.gn @@ -1455,6 +1455,7 @@ rtc_library("ssl") { "../api/task_queue:pending_task_safety_flag", "../api/units:time_delta", "../system_wrappers:field_trial", + "synchronization:mutex", "system:rtc_export", "task_utils:repeating_task", "third_party/base64", diff --git a/rtc_base/helpers.cc b/rtc_base/helpers.cc index 337239894a..84cbe5fba1 100644 --- a/rtc_base/helpers.cc +++ b/rtc_base/helpers.cc @@ -19,19 +19,14 @@ #include "absl/strings/string_view.h" #include "rtc_base/checks.h" #include "rtc_base/logging.h" +#include "rtc_base/synchronization/mutex.h" // Protect against max macro inclusion. #undef max namespace rtc { -// Base class for RNG implementations. -class RandomGenerator { - public: - virtual ~RandomGenerator() {} - virtual bool Init(const void* seed, size_t len) = 0; - virtual bool Generate(void* buf, size_t len) = 0; -}; +namespace { // The OpenSSL RNG. class SecureRandomGenerator : public RandomGenerator { @@ -64,8 +59,6 @@ class TestRandomGenerator : public RandomGenerator { int seed_; }; -namespace { - // TODO: Use Base64::Base64Table instead. static const char kBase64[64] = { 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', @@ -79,6 +72,13 @@ static const char kHex[16] = {'0', '1', '2', '3', '4', '5', '6', '7', static const char kUuidDigit17[4] = {'8', '9', 'a', 'b'}; +// Lock for the global random generator, only needed to serialize changing the +// generator. +webrtc::Mutex& GetRandomGeneratorLock() { + static webrtc::Mutex& mutex = *new webrtc::Mutex(); + return mutex; +} + // This round about way of creating a global RNG is to safe-guard against // indeterminant static initialization order. std::unique_ptr& GetGlobalRng() { @@ -94,7 +94,18 @@ RandomGenerator& Rng() { } // namespace +void SetDefaultRandomGenerator() { + webrtc::MutexLock lock(&GetRandomGeneratorLock()); + GetGlobalRng().reset(new SecureRandomGenerator()); +} + +void SetRandomGenerator(std::unique_ptr generator) { + webrtc::MutexLock lock(&GetRandomGeneratorLock()); + GetGlobalRng() = std::move(generator); +} + void SetRandomTestMode(bool test) { + webrtc::MutexLock lock(&GetRandomGeneratorLock()); if (!test) { GetGlobalRng().reset(new SecureRandomGenerator()); } else { diff --git a/rtc_base/helpers.h b/rtc_base/helpers.h index c214f5212f..51ca672ab5 100644 --- a/rtc_base/helpers.h +++ b/rtc_base/helpers.h @@ -14,6 +14,7 @@ #include #include +#include #include #include "absl/strings/string_view.h" @@ -21,6 +22,23 @@ namespace rtc { +// Interface for RNG implementations. +class RandomGenerator { + public: + virtual ~RandomGenerator() {} + virtual bool Init(const void* seed, size_t len) = 0; + virtual bool Generate(void* buf, size_t len) = 0; +}; + +// Sets the default random generator as the source of randomness. The default +// source uses the OpenSSL RNG and provides cryptographically secure randomness. +void SetDefaultRandomGenerator(); + +// Set a custom random generator. Results produced by CreateRandomXyz +// are cryptographically random iff the output of the supplied generator is +// cryptographically random. +void SetRandomGenerator(std::unique_ptr generator); + // For testing, we can return predictable data. void SetRandomTestMode(bool test); diff --git a/rtc_base/helpers_unittest.cc b/rtc_base/helpers_unittest.cc index b85587234a..015b4d0a7c 100644 --- a/rtc_base/helpers_unittest.cc +++ b/rtc_base/helpers_unittest.cc @@ -12,20 +12,30 @@ #include +#include #include #include "rtc_base/buffer.h" +#include "test/gmock.h" #include "test/gtest.h" namespace rtc { +namespace { -class RandomTest : public ::testing::Test {}; +using ::testing::_; +using ::testing::DoAll; +using ::testing::Invoke; +using ::testing::IsEmpty; +using ::testing::Not; +using ::testing::Return; +using ::testing::WithArg; +using ::testing::WithArgs; -TEST_F(RandomTest, TestCreateRandomId) { +TEST(RandomTest, TestCreateRandomId) { CreateRandomId(); } -TEST_F(RandomTest, TestCreateRandomDouble) { +TEST(RandomTest, TestCreateRandomDouble) { for (int i = 0; i < 100; ++i) { double r = CreateRandomDouble(); EXPECT_GE(r, 0.0); @@ -33,11 +43,11 @@ TEST_F(RandomTest, TestCreateRandomDouble) { } } -TEST_F(RandomTest, TestCreateNonZeroRandomId) { +TEST(RandomTest, TestCreateNonZeroRandomId) { EXPECT_NE(0U, CreateRandomNonZeroId()); } -TEST_F(RandomTest, TestCreateRandomString) { +TEST(RandomTest, TestCreateRandomString) { std::string random = CreateRandomString(256); EXPECT_EQ(256U, random.size()); std::string random2; @@ -46,7 +56,7 @@ TEST_F(RandomTest, TestCreateRandomString) { EXPECT_EQ(256U, random2.size()); } -TEST_F(RandomTest, TestCreateRandomData) { +TEST(RandomTest, TestCreateRandomData) { static size_t kRandomDataLength = 32; std::string random1; std::string random2; @@ -57,7 +67,7 @@ TEST_F(RandomTest, TestCreateRandomData) { EXPECT_NE(0, memcmp(random1.data(), random2.data(), kRandomDataLength)); } -TEST_F(RandomTest, TestCreateRandomStringEvenlyDivideTable) { +TEST(RandomTest, TestCreateRandomStringEvenlyDivideTable) { static std::string kUnbiasedTable("01234567"); std::string random; EXPECT_TRUE(CreateRandomString(256, kUnbiasedTable, &random)); @@ -68,12 +78,12 @@ TEST_F(RandomTest, TestCreateRandomStringEvenlyDivideTable) { EXPECT_EQ(0U, random.size()); } -TEST_F(RandomTest, TestCreateRandomUuid) { +TEST(RandomTest, TestCreateRandomUuid) { std::string random = CreateRandomUuid(); EXPECT_EQ(36U, random.size()); } -TEST_F(RandomTest, TestCreateRandomForTest) { +TEST(RandomTest, TestCreateRandomForTest) { // Make sure we get the output we expect. SetRandomTestMode(true); EXPECT_EQ(2154761789U, CreateRandomId()); @@ -112,4 +122,50 @@ TEST_F(RandomTest, TestCreateRandomForTest) { SetRandomTestMode(false); } +class MockRandomGenerator : public RandomGenerator { + public: + MOCK_METHOD(void, Die, ()); + ~MockRandomGenerator() override { Die(); } + + MOCK_METHOD(bool, Init, (const void* seed, size_t len), (override)); + MOCK_METHOD(bool, Generate, (void* buf, size_t len), (override)); +}; + +TEST(RandomTest, TestSetRandomGenerator) { + std::unique_ptr will_move = + std::make_unique(); + MockRandomGenerator* generator = will_move.get(); + SetRandomGenerator(std::move(will_move)); + + EXPECT_CALL(*generator, Init(_, sizeof(int))).WillOnce(Return(true)); + EXPECT_TRUE(InitRandom(5)); + + std::string seed = "seed"; + EXPECT_CALL(*generator, Init(seed.data(), seed.size())) + .WillOnce(Return(true)); + EXPECT_TRUE(InitRandom(seed.data(), seed.size())); + + uint32_t id = 4658; + EXPECT_CALL(*generator, Generate(_, sizeof(uint32_t))) + .WillOnce(DoAll(WithArg<0>(Invoke([&id](void* p) { + std::memcpy(p, &id, sizeof(uint32_t)); + })), + Return(true))); + EXPECT_EQ(CreateRandomId(), id); + + EXPECT_CALL(*generator, Generate) + .WillOnce(DoAll( + WithArgs<0, 1>([](void* p, size_t len) { std::memset(p, 0, len); }), + Return(true))); + EXPECT_THAT(CreateRandomUuid(), Not(IsEmpty())); + + // Set the default random generator, and expect that mock generator is + // not used beyond this point. + EXPECT_CALL(*generator, Die); + EXPECT_CALL(*generator, Generate).Times(0); + SetDefaultRandomGenerator(); + EXPECT_THAT(CreateRandomUuid(), Not(IsEmpty())); +} + +} // namespace } // namespace rtc