import CLibsql
import Foundation

/// A dynamically typed libsql value, mirroring the C layer's value kinds.
public enum Value {
    /// A 64-bit signed integer.
    case integer(Int64)
    /// A string.
    case text(String)
    /// Raw bytes.
    case blob(Data)
    /// A double-precision floating-point number.
    case real(Double)
    /// SQL `NULL`.
    case null
}

/// A Swift type that can be bound as a statement parameter.
public protocol ValueRepresentable {
    /// Returns the `Value` this instance should be bound as.
    func toValue() -> Value
}

extension Value: ValueRepresentable {
    /// A `Value` is returned unchanged.
    public func toValue() -> Value {
        return self
    }
}

extension Int: ValueRepresentable {
    /// Widens to `Int64` (lossless for every `Int` width) and wraps it as `.integer`.
    public func toValue() -> Value {
        return Value.integer(Int64(self))
    }
}

extension Int64: ValueRepresentable {
    /// Wraps the integer as `.integer`.
    public func toValue() -> Value {
        return Value.integer(self)
    }
}

extension String: ValueRepresentable {
    /// Wraps the string as `.text`.
    public func toValue() -> Value {
        return Value.text(self)
    }
}

extension Data: ValueRepresentable {
    /// Wraps the bytes as `.blob`.
    public func toValue() -> Value {
        return Value.blob(self)
    }
}

extension Double: ValueRepresentable {
    /// Wraps the number as `.real`.
    public func toValue() -> Value {
        return Value.real(self)
    }
}

/// A libsql handle that can compile SQL into a `Statement`.
///
/// The extension below layers one-shot `execute`/`query` helpers on top of
/// `prepare(_:)`; each helper prepares a fresh statement per call.
public protocol Prepareable {
    /// Compiles `sql` into a statement ready for binding and execution.
    func prepare(_ sql: String) throws -> Statement
}

extension Prepareable {
    /// Prepares and executes `sql` with no parameters.
    /// - Returns: The row-change count reported by `Statement.execute()`.
    /// - Throws: Any error from preparing or executing the statement.
    @discardableResult
    public func execute(_ sql: String) throws -> Int {
        return try self.prepare(sql).execute()
    }

    /// Prepares `sql`, binds `params` by name, and executes it.
    /// - Returns: The row-change count reported by `Statement.execute()`.
    /// - Throws: Any error from preparing, binding, or executing the statement.
    @discardableResult
    public func execute(_ sql: String, _ params: [String: ValueRepresentable]) throws -> Int {
        return try self.prepare(sql).bind(params).execute()
    }

    /// Prepares `sql`, binds `params` positionally (in array order), and executes it.
    /// - Returns: The row-change count reported by `Statement.execute()`.
    /// - Throws: Any error from preparing, binding, or executing the statement.
    @discardableResult
    public func execute(_ sql: String, _ params: [ValueRepresentable]) throws -> Int {
        return try self.prepare(sql).bind(params).execute()
    }

    /// Prepares and runs `sql` as a query with no parameters.
    /// - Returns: The resulting rows.
    /// - Throws: Any error from preparing or querying.
    public func query(_ sql: String) throws -> Rows {
        return try self.prepare(sql).query()
    }

    /// Prepares `sql`, binds `params` by name, and runs it as a query.
    /// - Returns: The resulting rows.
    /// - Throws: Any error from preparing, binding, or querying.
    public func query(_ sql: String, _ params: [String: ValueRepresentable]) throws -> Rows {
        return try self.prepare(sql).bind(params).query()
    }

    /// Prepares `sql`, binds `params` positionally (in array order), and runs it as a query.
    /// - Returns: The resulting rows.
    /// - Throws: Any error from preparing, binding, or querying.
    public func query(_ sql: String, _ params: [ValueRepresentable]) throws -> Rows {
        return try self.prepare(sql).bind(params).query()
    }
}

extension String? {
    /// Calls `body` with a NUL-terminated C string for the wrapped value, or with
    /// `nil` when the optional is empty.
    ///
    /// The pointer handed to `body` is only valid for the duration of the call.
    func withCString<Result>(_ body: (UnsafePointer<Int8>?) throws -> Result) rethrows -> Result {
        switch self {
        case .some(let string):
            return try string.withCString(body)
        case .none:
            return try body(nil)
        }
    }
}

/// Converts a libsql error handle into a thrown `LibsqlError`.
///
/// - Parameter err: The `err` field returned by a libsql C call; `nil` means success.
/// - Throws: `LibsqlError.runtimeError` carrying libsql's message. The error handle
///   is released with `libsql_error_deinit` before the error propagates.
func errIf(_ err: OpaquePointer!) throws {
    guard let err else { return }
    defer { libsql_error_deinit(err) }

    // Copy the message out before the deferred deinit frees the handle that owns it.
    // A nil message is reported generically instead of crashing on a force-unwrap.
    let message = libsql_error_message(err).map { String(cString: $0) } ?? "unknown libsql error"
    throw LibsqlError.runtimeError(message)
}

/// Errors thrown by this wrapper.
///
/// Public so that callers of the public API can match individual cases.
public enum LibsqlError: Error {
    /// An error reported by the libsql C library, with its message.
    case runtimeError(String)
    /// A typed `Row` accessor found a value of a different kind than requested.
    case typeMismatch
}

/// A single result row produced by `Rows`.
///
/// Owns its `libsql_row_t` handle and releases it on deinit.
public class Row {
    var inner: libsql_row_t

    fileprivate init?(from inner: libsql_row_t?) {
        guard let inner else { return nil }
        self.inner = inner
    }

    deinit {
        // Previously missing: every row handle returned by libsql_rows_next leaked.
        libsql_row_deinit(self.inner)
    }

    /// Reads the value of column `index` (presumably zero-based — confirm against libsql-c).
    /// - Returns: The column value converted to a `Value`.
    /// - Throws: `LibsqlError.runtimeError` if libsql reports an error.
    public func get(_ index: Int32) throws -> Value {
        let result = libsql_row_value(self.inner, index)
        try errIf(result.err)

        switch result.ok.type {
        case LIBSQL_TYPE_BLOB:
            let slice = result.ok.value.blob
            // The slice is copied into `Data` before being released.
            defer { libsql_slice_deinit(slice) }
            return .blob(Data(bytes: slice.ptr, count: Int(slice.len)))
        case LIBSQL_TYPE_TEXT:
            let slice = result.ok.value.text
            defer { libsql_slice_deinit(slice) }
            // NOTE(review): relies on libsql NUL-terminating text slices; `slice.len`
            // is not consulted, so embedded NULs would truncate the string.
            return .text(String(cString: slice.ptr.assumingMemoryBound(to: UInt8.self)))
        case LIBSQL_TYPE_INTEGER:
            return .integer(result.ok.value.integer)
        case LIBSQL_TYPE_REAL:
            return .real(result.ok.value.real)
        case LIBSQL_TYPE_NULL:
            return .null
        default:
            preconditionFailure("unexpected libsql value type \(result.ok.type)")
        }
    }

    /// Reads column `index` as a blob.
    /// - Throws: `LibsqlError.typeMismatch` if the column is not a blob.
    public func getData(_ index: Int32) throws -> Data {
        guard case let .blob(data) = try self.get(index) else {
            throw LibsqlError.typeMismatch
        }
        return data
    }

    /// Reads column `index` as a real number.
    /// - Throws: `LibsqlError.typeMismatch` if the column is not a real.
    public func getDouble(_ index: Int32) throws -> Double {
        guard case let .real(double) = try self.get(index) else {
            throw LibsqlError.typeMismatch
        }
        return double
    }

    /// Reads column `index` as text.
    /// - Throws: `LibsqlError.typeMismatch` if the column is not text.
    public func getString(_ index: Int32) throws -> String {
        guard case let .text(string) = try self.get(index) else {
            throw LibsqlError.typeMismatch
        }
        return string
    }

    /// Reads column `index` as an integer.
    /// - Throws: `LibsqlError.typeMismatch` if the column is not an integer, or if
    ///   the stored `Int64` does not fit in `Int` (32-bit platforms) — previously a trap.
    public func getInt(_ index: Int32) throws -> Int {
        guard case let .integer(int) = try self.get(index), let value = Int(exactly: int) else {
            throw LibsqlError.typeMismatch
        }
        return value
    }
}

/// A forward-only cursor over query results; it is its own iterator.
///
/// Owns its `libsql_rows_t` handle and releases it on deinit.
public class Rows: Sequence, IteratorProtocol {
    var inner: libsql_rows_t

    fileprivate init(from inner: libsql_rows_t) {
        self.inner = inner
    }

    deinit {
        libsql_rows_deinit(self.inner)
    }

    /// Advances to the next row, or returns `nil` once libsql reports an empty row.
    ///
    /// A libsql error while advancing traps (`try!`), because `IteratorProtocol.next()`
    /// cannot throw.
    public func next() -> Row? {
        let candidate = libsql_rows_next(inner)
        try! errIf(candidate.err)

        guard !libsql_row_empty(candidate) else { return nil }
        return Row(from: candidate)
    }
}

/// A prepared SQL statement.
///
/// Owns its `libsql_statement_t` handle and releases it on deinit.
public class Statement {
    var inner: libsql_statement_t

    deinit {
        libsql_statement_deinit(self.inner)
    }

    fileprivate init(from inner: libsql_statement_t) {
        self.inner = inner
    }

    /// Executes the statement.
    /// - Returns: The `rows_changed` count reported by libsql.
    /// - Throws: `LibsqlError.runtimeError` if execution fails.
    @discardableResult
    public func execute() throws -> Int {
        let exec = libsql_statement_execute(self.inner)
        try errIf(exec.err)

        return Int(exec.rows_changed)
    }

    /// Runs the statement as a query.
    /// - Returns: A cursor over the resulting rows.
    /// - Throws: `LibsqlError.runtimeError` if the query fails.
    public func query() throws -> Rows {
        let rows = libsql_statement_query(self.inner)
        try errIf(rows.err)

        return Rows(from: rows)
    }

    /// Binds each entry of `params` to the parameter with the matching name.
    /// - Returns: `self`, for chaining.
    /// - Throws: `LibsqlError.runtimeError` on the first bind libsql rejects.
    @discardableResult
    public func bind(_ params: [String: ValueRepresentable]) throws -> Self {
        for (name, value) in params {
            try withLibsqlValue(value.toValue()) { raw in
                try errIf(libsql_statement_bind_named(self.inner, name, raw).err)
            }
        }

        return self
    }

    /// Binds `params` positionally, one `libsql_statement_bind_value` call per element in order.
    /// - Returns: `self`, for chaining.
    /// - Throws: `LibsqlError.runtimeError` on the first bind libsql rejects.
    @discardableResult
    public func bind(_ params: [ValueRepresentable]) throws -> Self {
        for value in params {
            try withLibsqlValue(value.toValue()) { raw in
                try errIf(libsql_statement_bind_value(self.inner, raw).err)
            }
        }

        return self
    }

    /// Converts `value` to a `libsql_value_t` and passes it to `body`.
    ///
    /// Text and blob values borrow Swift-owned memory, so the `libsql_value_t`
    /// handed to `body` is only valid for the duration of that call.
    private func withLibsqlValue(
        _ value: Value,
        _ body: (libsql_value_t) throws -> Void
    ) throws {
        switch value {
        case .integer(let integer):
            try body(libsql_integer(integer))
        case .text(let text):
            // Byte length of the UTF-8 encoding plus the NUL that `withCString`
            // appends. `text.count` counts Characters, which is smaller than the
            // byte length for non-ASCII text and truncated it on bind.
            let len = text.utf8.count + 1
            try text.withCString { ptr in
                try body(libsql_text(ptr, len))
            }
        case .blob(let data):
            try data.withUnsafeBytes { bytes in
                try body(libsql_blob(bytes.baseAddress, bytes.count))
            }
        case .real(let real):
            try body(libsql_real(real))
        case .null:
            try body(libsql_value_t(value: .init(), type: LIBSQL_TYPE_NULL))
        }
    }
}

/// A transaction opened from a `Connection`.
///
/// NOTE(review): this class has no deinit, so if neither `commit()` nor
/// `rollback()` is called the handle is not released here.
public class Transaction: Prepareable {
    var inner: libsql_transaction_t

    /// Commits the transaction via `libsql_transaction_commit`.
    public consuming func commit() {
        libsql_transaction_commit(self.inner)
    }

    /// Rolls back the transaction via `libsql_transaction_rollback`.
    public consuming func rollback() {
        libsql_transaction_rollback(self.inner)
    }

    fileprivate init(from inner: libsql_transaction_t) {
        self.inner = inner
    }

    /// Runs `sql` through libsql's batch API inside this transaction.
    /// - Throws: `LibsqlError.runtimeError` if libsql reports an error.
    public func executeBatch(_ sql: String) throws {
        let batch = libsql_transaction_batch(self.inner, sql)
        try errIf(batch.err)
    }

    /// Non-throwing variant of `executeBatch(_:)`; any error is discarded.
    ///
    /// It previously dropped the returned error handle without freeing it. It now
    /// routes through `executeBatch`, so the handle is released. Prefer the throwing
    /// variant to observe failures.
    public func execute_batch(_ sql: String) {
        _ = try? executeBatch(sql)
    }

    /// Compiles `sql` into a statement bound to this transaction.
    /// - Throws: `LibsqlError.runtimeError` if preparation fails.
    public func prepare(_ sql: String) throws -> Statement {
        let stmt = libsql_transaction_prepare(self.inner, sql)
        try errIf(stmt.err)

        return Statement(from: stmt)
    }
}

/// A connection to a `Database`.
///
/// Owns its `libsql_connection_t` handle and releases it on deinit.
public class Connection: Prepareable {
    var inner: libsql_connection_t

    deinit {
        libsql_connection_deinit(self.inner)
    }

    fileprivate init(from inner: libsql_connection_t) {
        self.inner = inner
    }

    /// Begins a transaction on this connection.
    /// - Throws: `LibsqlError.runtimeError` if libsql cannot start the transaction.
    public func transaction() throws -> Transaction {
        let tx = libsql_connection_transaction(self.inner)
        try errIf(tx.err)

        return Transaction(from: tx)
    }

    /// Runs `sql` through libsql's batch API on this connection.
    /// - Throws: `LibsqlError.runtimeError` if libsql reports an error.
    public func executeBatch(_ sql: String) throws {
        let batch = libsql_connection_batch(self.inner, sql)
        try errIf(batch.err)
    }

    /// Non-throwing variant of `executeBatch(_:)`; any error is discarded.
    ///
    /// It previously dropped the returned error handle without freeing it. It now
    /// routes through `executeBatch`, so the handle is released. Prefer the throwing
    /// variant to observe failures.
    public func execute_batch(_ sql: String) {
        _ = try? executeBatch(sql)
    }

    /// Compiles `sql` into a statement on this connection.
    /// - Throws: `LibsqlError.runtimeError` if preparation fails.
    public func prepare(_ sql: String) throws -> Statement {
        let stmt = libsql_connection_prepare(self.inner, sql)
        try errIf(stmt.err)

        return Statement(from: stmt)
    }
}

/// A libsql database: local file, remote, or a local replica of a remote.
///
/// Owns its `libsql_database_t` handle and releases it on deinit.
public class Database {
    var inner: libsql_database_t

    deinit {
        libsql_database_deinit(self.inner)
    }

    /// Calls `libsql_database_sync` on this database.
    /// - Throws: `LibsqlError.runtimeError` if libsql reports an error.
    public func sync() throws {
        try errIf(libsql_database_sync(inner).err)
    }

    /// Opens a new connection to this database.
    /// - Throws: `LibsqlError.runtimeError` if libsql cannot connect.
    public func connect() throws -> Connection {
        let connection = libsql_database_connect(inner)
        try errIf(connection.err)
        return Connection(from: connection)
    }

    /// Opens a database with only `path` set in the descriptor.
    /// - Throws: `LibsqlError.runtimeError` if libsql cannot open it.
    public init(_ path: String) throws {
        inner = try path.withCString { cPath in
            var desc = libsql_database_desc_t()
            desc.path = cPath
            return try Database.open(desc)
        }
    }

    /// Opens a database with `url`, `auth_token`, and `webpki` set in the descriptor.
    /// - Throws: `LibsqlError.runtimeError` if libsql cannot open it.
    public init(url: String, authToken: String, withWebpki: Bool = false) throws {
        inner = try url.withCString { cURL in
            try authToken.withCString { cToken in
                var desc = libsql_database_desc_t()
                desc.url = cURL
                desc.auth_token = cToken
                desc.webpki = withWebpki
                return try Database.open(desc)
            }
        }
    }

    /// Opens a database with both a local `path` and a remote `url` configured.
    ///
    /// `readYourWrites` is stored inverted as `not_read_your_writes`; a `nil`
    /// `encryptionKey` is passed to libsql as a NULL pointer.
    /// - Throws: `LibsqlError.runtimeError` if libsql cannot open it.
    public init(
        path: String,
        url: String,
        authToken: String,
        readYourWrites: Bool = true,
        encryptionKey: String? = nil,
        syncInterval: UInt64 = 0,
        withWebpki: Bool = false
    ) throws {
        inner = try path.withCString { cPath in
            try url.withCString { cURL in
                try authToken.withCString { cToken in
                    try encryptionKey.withCString { cKey in
                        var desc = libsql_database_desc_t()
                        desc.path = cPath
                        desc.url = cURL
                        desc.auth_token = cToken
                        desc.encryption_key = cKey
                        desc.not_read_your_writes = !readYourWrites
                        desc.sync_interval = syncInterval
                        desc.webpki = withWebpki
                        return try Database.open(desc)
                    }
                }
            }
        }
    }

    /// Calls `libsql_database_init` with `desc` and converts a reported error into a throw.
    ///
    /// Must be called while the C strings referenced by `desc` are still alive.
    private static func open(_ desc: libsql_database_desc_t) throws -> libsql_database_t {
        let database = libsql_database_init(desc)
        try errIf(database.err)
        return database
    }
}