# A prepared statement: the client-library statement handle together with the
# parameter binds and result-column binds that are reused across executions.
mutable struct Statement <: DBInterface.Statement
    # connection the statement was prepared on
    conn::Connection
    # underlying client-library statement handle
    stmt::API.MYSQL_STMT
    sql::String
    nparams::Int
    nfields::Int
    # parameter-side helpers and bind structs (one per parameter)
    bindhelpers::Vector{API.BindHelper}
    binds::Vector{API.MYSQL_BIND}
    # result-column names and Julia types
    names::Vector{Symbol}
    types::Vector{Type}
    # column name => column position, derived from `names` by the constructor
    lookup::Dict{Symbol, Int}
    # result-side helpers and bind structs (one per result column)
    valuehelpers::Vector{API.BindHelper}
    values::Vector{API.MYSQL_BIND}

    function Statement(conn::Connection, stmt::API.MYSQL_STMT, sql::AbstractString, nparams::Integer, nfields::Integer, bindhelpers, binds, names, types, valuehelpers, values)
        # Index the column names by position so rows can be accessed by name.
        positions = Dict{Symbol, Int}()
        for (position, name) in enumerate(names)
            positions[name] = position
        end
        return new(conn, stmt, sql, nparams, nfields, bindhelpers, binds, names, types, positions, valuehelpers, values)
    end
end

# Guard used before touching a prepared statement: raises an error when the
# underlying handle's pointer is null (i.e. it has been closed), and evaluates to
# `false` otherwise.
@noinline function checkstmt(stmt::API.MYSQL_STMT)
    stmt.ptr == C_NULL || return false
    return error("prepared mysql statement has been closed")
end

# Convenience form: check the handle wrapped by a `Statement`.
@noinline function checkstmt(stmt::Statement)
    return checkstmt(stmt.stmt)
end

# The connection a statement was prepared on.
function DBInterface.getconnection(stmt::Statement)
    return stmt.conn
end

"""
    DBInterface.close!(stmt)

Close a prepared statement and free any underlying resources. The statement should not be used in any way afterwards.
"""
function DBInterface.close!(stmt::Statement)
    # Closing is delegated to whatever finalizer is registered on the handle.
    handle = stmt.stmt
    return finalize(handle)
end

"""
    DBInterface.prepare(conn::MySQL.Connection, sql) => MySQL.Statement

Send a `sql` SQL string to the database to be prepared, returning a `MySQL.Statement` object
that can be passed to `DBInterface.execute(stmt, args...)` to be repeatedly executed,
optionally passing `args` for parameters to be bound on each execution.

Note that `DBInterface.close!(stmt)` should be called once statement executions are finished. Apart from
freeing resources, it has been noted that too many unclosed statements and resultsets, used in conjunction
with streaming queries (i.e. `mysql_store_result=false`) has led to occasional resultset corruption.
"""
function DBInterface.prepare(conn::Connection, sql::AbstractString; mysql_date_and_time::Bool=false)
    clear!(conn)
    handle = API.stmtinit(conn.mysql)
    API.prepare(handle, sql)

    # Parameter side: one helper per placeholder, each bind struct constructed from
    # its helper's `length` and `is_null` slots.
    nparams = API.paramcount(handle)
    bindhelpers = [API.BindHelper() for _ in 1:nparams]
    binds = [API.MYSQL_BIND(h.length, h.is_null) for h in bindhelpers]

    nfields = API.fieldcount(handle)
    meta = API.resultmetadata(handle)
    if meta.ptr == C_NULL
        # No result metadata at prepare time: the statement starts with empty column info.
        return Statement(conn, handle, sql, nparams, nfields, bindhelpers, binds,
                         Symbol[], Type[], API.BindHelper[], API.MYSQL_BIND[])
    end

    # Result side: derive column names/types from the metadata and bind one output
    # buffer per column.
    fields = API.fetchfields(meta, nfields)
    names = [ccall(:jl_symbol_n, Ref{Symbol}, (Cstring, Csize_t), f.name, f.name_length) for f in fields]
    types = [juliatype(f.field_type, API.notnullable(f), API.isunsigned(f), API.isbinary(f), mysql_date_and_time) for f in fields]
    valuehelpers = [API.BindHelper() for _ in 1:nfields]
    values = [API.MYSQL_BIND(h.length, h.is_null) for h in valuehelpers]
    for col in 1:nfields
        returnbind!(valuehelpers[col], values, col, fields[col].field_type, types[col])
    end
    API.bindresult(handle, values)
    return Statement(conn, handle, sql, nparams, nfields, bindhelpers, binds, names, types, valuehelpers, values)
end

# Iterable over the rows produced by one execution of a prepared statement.
# `DBInterface.execute` constructs it as `Cursor{true}` when the resultset was stored
# client-side (`mysql_store_result=true`) and as `Cursor{false}` otherwise; the
# parameter selects the `Base.IteratorSize` trait.
mutable struct Cursor{buffered} <: DBInterface.Cursor
    # statement handle that rows are fetched from
    stmt::API.MYSQL_STMT
    # number of result columns
    nfields::Int
    # result-column names and Julia types, in column order
    names::Vector{Symbol}
    types::Vector{Type}
    # column name => column position
    lookup::Dict{Symbol, Int}
    # per-column helpers and bind structs that receive fetched values
    valuehelpers::Vector{API.BindHelper}
    values::Vector{API.MYSQL_BIND}
    # value of `API.affectedrows` right after execution, reinterpreted as Int64
    rows_affected::Int64
    # `API.numrows` for a stored resultset; `execute` leaves it at -1 otherwise
    rows::Int
    # row number most recently produced by `iterate` (0 before the first fetch)
    current_rownumber::Int
end

# A lightweight handle on "the row a cursor is currently positioned at". It stores no
# column data itself: values are read from the cursor's bind buffers on access, and
# `rownumber` is compared against the cursor's `current_rownumber` in `Tables.getcolumn`
# to detect access to a row the cursor has already moved past.
struct Row <: Tables.AbstractRow
    cursor::Cursor
    rownumber::Int
end

# Field accessors for `Row`. They go through `getfield` rather than dot syntax;
# presumably because `Tables.AbstractRow` routes property access to column lookup.
function getcursor(r::Row)
    return getfield(r, :cursor)
end

function getrownumber(r::Row)
    return getfield(r, :rownumber)
end

# Column names of a row are those of the cursor it came from.
function Tables.columnnames(r::Row)
    cur = getcursor(r)
    return cur.names
end

# Typed column access: read column `i` of the cursor's current row as a `T`.
# `wrongrow` is invoked first when `r` is not the row the cursor is positioned at.
function Tables.getcolumn(r::Row, ::Type{T}, i::Int, nm::Symbol) where {T}
    cur = getcursor(r)
    rownum = getrownumber(r)
    if rownum != cur.current_rownumber
        wrongrow(rownum)
    end
    return getvalue(cur.stmt, cur.valuehelpers[i], cur.values, i, T)
end

# Access by position: look up the column's type and name, then defer to the typed method.
function Tables.getcolumn(r::Row, i::Int)
    cur = getcursor(r)
    return Tables.getcolumn(r, cur.types[i], i, cur.names[i])
end

# Access by name: translate the name to a position through the cursor's lookup table.
function Tables.getcolumn(r::Row, nm::Symbol)
    position = getcursor(r).lookup[nm]
    return Tables.getcolumn(r, position)
end

# Tables.jl integration: a cursor is itself an iterator of rows, with a schema built
# from its column names and types.
Tables.isrowtable(::Type{<:Cursor}) = true
Tables.schema(c::Cursor) = Tables.Schema(c.names, c.types)

# Iteration traits.
# `eltype` is defined on the *type*, as Base's iteration interface asks of new
# iterator types; `eltype(cursor)` keeps working through Base's generic
# `eltype(x) = eltype(typeof(x))` fallback, and `eltype(typeof(cursor))` now reports
# `Row` instead of `Any`.
Base.eltype(::Type{<:Cursor}) = Row
# Only cursors over a stored (buffered) resultset know their length up front.
Base.IteratorSize(::Type{Cursor{true}}) = Base.HasLength()
Base.IteratorSize(::Type{Cursor{false}}) = Base.SizeUnknown()
# `rows` is filled from `API.numrows` for stored resultsets; `execute` leaves it at -1
# for unbuffered cursors.
Base.length(c::Cursor) = c.rows

# Advance the cursor by one row. Iteration ends when the statement handle is closed
# (null pointer) or the fetch reports `API.MYSQL_NO_DATA`; a fetch status of 1 is
# raised as an `API.StmtError`. The iteration state is the 1-based row number.
function Base.iterate(cursor::Cursor, i=1)
    handle = cursor.stmt
    if handle.ptr == C_NULL
        return nothing
    end
    status = API.fetch(handle)
    if status == API.MYSQL_NO_DATA
        return nothing
    elseif status == 1
        throw(API.StmtError(handle))
    end
    # Remember which row the bind buffers now hold so stale `Row`s can be detected.
    cursor.current_rownumber = i
    return (Row(cursor, i), i + 1)
end

"""
    DBInterface.lastrowid(c::MySQL.Cursor)

Return the last inserted row id.
"""
function DBInterface.lastrowid(c::Cursor)
    handle = c.stmt
    # Errors out if the statement has already been closed.
    checkstmt(handle)
    return API.insertid(handle)
end

"""
    DBInterface.close!(cursor)

Close a cursor. No more results will be available.
"""
DBInterface.close!(c::Cursor) = clear!(c.conn)  # NOTE(review): `Cursor` declares no `conn` field in this file; unless `getproperty` is extended elsewhere this access errors. TODO confirm the intended `clear!` target.

# Verify that exactly `stmt.nparams` parameters were supplied for an execution.
# Returns `true` on a match and throws `MySQLInterfaceError` otherwise. The message is
# worded neutrally because the mismatch can be in either direction (too few or too many).
@noinline function paramcheck(stmt, args)
    nprovided = length(args)
    nprovided == stmt.nparams && return true
    throw(MySQLInterfaceError("stmt requires $(stmt.nparams) params, but $(nprovided) provided"))
end

"""
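    DBInterface.execute(stmt, params; mysql_store_result=true, mysql_date_and_time=false) => MySQL.Cursor

Execute a prepared statement, optionally passing `params` to be bound as parameters (like `?` in the sql).
The number of `params` must match the number of parameters in the prepared statement.
Returns a `Cursor` object, which iterates resultset rows and satisfies the `Tables.jl` interface, meaning
results can be sent to any valid sink function (`DataFrame(cursor)`, `CSV.write("results.csv", cursor)`, etc.).
Specifying `mysql_store_result=false` will avoid buffering the full resultset to the client after executing
the query, which has memory use advantages, though it ties up the database server since resultset rows must be
fetched one at a time.

`mysql_date_and_time` is forwarded to the column-type mapping, but only when the statement reported no result
fields at prepare time and the result metadata therefore has to be resolved during execution; otherwise the
column types chosen by `DBInterface.prepare` are reused.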
"""
function DBInterface.execute(stmt::Statement, params=(); mysql_store_result::Bool=true, mysql_date_and_time::Bool=false)
    # Refuse to run on a closed statement, and require one value per parameter.
    checkstmt(stmt)
    paramcheck(stmt, params)
    # `clear!` is defined elsewhere; presumably it discards any result still pending
    # on the connection from a previous execution — TODO confirm.
    clear!(stmt.conn)
    if length(params) > 0
        # Copy each value into its parameter bind slot, then hand the binds to the statement.
        foreach(1:stmt.nparams) do i
            bind!(stmt.bindhelpers[i], stmt.binds, i, params[i])
        end
        API.bindparam(stmt.stmt, stmt.binds)
    end
    API.execute(stmt.stmt)
    # Record this statement as the connection's most recent execution.
    stmt.conn.lastexecute = stmt.stmt
    # Reinterpret the affected-rows count as a signed 64-bit integer (bit pattern unchanged).
    rows_affected = Core.bitcast(Int64, API.affectedrows(stmt.stmt))
    # Defaults describe an unbuffered cursor: unknown length, signalled by rows == -1.
    buffered = false
    rows = -1
    if mysql_store_result
        API.storeresult(stmt.stmt)
        buffered = true
        rows = API.numrows(stmt.stmt)
    end
    # Start from the column info captured when the statement was prepared.
    nfields = stmt.nfields
    names = stmt.names
    types = stmt.types
    valuehelpers = stmt.valuehelpers
    values = stmt.values
    lookup = stmt.lookup
    if stmt.nfields == 0
        # The statement reported no result fields at prepare time: ask again now that
        # it has executed and, if metadata is available, build fresh column info and
        # result binds for this cursor only (the Statement itself is not updated).
        nfields = API.fieldcount(stmt.stmt)
        result = API.resultmetadata(stmt.stmt)
        if result.ptr != C_NULL
            fields = API.fetchfields(result, nfields)
            names = [ccall(:jl_symbol_n, Ref{Symbol}, (Cstring, Csize_t), x.name, x.name_length) for x in fields]
            types = [juliatype(x.field_type, API.notnullable(x), API.isunsigned(x), API.isbinary(x), mysql_date_and_time) for x in fields]
            valuehelpers = [API.BindHelper() for i = 1:nfields]
            values = [API.MYSQL_BIND(valuehelpers[i].length, valuehelpers[i].is_null) for i = 1:nfields]
            foreach(1:nfields) do i
                returnbind!(valuehelpers[i], values, i, fields[i].field_type, types[i])
            end
            API.bindresult(stmt.stmt, values)
            lookup = Dict(x => i for (i, x) in enumerate(names))
        end
    end
    # The cursor starts before the first row (current_rownumber == 0).
    return Cursor{buffered}(stmt.stmt, nfields, names, types, lookup, valuehelpers, values, rows_affected, rows, 0)
end

# `missing` needs no typed buffer: nothing to allocate, and the buffer pointer is null.
function inithelper!(helper, x::Missing)
    return nothing
end

function ptrhelper(helper, x::Missing)
    return C_NULL
end

# Nullable column: report `missing` when the helper's null flag is set, otherwise
# read the value through the getter for the non-missing type `T`.
function getvalue(stmt, helper, values, i, ::Type{Union{T, Missing}}) where {T}
    isnull = helper.is_null[1] == 1
    return isnull ? missing : getvalue(stmt, helper, values, i, T)
end

# Bit columns keep no persistent buffer in the helper: a scratch buffer is attached
# at read time instead (see `getvalue` below).
function inithelper!(helper, x::API.Bit)
    return nothing
end

function ptrhelper(helper, x::API.Bit)
    return C_NULL
end

function sethelper!(helper, x::API.Bit)
    return nothing
end

# Read a bit column of the current row: point bind `i` at a one-element UInt64
# scratch buffer, fetch the column into it, then shift the raw word right by
# 8 * (byte length - 1) bits before wrapping it in `API.Bit`.
function getvalue(stmt, helper, values, i, ::Type{API.Bit})
    nbytes = helper.length[1]
    scratch = UInt64[0]
    bind = pointer(values, i)
    API.setbuffer!(bind, pointer(scratch))
    API.setbufferlength!(bind, nbytes)
    API.mysql_stmt_fetch_column(stmt.ptr, convert(Ptr{Cvoid}, bind), i - 1, 0)
    shift = 8 * (nbytes - 1)
    return API.Bit(scratch[1] >> shift)
end

# Fixed-width integer slots. Each width has a one-element unsigned buffer on the
# helper; signed values (and Bool) are stored by bit pattern via `Core.bitcast` and
# converted back the same way on read.
#   inithelper! – allocate the buffer holding `x`, returning the new buffer
#   ptrhelper   – pointer to that buffer
#   sethelper!  – overwrite the buffered value, returning the stored bit pattern
#   getvalue    – read the buffered value back as the requested type

# --- 8-bit ---
function inithelper!(helper, x::Union{Bool, UInt8, Int8})
    buf = UInt8[Core.bitcast(UInt8, x)]
    helper.uint8 = buf
    return buf
end
function ptrhelper(helper, x::Union{Bool, UInt8, Int8})
    return pointer(helper.uint8)
end
function sethelper!(helper, x::Union{Bool, UInt8, Int8})
    raw = Core.bitcast(UInt8, x)
    helper.uint8[1] = raw
    return raw
end
function getvalue(stmt, helper, values, i, ::Type{T}) where {T <: Union{Bool, UInt8, Int8}}
    return Core.bitcast(T, helper.uint8[1])
end

# --- 16-bit ---
function inithelper!(helper, x::Union{UInt16, Int16})
    buf = UInt16[Core.bitcast(UInt16, x)]
    helper.uint16 = buf
    return buf
end
function ptrhelper(helper, x::Union{UInt16, Int16})
    return pointer(helper.uint16)
end
function sethelper!(helper, x::Union{UInt16, Int16})
    raw = Core.bitcast(UInt16, x)
    helper.uint16[1] = raw
    return raw
end
function getvalue(stmt, helper, values, i, ::Type{T}) where {T <: Union{UInt16, Int16}}
    return Core.bitcast(T, helper.uint16[1])
end

# --- 32-bit ---
function inithelper!(helper, x::Union{UInt32, Int32})
    buf = UInt32[Core.bitcast(UInt32, x)]
    helper.uint32 = buf
    return buf
end
function ptrhelper(helper, x::Union{UInt32, Int32})
    return pointer(helper.uint32)
end
function sethelper!(helper, x::Union{UInt32, Int32})
    raw = Core.bitcast(UInt32, x)
    helper.uint32[1] = raw
    return raw
end
function getvalue(stmt, helper, values, i, ::Type{T}) where {T <: Union{UInt32, Int32}}
    return Core.bitcast(T, helper.uint32[1])
end

# --- 64-bit ---
function inithelper!(helper, x::Union{UInt64, Int64})
    buf = UInt64[Core.bitcast(UInt64, x)]
    helper.uint64 = buf
    return buf
end
function ptrhelper(helper, x::Union{UInt64, Int64})
    return pointer(helper.uint64)
end
function sethelper!(helper, x::Union{UInt64, Int64})
    raw = Core.bitcast(UInt64, x)
    helper.uint64[1] = raw
    return raw
end
function getvalue(stmt, helper, values, i, ::Type{T}) where {T <: Union{UInt64, Int64}}
    return Core.bitcast(T, helper.uint64[1])
end

# Floating-point slots: one-element `Float32` / `Float64` buffers on the helper,
# following the same init / pointer / set / get pattern as the integer slots.

# --- single precision ---
function inithelper!(helper, x::Float32)
    buf = Float32[x]
    helper.float = buf
    return buf
end
function ptrhelper(helper, x::Float32)
    return pointer(helper.float)
end
function sethelper!(helper, x::Float32)
    helper.float[1] = x
    return x
end
function getvalue(stmt, helper, values, i, ::Type{Float32})
    return helper.float[1]
end

# --- double precision ---
function inithelper!(helper, x::Float64)
    buf = Float64[x]
    helper.double = buf
    return buf
end
function ptrhelper(helper, x::Float64)
    return pointer(helper.double)
end
function sethelper!(helper, x::Float64)
    helper.double[1] = x
    return x
end
function getvalue(stmt, helper, values, i, ::Type{Float64})
    return helper.double[1]
end

# Date/time slots: a one-element `API.MYSQL_TIME` buffer on the helper. Reads convert
# the buffered struct to whichever Julia date/time type the column was mapped to.
function inithelper!(helper, x::API.MYSQL_TIME)
    buf = API.MYSQL_TIME[x]
    helper.time = buf
    return buf
end

function ptrhelper(helper, x::API.MYSQL_TIME)
    return pointer(helper.time)
end

function getvalue(stmt, helper, values, i, ::Type{Time})
    return convert(Time, helper.time[1])
end

function getvalue(stmt, helper, values, i, ::Type{Date})
    return convert(Date, helper.time[1])
end

function getvalue(stmt, helper, values, i, ::Type{DateTime})
    return convert(DateTime, helper.time[1])
end

function getvalue(stmt, helper, values, i, ::Type{DateAndTime})
    return convert(DateAndTime, helper.time[1])
end

# Variable-length columns (strings, blobs, decimals) have no persistent result buffer:
# the byte length of the current value is read from `helper.length[1]`, a buffer of
# exactly that size is allocated, and the column is then fetched into it.

# Shared by the variable-length getters below: point result bind `i` at `buffer`
# (which must hold at least `len` bytes) and call `mysql_stmt_fetch_column` for
# column `i` (passed to C as the 0-based index `i - 1`, offset 0). Only raw pointers
# are handed to C, so `values` and `buffer` are pinned with `GC.@preserve` for the
# duration of the calls. Returns `buffer`.
function _fetchcolumn!(stmt, values, i, buffer, len)
    GC.@preserve values buffer begin
        bind = pointer(values, i)
        API.setbuffer!(bind, pointer(buffer))
        API.setbufferlength!(bind, len)
        API.mysql_stmt_fetch_column(stmt.ptr, convert(Ptr{Cvoid}, bind), i - 1, 0)
    end
    return buffer
end

# --- String ---
inithelper!(helper, x::String) = nothing
ptrhelper(helper, x::String) = C_NULL
# Keep a reference to the bound string on the helper.
sethelper!(helper, x::String) = helper.string = x

# Read column `i` of the current row as a `String` of `helper.length[1]` bytes.
function getvalue(stmt, helper, values, i, ::Type{String})
    len = helper.length[1]
    return _fetchcolumn!(stmt, values, i, Base._string_n(len), len)
end

# --- Vector{UInt8} (binary data) ---
inithelper!(helper, x::Vector{UInt8}) = nothing
ptrhelper(helper, x::Vector{UInt8}) = C_NULL
# Keep a reference to the bound byte vector on the helper.
sethelper!(helper, x::Vector{UInt8}) = helper.blob = x

# Read column `i` of the current row as a byte vector of `helper.length[1]` bytes.
function getvalue(stmt, helper, values, i, ::Type{Vector{UInt8}})
    len = helper.length[1]
    return _fetchcolumn!(stmt, values, i, Vector{UInt8}(undef, len), len)
end

# --- Dec64 (fetched as text, then parsed) ---
inithelper!(helper, x::Dec64) = nothing
ptrhelper(helper, x::Dec64) = C_NULL

# Read column `i` of the current row as text and parse it into a `Dec64`.
function getvalue(stmt, helper, values, i, ::Type{Dec64})
    len = helper.length[1]
    str = _fetchcolumn!(stmt, values, i, Base._string_n(len), len)
    return parse(Dec64, str)
end

# Placeholder value of the right kind for a column type; `returnbind!` uses it to
# pick and initialise the helper buffer for a result column.

# Generic fallback: the type's zero.
function defaultvalue(T)
    return zero(T)
end

# Nullable columns use the placeholder of their non-missing type.
function defaultvalue(::Type{Union{Missing, T}}) where {T}
    return defaultvalue(T)
end

function defaultvalue(::Type{API.Bit})
    return API.Bit(0)
end

# Every date/time type is buffered as an `API.MYSQL_TIME`; the date itself is arbitrary.
function defaultvalue(::Type{T}) where {T <: Dates.TimeType}
    return convert(API.MYSQL_TIME, Date(2000))
end

function defaultvalue(::Type{String})
    return ""
end

function defaultvalue(::Type{Vector{UInt8}})
    return UInt8[]
end

# Prepare result bind `i` for a column of Julia type `T`: initialise the helper's
# buffer from a placeholder value of that type, point the bind at it, record the
# column's MySQL field type `type`, and mark the helper as typed.
function returnbind!(helper, binds, i, type, ::Type{T}) where {T}
    placeholder = defaultvalue(T)
    inithelper!(helper, placeholder)
    bind = pointer(binds, i)
    API.setbuffer!(bind, ptrhelper(helper, placeholder))
    API.setbuffertype!(bind, type)
    helper.typeset = true
    return nothing
end

# Bind `missing` as a parameter: only the helper's null flag is raised; the bind
# struct itself is left untouched.
function bind!(helper, binds, i, x::Missing)
    setindex!(helper.is_null, true, 1)
    return nothing
end

# Bind a numeric parameter. The first time a slot is used its buffer is allocated and
# the bind struct is pointed at it with the matching MySQL type (and the unsigned flag
# for unsigned values); afterwards only the buffered value is overwritten.
function bind!(helper, binds, i, x::Real)
    if !helper.typeset
        inithelper!(helper, x)
        bind = pointer(binds, i)
        # buffer address
        API.setbuffer!(bind, ptrhelper(helper, x))
        # buffer_type
        API.setbuffertype!(bind, API.mysqltype(x))
        if x isa Unsigned
            API.setisunsigned!(bind, true)
        end
        helper.typeset = true
    end
    sethelper!(helper, x)
    helper.is_null[1] = false
    return nothing
end

# Bind a date/time parameter. The value is converted to an `API.MYSQL_TIME`; on first
# use the helper's time buffer is allocated and the bind struct is pointed at it,
# afterwards the buffered struct is simply overwritten.
function bind!(helper, binds, i, x::Dates.TimeType)
    mtime = convert(API.MYSQL_TIME, x)
    if !helper.typeset
        helper.time = API.MYSQL_TIME[mtime]
        bind = pointer(binds, i)
        # buffer address
        API.setbuffer!(bind, pointer(helper.time))
        # buffer_type
        API.setbuffertype!(bind, API.mysqltype(x))
        helper.typeset = true
    end
    helper.time[1] = mtime
    helper.is_null[1] = false
    return nothing
end

# Normalise a parameter into the representation that is actually handed to the bind.

# Fallback: use the value unchanged.
function val(x)
    return x
end

# Any string type is converted to a `String`.
function val(x::AbstractString)
    return String(x)
end

# Bit values are bound through `API.bitvalue`.
function val(x::API.Bit)
    return API.bitvalue(x)
end

# Decimal floating-point values are bound as their string representation.
function val(x::DecFP.DecimalFloatingPoint)
    return string(x)
end

# Size in bytes of a bound buffer: the number of code units of a `String`, the number
# of bytes held by a `Vector{UInt8}`.
function len(x::String)
    return ncodeunits(x)
end

function len(x::Vector{UInt8})
    return sizeof(x)
end

# Bind a variable-length parameter (bytes, string, bit value or decimal). The value is
# first normalised with `val`; the bind is then pointed directly at the normalised
# value's memory on every call, and its byte length is recorded both on the bind and
# on the helper. The buffer type is only set the first time the slot is used.
function bind!(helper, binds, i, x::Union{Vector{UInt8}, AbstractString, API.Bit, DecFP.DecimalFloatingPoint})
    bind = pointer(binds, i)
    payload = val(x)
    if !helper.typeset
        # buffer_type
        API.setbuffertype!(bind, API.mysqltype(payload))
        helper.typeset = true
    end
    # Store the payload on the helper before handing its pointer to the bind.
    sethelper!(helper, payload)
    nbytes = len(payload)
    API.setbuffer!(bind, pointer(payload))
    API.setbufferlength!(bind, nbytes)
    helper.is_null[1] = false
    helper.length[1] = nbytes
    return nothing
end