@@ -1806,3 +1806,167 @@ func TestCopyInStmtAffectedRows(t *testing.T) {
18061806 res .RowsAffected ()
18071807 res .LastInsertId ()
18081808}
1809+
1810+ func TestConnPrepareContext (t * testing.T ) {
1811+ db := openTestConn (t )
1812+ defer db .Close ()
1813+
1814+ tests := []struct {
1815+ name string
1816+ ctx func () (context.Context , context.CancelFunc )
1817+ sql string
1818+ err error
1819+ }{
1820+ {
1821+ name : "context.Background" ,
1822+ ctx : func () (context.Context , context.CancelFunc ) {
1823+ return context .Background (), nil
1824+ },
1825+ sql : "SELECT 1" ,
1826+ err : nil ,
1827+ },
1828+ {
1829+ name : "context.WithTimeout exceeded" ,
1830+ ctx : func () (context.Context , context.CancelFunc ) {
1831+ return context .WithTimeout (context .Background (), time .Microsecond )
1832+ },
1833+ sql : "SELECT 1" ,
1834+ err : context .DeadlineExceeded ,
1835+ },
1836+ {
1837+ name : "context.WithTimeout" ,
1838+ ctx : func () (context.Context , context.CancelFunc ) {
1839+ return context .WithTimeout (context .Background (), time .Minute )
1840+ },
1841+ sql : "SELECT 1" ,
1842+ err : nil ,
1843+ },
1844+ }
1845+ for _ , tt := range tests {
1846+ t .Run (tt .name , func (t * testing.T ) {
1847+ ctx , cancel := tt .ctx ()
1848+ if cancel != nil {
1849+ defer cancel ()
1850+ }
1851+ _ , err := db .PrepareContext (ctx , tt .sql )
1852+ switch {
1853+ case (err != nil ) != (tt .err != nil ):
1854+ t .Fatalf ("conn.PrepareContext() unexpected nil err got = %v, expected = %v" , err , tt .err )
1855+ case (err != nil && tt .err != nil ) && (err .Error () != tt .err .Error ()):
1856+ t .Errorf ("conn.PrepareContext() got = %v, expected = %v" , err .Error (), tt .err .Error ())
1857+ }
1858+ })
1859+ }
1860+ }
1861+
1862+ func TestStmtQueryContext (t * testing.T ) {
1863+ db := openTestConn (t )
1864+ defer db .Close ()
1865+
1866+ tests := []struct {
1867+ name string
1868+ ctx func () (context.Context , context.CancelFunc )
1869+ sql string
1870+ err error
1871+ }{
1872+ {
1873+ name : "context.Background" ,
1874+ ctx : func () (context.Context , context.CancelFunc ) {
1875+ return context .Background (), nil
1876+ },
1877+ sql : "SELECT pg_sleep(1);" ,
1878+ err : nil ,
1879+ },
1880+ {
1881+ name : "context.WithTimeout exceeded" ,
1882+ ctx : func () (context.Context , context.CancelFunc ) {
1883+ return context .WithTimeout (context .Background (), 1 * time .Second )
1884+ },
1885+ sql : "SELECT pg_sleep(10);" ,
1886+ err : & Error {Message : "canceling statement due to user request" },
1887+ },
1888+ {
1889+ name : "context.WithTimeout" ,
1890+ ctx : func () (context.Context , context.CancelFunc ) {
1891+ return context .WithTimeout (context .Background (), time .Minute )
1892+ },
1893+ sql : "SELECT pg_sleep(1);" ,
1894+ err : nil ,
1895+ },
1896+ }
1897+ for _ , tt := range tests {
1898+ t .Run (tt .name , func (t * testing.T ) {
1899+ ctx , cancel := tt .ctx ()
1900+ if cancel != nil {
1901+ defer cancel ()
1902+ }
1903+ stmt , err := db .PrepareContext (ctx , tt .sql )
1904+ if err != nil {
1905+ t .Fatal (err )
1906+ }
1907+ _ , err = stmt .QueryContext (ctx )
1908+ switch {
1909+ case (err != nil ) != (tt .err != nil ):
1910+ t .Fatalf ("stmt.QueryContext() unexpected nil err got = %v, expected = %v" , err , tt .err )
1911+ case (err != nil && tt .err != nil ) && (err .Error () != tt .err .Error ()):
1912+ t .Errorf ("stmt.QueryContext() got = %v, expected = %v" , err .Error (), tt .err .Error ())
1913+ }
1914+ })
1915+ }
1916+ }
1917+
1918+ func TestStmtExecContext (t * testing.T ) {
1919+ db := openTestConn (t )
1920+ defer db .Close ()
1921+
1922+ tests := []struct {
1923+ name string
1924+ ctx func () (context.Context , context.CancelFunc )
1925+ sql string
1926+ err error
1927+ }{
1928+ {
1929+ name : "context.Background" ,
1930+ ctx : func () (context.Context , context.CancelFunc ) {
1931+ return context .Background (), nil
1932+ },
1933+ sql : "SELECT pg_sleep(1);" ,
1934+ err : nil ,
1935+ },
1936+ {
1937+ name : "context.WithTimeout exceeded" ,
1938+ ctx : func () (context.Context , context.CancelFunc ) {
1939+ return context .WithTimeout (context .Background (), 1 * time .Second )
1940+ },
1941+ sql : "SELECT pg_sleep(10);" ,
1942+ err : & Error {Message : "canceling statement due to user request" },
1943+ },
1944+ {
1945+ name : "context.WithTimeout" ,
1946+ ctx : func () (context.Context , context.CancelFunc ) {
1947+ return context .WithTimeout (context .Background (), time .Minute )
1948+ },
1949+ sql : "SELECT pg_sleep(1);" ,
1950+ err : nil ,
1951+ },
1952+ }
1953+ for _ , tt := range tests {
1954+ t .Run (tt .name , func (t * testing.T ) {
1955+ ctx , cancel := tt .ctx ()
1956+ if cancel != nil {
1957+ defer cancel ()
1958+ }
1959+ stmt , err := db .PrepareContext (ctx , tt .sql )
1960+ if err != nil {
1961+ t .Fatal (err )
1962+ }
1963+ _ , err = stmt .ExecContext (ctx )
1964+ switch {
1965+ case (err != nil ) != (tt .err != nil ):
1966+ t .Fatalf ("stmt.ExecContext() unexpected nil err got = %v, expected = %v" , err , tt .err )
1967+ case (err != nil && tt .err != nil ) && (err .Error () != tt .err .Error ()):
1968+ t .Errorf ("stmt.ExecContext() got = %v, expected = %v" , err .Error (), tt .err .Error ())
1969+ }
1970+ })
1971+ }
1972+ }
0 commit comments