并发编程系列之ThreadLocal实现原理

x33g5p2x  于2022-04-17 转载在 其他  
字(10.5k)|赞(0)|评价(0)|浏览(537)

并发编程系列之ThreadLocal实现原理

ThreadLocal看词义,线程本地变量?线程的变量,要怎么定义?怎么使用?ThreadLocal是线程安全的?下面给出一个简单例子,引出本文

1、变量的作用域?

  • 局部变量(线程安全)
  1. public class A {
  2. void doSome1() {
  3. int a = 11;
  4. }
  5. void doSome2() {
  6. int a = 12;
  7. }
  8. void doSome3() {
  9. doSome1();
  10. doSome2();
  11. }
  12. }
  • 全局变量(线程不安全)
    需要加同步控制才能保证线程安全
  1. public class A {
  2. public static int count =1;
  3. }

2、什么是ThreadLocal?

引用ThreadLocal里的代码注释:
This class provides thread-local variables. These variables differ from their normal counterparts in that each thread that accesses one (via its {@code get} or {@code set} method) has its own, independently initialized copy of the variable. {@code ThreadLocal} instances are typically private static fields in classes that wish to associate state with a thread (e.g., a user ID or Transaction ID).

ThreadLocal是一个线程的本地变量,可以理解为线程的变量,在线程执行过程随时可以访问。ThreadLocal变量,只有当前线程才能访问,其它线程不能访问,所以本质上ThreadLocal就是线程安全的。所以ThreadLocal的作用和上面例子说的局部变量一样是线程安全的。

前面的学习,我们知道要保证线程安全,一般就是想到加锁,不管是synchronized还是cas锁等,都会在并发的时候对性能产生一定的影响。ThreadLocal是怎么实现线程安全的?详细可以学习一下ThreadLocal源码

3、ThreadLocal 主要方法和成员变量

ThreadLocal主要的方法有:

  1. // 获取当前线程本地变量的值
  2. public T get() {}
  3. // 给当前线程本地变量设置值
  4. public void set(T value){}
  5. // 清除当前线程本地变量的值。
  6. public void remove(){}
  7. // 统一初始化所有线程的ThreadLocal的值
  8. public static <S> ThreadLocal<S> withInitial(Supplier<? extends S> supplier) {
  9. }

主要变量:

  1. // 调用nextHashCode()方法获取下一个hashCode值
  2. private final int threadLocalHashCode = nextHashCode();
  3. // AmoicInteger原子类,用于计算hashCode值
  4. private staitc AmoicInteger nextHashCode = new AmoicInteger();
  5. // 斐波那契数,也叫黄金分割数,可以让hash值分布非常均匀
  6. private static final int HASH_INCREMENT = 0x61c88647
  7. // 获取下一个hashCode值方法,只用原子类操作
  8. private static int nextHashCode () {
  9. return nextHashCode.getAndAdd(HASH_INCREMENT);
  10. }

4、ThreadLocalMap

看了源码,找到set方法都可以找到一个关键的ThreadLocalMapThreadLocalMapThreadLocal 类的一个静态内部类
ThreadLocalMap is a customized hash map suitable only for maintaining thread local values.

ThreadLocal是ThreadLocal里自定义的hash map,当然和jdk里的HashMap实现是不同,这个map主要作用也是存储ThreadLocal变量值

ThreadLocalMap内部维护着一个Entry节点,Entry继承WeakReference,泛型是ThreadLocal,key申明为ThreadLocal<?> k,实际上就是ThreadLocal的弱引用

  1. /**
  2. * The entries in this hash map extend WeakReference, using
  3. * its main ref field as the key (which is always a
  4. * ThreadLocal object). Note that null keys (i.e. entry.get()
  5. * == null) mean that the key is no longer referenced, so the
  6. * entry can be expunged from table. Such entries are referred to
  7. * as "stale entries" in the code that follows.
  8. */
  9. static class Entry extends WeakReference<ThreadLocal<?>> {
  10. /** The value associated with this ThreadLocal. */
  11. Object value;
  12. Entry(ThreadLocal<?> k, Object v) {
  13. super(k);
  14. value = v;
  15. }
  16. }
  • 强引用: new 出来的对象就是强引用,内存不足或者垃圾回收时候,垃圾回收器都不会回收强引用的对象
  • 软引用:使用 SoftReference 修饰的对象被称为软引用,在内存溢出时,软引用指向的对象会被回收
  • 弱引用:使用 WeakReference 修饰的对象被称为弱引用,只要发生垃圾回收,被弱引用指向的对象就会被回收。
  • 虚引用:虚引用是最弱的引用,用 PhantomReference 进行指定。同样是发生垃圾回收也会被回收,作用是跟踪对象的垃圾回收。
引用类型回收时间用途
强引用JVM停止运行时对象的一般状态
软引用当内存不足时对象缓存
弱引用正常垃圾回收时对象缓存
虚引用正常垃圾回收时跟踪对象的垃圾回收

5、Thread、ThreadLocalMap、ThreadLocal关系

Thread、ThreadLocalMap、ThreadLocal 结构关系图:
每一个Thread都有一个threadLocals变量,这个threadLocals变量其实就是ThreadLocal.ThreadLocalMapThreadLocalMap被设计为ThreadLocal的内部类,在ThreadLocalMap内部类里,在其静态内部类Entry是以ThreadLocal的虚引用为key

Thread、ThreadLocalMap、ThreadLocal 类关系图:

6、ThreadLocal.set () 方法源码实现

  1. public void set(T value) {
  2. // 获取当前线程
  3. Thread t = Thread.currentThread();
  4. // 获取当前线程的ThreadLocalMap
  5. ThreadLocalMap map = getMap(t);
  6. // map不为null,调用ThreadLocalMap的set方法设置值
  7. if (map != null)
  8. map.set(this, value);
  9. else
  10. // map为null,调用createMap方法初始化创建map
  11. createMap(t, value);
  12. }
  13. // 获取当前线程的threadLocals,也就是ThreadLocal.ThreadLocalMap
  14. ThreadLocalMap getMap(Thread t) {
  15. return t.threadLocals;
  16. }
  17. // 创建ThreadLocalMap
  18. void createMap(Thread t, T firstValue) {
  19. t.threadLocals = new ThreadLocalMap(this, firstValue);
  20. }
  21. // ThreadLocalMap构造函数
  22. ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
  23. // 初始化Entry表的容量默认为16
  24. table = new Entry[INITIAL_CAPACITY];
  25. // 数组下标,hashCode与(INITIAL_CAPACITY - 1)
  26. int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
  27. // 创建Entry
  28. table[i] = new Entry(firstKey, firstValue);
  29. // size初始化为1
  30. size = 1;
  31. // 设置扩容阙值 ,默认为 len * 2 / 3
  32. setThreshold(INITIAL_CAPACITY);
  33. }
  34. // 设置阙值
  35. private void setThreshold(int len) {
  36. threshold = len * 2 / 3;
  37. }

所以,set方法主要流程为:

  1. 获取当前线程的 ThreadLocalMap
  2. 获取得到,调用ThreadLocalMap 的set方法设置值
  3. 获取不到,调用createMap方法创建ThreadLocalMap

看起来并不复杂,其实并不然,复杂的逻辑在ThreadLocalMapset方法里

  1. private void set(ThreadLocal<?> key, Object value) {
  2. // 获取Entry表
  3. Entry[] tab = table;
  4. // 获取表长度
  5. int len = tab.length;
  6. // 获取数组下标 ,hashcode 与 (len-1)
  7. int i = key.threadLocalHashCode & (len-1);
  8. for (Entry e = tab[i];
  9. e != null;
  10. e = tab[i = nextIndex(i, len)]) {
  11. ThreadLocal<?> k = e.get();
  12. // 找到key相同的就更新value的值
  13. if (k == key) {
  14. e.value = value;
  15. return;
  16. }
  17. // key为null,说明key过期了,被gc回收
  18. if (k == null) {
  19. // 初始化探测式清理的起始位置,替换过期元素
  20. replaceStaleEntry(key, value, i);
  21. return;
  22. }
  23. }
  24. // 没有找到key相等的entry,而且没有key过期的entry,新建一个entry
  25. tab[i] = new Entry(key, value);
  26. // 存放元素数量+1
  27. int sz = ++size;
  28. if (!cleanSomeSlots(i, sz) && sz >= threshold)
  29. rehash();
  30. }

replaceStaleEntry方法:

  1. private void replaceStaleEntry(ThreadLocal<?> key, Object value,
  2. int staleSlot) {
  3. // 获取Entry表
  4. Entry[] tab = table;
  5. // Entry表长度
  6. int len = tab.length;
  7. Entry e;
  8. // 定义探测式清理起始位置
  9. int slotToExpunge = staleSlot;
  10. // 从staleSlot开始遍历查找是否有key为null的,有就更新slaleSlot
  11. for (int i = prevIndex(staleSlot, len);
  12. (e = tab[i]) != null;
  13. i = prevIndex(i, len))
  14. if (e.get() == null)
  15. slotToExpunge = i;
  16. // staleSlot开始向后循环
  17. for (int i = nextIndex(staleSlot, len);
  18. (e = tab[i]) != null;
  19. i = nextIndex(i, len)) {
  20. ThreadLocal<?> k = e.get();
  21. // 如果找到key相同的entry,就替换staleSlot和i的位置,更新value的值
  22. if (k == key) {
  23. e.value = value;
  24. // 替换staleSlot和i的位置
  25. tab[i] = tab[staleSlot];
  26. // 更新value的值
  27. tab[staleSlot] = e;
  28. // 向前循环的没有查找到key过期的entry,更新slotToExpunge值
  29. if (slotToExpunge == staleSlot)
  30. slotToExpunge = i;
  31. // 会调用启动式过期清理,先会进行一遍过期元素探测操作
  32. cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
  33. return;
  34. }
  35. // 没找到过期的key,更新slotToExpunge
  36. if (k == null && slotToExpunge == staleSlot)
  37. slotToExpunge = i;
  38. }
  39. // 找到Entry为null的数据,将数据放入该槽位
  40. tab[staleSlot].value = null;
  41. tab[staleSlot] = new Entry(key, value);
  42. // 从staleSlot开始向前迭代查找有key=null的entry
  43. if (slotToExpunge != staleSlot)
  44. // 调用启动式过期清理,先会进行一次过期元素探测,如果发现了有过期的数据就会先进行探测式清理
  45. cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
  46. }

探测式清理:

  1. private int expungeStaleEntry(int staleSlot) {
  2. Entry[] tab = table;
  3. int len = tab.length;
  4. // 将起始位置置空
  5. tab[staleSlot].value = null;
  6. tab[staleSlot] = null;
  7. // 元素数量减1
  8. size--;
  9. Entry e;
  10. int i;
  11. for (i = nextIndex(staleSlot, len);
  12. (e = tab[i]) != null;
  13. i = nextIndex(i, len)) {
  14. ThreadLocal<?> k = e.get();
  15. // key为null,说明过期了,被GC回收
  16. if (k == null) {
  17. // 清空元素,并减1
  18. e.value = null;
  19. tab[i] = null;
  20. size--;
  21. } else {
  22. // key没有过期,则重新计算hash,重新获取下标
  23. int h = k.threadLocalHashCode & (len - 1);
  24. if (h != i) {
  25. // i位置槽置空
  26. tab[i] = null;
  27. // Unlike Knuth 6.4 Algorithm R, we must scan until
  28. // null because multiple entries could have been stale.
  29. // 寻找离冲突key所在entry最近的空槽,放入该槽
  30. while (tab[h] != null)
  31. h = nextIndex(h, len);
  32. tab[h] = e;
  33. }
  34. }
  35. }
  36. return i;
  37. }

启动式清理:

  1. private boolean cleanSomeSlots(int i, int n) {
  2. boolean removed = false;
  3. Entry[] tab = table;
  4. int len = tab.length;
  5. do {
  6. // 从下一个位置开始
  7. i = nextIndex(i, len);
  8. Entry e = tab[i];
  9. // 遍历到key==null的Entry
  10. if (e != null && e.get() == null) {
  11. // 重置n
  12. n = len;
  13. // 标志有清理元素
  14. removed = true;
  15. // 清理
  16. i = expungeStaleEntry(i);
  17. }
  18. } while ( (n >>>= 1) != 0); // log(n) 限制 对数次
  19. return removed;
  20. }

7、ThreadLocal.get () 方法源码实现

  1. public T get() {
  2. // 获取当前线程
  3. Thread t = Thread.currentThread();
  4. // 获取当前线程的ThreadLocalMap
  5. ThreadLocalMap map = getMap(t);
  6. if (map != null) {
  7. // map获取得到,返回value
  8. ThreadLocalMap.Entry e = map.getEntry(this);
  9. if (e != null) {
  10. @SuppressWarnings("unchecked")
  11. T result = (T)e.value;
  12. return result;
  13. }
  14. }
  15. // 未找到的话,则调用setInitialValue()方法设置null
  16. return setInitialValue();
  17. }
  18. private Entry getEntry(ThreadLocal<?> key) {
  19. int i = key.threadLocalHashCode & (table.length - 1);
  20. Entry e = table[i];
  21. // key相等直接返回
  22. if (e != null && e.get() == key)
  23. return e;
  24. else
  25. // key不相等,调用getEntryAfterMiss()方法
  26. return getEntryAfterMiss(key, i, e);
  27. }
  28. private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
  29. Entry[] tab = table;
  30. int len = tab.length;
  31. // 迭代往后查找key相等的entry
  32. while (e != null) {
  33. ThreadLocal<?> k = e.get();
  34. if (k == key)
  35. return e;
  36. // 遇到key=null的entry,先进行探测式清理工作
  37. if (k == null)
  38. expungeStaleEntry(i);
  39. else
  40. i = nextIndex(i, len);
  41. e = tab[i];
  42. }
  43. return null;
  44. }

8、ThreadLocal的扩容机制

当散列数组中元素已经超过扩容阙值 len*2/3,会进行扩容

  1. if (!cleanSomeSlots(i, sz) && sz >= threshold)
  2. rehash();

扩容机制核心方法:

  1. private void rehash() {
  2. //先进行探测式清理工作
  3. expungeStaleEntries();
  4. //探测式清理完毕之后 如果size >= threshold - threshold / 4(也就是 size >= len * 1/2),则扩容
  5. if (size >= threshold - threshold / 4)
  6. resize();
  7. }
  8. private void expungeStaleEntries() {
  9. Entry[] tab = table;
  10. int len = tab.length;
  11. for (int j = 0; j < len; j++) {
  12. Entry e = tab[j];
  13. if (e != null && e.get() == null)
  14. expungeStaleEntry(j);
  15. }
  16. }

所以,主要流程是:

  1. 先进行探测式清理工作
  2. 探测式清理完毕之后 如果size >= threshold - threshold / 4(也就是 size >= len * 1/2),则扩容
  1. private void resize() {
  2. Entry[] oldTab = table;
  3. int oldLen = oldTab.length;
  4. // tab 的大小变为原先的两倍 oldLen * 2
  5. int newLen = oldLen * 2;
  6. Entry[] newTab = new Entry[newLen];
  7. int count = 0;
  8. // 遍历生成新的散列表
  9. for (int j = 0; j < oldLen; ++j) {
  10. Entry e = oldTab[j];
  11. if (e != null) {
  12. ThreadLocal<?> k = e.get();
  13. if (k == null) {
  14. e.value = null;
  15. } else {
  16. // entry表下标
  17. int h = k.threadLocalHashCode & (newLen - 1);
  18. while (newTab[h] != null)
  19. h = nextIndex(h, newLen);
  20. newTab[h] = e;
  21. count++;
  22. }
  23. }
  24. }
  25. // 重新计算扩容阙值
  26. setThreshold(newLen);
  27. size = count;
  28. table = newTab;
  29. }

9、ThreadLocal.remove()方法实现

  1. public void remove() {
  2. // 获取当前线程的ThreadLocalMap
  3. ThreadLocalMap m = getMap(Thread.currentThread());
  4. if (m != null)
  5. m.remove(this);
  6. }
  7. private void remove(ThreadLocal<?> key) {
  8. Entry[] tab = table;
  9. int len = tab.length;
  10. // 获取Entry下标
  11. int i = key.threadLocalHashCode & (len-1);
  12. // 从hash获取的下标开始,寻找key相等的entry元素清除
  13. for (Entry e = tab[i];
  14. e != null;
  15. e = tab[i = nextIndex(i, len)]) {
  16. if (e.get() == key) {
  17. e.clear();
  18. // 进行探测式清理工作
  19. expungeStaleEntry(i);
  20. return;
  21. }
  22. }
  23. }

10、如何正确使用ThreadLocal

前面已经对ThreadLocal进行了浅显的分析,然后在实际工作中如何使用ThreadLocal?

在ThreadLocal源码的注释里,作者已经给出一个例子:

  1. package com.example.concurrent.threadlocal;
  2. import java.util.concurrent.atomic.AtomicInteger;
  3. public class ThreadId {
  4. // Atomic integer containing the next thread ID to be assigned
  5. private static final AtomicInteger nextId = new AtomicInteger(0);
  6. // Thread local variable containing each thread's ID
  7. private static final ThreadLocal<Integer> threadId =
  8. new ThreadLocal<Integer>() {
  9. @Override
  10. protected Integer initialValue() {
  11. return nextId.getAndIncrement();
  12. }
  13. };
  14. // Returns the current thread's unique ID, assigning it if necessary
  15. public static int get() {
  16. return threadId.get();
  17. }
  18. }

我们复制例子运行一下,例子也比较简单,是通过原子类加上ThreadLocal实现的线程安全的计数例子,然后ThreadLocal如何正确使用?

  1. 使用ThreadLocal时候,最好声明为static的
  2. 使用ThreadLocal之后,记得手动调用remove方法

为什么要使用remove?在阿里编程规范里也说明了不remove可能会造成内存泄漏问题,不正确使用可能造成:

  1. 内存被占用
  2. 内存泄漏
  3. 线程被复用的情况,比如使用线程池或者是在web容器线程池中的线程,都可能会造成使用遗留的脏数据,影响业务逻辑。
    所以正确的使用规范:
  1. private static final ThreadLocal<?> threadLocal = new ThreadLocal<>();
  2. try {
  3. threadLocal.set(a);
  4. //执行业务逻辑,逻辑中 get()值
  5. }finally{
  6. //确保用完后,清除
  7. threadLocal.remove();
  8. }
  • InheritableThreadLocal
    在实际的使用中,可能会遇到,子线程获取父线程里创建的ThreadLocal对象的数据,不过ThreadLocal是不支持这种情况,需要使用InheritableThreadLocal
  1. package com.example.concurrent.threadlocal;
  2. public class InheritableThreadLocalSample {
  3. public static void main(String[] args) {
  4. ThreadLocal<String> t1 = new ThreadLocal<>();
  5. InheritableThreadLocal<String> t2 = new InheritableThreadLocal<>();
  6. t1.set("test1");
  7. t2.set("test2");
  8. new Thread(()->{
  9. System.out.println(String.format("获取ThreadLocal数据 %s" , t1.get()));
  10. System.out.println(String.format("获取InheritableThreadLocal数据 %s" , t2.get()));
  11. }).start();
  12. }
  13. }

获取ThreadLocal数据 null
获取InheritableThreadLocal数据 test2

参考资料

相关文章