本文为廖雪峰老师的手写Springopen in new window笔记中的JDBC和事务一节,仅作个人学习使用。

我们在本节会完成一个JdbcTemplate,可以覆盖绝大多数数据库操作。

同时会实现声明式事务@Transcation

JdbcTemplate

我们首先提供一个默认的数据源以及数据库连接池HikariCP

配置项在yml上长这样:

summer:
  datasource:
    url: 
    driver-class-name: 
    username: 
    password: 

我们创建一个JdbcConfiguration,用来管理这个包下的所有要注册的Bean,并将默认的数据数据源注入:

@Configuration
public class JdbcConfiguration {

    @Bean(destroyMethod = "close")
    DataSource dataSource(
            // properties:
            @Value("${summer.datasource.url}") String url,
            @Value("${summer.datasource.username}") String username,
            @Value("${summer.datasource.password}") String password,
            @Value("${summer.datasource.driver-class-name:}") String driver,
            @Value("${summer.datasource.maximum-pool-size:20}") int maximumPoolSize,
            @Value("${summer.datasource.minimum-pool-size:1}") int minimumPoolSize,
            @Value("${summer.datasource.connection-timeout:30000}") int connTimeout
    ) {
        var config = new HikariConfig();
        config.setAutoCommit(false);
        config.setJdbcUrl(url);
        config.setUsername(username);
        config.setPassword(password);
        if (driver != null) {
            config.setDriverClassName(driver);
        }
        config.setMaximumPoolSize(maximumPoolSize);
        config.setMinimumIdle(minimumPoolSize);
        config.setConnectionTimeout(connTimeout);
        return new HikariDataSource(config);
    }
}

然后开始创建我们的JdbcTemplate。

JdbcTemplate中使用到了大量的基于回调函数模板方法,并且设计到了许多函数式接口,我们在这里对重要的类先展开说明。

RowMapper

RowMapper是一个用于将ResultSet转换为指定类型的FunctionalInterface,

FunctionalInterface首先是一个接口,然后就是在这个接口里面只能有一个抽象方法

RowMapper的定义如下:

@FunctionalInterface
public interface RowMapper<T> {
    @Nullable
    T mapRow(ResultSet rs, int rowNum) throws SQLException;
}

我们可以基于它创建很多子类,如

class StringRowMapper implements RowMapper<String> {
    static StringRowMapper instance = new StringRowMapper();
    @Override
    public String mapRow(ResultSet rs, int rowNum) throws SQLException {
        // 将结果集合中的第一列数据转换为String类型
        return rs.getString(1);
    }
}

代码中提供了StringRowMapper,NumberRowMapper,BooleanRowMapper,BeanRowMapper。前三种都是转换为一些基础类型,而最后一种是转换为自定义的Bean。

public class BeanRowMapper<T> implements RowMapper<T> {

    final Logger logger = LoggerFactory.getLogger(getClass());
    // Bean的Class
    Class<T> clazz;
    // Bean的构造器
    Constructor<T> constructor;
    // Bean的字段
    Map<String, Field> fields = new HashMap<>();
    // Bean的set方法
    Map<String, Method> methods = new HashMap<>();

    public BeanRowMapper(Class<T> clazz) {
        this.clazz = clazz;
        try {
            this.constructor = clazz.getConstructor();
        } catch (ReflectiveOperationException e) {
            throw new DataAccessException();
        }
        for (Field f : clazz.getFields()) {
            String name = f.getName();
            this.fields.put(name, f);
        }
        for (Method m : clazz.getMethods()) {
            Parameter[] ps = m.getParameters();
            if (ps.length == 1) {
                String name = m.getName();
                // 到这一步相当于只向methods中放入以set开头并且参数个数为一的方法
                if (name.length() >= 4 && name.startsWith("set")) {
                    String prop = Character.toLowerCase(name.charAt(3)) + name.substring(4);
                    this.methods.put(prop, m);
                }
            }
        }
    }

    @Override
    public T mapRow(ResultSet rs, int rowNum) throws SQLException {
        T bean;
        try {
            bean = this.constructor.newInstance();
            ResultSetMetaData meta = rs.getMetaData();
            int columns = meta.getColumnCount();
            // 遍历MetaData中的所有列,和Methods做匹配
            for (int i = 1; i <= columns; i++) {
                String label = meta.getColumnLabel(i);
                Method method = this.methods.get(label);
                if (method != null) {
                    // 执行set方法
                    method.invoke(bean, rs.getObject(label));
                } else {
                    // 如果没有该字段的set方法,尝试直接向field直接set值
                    Field field = this.fields.get(label);
                    if (field != null) {
                        field.set(bean, rs.getObject(label));
                    }
                }
            }
        } catch (ReflectiveOperationException e) {
            throw new DataAccessException();
        }
        return bean;
    }
}

CallBack

接着我们来看一下代码中定义的回调函数:

  • ConnectionCallback:定义拿到Connection后的行为(一般都是创建PreparedStatement,执行PreparedStatementCallback#doInPreparedStatement)

    @FunctionalInterface
    public interface ConnectionCallback<T> {
        @Nullable
        T doInConnection(Connection con) throws SQLException;
    }
    
  • PreparedStatementCallback:用于定义SQL的类型,也就是拿到PreparedStatement后具体的执行逻辑,可以是查询,可以是其他的语句,由我们自己定义。

    @FunctionalInterface
    public interface PreparedStatementCallback<T> {
        @Nullable
        T doInPreparedStatement(PreparedStatement ps) throws SQLException;
    }
    
  • PreparedStatementCreator:创建PreparedStatement,可以选择是否返回自动增长的主键(主要是在Insert中使用,同时MySQL必须设置自动增长)

    @FunctionalInterface
    public interface PreparedStatementCreator {
        PreparedStatement createPreparedStatement(Connection con) throws SQLException;
    }
    

Template

现在我们来从上往下看JdbcTemplate的核心代码:

  • 首先我们需要注入DataSource,放入成员变量中。
  • preparedStatementCreator方法规定了一个最基础的PreparedStatementCreator方法,在其中对sql进行预编译(放置sql注入,提升效率),向PreparedStatement动态绑定参数。
  • 最核心的就是execute(ConnectionCallback<T> action)方法,每次获取一个Connection,然后执行ConnectionCallback#doInConnection方法,相当于把业务方法抽象出来,在这个方法中专心的做有关连接的事情。
  • 然后就是execute(PreparedStatementCreator psc, PreparedStatementCallback<T> action)方法,它需要传入这两个回调函数,将PreparedStatement的生成和具体要执行的操作隔离起来。
  • 其他的方法就都是在此之上构建出来的,对于基础类型,通过Class对象获取到对应的RowMapper,对于Bean类型我们创建一个指定的BeanRowMapper。
public class JdbcTemplate {

    final DataSource dataSource;

    public JdbcTemplate(DataSource dataSource) {
        this.dataSource = dataSource;
    }
    private PreparedStatementCreator preparedStatementCreator(String sql, Object... args) {
        return (Connection con) -> {
            var ps = con.prepareStatement(sql);
            bindArgs(ps, args);
            return ps;
        };
    }

    private void bindArgs(PreparedStatement ps, Object... args) throws SQLException {
        for (int i = 0; i < args.length; i++) {
            ps.setObject(i + 1, args[i]);
        }
    }
}

     public <T> T execute(ConnectionCallback<T> action) throws DataAccessException {
        // 获取新连接:
        try (Connection newConn = dataSource.getConnection()) {
            final boolean autoCommit = newConn.getAutoCommit();
            if (!autoCommit) {
                newConn.setAutoCommit(true);
            }
            T result = action.doInConnection(newConn);
            if (!autoCommit) {
                newConn.setAutoCommit(false);
            }
            return result;
        } catch (SQLException e) {
            throw new DataAccessException(e);
        }
    }
    public <T> T execute(PreparedStatementCreator psc, PreparedStatementCallback<T> action) {
        return execute((Connection con) -> {
            try (PreparedStatement ps = psc.createPreparedStatement(con)) {
                return action.doInPreparedStatement(ps);
            }
        });
    }
    
    public <T> T queryForObject(String sql, RowMapper<T> rowMapper, Object... args) throws DataAccessException {
        return execute(preparedStatementCreator(sql, args),
                // PreparedStatementCallback
                (PreparedStatement ps) -> {
                    T t = null;
                    try (ResultSet rs = ps.executeQuery()) {
                        while (rs.next()) {
                            if (t == null) {
                                t = rowMapper.mapRow(rs, rs.getRow());
                            } else {
                                throw new DataAccessException("Multiple rows found.");
                            }
                        }
                    }
                    if (t == null) {
                        throw new DataAccessException("Empty result set.");
                    }
                    return t;
                });
    }

    @SuppressWarnings("unchecked")
    public <T> T queryForObject(String sql, Class<T> clazz, Object... args) throws DataAccessException {
        if (clazz == String.class) {
            return (T) queryForObject(sql, StringRowMapper.instance, args);
        }
        if (clazz == Boolean.class || clazz == boolean.class) {
            return (T) queryForObject(sql, BooleanRowMapper.instance, args);
        }
        if (Number.class.isAssignableFrom(clazz) || clazz.isPrimitive()) {
            return (T) queryForObject(sql, NumberRowMapper.instance, args);
        }
        return queryForObject(sql, new BeanRowMapper<>(clazz), args);
    }

    public <T> List<T> queryForList(String sql, RowMapper<T> rowMapper, Object... args) throws DataAccessException {
        return execute(preparedStatementCreator(sql, args),
                // PreparedStatementCallback
                (PreparedStatement ps) -> {
                    List<T> list = new ArrayList<>();
                    try (ResultSet rs = ps.executeQuery()) {
                        while (rs.next()) {
                            list.add(rowMapper.mapRow(rs, rs.getRow()));
                        }
                    }
                    return list;
                });
    }

    public <T> List<T> queryForList(String sql, Class<T> clazz, Object... args) throws DataAccessException {
        return queryForList(sql, new BeanRowMapper<>(clazz), args);
    }

    public Number updateAndReturnGeneratedKey(String sql, Object... args) throws DataAccessException {
        return execute(
                // PreparedStatementCreator
                (Connection con) -> {
                    var ps = con.prepareStatement(sql, Statement.RETURN_GENERATED_KEYS);
                    bindArgs(ps, args);
                    return ps;
                },
                // PreparedStatementCallback
                (PreparedStatement ps) -> {
                    int n = ps.executeUpdate();
                    if (n == 0) {
                        throw new DataAccessException("0 rows inserted.");
                    }
                    if (n > 1) {
                        throw new DataAccessException("Multiple rows inserted.");
                    }
                    try (ResultSet keys = ps.getGeneratedKeys()) {
                        while (keys.next()) {
                            return (Number) keys.getObject(1);
                        }
                    }
                    throw new DataAccessException("Should not reach here.");
                });
    }

    public int update(String sql, Object... args) throws DataAccessException {
        return execute(preparedStatementCreator(sql, args),
                // PreparedStatementCallback
                (PreparedStatement ps) -> {
                    return ps.executeUpdate();
                });
    }
}

声明式事务

在这里我们使用@Transaction来定义声明式事务,同时只支持REQUIRED传播模式即如果当前连接没有事务,声明一个事务,如果由事务,那么就加入。对于隔离级别只能采用数据库默认的隔离级别。

首先定义注解:

@Target(ElementType.TYPE)
@Retention(RetentionPolicy.RUNTIME)
@Documented
@Inherited
public @interface Transactional {
    String value() default "platformTransactionManager";
}

其中@Inherited指明注解可被子类继承。

platformTransactionManager说明了默认使用platformTransactionManager为增强逻辑。

所以我们创建一个接口:

public interface PlatformTransactionManager {
}

同时创建一个实现类:

  • 我们使用ThreadLocal存储当前线程的事务状态,后期可以通过TransactionStatus来实现传播方式。
  • 在invoke中,我们首先通过ThreadLocal获取到当前的事务状态,如果存在事务,那么直接执行方法(加入),如果不存在事务,则创建一个事务。
  • 操作执行完毕后,执行commit操作。如果有异常抛出则回滚。
  • 最后将ThreadLocal清除。
public class TransactionStatus {
    final Connection connection;
    public TransactionStatus(Connection connection) {
        this.connection = connection;
    }
}

public class DataSourceTransactionManager implements PlatformTransactionManager, InvocationHandler {

    static final ThreadLocal<TransactionStatus> transactionStatus = new ThreadLocal<>();

    final Logger logger = LoggerFactory.getLogger(getClass());

    final DataSource dataSource;

    public DataSourceTransactionManager(DataSource dataSource) {
        this.dataSource = dataSource;
    }

    @Override
    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
        TransactionStatus ts = transactionStatus.get();
        if (ts == null) {
            // start new transaction:
            try (Connection connection = dataSource.getConnection()) {
                final boolean autoCommit = connection.getAutoCommit();
                if (autoCommit) {
                    connection.setAutoCommit(false);
                }
                try {
                    transactionStatus.set(new TransactionStatus(connection));
                    Object r = method.invoke(proxy, args);
                    connection.commit();
                    return r;
                } catch (InvocationTargetException e) {
                    TransactionException te = new TransactionException(e.getCause());
                    try {
                        connection.rollback();
                    } catch (SQLException sqle) {
                        te.addSuppressed(sqle);
                    }
                    throw te;
                } finally {
                    transactionStatus.remove();
                    if (autoCommit) {
                        connection.setAutoCommit(true);
                    }
                }
            }
        } else {
            // join current transaction:
            return method.invoke(proxy, args);
        }
    }
}

为了实现加入事务的功能,我们需要对execute进行改造,如果当前存在一个事务连接,那么直接使用当前连接:

public <T> T execute(ConnectionCallback<T> action) throws DataAccessException {
    // 尝试获取当前事务连接:
    Connection current = TransactionalUtils.getCurrentConnection();
    if (current != null) {
        try {
            return action.doInConnection(current);
        } catch (SQLException e) {
            throw new DataAccessException(e);
        }
    }
    // 获取新连接:
    try (Connection newConn = dataSource.getConnection()) {
        final boolean autoCommit = newConn.getAutoCommit();
        if (!autoCommit) {
            newConn.setAutoCommit(true);
        }
        T result = action.doInConnection(newConn);
        if (!autoCommit) {
            newConn.setAutoCommit(false);
        }
        return result;
    } catch (SQLException e) {
        throw new DataAccessException(e);
    }
}
public class TransactionalUtils {
    @Nullable
    public static Connection getCurrentConnection() {
        TransactionStatus ts = DataSourceTransactionManager.transactionStatus.get();
        return ts == null ? null : ts.connection;
    }
}
上次更新:
Contributors: YangZhang