Skip to content

Commit 8cbeffa

Browse files
lucaloozjulienschmidt
authored andcommitted
Enable Multi Results support and discard additional results
- packets.go: flag clientMultiResults, update status when receiving an EOF packet, discard additional results on readRow when EOF is reached - statement.go: currently a nil rows.mc is used as an eof, don’t set it if there are no columns to avoid that Next() waits indefinitely - rows.go: discard additional results on close and avoid panic on Columns()
1 parent acb04ff commit 8cbeffa

File tree

3 files changed

+51
-4
lines changed

3 files changed

+51
-4
lines changed

packets.go

+43-2
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,7 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
224224
clientTransactions |
225225
clientLocalFiles |
226226
clientPluginAuth |
227+
clientMultiResults |
227228
mc.flags&clientLongFlag
228229

229230
if mc.cfg.ClientFoundRows {
@@ -519,6 +520,10 @@ func (mc *mysqlConn) handleErrorPacket(data []byte) error {
519520
}
520521
}
521522

523+
func readStatus(b []byte) statusFlag {
524+
return statusFlag(b[0]) | statusFlag(b[1])<<8
525+
}
526+
522527
// Ok Packet
523528
// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet
524529
func (mc *mysqlConn) handleOkPacket(data []byte) error {
@@ -533,7 +538,7 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error {
533538
mc.insertId, _, m = readLengthEncodedInteger(data[1+n:])
534539

535540
// server_status [2 bytes]
536-
mc.status = statusFlag(data[1+n+m]) | statusFlag(data[1+n+m+1])<<8
541+
mc.status = readStatus(data[1+n+m : 1+n+m+2])
537542

538543
// warning count [2 bytes]
539544
if !mc.strict {
@@ -652,6 +657,11 @@ func (rows *textRows) readRow(dest []driver.Value) error {
652657

653658
// EOF Packet
654659
if data[0] == iEOF && len(data) == 5 {
660+
// server_status [2 bytes]
661+
rows.mc.status = readStatus(data[3:])
662+
if err := rows.mc.discardMoreResultsIfExists(); err != nil {
663+
return err
664+
}
655665
rows.mc = nil
656666
return io.EOF
657667
}
@@ -709,6 +719,10 @@ func (mc *mysqlConn) readUntilEOF() error {
709719
if err == nil && data[0] != iEOF {
710720
continue
711721
}
722+
if err == nil && data[0] == iEOF && len(data) == 5 {
723+
mc.status = readStatus(data[3:])
724+
}
725+
712726
return err // Err or EOF
713727
}
714728
}
@@ -1013,6 +1027,28 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
10131027
return mc.writePacket(data)
10141028
}
10151029

1030+
func (mc *mysqlConn) discardMoreResultsIfExists() error {
1031+
for mc.status&statusMoreResultsExists != 0 {
1032+
resLen, err := mc.readResultSetHeaderPacket()
1033+
if err != nil {
1034+
return err
1035+
}
1036+
if resLen > 0 {
1037+
// columns
1038+
if err := mc.readUntilEOF(); err != nil {
1039+
return err
1040+
}
1041+
// rows
1042+
if err := mc.readUntilEOF(); err != nil {
1043+
return err
1044+
}
1045+
} else {
1046+
mc.status &^= statusMoreResultsExists
1047+
}
1048+
}
1049+
return nil
1050+
}
1051+
10161052
// http://dev.mysql.com/doc/internals/en/binary-protocol-resultset-row.html
10171053
func (rows *binaryRows) readRow(dest []driver.Value) error {
10181054
data, err := rows.mc.readPacket()
@@ -1022,11 +1058,16 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
10221058

10231059
// packet indicator [1 byte]
10241060
if data[0] != iOK {
1025-
rows.mc = nil
10261061
// EOF Packet
10271062
if data[0] == iEOF && len(data) == 5 {
1063+
rows.mc.status = readStatus(data[3:])
1064+
if err := rows.mc.discardMoreResultsIfExists(); err != nil {
1065+
return err
1066+
}
1067+
rows.mc = nil
10281068
return io.EOF
10291069
}
1070+
rows.mc = nil
10301071

10311072
// Error otherwise
10321073
return rows.mc.handleErrorPacket(data)

rows.go

+7-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ type emptyRows struct{}
3838

3939
func (rows *mysqlRows) Columns() []string {
4040
columns := make([]string, len(rows.columns))
41-
if rows.mc.cfg.ColumnsWithAlias {
41+
if rows.mc != nil && rows.mc.cfg.ColumnsWithAlias {
4242
for i := range columns {
4343
if tableName := rows.columns[i].tableName; len(tableName) > 0 {
4444
columns[i] = tableName + "." + rows.columns[i].name
@@ -65,6 +65,12 @@ func (rows *mysqlRows) Close() error {
6565

6666
// Remove unread packets from stream
6767
err := mc.readUntilEOF()
68+
if err == nil {
69+
if err = mc.discardMoreResultsIfExists(); err != nil {
70+
return err
71+
}
72+
}
73+
6874
rows.mc = nil
6975
return err
7076
}

statement.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,9 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
101101
}
102102

103103
rows := new(binaryRows)
104-
rows.mc = mc
105104

106105
if resLen > 0 {
106+
rows.mc = mc
107107
// Columns
108108
// If not cached, read them and cache them
109109
if stmt.columns == nil {

0 commit comments

Comments
 (0)