本人微信公众号"aeolian"~

架构探险笔记6-ThreadLocal简介

什么是ThreadLocal?

ThreadLocal直译为“线程本地”或“本地线程”,如果真的这么认为,那就错了!其实它就是一个容器,用于存放线程的局部变量,应该叫ThreadLocalVariable(线程局部变量)才对。

早在JDK1.2的时代,java.lang.ThreadLocal就诞生了,它是为了解决多线程并发问题而设计的,只不过设计得有些难用而已,所以至今没有得到广泛的应用。

一个序列号生成器的程序可能同时会有多个线程并发访问它,要保证每个线程得到的序列号都是自增的,而补鞥呢互相干扰。

先定义一个接口:

public interface Sequence {
    int getNumber();
}

每次调用getNumber方法可获取一个序列号,下次再调用时,序列号会自增。

在做一个线程类:

public class ClientThread extends Thread{
    private Sequence sequence;

    public ClientThread(Sequence sequence) {
        this.sequence = sequence;
    }

    @Override
    public void run() {
        for (int i=0;i<3;i++){
            System.out.println(Thread.currentThread().getName() + " =>"+sequence.getNumber());
        }
    }
}

在线程中连续输出三次线程名与其对应的序列号。

我们不用ThreadLocal,先做一个实现类:

public class SequenceA implements Sequence {
    private static int number = 0;

    @Override
    public int getNumber() {
        number = number +1;
        return number;
    }

    public static void main(String[] args) {
        Sequence sequence = new SequenceA();
        ClientThread thread1 = new ClientThread(sequence);
        ClientThread thread2 = new ClientThread(sequence);
        ClientThread thread3 = new ClientThread(sequence);

        thread1.start();
        thread2.start();
        thread3.start();
    }
}

序列号初始值是0,在main方法中模拟了三个线程,运行后结果如下:

《架构探险笔记6-ThreadLocal简介》

分析发现,线程之间共享的static变量无法保证对于不同线程而言是安全的,也就是说,此时无法保证”线程安全”。

那么如何才能做到“线程安全”呢?对于这个案例,就是说不同的线程可拥有自己的static变量,如何实现呢?下面看另一个实现:

public class SequenceB implements Sequence {
    //private static int number = 0;
    private static ThreadLocal numberContainer = new ThreadLocal(){

        @Override
        protected Integer initialValue() {
            return 0;
        }
    };
    @Override
    public int getNumber() {
        numberContainer.set(numberContainer.get()+1);
        return numberContainer.get();
    }

    public static void main(String[] args) {
        Sequence sequence = new SequenceB();
        ClientThread thread1 = new ClientThread(sequence);
        ClientThread thread2 = new ClientThread(sequence);
        ClientThread thread3 = new ClientThread(sequence);

        thread1.start();
        thread2.start();
        thread3.start();
    }
}

通过ThreadLocal封装了一个Integer类型的numberContainer静态成员变量,并且初始值是0。再看getNumber方法,首先从numberContainer中get出当前的值,加1,随后set到numberContainer中,最后在numberContainer中get出当前的值并返回。

是不是很绕?但是很强大!我们不妨把ThreadLocal看作是一个容器,这样理解起来就简单了。所以,这里故意用了Container这个词作为后缀来命名ThreadLocal变量。

《架构探险笔记6-ThreadLocal简介》

每个线程独立了,同样是static变量,对于不同的线程而言,它没有被共享,而是每个线程各一份,这样也就保证了线程安全。也就是说,ThreadLocal为每一个线程提供了一个独立的副本。

搞清楚ThreadLocal的原理后,总结一下API:

public void set(T value):将值放入线程局部变量中;

public T get():从线程局部变量中获取值;

public void remove():从线程局部变量中移除值(有助于JVM垃圾回收);

protected T initialValue():返回线程局部变量中的初始值(默认为null)。

为什么initialValue方法是protected的呢?就是为了提醒程序员,这个方法是要程序员来实现的,要给这个线程局部变量设置一个初始值。

自己实现ThreadLocal

熟悉了原理之后与这些API之后,可以想想ThreadLocal里面不就是封装了一个Map吗?我们自己可以写一个ThreadLocal了:

package com.autumn.threadlocal;

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

/**
 * @program: MyThreadLocal
 * @description: 模式ThreadLocal
 * @author: Created by Autumn
 * @create: 2018-11-21 17:09
 */
public class MyThreadLocal {
    private Map container = Collections.synchronizedMap(new HashMap());

    public void set(T value){
        container.put(Thread.currentThread(),value);
    }

    public T get(){
        Thread thread = Thread.currentThread();
        T value = container.get(thread);
        if (value == null && !container.containsKey(thread)){
            value = initialValue();
            container.put(thread,value);
        }
        return value;
    }

     public void remove(){
        container.remove(Thread.currentThread());
    }

    protected T initialValue(){
        return null;
    }
}

上面定义了一个山寨版的ThreadLocal,其中定义了一个同步Map(这个操作会在map上加锁)

写个类运行一下

/**
 * @program: SequenceB
 * @description: 用ThreadLocal实现线程共享
 * @author: Created by Autumn
 * @create: 2018-11-21 15:45
 */
public class SequenceC implements Sequence {
    //private static int number = 0;
    private static MyThreadLocal numberContainer = new MyThreadLocal(){
        @Override
        protected Integer initialValue() {
            return 0;
        }
    };
    @Override
    public int getNumber() {
        numberContainer.set(numberContainer.get()+1);
        return numberContainer.get();
    }

    public static void main(String[] args) {
        Sequence sequence = new SequenceC();
        ClientThread thread1 = new ClientThread(sequence);
        ClientThread thread2 = new ClientThread(sequence);
        ClientThread thread3 = new ClientThread(sequence);

        thread1.start();
        thread2.start();
        thread3.start();
    }
}

返回结果

《架构探险笔记6-ThreadLocal简介》

只是把ThreadLocal换成了MyThreadLocal而已,运行效果和之前的一样,也是正确的。

提示:当在一个类中使用了static成员变量的时候,一定要多问问自己,这个static成员变量考虑“线程安全”了吗?也就是说,多个线程需要独享自己的static成员变量吗?如果需要考虑,不妨用ThreadLocal

ThreadLocal使用例子

ThreadLocal具体有哪些使用案例呢?

首先要说的就是通过ThreadLocal存放JDBC Connection,以达到事务控制的能力。

记得在很久以前,用户提出过一个需求,需求就很繁琐,就一句话:

当修改产品价格的时候,需要记录操作日志,什么时候做了什么事情。

想必这个案例,只要是做个应用系统的小伙伴都应该遇到过。不外乎数据库里就两张表:product与log,用两条sql语句应该就可以解决问题:

update product set price = ? and id = ?
insert into log(created,description) values(?,?)

但要确保这两条sql语句必须在同一个事务里进行提交,否则有可能update提交了,但是insert却没有提交。

为了解决这个问题,首先我们写一个DBUtil的工具类

/**
 * @program: DBUtil
 * @description: 数据库配置工具类
 * @author: qiuyu
 * @create: 2018-11-28 05:52
 **/
public class DBUtil {
    private static final Logger LOGGER = LoggerFactory.getLogger(DBUtil.class);
    //数据库配置
    private static final String DRIVER = "com.mysql.jdbc.Driver";
    private static final String URL = "jdbc:mysql://222.222.221.198:3306/customer";
    private static final String USERNAME = "root";
    private static final String PASSWORD = "root";

    //定义一个数据库连接
    private static Connection conn = null;

    /**
     * 获取数据库连接
     * @return
     */
    public static Connection getConnection(){
        try {
            /*JDBC获取连接*/
            Class.forName(DRIVER);
            conn = DriverManager.getConnection(URL,USERNAME,PASSWORD);
        } catch (Exception e) {
            e.printStackTrace();  //在catlina.out中打印
            LOGGER.error("get connection failure",e);
        }
        return conn;
    }

    /**
     * 关闭数据库连接
     * @param conn
     */
    public static void closeConnection(Connection conn){
        if (conn!=null){
            try {
                conn.close();
            } catch (SQLException e) {
                e.printStackTrace();
                LOGGER.error("close connection failure",e);
            }
        }
    }
}

里面设置了一个static的Connection,这下数据库连接就好操作了。

然后定义一个借口用于逻辑层调用:

/**
 * 接口 - 更新数据添加日志表记录
 */
public interface ProductService {
    void updateProductPrice(long id,int price);
}

根据productId去更新对应的Product的price,然后再插入一条数据到log表中。

实现类

/**
 * @program: ProductServiceImpl
 * @description: ProductService实现类
 * @author: qiuyu
 * @create: 2018-11-28 06:04
 **/
public class ProductServiceImpl implements ProductService{
    private static final String UPDATE_PRODUCT_SQL = "update product set price = ? where id = ?";
    private static final String INSERT_LOG_SQL = "insert into log(createid,description) value (?,?)";

    @Override
    public void updateProductPrice(long id, int price) {
        try {
            //获取连接
            Connection conn = DBUtil.getConnection();
            conn.setAutoCommit(false);   //关闭自动提交事物(开启事物)

            //执行操作
            updateProduct(conn,UPDATE_PRODUCT_SQL,id,price);   //更新产品
            insertLog(conn,INSERT_LOG_SQL,"create product.");    //插入日志

            //提交事物
            conn.commit();
        }catch (Exception e){
            e.printStackTrace();
        }finally {
            DBUtil.closeConnection();  //关闭连接
        }
    }

    private void updateProduct(Connection conn,String updateProdutSQL,long productId,int productPrice) throws SQLException {
        PreparedStatement pstmt = conn.prepareStatement(updateProdutSQL);
        pstmt.setInt(1,productPrice);
        pstmt.setLong(2,productId);
        int rows = pstmt.executeUpdate();
        if (rows != 0){
            System.out.println("Update Product Success!");
        }
    }

    private void insertLog(Connection conn,String insertLogSQL,String logDescription) throws SQLException {
        PreparedStatement pstmt = conn.prepareStatement(insertLogSQL);
        pstmt.setString(1,new SimpleDateFormat("yyyy-MM-dd HH:mm:ss").format(new Date()));
        pstmt.setString(2,logDescription);
        int rows = pstmt.executeUpdate();
        if (rows != 0){
            System.out.println("Insert log Success!");
        }
    }

}

这里用到了JDBC的高级特性Transaction。暗自庆幸了一番后,是不是有必要写一个客户端来测试一下执行结果是不是我想要的呢?于是偷懒,直接在ProductServiceImpl中加了一个main方法:

    public static void main(String[] args) {
        ProductService service = new ProductServiceImpl();
        service.updateProductPrice(1,3000);
    }

运行程序

 《架构探险笔记6-ThreadLocal简介》

作为一名专业的程序员,为了万无一失,我一定要到数据库里再看看。没错!product表对应的记录更新了,log表也插入了一条记录。这样就可以将ProductService接口交付给别人来调用了。

几个小时过去了,QA妹妹开始对着我嚷:“那谁!我刚才模拟10个请求,你这个接口怎么就挂了?报错说是数据库连接关闭了!”。

她是用工具模拟的,也就是模拟多个线程了!那我也可以模拟,于是写了一个线程类:

/**
 * @program: ClientThread
 * @description: 线程类
 * @author: qiuyu
 * @create: 2018-11-28 07:16
 **/
public class ClientThread extends Thread {
    private ProductService productService;

    public ClientThread(ProductService productService) {
        this.productService = productService;
    }

    @Override
    public void run() {
        System.out.println(Thread.currentThread().getName());
        productService.updateProductPrice(1,3000);
    }
}

用这个线程去调用ProductService的方法,看看是不是有问题。此时,还要再修改一下main方法:

    public static void main(String[] args) {
        /*调用*/
        /*ProductService service = new ProductServiceImpl();
        service.updateProductPrice(1,3000);*/
        /*多线程调用*/
        for (int i=1;i<10;i++){
            ProductService service = new ProductServiceImpl();
            ClientThread thread = new ClientThread(service);
            thread.start();
        }
    }

模拟十个线程,运行结果如下:

《架构探险笔记6-ThreadLocal简介》

没想到!竟然在多线程的环境下报错了,果然是数据库连接关闭了。怎么回事呢?我陷入了沉思中。在百度、Google,还有OSC上都查找了那句报错信息,解答实在是千奇百怪。

既然是跟Connection有关系,那就将主要精力放在检查Connection相关的代码上。是不是Connection不应该是static呢?当初设计成static的主要目的是为了让DBUtil的static方法访问更加方便,用static变量来存放Connection也提高了性能。怎么办呢?

后来看到OSC上非常火爆的一片文章“ThreadLocal”那点事,才终于明白了,原来要使每个线程都拥有自己的连接,而不是共享同一个连接,否则“线程一”有可能会关闭“线程二”的连接,所以“线程二”就报错了。

于是将DBUtil重构:

public class DBUtil {
    private static final Logger LOGGER = LoggerFactory.getLogger(DBUtil.class);
    //数据库配置
    private static final String DRIVER = "com.mysql.jdbc.Driver";
    private static final String URL = "jdbc:mysql://222.222.221.198:3306/demo2";
    private static final String USERNAME = "root";
    private static final String PASSWORD = "autumn-19950926";

    //定义一个数据库连接
    //private static Connection conn = null;
    //定义一个用于放置数据库连接的局部线程变量(是每个线程拥有自己的连接)
    private static ThreadLocal connContainer = new ThreadLocal();

    /**
     * 获取数据库连接
     * @return
     */
    public static Connection getConnection(){
        Connection conn = connContainer.get();   //从ThreadLocal中获取conn

        try {
            if (conn == null) {   //从ThreadLocal中拿到的conn如果为null
                /*JDBC获取连接*/
                Class.forName(DRIVER);
                conn = DriverManager.getConnection(URL, USERNAME, PASSWORD);
            }
        } catch (Exception e) {
            e.printStackTrace();  //在catlina.out中打印
            LOGGER.error("get connection failure",e);
        }finally {
            connContainer.set(conn);
        }
        return conn;
    }

    /**
     * 关闭数据库连接
     */
    public static void closeConnection(){
        Connection conn = connContainer.get();   //从ThreadLocal中获取conn
        if (conn!=null){
            try {
                conn.close();
            } catch (SQLException e) {
                e.printStackTrace();
                LOGGER.error("close connection failure",e);
            }finally {
                connContainer.remove();   //从ThreadLocal中删除当前线程的conn
            }
        }
    }
}

把Connection放入到了ThreadLocal中,这样每个线程之间就隔离了,不会互相干扰了。

此外,在getConnection方法中,首先从ThreadLocal(也就是ConnContainer)中获取Connection,如果没有,就通过JDBC来创建连接,最后再把创建好的连接放入这个ThreadLocal中。可以把ThreadLocal看作一个容器。

同样也对closeConnection方法做了重构,先从容器中获取Connection,拿到了就close掉,最后从容器中将其remove掉,以保持容器的清洁。

注意:该示例仅用于ThreadLocal的基本用法。在实际工作中,推荐使用连接池来管理数据库连接。

源码 

点赞

Leave a Reply

Your email address will not be published. Required fields are marked *