Java 线程同步组件 CountDownLatch

简介

CountDownLatch 允许一个或一组线程等待其他线程完成后再恢复运行。线程可通过调用await方法进入等待状态,在其他线程调用countDown方法将计数器减为0后,处于等待状态的线程即可恢复运行。

CountDownLatch 的同步功能是基于 AQS 实现的,CountDownLatch 使用 AQS 中的 state 成员变量作为计数器。在 state 不为0的情况下,凡是调用 await 方法的线程将会被阻塞,并被放入 AQS 所维护的同步队列中进行等待。大致示意图如下:

每个阻塞的线程都会被封装成节点对象,节点之间通过 prev 和 next 指针形成同步队列。初始情况下,队列的头结点是一个虚拟节点。该节点仅是一个占位符,没什么特别的意义。每当有一个线程调用 countDown 方法,就将计数器 state–。当 state 被减至0时,队列中的节点就会按照 FIFO 顺序被唤醒,被阻塞的线程即可恢复运行。

CountDownLatch 本身的原理并不难理解,不过如果大家想深入理解 CountDownLatch 的实现细节,那么需要先去学习一下 AQS 的相关原理。CountDownLatch 是基于 AQS 实现的,所以理解 AQS 是学习 CountDownLatch 的前置条件,可以读这篇文章 AbstractQueuedSynchronizer源码分析

Demo例子

该例子前几天写在这篇文章 ReentrantLock源码分析

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
/**
* reentrantloct 测试
* @author: peijiepang
* @date 2018/11/7
* @Description:
*/
public class ReentrantLockTest extends Thread{

private final static Logger LOGGER = LoggerFactory.getLogger(ReentrantLockTest.class);

private ReentrantLock reentrantLock = new ReentrantLock();

private CountDownLatch countDownLatch = null;

public static int j = 0;

public ReentrantLockTest(String threadName,CountDownLatch countDownLatch) {
super(threadName);
this.countDownLatch = countDownLatch;
}

@Override
public void run() {
for(int i=0;i<1000;i++){
//可限时加锁
//reentrantLock.tryLock(1000,TimeUnit.MILLISECONDS);

//可响应线程中断请求
//reentrantLock.lockInterruptibly();

//可指定公平锁
//ReentrantLock fairLock = new ReentrantLock(true);

reentrantLock.lock();
try{
LOGGER.info("{}:{}",Thread.currentThread().getName(),i);
j++;
}finally {
reentrantLock.unlock();
}
}
countDownLatch.countDown();
}

public static void main(String[] args) throws InterruptedException {
CountDownLatch countDownLatch = new CountDownLatch(2);
ReentrantLockTest reentrantLockTest1 = new ReentrantLockTest("thread1",countDownLatch);
ReentrantLockTest reentrantLockTest2 = new ReentrantLockTest("thread2",countDownLatch);
reentrantLockTest1.start();
reentrantLockTest2.start();
countDownLatch.await();
LOGGER.info("---------j:{}",j);
}
}

源码分析

  1. 类图
  1. 构造函数

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    /**
    * CountDownLatch 的同步控制器,继承自 AQS
    */
    private static final class Sync extends AbstractQueuedSynchronizer {
    private static final long serialVersionUID = 4982264981922014374L;

    Sync(int count) {
    setState(count); // 设置 AQS state
    }

    int getCount() {
    return getState();
    }

    /**
    * 尝试在共享状态下获取同步状态,该方法在 AQS 中是抽象方法,这里进行了覆写
    * @param acquires
    * @return
    */
    protected int tryAcquireShared(int acquires) {
    return (getState() == 0) ? 1 : -1; //如果 state = 0,则返回1,表明可获取同步状态 此时线程调用 await 方法时就不会被阻塞。
    }

    /**
    * 尝试在共享状态下释放同步状态,该方法在 AQS 中也是抽象方法
    * @param releases
    * @return
    */
    protected boolean tryReleaseShared(int releases) {
    // Decrement count; signal when transition to zero
    /*
    * 下面的逻辑是将 state--,state 减至0时,调用 await 等待的线程会被唤醒。
    * 这里使用循环 + CAS,表明会存在竞争的情况,也就是多个线程可能会同时调用
    * countDown 方法。在 state 不为0的情况下,线程调用 countDown 是必须要完
    * 成 state-- 这个操作。所以这里使用了循环 + CAS,确保 countDown 方法可正
    * 常运行。
    */
    for (;;) {
    int c = getState(); // 获取 state
    if (c == 0)
    return false;
    int nextc = c-1;
    if (compareAndSetState(c, nextc)) // 使用 CAS 设置新的 state 值
    return nextc == 0;
    }
    }
    }

    /**
    * 同步器
    */
    private final Sync sync;

    /**
    * Constructs a {@code CountDownLatch} initialized with the given count.
    *
    * @param count the number of times {@link #countDown} must be invoked
    * before threads can pass through {@link #await}
    * @throws IllegalArgumentException if {@code count} is negative
    */
    /**
    * CountDownLatch 的构造方法,该方法要求传入大于0的整型数值作为计数器
    * @param count
    */
    public CountDownLatch(int count) {
    if (count < 0) throw new IllegalArgumentException("count < 0");
    this.sync = new Sync(count); //初始化 Sync
    }
  2. await分析
    CountDownLatch中有两个版本的 await 方法,一个响应中断,另一个在此基础上增加了超时功能。

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    131
    132
    133
    134
    135
    136
    137
    138
    139
    140
    141
    142
    143
    144
    145
    146
    147
    148
    149
    150
    151
    152
    153
    154
    155
    156
    157
    158
    159
    160
    161
    162
    163
    164
    165
    166
    167
    168
    169
    170
    171
    172
    173
    174
    175
    176
    177
    178
    179
    180
    181
    182
    183
    184
    185
    186
    187
    188
    189
    190
    191
    192
    193
    194
    195
    196
    197
    198
    199
    200
    201
    202
    203
    204
    205
    206
    207
    208
    209
    /**
    * 该方法会使线程进入等待状态,直到计数器减至0,或者线程被中断。当计数器为0时,调用
    * 此方法将会立即返回,不会被阻塞住。
    */
    public void await() throws InterruptedException {
    sync.acquireSharedInterruptibly(1); //调用 AQS 中的 acquireSharedInterruptibly 方法
    }

    /**
    * 带有超时功能的 await
    * @param timeout
    * @param unit
    * @return
    * @throws InterruptedException
    */
    public boolean await(long timeout, TimeUnit unit)
    throws InterruptedException {
    return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
    }

    //--- AbstractQueuedSynchronizer---//
    //该函数只是简单的判断AQS的state是否为0,为0则返回1,不为0则返回-1。doAcquireSharedInterruptibly函数的源码如下  
    public final void acquireSharedInterruptibly(int arg) throws InterruptedException {
    // 若线程被中断,则直接抛出中断异常
    if (Thread.interrupted())
    throw new InterruptedException();
    // 调用 Sync 中覆写的 tryAcquireShared 方法,尝试获取同步状态
    if (tryAcquireShared(arg) < 0)
    /*
    * 若 tryAcquireShared 小于0,则表示获取同步状态失败,
    * 此时将线程放入 AQS 的同步队列中进行等待。
    */
    doAcquireSharedInterruptibly(arg);
    }

    private void doAcquireSharedInterruptibly(int arg)
    throws InterruptedException {
    // 添加节点至等待队列
    final Node node = addWaiter(Node.SHARED);
    boolean failed = true;
    try {
    for (;;) { // 无限循环
    // 获取node的前驱节点
    final Node p = node.predecessor();
    if (p == head) { // 前驱节点为头结点
    // 试图在共享模式下获取对象状态
    int r = tryAcquireShared(arg);
    if (r >= 0) { // 获取成功
    // 设置头结点并进行繁殖
    setHeadAndPropagate(node, r);
    // 设置节点next域
    p.next = null; // help GC
    failed = false;
    return;
    }
    }
    if (shouldParkAfterFailedAcquire(p, node) &&
    parkAndCheckInterrupt()) // 在获取失败后是否需要禁止线程并且进行中断检查
    // 抛出异常
    throw new InterruptedException();
    }
    } finally {
    if (failed)
    cancelAcquire(node);
    }
    }

    //在AQS的doAcquireSharedInterruptibly中可能会再次调用CountDownLatch的内部类Sync的tryAcquireShared方法和AQS的setHeadAndPropagate方法。setHeadAndPropagate方法源码如下
    private void setHeadAndPropagate(Node node, int propagate) {
    // 获取头结点
    Node h = head; // Record old head for check below
    // 设置头结点
    setHead(node);
    /*
    * Try to signal next queued node if:
    * Propagation was indicated by caller,
    * or was recorded (as h.waitStatus either before
    * or after setHead) by a previous operation
    * (note: this uses sign-check of waitStatus because
    * PROPAGATE status may transition to SIGNAL.)
    * and
    * The next node is waiting in shared mode,
    * or we don't know, because it appears null
    *
    * The conservatism in both of these checks may cause
    * unnecessary wake-ups, but only when there are multiple
    * racing acquires/releases, so most need signals now or soon
    * anyway.
    */
    // 进行判断
    if (propagate > 0 || h == null || h.waitStatus < 0 ||
    (h = head) == null || h.waitStatus < 0) {
    // 获取节点的后继
    Node s = node.next;
    if (s == null || s.isShared()) // 后继为空或者为共享模式
    // 以共享模式进行释放
    doReleaseShared();
    }
    }

    //该方法设置头结点并且释放头结点后面的满足条件的结点,该方法中可能会调用到AQS的doReleaseShared方法,其源码如下。
    private void doReleaseShared() {
    /*
    * Ensure that a release propagates, even if there are other
    * in-progress acquires/releases. This proceeds in the usual
    * way of trying to unparkSuccessor of head if it needs
    * signal. But if it does not, status is set to PROPAGATE to
    * ensure that upon release, propagation continues.
    * Additionally, we must loop in case a new node is added
    * while we are doing this. Also, unlike other uses of
    * unparkSuccessor, we need to know if CAS to reset status
    * fails, if so rechecking.
    */
    // 无限循环
    for (;;) {
    // 保存头结点
    Node h = head;
    if (h != null && h != tail) { // 头结点不为空并且头结点不为尾结点
    // 获取头结点的等待状态
    int ws = h.waitStatus;
    if (ws == Node.SIGNAL) { // 状态为SIGNAL
    if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0)) // 不成功就继续
    continue; // loop to recheck cases
    // 释放后继结点
    unparkSuccessor(h);
    }
    else if (ws == 0 &&
    !compareAndSetWaitStatus(h, 0, Node.PROPAGATE)) // 状态为0并且不成功,继续
    continue; // loop on failed CAS
    }
    if (h == head) // 若头结点改变,继续循环
    break;
    }
    }

    ```
    > 从上面的代码中可以看出,CountDownLatch await 方法实际上调用的是 AQS 的 acquireSharedInterruptibly 方法。该方法会在内部调用 Sync 所覆写的 tryAcquireShared 方法。在 state != 0时,tryAcquireShared 返回值 -1。此时线程将进入 doAcquireSharedInterruptibly 方法中,在此方法中,线程会被放入同步队列中进行等待。若 state = 0,此时 tryAcquireShared 返回1,acquireSharedInterruptibly 会直接返回。此时调用 await 的线程也不会被阻塞住。

    4. countDown分析
    ```java
    /**
    * 此函数将递减锁存器的计数,如果计数到达零,则释放所有等待的线程
    */
    public void countDown() {
    sync.releaseShared(1);//对countDown的调用转换为对Sync对象的releaseShared(从AQS继承而来)方法的调用
    }

    /**
    * 此函数会以共享模式释放对象,并且在函数中会调用到CountDownLatch的tryReleaseShared函数,并且可能会调用AQS的doReleaseShared函数
    */
    public final boolean releaseShared(int arg) {
    if (tryReleaseShared(arg)) {
    doReleaseShared();
    return true;
    }
    return false;
    }

    // Sync重写的tryreleaseshared
    protected boolean tryReleaseShared(int releases) {
    // Decrement count; signal when transition to zero
    // 无限循环
    for (;;) {
    // 获取状态
    int c = getState();
    if (c == 0) // 没有被线程占有
    return false;
    // 下一个状态
    int nextc = c-1;
    if (compareAndSetState(c, nextc)) // 比较并且设置成功
    return nextc == 0;
    }
    }

    //调用aqs的doReleaseShared
    private void doReleaseShared() {
    /*
    * Ensure that a release propagates, even if there are other
    * in-progress acquires/releases. This proceeds in the usual
    * way of trying to unparkSuccessor of head if it needs
    * signal. But if it does not, status is set to PROPAGATE to
    * ensure that upon release, propagation continues.
    * Additionally, we must loop in case a new node is added
    * while we are doing this. Also, unlike other uses of
    * unparkSuccessor, we need to know if CAS to reset status
    * fails, if so rechecking.
    */
    // 无限循环
    for (;;) {
    // 保存头结点
    Node h = head;
    if (h != null && h != tail) { // 头结点不为空并且头结点不为尾结点
    // 获取头结点的等待状态
    int ws = h.waitStatus;
    if (ws == Node.SIGNAL) { // 状态为SIGNAL
    if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0)) // 不成功就继续
    continue; // loop to recheck cases
    // 释放后继结点
    unparkSuccessor(h);
    }
    else if (ws == 0 &&
    !compareAndSetWaitStatus(h, 0, Node.PROPAGATE)) // 状态为0并且不成功,继续
    continue; // loop on failed CAS
    }
    if (h == head) // 若头结点改变,继续循环
    break;
    }
    }

总结

经过分析CountDownLatch的源码可知,其底层结构仍然是AQS,对其线程所封装的结点是采用共享模式,而ReentrantLock是采用独占模式。由于采用的共享模式,所以会导致后面的操作会有所差异,通过阅读源码就会很容易掌握CountDownLatch实现机制。