Fix UAF in the test case where signaling thread goes away

Bug: chromium:1478193
Change-Id: If5207e7f740abcc43f74cf8eab30455a8bb0d5ac
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/318622
Commit-Queue: Harald Alvestrand <hta@webrtc.org>
Reviewed-by: Tomas Gunnarsson <tommi@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#40687}
This commit is contained in:
Harald Alvestrand 2023-09-02 05:26:55 +00:00 committed by WebRTC LUCI CQ
parent 6e586e1ad2
commit 8219cc3dc9
4 changed files with 73 additions and 10 deletions

View file

@ -935,7 +935,9 @@ rtc_library("async_dns_resolver") {
":logging",
":macromagic",
":platform_thread",
":refcount",
"../api:async_dns_resolver",
"../api:make_ref_counted",
"../api:sequence_checker",
"../api/task_queue:pending_task_safety_flag",
]

View file

@ -15,6 +15,7 @@
#include <utility>
#include <vector>
#include "api/make_ref_counted.h"
#include "rtc_base/logging.h"
#include "rtc_base/platform_thread.h"
@ -98,6 +99,42 @@ void PostTaskToGlobalQueue(
} // namespace
class AsyncDnsResolver::State : public rtc::RefCountedBase {
public:
enum class Status {
kActive, // Running request, or able to be passed one
kFinished, // Request has finished processing
kDead // The owning AsyncDnsResolver has been deleted
};
static rtc::scoped_refptr<AsyncDnsResolver::State> Create() {
return rtc::make_ref_counted<AsyncDnsResolver::State>();
}
// Execute the passed function if the state is Active.
void Finish(absl::AnyInvocable<void()> function) {
webrtc::MutexLock lock(&mutex_);
if (status_ != Status::kActive) {
return;
}
status_ = Status::kFinished;
function();
}
void Kill() {
webrtc::MutexLock lock(&mutex_);
status_ = Status::kDead;
}
private:
webrtc::Mutex mutex_;
Status status_ RTC_GUARDED_BY(mutex_) = Status::kActive;
};
AsyncDnsResolver::AsyncDnsResolver() : state_(State::Create()) {}
AsyncDnsResolver::~AsyncDnsResolver() {
state_->Kill();
}
void AsyncDnsResolver::Start(const rtc::SocketAddress& addr,
absl::AnyInvocable<void()> callback) {
Start(addr, addr.family(), std::move(callback));
@ -111,17 +148,22 @@ void AsyncDnsResolver::Start(const rtc::SocketAddress& addr,
result_.addr_ = addr;
callback_ = std::move(callback);
auto thread_function = [this, addr, family, flag = safety_.flag(),
caller_task_queue =
webrtc::TaskQueueBase::Current()] {
caller_task_queue = webrtc::TaskQueueBase::Current(),
state = state_] {
std::vector<rtc::IPAddress> addresses;
int error = ResolveHostname(addr.hostname(), family, addresses);
caller_task_queue->PostTask(
SafeTask(flag, [this, error, addresses = std::move(addresses)] {
RTC_DCHECK_RUN_ON(&result_.sequence_checker_);
result_.addresses_ = addresses;
result_.error_ = error;
callback_();
}));
// We assume that the caller task queue is still around if the
// AsyncDnsResolver has not been destroyed.
state->Finish([this, error, flag, caller_task_queue,
addresses = std::move(addresses)]() {
caller_task_queue->PostTask(
SafeTask(flag, [this, error, addresses = std::move(addresses)] {
RTC_DCHECK_RUN_ON(&result_.sequence_checker_);
result_.addresses_ = addresses;
result_.error_ = error;
callback_();
}));
});
};
#if defined(WEBRTC_MAC) || defined(WEBRTC_IOS)
PostTaskToGlobalQueue(

View file

@ -15,6 +15,7 @@
#include "api/async_dns_resolver.h"
#include "api/sequence_checker.h"
#include "api/task_queue/pending_task_safety_flag.h"
#include "rtc_base/ref_counted_object.h"
#include "rtc_base/thread_annotations.h"
namespace webrtc {
@ -38,6 +39,8 @@ class AsyncDnsResolverResultImpl : public AsyncDnsResolverResult {
class AsyncDnsResolver : public AsyncDnsResolverInterface {
public:
AsyncDnsResolver();
~AsyncDnsResolver();
// Start address resolution of the hostname in `addr`.
void Start(const rtc::SocketAddress& addr,
absl::AnyInvocable<void()> callback) override;
@ -48,7 +51,9 @@ class AsyncDnsResolver : public AsyncDnsResolverInterface {
const AsyncDnsResolverResult& result() const override;
private:
ScopedTaskSafety safety_;
class State;
ScopedTaskSafety safety_; // To check for client going away
rtc::scoped_refptr<State> state_; // To check for "this" going away
AsyncDnsResolverResultImpl result_;
absl::AnyInvocable<void()> callback_;
};

View file

@ -40,5 +40,19 @@ TEST(AsyncDnsResolver, ResolvingLocalhostWorks) {
}
}
TEST(AsyncDnsResolver, ResolveAfterDeleteDoesNotReturn) {
test::RunLoop loop;
std::unique_ptr<AsyncDnsResolver> resolver =
std::make_unique<AsyncDnsResolver>();
rtc::SocketAddress address("localhost",
kPortNumber); // Port number does not matter
rtc::SocketAddress resolved_address;
bool done = false;
resolver->Start(address, [&done] { done = true; });
resolver.reset(); // Deletes resolver.
rtc::Thread::Current()->SleepMs(1); // Allows callback to execute
EXPECT_FALSE(done); // Expect no result.
}
} // namespace
} // namespace webrtc