Skip to content

Commit 4aa920d

Browse files
badoetjulienschmidt
authored andcommitted
TestMultiQuery
discard additional OK response after Multi Statement Exec Calls
1 parent 71c5db6 commit 4aa920d

File tree

3 files changed

+83
-0
lines changed

3 files changed

+83
-0
lines changed

AUTHORS

+1
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ Nicola Peduzzi <thenikso at gmail.com>
3939
Runrioter Wung <runrioter at gmail.com>
4040
Soroush Pour <me at soroushjp.com>
4141
Stan Putrya <root.vagner at gmail.com>
42+
Stanley Gunawan <gunawan.stanley at gmail.com>
4243
Xiaobing Jiang <s7v7nislands at gmail.com>
4344
Xiuming Chen <cc at cxm.cc>
4445

driver_test.go

+81
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,28 @@ type DBTest struct {
7676
db *sql.DB
7777
}
7878

79+
func runTestsWithMultiStatement(t *testing.T, dsn string, tests ...func(dbt *DBTest)) {
80+
if !available {
81+
t.Skipf("MySQL-Server not running on %s", netAddr)
82+
}
83+
84+
dsn3 := dsn + "&multiStatements=true"
85+
var db3 *sql.DB
86+
if _, err := parseDSN(dsn3); err != errInvalidDSNUnsafeCollation {
87+
db3, err = sql.Open("mysql", dsn3)
88+
if err != nil {
89+
t.Fatalf("Error connecting: %s", err.Error())
90+
}
91+
defer db3.Close()
92+
}
93+
94+
dbt3 := &DBTest{t, db3}
95+
for _, test := range tests {
96+
test(dbt3)
97+
dbt3.db.Exec("DROP TABLE IF EXISTS test")
98+
}
99+
}
100+
79101
func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) {
80102
if !available {
81103
t.Skipf("MySQL server not running on %s", netAddr)
@@ -99,15 +121,30 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) {
99121
defer db2.Close()
100122
}
101123

124+
dsn3 := dsn + "&multiStatements=true"
125+
var db3 *sql.DB
126+
if _, err := parseDSN(dsn3); err != errInvalidDSNUnsafeCollation {
127+
db3, err = sql.Open("mysql", dsn3)
128+
if err != nil {
129+
t.Fatalf("Error connecting: %s", err.Error())
130+
}
131+
defer db3.Close()
132+
}
133+
102134
dbt := &DBTest{t, db}
103135
dbt2 := &DBTest{t, db2}
136+
dbt3 := &DBTest{t, db3}
104137
for _, test := range tests {
105138
test(dbt)
106139
dbt.db.Exec("DROP TABLE IF EXISTS test")
107140
if db2 != nil {
108141
test(dbt2)
109142
dbt2.db.Exec("DROP TABLE IF EXISTS test")
110143
}
144+
if db3 != nil {
145+
test(dbt3)
146+
dbt3.db.Exec("DROP TABLE IF EXISTS test")
147+
}
111148
}
112149
}
113150

@@ -237,6 +274,50 @@ func TestCRUD(t *testing.T) {
237274
})
238275
}
239276

277+
func TestMultiQuery(t *testing.T) {
278+
runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) {
279+
// Create Table
280+
dbt.mustExec("CREATE TABLE `test` (`id` int(11) NOT NULL, `value` int(11) NOT NULL) ")
281+
282+
// Create Data
283+
res := dbt.mustExec("INSERT INTO test VALUES (1, 1)")
284+
count, err := res.RowsAffected()
285+
if err != nil {
286+
dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
287+
}
288+
if count != 1 {
289+
dbt.Fatalf("Expected 1 affected row, got %d", count)
290+
}
291+
292+
// Update
293+
res = dbt.mustExec("UPDATE test SET value = 3 WHERE id = 1; UPDATE test SET value = 4 WHERE id = 1; UPDATE test SET value = 5 WHERE id = 1;")
294+
count, err = res.RowsAffected()
295+
if err != nil {
296+
dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
297+
}
298+
if count != 1 {
299+
dbt.Fatalf("Expected 1 affected row, got %d", count)
300+
}
301+
302+
// Read
303+
var out int
304+
rows := dbt.mustQuery("SELECT value FROM test WHERE id=1;")
305+
if rows.Next() {
306+
rows.Scan(&out)
307+
if 5 != out {
308+
dbt.Errorf("5 != %t", out)
309+
}
310+
311+
if rows.Next() {
312+
dbt.Error("unexpected data")
313+
}
314+
} else {
315+
dbt.Error("no data")
316+
}
317+
318+
})
319+
}
320+
240321
func TestInt(t *testing.T) {
241322
runTests(t, dsn, func(dbt *DBTest) {
242323
types := [5]string{"TINYINT", "SMALLINT", "MEDIUMINT", "INT", "BIGINT"}

packets.go

+1
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,7 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error {
543543

544544
// server_status [2 bytes]
545545
mc.status = readStatus(data[1+n+m : 1+n+m+2])
546+
mc.discardMoreResultsIfExists()
546547

547548
// warning count [2 bytes]
548549
if !mc.strict {

0 commit comments

Comments
 (0)