附上源码,仅供参考:此方法通过反射拼接出完整的 MERGE INTO 语句,再交给 MyBatis 执行该 SQL。
使用这种方法更新时需要注意:实体类中值为 null 的属性,也会把数据库中对应的字段更新为 null。因此用它做批量更新时,必须先把所有字段的值都填入实体类。另外,SQL 是拼接成字符串后经 MyBatis 的 ${} 原样执行的,不要把未经校验的外部输入直接传入。
使用 Oracle 时,需要自增的主键字段必须配置好序列和触发器(DB2 下主键列不会出现在 INSERT 语句中)。此外,无论是 DB2 还是 Oracle,都必须用 @TableId 注解标出主键;实体类中与表字段不对应的多余属性需要用 @TableField(exist = false) 排除。
import com.cmbchina.cc.mc.aptms.infrastructure.helper.McCmsDbSysParmCacheHelper; import com.cmbchina.cc.mc.aptms.infrastructure.repository.db2.BatchMapper; import com.cmbchina.cc.mc.aptms.infrastructure.util.SqlTemplateUtils; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.stereotype.Service; import org.springframework.util.CollectionUtils; import java.util.List; @Service @Slf4j public class CommonBatchService { @Value("${batchSize:1000}") private int batchSize; @Autowired private BatchMapper batchMapper; @Autowired private McCmsDbSysParmCacheHelper mcCmsDbSysParmCacheHelper; publicBoolean batchSaveOrUpdate(List sources, Class clazz) { return batchSaveOrUpdate(sources, clazz, batchSize); } public Boolean batchSaveOrUpdate(List sources, Class clazz, int size) { if (CollectionUtils.isEmpty(sources)) { return false; } try { Boolean isOracleDataSource = mcCmsDbSysParmCacheHelper.isChangeDatabase(); int sourceSize = sources.size(); for (int i = 0; i < sourceSize; i = i + size) { List subList; if (i + size > sourceSize) { subList = sources.subList(i, sourceSize); } else { subList = sources.subList(i, i + size); } String batchSaveOrUpdateSql = SqlTemplateUtils.batchSaveOrUpdateSql(subList, clazz, isOracleDataSource); batchMapper.mergeInto(batchSaveOrUpdateSql); } } catch (Exception e) { log.error(e.toString()); return false; } return true; } }
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import org.apache.ibatis.annotations.Mapper;
import org.apache.ibatis.annotations.Param;
import org.apache.ibatis.annotations.Update;

/**
 * Mapper that executes a caller-built SQL statement verbatim.
 *
 * <p>SECURITY NOTE: {@code ${sql}} is MyBatis text substitution, not a bound {@code ?}
 * parameter, so the string passed in is executed exactly as given. Only pass SQL produced
 * by trusted code.
 */
@Mapper
public interface BatchMapper extends BaseMapper {

    /**
     * Runs {@code sql} through MyBatis' update path.
     *
     * @param sql a complete SQL statement, e.g. a generated {@code MERGE INTO}
     */
    @Update(value = "${sql}")
    void mergeInto(@Param(value = "sql") String sql);
}
import com.baomidou.mybatisplus.annotation.TableField; import com.baomidou.mybatisplus.annotation.TableId; import com.cmbchina.cc.mc.aptms.enums.MCSysErrorCodeType; import lombok.extern.slf4j.Slf4j; import java.lang.reflect.Field; import java.util.List; @Slf4j public class SqlTemplateUtils { public staticString batchSaveOrUpdateSql(List sources, Class clazz, Boolean isOracleDataSource) { if (sources.size() < 1) { return null; } StringBuffer finalSqlTemplate = new StringBuffer("MERGE INTO "); finalSqlTemplate.append(StringUtils.getTableName(clazz)).append(" u USING ( "); StringBuffer tableIdName = new StringBuffer(""); StringBuffer basicSql = getBasicSql(sources, clazz, tableIdName, isOracleDataSource); if (StringUtils.isEmptyString(tableIdName.toString())) { throw new RuntimeException(MCSysErrorCodeType.UN_CAUGHT_TABLE_ID.getErrorMessage()); } finalSqlTemplate.append(basicSql).append(" ) t on (").append(String.format("u.%s = t.%s) ", tableIdName, tableIdName)); StringBuffer lastSql = getLastSql(clazz, isOracleDataSource); return finalSqlTemplate.append(lastSql).toString(); } private static StringBuffer getBasicSql(List list, Class clazz, StringBuffer tableIdName, Boolean isOracleDataSource) { StringBuffer basicSql = new StringBuffer(""); Field[] declaredFields = clazz.getDeclaredFields(); int fieldsLength = declaredFields.length; String tableIdValue = ""; for (int j = 0; j < list.size(); ) { basicSql.append("SELECT "); Object o = list.get(j); for (int i = 0; i < fieldsLength; ) { Field field = declaredFields[i]; TableField tableField = field.getAnnotation(TableField.class); if (null != tableField && !tableField.exist()) { continue; } TableId tableId = field.getAnnotation(TableId.class); String tableFieldValue = getColumnByField(field); if (null != tableId) { tableIdValue = tableFieldValue; } if (StringUtils.isEmpty(tableFieldValue)) { throw new RuntimeException(MCSysErrorCodeType.UN_CAUGHT_COLUMNS.getErrorMessage()); } field.setAccessible(true); try { String 
simpleName = field.getType().getSimpleName(); Object fieldValue = field.get(o); if (null == fieldValue) { if (!isOracleDataSource && null != tableId) { basicSql.append("-1 AS ").append(tableFieldValue); } else { basicSql.append("NULL AS ").append(tableFieldValue); } } else if ("String".equals(simpleName)) { basicSql.append(String.format("'%s'", fieldValue)).append(" AS ").append(tableFieldValue); } else if ("Timestamp".equals(simpleName)) { basicSql.append(String.format("TIMESTAMP'%s'", fieldValue)).append(" AS ").append(tableFieldValue); } else if ("Date".equals(simpleName)) { if (isOracleDataSource) { basicSql.append(String.format("TO_DATE('%s','yyyy-mm-dd hh24:mi:ss')", fieldValue)).append(" AS ").append(tableFieldValue); } else { basicSql.append(String.format("'%s'", fieldValue)).append(" AS ").append(tableFieldValue); } } else { basicSql.append(fieldValue).append(" AS ").append(tableFieldValue); } if (++i < fieldsLength) { basicSql.append(", "); } } catch (IllegalAccessException e) { log.error(e.toString()); } } if (isOracleDataSource) { basicSql.append(" FROM DUAL"); } else { basicSql.append(" FROM sysibm.sysdummy1"); } if (++j < list.size()) { basicSql.append(" UNIOn ALL "); } } tableIdName.append(tableIdValue); return basicSql; } private static StringBuffer getLastSql(Class clazz, Boolean isOracleDataSource) { StringBuffer matchedBuf = new StringBuffer(" WHEN MATCHED THEN UPDATE SET "); String whereSql = ""; StringBuffer notMatchedBuf = new StringBuffer(" WHEN NOT MATCHED THEN INSERT "); StringBuffer columns = new StringBuffer("( "); StringBuffer values = new StringBuffer("VALUES ( "); Field[] declaredFields = clazz.getDeclaredFields(); for (int i = 0; i < declaredFields.length; ) { Field field = declaredFields[i]; TableField tableField = field.getAnnotation(TableField.class); if (null != tableField && !tableField.exist()) { continue; } TableId tableId = field.getAnnotation(TableId.class); String tableFieldVal = getColumnByField(field); if (null != tableId) 
{ if (!isOracleDataSource) { i++; continue; } whereSql = String.format(" WHERe u.%s = t.%s ", tableFieldVal, tableFieldVal); columns.append(tableFieldVal).append(", "); values.append(String.format("t.%s", tableFieldVal)).append(", "); i++; continue; } matchedBuf.append(String.format("u.%s = t.%s", tableFieldVal, tableFieldVal)); columns.append(tableFieldVal); values.append(String.format("t.%s", tableFieldVal)); if (++i < declaredFields.length) { matchedBuf.append(", "); columns.append(", "); values.append(", "); } else { columns.append(") "); values.append(")"); } } matchedBuf.append(whereSql).append(notMatchedBuf).append(columns).append(values); return matchedBuf; } private static String getColumnByField(Field field) { String fieldName = field.getName(); String[] split = fieldName.split(""); StringBuffer buffer = new StringBuffer(""); for (String s : split) { if (s.equals(s.toUpperCase())) { buffer.append("_").append(s.toUpperCase()); } else { buffer.append(s.toUpperCase()); } } return buffer.toString(); } }
import com.baomidou.mybatisplus.annotation.TableName;

/**
 * Null-safe string helpers plus a MyBatis-Plus {@link TableName} lookup.
 */
public class StringUtils {

    private StringUtils() {
    }

    /**
     * Returns the table name declared by {@link TableName} on {@code clazz}.
     *
     * @param clazz entity class
     * @return the annotation's value
     * @throws RuntimeException if the annotation is missing or its value is blank
     */
    public static String getTableName(Class<?> clazz) {
        TableName tableName = clazz.getAnnotation(TableName.class);
        if (null == tableName || isEmptyString(tableName.value())) {
            // Message means "table name not found".
            throw new RuntimeException("未查询到表名");
        }
        return tableName.value();
    }

    /** Returns {@code true} if {@code str} is {@code null} or equal to the empty string. */
    public static boolean isEmpty(Object str) {
        return str == null || "".equals(str);
    }

    /** Returns {@code true} if {@code str} is neither {@code null} nor equal to the empty string. */
    public static boolean isNotEmpty(Object str) {
        return str != null && !"".equals(str);
    }

    /**
     * Returns {@code ""} when {@code s} is null or empty after {@link String#trim()},
     * otherwise {@code s} unchanged.
     */
    public static String nullToString(String s) {
        return isEmptyString(s) ? "" : s;
    }

    /** Returns {@code true} if {@code s} is null or empty after {@link String#trim()}. */
    public static boolean isEmptyString(String s) {
        if (s == null) {
            return true;
        }
        return "".equals(s.trim());
    }

    /**
     * Returns {@code true} if {@code toFind} occurs at least {@code atLeastCount} times in {@code src},
     * counting non-overlapping occurrences. Returns {@code false} for blank inputs or a count below 1.
     */
    public static boolean occurAtLeastCount(String src, String toFind, int atLeastCount) {
        if (isEmptyString(src) || isEmptyString(toFind) || atLeastCount < 1) {
            return false;
        }
        int index = 0;
        int count = 0;
        while ((index = src.indexOf(toFind, index)) != -1) {
            count++;
            if (count >= atLeastCount) {
                return true;
            }
            // Skip past this match so occurrences are not counted twice.
            index = index + toFind.length();
        }
        return false;
    }

    /** Returns {@code true} if {@code string} is null, empty, or only {@link Character#isWhitespace} chars. */
    public static boolean isBlank(String string) {
        if (string == null || string.isEmpty()) {
            return true;
        }
        for (int i = 0; i < string.length(); i++) {
            if (!Character.isWhitespace(string.charAt(i))) {
                return false;
            }
        }
        return true;
    }

    /** Null-safe {@link String#trim()}: returns {@code null} for {@code null} input. */
    public static String trim(String string) {
        return string != null ? string.trim() : null;
    }
}