@@ -73,6 +73,25 @@ public boolean isDefaultInsertSQL() {
7373 return valuesMatcher .find ();
7474 }
7575
76+ /**
77+ * 是否为批量的insert语句类型
78+ * INSERT INTO user (id,name) VALUES (?,?),(?,?)
79+ */
80+ public boolean isBatchInsertSQL () {
81+ if (!isDefaultInsertSQL ()) {
82+ return false ;
83+ }
84+
85+ String valuesSQL = getValuesSQL ();
86+ if (valuesSQL == null ) {
87+ return false ;
88+ }
89+
90+ String normalized = valuesSQL .replaceAll ("\\ s+" , "" );
91+ return normalized .contains ("),(" );
92+ }
93+
94+
7695 /**
7796 * 提取 VALUES 或 SELECT 后面的 SQL 内容(包含完整结构)
7897 * 示例:
@@ -103,6 +122,100 @@ public String getValuesSQL() {
103122 return null ;
104123 }
105124
125+ private List <String > splitValueGroups (String input ) {
126+ List <String > groups = new ArrayList <>();
127+
128+ int level = 0 ;
129+ StringBuilder current = new StringBuilder ();
130+
131+ for (int i = 0 ; i < input .length (); i ++) {
132+ char c = input .charAt (i );
133+
134+ if (c == '(' ) {
135+ if (level > 0 ) {
136+ current .append (c );
137+ }
138+ level ++;
139+ } else if (c == ')' ) {
140+ level --;
141+ if (level > 0 ) {
142+ current .append (c );
143+ } else {
144+ // 一个完整 group
145+ groups .add ("(" + current .toString () + ")" );
146+ current .setLength (0 );
147+ }
148+ } else if (c == ',' && level == 0 ) {
149+ // group 之间的逗号,忽略
150+ continue ;
151+ } else {
152+ if (level > 0 ) {
153+ current .append (c );
154+ }
155+ }
156+ }
157+
158+ return groups ;
159+ }
160+
161+ public List <List <InsertValue >> getBatchValues () {
162+ List <List <InsertValue >> result = new ArrayList <>();
163+
164+ String valuesSQL = getValuesSQL ();
165+ if (valuesSQL == null ) {
166+ return result ;
167+ }
168+
169+ // 去掉 VALUES 关键字
170+ String normalized = valuesSQL .trim ();
171+ // 如果没有以 ( 开头,补一个
172+ if (!normalized .startsWith ("(" )) {
173+ normalized = "(" + normalized ;
174+ }
175+
176+ // 如果没有以 ) 结尾,补一个
177+ if (!normalized .endsWith (")" )) {
178+ normalized = normalized + ")" ;
179+ }
180+
181+ List <String > groups = splitValueGroups (normalized );
182+
183+ int jdbcIndex = 0 ;
184+
185+ for (String group : groups ) {
186+ // 去掉外层括号
187+ String inner = group .trim ();
188+ if (inner .startsWith ("(" ) && inner .endsWith (")" )) {
189+ inner = inner .substring (1 , inner .length () - 1 );
190+ }
191+
192+ List <String > values = SQLUtils .parseInsertSQLValues (inner );
193+ List <InsertValue > row = new ArrayList <>();
194+
195+ for (String value : values ) {
196+ InsertValue insertValue = new InsertValue ();
197+
198+ String v = value .trim ();
199+ if ("?" .equals (v )) {
200+ insertValue .setType (ValueType .JDBC );
201+ jdbcIndex ++;
202+ insertValue .setValue ("?" + jdbcIndex );
203+ } else if (SQLUtils .isSQLKeyword (v ) || v .startsWith ("(" )) {
204+ insertValue .setType (ValueType .SELECT );
205+ insertValue .setValue (v );
206+ } else {
207+ insertValue .setType (ValueType .STATIC );
208+ insertValue .setValue (v );
209+ }
210+
211+ row .add (insertValue );
212+ }
213+
214+ result .add (row );
215+ }
216+
217+ return result ;
218+ }
106219
107220 public List <InsertValue > getValues () {
108221 List <InsertValue > insertValues = new ArrayList <>();
0 commit comments