Skip to content

Commit 9b50c47

Browse files
committed
BUG: enable multivalues insert
1 parent 569bc7a commit 9b50c47

File tree

2 files changed

+28
-3
lines changed

2 files changed

+28
-3
lines changed

pandas/io/sql.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -572,8 +572,11 @@ def create(self):
572572
else:
573573
self._execute_create()
574574

575-
def insert_statement(self):
576-
return self.table.insert()
575+
def insert_statement(self, data, conn):
576+
dialect = getattr(conn, 'dialect', None)
577+
if dialect and getattr(dialect, 'supports_multivalues_insert', False):
578+
return (self.table.insert(data),)
579+
return (self.table.insert(), data)
577580

578581
def insert_data(self):
579582
if self.index is not None:
@@ -613,7 +616,7 @@ def insert_data(self):
613616

614617
def _execute_insert(self, conn, keys, data_iter):
615618
data = [{k: v for k, v in zip(keys, row)} for row in data_iter]
616-
conn.execute(self.insert_statement(), data)
619+
conn.execute(*self.insert_statement(data, conn))
617620

618621
def insert(self, chunksize=None):
619622
keys, data_list = self.insert_data()

pandas/tests/io/test_sql.py

+22
Original file line numberDiff line numberDiff line change
@@ -479,6 +479,25 @@ def _transaction_test(self):
479479
res2 = self.pandasSQL.read_query('SELECT * FROM test_trans')
480480
assert len(res2) == 1
481481

482+
def _test_insert_multivalues(self):
483+
db = sql.SQLDatabase(self.conn)
484+
df = DataFrame({'A': [1, 0, 0], 'B': [1.1, 0.2, 4.3]})
485+
table = sql.SQLTable("test_table", db, frame=df)
486+
data = [
487+
{'A': 1, 'B': 0.46},
488+
{'A': 0, 'B': -2.06}
489+
]
490+
statement = table.insert_statement(data, conn=self.conn)[0]
491+
dialect = getattr(self.conn, 'dialect', None)
492+
if dialect and getattr(dialect, 'supports_multivalues_insert', False):
493+
assert statement.parameters == data, (
494+
'insert statement should be multivalues'
495+
)
496+
else:
497+
assert statement.parameters is None, (
498+
'insert statement should not be multivalues'
499+
)
500+
482501

483502
# -----------------------------------------------------------------------------
484503
# -- Testing the public API
@@ -1665,6 +1684,9 @@ class Temporary(Base):
16651684

16661685
tm.assert_frame_equal(df, expected)
16671686

1687+
def test_insert_multivalues(self):
1688+
self._test_insert_multivalues()
1689+
16681690

16691691
class _TestSQLAlchemyConn(_EngineToConnMixin, _TestSQLAlchemy):
16701692

0 commit comments

Comments
 (0)