跳到主要内容

限制单线程调用

Talk is cheap, Let me show your the code.

用于保证特殊场景中,某个特定函数只能被一个线程调用。当有多个线程调用时,触发断言。

📌 代码详细解析

#define THREAD_GUARD()                                           			\
static std::atomic<std::thread::id> stored_thread_id; \
std::thread::id current_thread_id = std::this_thread::get_id(); \
std::thread::id expected = std::thread::id(); \
\
/* CAS 操作:如果 stored_thread_id 还是默认值,则尝试设置为当前线程 ID */ \
stored_thread_id.compare_exchange_strong(expected, current_thread_id); \
\
/* 现在 stored_thread_id 要么是当前线程 ID,要么是之前的线程 ID */ \
assert(stored_thread_id.load() == current_thread_id && "Function called by multiple threads!");

📍 主要变量

  • stored_thread_id存储调用此函数的线程 ID(静态的 std::atomic<std::thread::id>)。
  • current_thread_id当前正在调用该函数的线程的 ID
  • expectedCAS 操作的对比值,初始为 std::thread::id()(未初始化状态)。

📍 关键逻辑

  1. 初次调用stored_thread_id 还是默认值(未初始化)。
  2. 使用 CAS (compare_exchange_strong()):
    • 如果 stored_thread_id == expected (std::thread::id())
      • 说明当前函数是 第一个被调用,将 stored_thread_id 设置为 current_thread_id(当前线程 ID)。
    • 如果 stored_thread_id != expected
      • 说明已有线程调用过该函数,不能修改 stored_thread_id
  3. 最后进行断言:
    • stored_thread_id 要么是 第一个调用线程的 ID,要么是 之前已经存储的 ID
    • 如果当前线程的 ID 和 stored_thread_id 不匹配,触发 assert,程序崩溃!

🔍 两种情况分析

✅ 情况 1:两个线程「先后」调用

假设:

  • 线程 At=0 调用 testFunction()
  • 线程 A 退出,线程 B 在 t=10 调用 testFunction()

执行步骤:

  1. 线程 A 进入 testFunction()
    • stored_thread_id == std::thread::id()(默认值)。
    • current_thread_id = A 的线程 ID
    • CAS 成功:stored_thread_id 被设置为 A 的 ID。
    • 线程 A 继续执行,不触发 assert
  2. 线程 A 执行完毕,退出函数
  3. 线程 B 在 t=10 进入 testFunction()
    • stored_thread_id == A 的 ID(已被线程 A 设置)。
    • current_thread_id = B 的线程 ID
    • CAS 失败(因为 stored_thread_id != std::thread::id())。
    • stored_thread_id 仍然是 A 的 ID,B 的 ID 和它不匹配
    • 触发 assert,程序崩溃!🚨

成功拦截不同线程先后调用的情况!


✅ 情况 2:两个线程「同时」调用

假设:

  • 线程 A 和 线程 B 在 t=0 几乎同时调用 testFunction()

执行步骤:

  1. 线程 A 进入 testFunction()
    • stored_thread_id == std::thread::id()(默认值)。
    • current_thread_id = A 的线程 ID
    • 线程 A 执行 CAS:
      • 如果 stored_thread_id == std::thread::id(),则设置 stored_thread_id = A 的 ID
      • CAS 可能成功,线程 A 继续执行。
  2. 线程 B 也进入 testFunction()
    • stored_thread_id 可能已经被 线程 A 修改为 A 的 ID(如果 A 先执行 CAS)。
    • current_thread_id = B 的线程 ID
    • 线程 B 执行CAS:
      • 如果 stored_thread_id == std::thread::id(),尝试设置 stored_thread_id = B 的 ID
      • stored_thread_id 已经是 A 的 ID,CAS 失败
      • stored_thread_id 仍然是 A 的 ID,B 的 ID 和它不匹配。
      • 触发 assert,程序崩溃!🚨

成功拦截多个线程同时调用的情况!


📌 为什么 CAS 是关键?

  • 保证只有一个线程能成功初始化 stored_thread_id,所有其他线程只能看到已经存储的值。
  • 即使多个线程并发执行,CAS 仍然是线程安全的,不会导致竞态条件(不像 mutex 需要显式锁)。
  • 后续线程无法修改 stored_thread_id,确保只能由 一个线程 访问该函数。

🛠 总结

情况执行逻辑结果
线程 A 先调用,线程 B 后调用线程 A 成功初始化 stored_thread_id,线程 B 发现 ID 不匹配B 触发 assert 🚨
线程 A 和 B 同时调用线程 A 或 B 其中一个先成功 CAS,另一个线程失败失败的线程触发 assert 🚨