// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import CTensorFlow
import XCTest

@testable import TensorFlow

/// Fixture with no stored properties.
struct Empty: TensorGroup {
}

/// Fixture holding two `Float` tensors, declared in the order `w`, `b`.
struct Simple: TensorGroup, Equatable {
  var w: Tensor<Float>
  var b: Tensor<Float>
}

/// Fixture pairing a `Float` tensor with an `Int32` tensor, one stored as a mutable `var` and the
/// other as an immutable `let`, so both kinds of stored property appear in one group.
struct Mixed: TensorGroup, Equatable {
  // Mutable.
  var float: Tensor<Float>
  // Immutable.
  let int: Tensor<Int32>
}

/// Fixture whose members are themselves groups: `simple` is stored as an immutable `let` and
/// `mixed` as a mutable `var`, in that declaration order.
struct Nested: TensorGroup, Equatable {
  // Immutable.
  let simple: Simple
  // Mutable.
  var mixed: Mixed
}

/// Fixture generic over two group types; stores a `T` followed by a `U`.
struct Generic<T, U>: TensorGroup, Equatable
where T: TensorGroup & Equatable, U: TensorGroup & Equatable {
  var t: T
  var u: U
}

/// Fixture nesting `Generic` twice with the type arguments in opposite orders:
/// `a` is `Generic<T, V>` and `b` is `Generic<V, T>`.
struct UltraNested<T, V>: TensorGroup, Equatable
where T: TensorGroup & Equatable, V: TensorGroup & Equatable {
  var a: Generic<T, V>
  var b: Generic<V, T>
}

extension TensorHandle {
  /// Returns a new `TFETensorHandle` that owns the result of calling
  /// `TFE_TensorHandleCopySharingTensor` on this handle's underlying C tensor handle.
  ///
  /// Records an XCTest failure when the call's status code is not `TF_OK`, and traps with the
  /// status message when the C API returns no handle.
  func makeCopy() -> TFETensorHandle {
    let status = TF_NewStatus()
    // Release the status object on every exit path.
    defer { TF_DeleteStatus(status) }

    let copiedHandle = TFE_TensorHandleCopySharingTensor(handle._cTensorHandle, status)
    // Inspect the status before unwrapping the result, so a failed copy is reported through
    // the assertion rather than as an unexplained nil unwrap.
    XCTAssertEqual(TF_GetCode(status), TF_OK)
    guard let cHandle = copiedHandle else {
      fatalError(
        "TFE_TensorHandleCopySharingTensor returned no handle: "
          + String(cString: TF_Message(status)))
    }
    return TFETensorHandle(_owning: cHandle)
  }
}

extension TensorArrayProtocol {
  /// The elements of `_tensorHandles`, each force-cast to `TFETensorHandle`.
  ///
  /// Traps if any element is not a `TFETensorHandle`.
  var tfeTensorHandles: [TFETensorHandle] {
    return _tensorHandles.map { anyHandle in
      anyHandle as! TFETensorHandle
    }
  }
}

/// Tests for the `TensorGroup` conformances of the fixture types declared in this file.
///
/// Each fixture gets a type-list test, which checks the `_typeList` it reports, and an init test,
/// which rebuilds the value through `init(_handles:)` and compares it with the original.
final class TensorGroupTests: XCTestCase {
  /// `Empty` reports an empty type list and no tensor handles.
  func testEmptyList() {
    XCTAssertEqual([], Empty._typeList)
    XCTAssertEqual(Empty()._tensorHandles.count, 0)
  }

  /// `Simple` reports one `Float` data type per stored tensor.
  func testSimpleTypeList() {
    let float = Float.tensorFlowDataType
    XCTAssertEqual([float, float], Simple._typeList)
  }

  /// `Simple` can be rebuilt from copied handles and from its own handles.
  func testSimpleInit() {
    // Distinct values, so a handle-ordering mistake shows up as an inequality.
    let w = Tensor<Float>(0.1)
    let b = Tensor<Float>(0.2)
    let simple = Simple(w: w, b: b)

    let wHandle = w.handle.makeCopy()
    let bHandle = b.handle.makeCopy()

    let expectedSimple = Simple(_handles: [wHandle, bHandle])
    XCTAssertEqual(expectedSimple, simple)

    let reconstructedSimple = Simple(_handles: simple.tfeTensorHandles)
    XCTAssertEqual(reconstructedSimple, simple)
  }

  /// `Mixed` reports `Float` followed by `Int32`.
  func testMixedTypeList() {
    let float = Float.tensorFlowDataType
    let int = Int32.tensorFlowDataType
    XCTAssertEqual([float, int], Mixed._typeList)
  }

  /// `Mixed` can be rebuilt from copied handles and from its own handles.
  func testMixedInit() {
    let float = Tensor<Float>(0.1)
    let int = Tensor<Int32>(1)
    let mixed = Mixed(float: float, int: int)

    let floatHandle = float.handle.makeCopy()
    let intHandle = int.handle.makeCopy()

    let expectedMixed = Mixed(_handles: [floatHandle, intHandle])
    XCTAssertEqual(expectedMixed, mixed)

    let reconstructedMixed = Mixed(_handles: mixed.tfeTensorHandles)
    XCTAssertEqual(reconstructedMixed, mixed)
  }

  /// `Nested` reports the data types of `simple` followed by those of `mixed`.
  func testNestedTypeList() {
    let float = Float.tensorFlowDataType
    let int = Int32.tensorFlowDataType
    XCTAssertEqual([float, float, float, int], Nested._typeList)
  }

  /// `Nested` can be rebuilt from copied handles and from its own handles.
  func testNestedInit() {
    // Distinct Float values, so a handle-ordering mistake shows up as an inequality.
    let w = Tensor<Float>(0.1)
    let b = Tensor<Float>(0.2)
    let simple = Simple(w: w, b: b)
    let float = Tensor<Float>(0.3)
    let int = Tensor<Int32>(1)
    let mixed = Mixed(float: float, int: int)
    let nested = Nested(simple: simple, mixed: mixed)

    let wHandle = w.handle.makeCopy()
    let bHandle = b.handle.makeCopy()
    let floatHandle = float.handle.makeCopy()
    let intHandle = int.handle.makeCopy()

    let expectedNested = Nested(
      _handles: [wHandle, bHandle, floatHandle, intHandle])
    XCTAssertEqual(expectedNested, nested)

    let reconstructedNested = Nested(_handles: nested.tfeTensorHandles)
    XCTAssertEqual(reconstructedNested, nested)
  }

  /// `Generic<Simple, Mixed>` reports the data types of `t` followed by those of `u`.
  func testGenericTypeList() {
    let float = Float.tensorFlowDataType
    let int = Int32.tensorFlowDataType
    XCTAssertEqual(
      [float, float, float, int], Generic<Simple, Mixed>._typeList)
  }

  /// `Generic<Simple, Mixed>` can be rebuilt from copied handles and from its own handles.
  func testGenericInit() {
    // Distinct Float values, so a handle-ordering mistake shows up as an inequality.
    let w = Tensor<Float>(0.1)
    let b = Tensor<Float>(0.2)
    let simple = Simple(w: w, b: b)
    let float = Tensor<Float>(0.3)
    let int = Tensor<Int32>(1)
    let mixed = Mixed(float: float, int: int)
    let generic = Generic(t: simple, u: mixed)

    let wHandle = w.handle.makeCopy()
    let bHandle = b.handle.makeCopy()
    let floatHandle = float.handle.makeCopy()
    let intHandle = int.handle.makeCopy()

    let expectedGeneric = Generic<Simple, Mixed>(
      _handles: [wHandle, bHandle, floatHandle, intHandle])
    XCTAssertEqual(expectedGeneric, generic)

    let reconstructedGeneric = Generic<Simple, Mixed>(_handles: generic.tfeTensorHandles)
    XCTAssertEqual(reconstructedGeneric, generic)
  }

  /// `UltraNested<Simple, Mixed>` reports the data types of `a` (`Generic<Simple, Mixed>`)
  /// followed by those of `b` (`Generic<Mixed, Simple>`).
  func testNestedGenericTypeList() {
    // The assertion runs inside a method of a locally declared struct.
    // NOTE(review): presumably to exercise the generic fixture from a nested declaration
    // context; confirm before flattening.
    struct NestedGeneric {
      func function() {
        let float = Float.tensorFlowDataType
        let int = Int32.tensorFlowDataType
        XCTAssertEqual(
          [float, float, float, int, float, int, float, float],
          UltraNested<Simple, Mixed>._typeList)
      }
    }

    NestedGeneric().function()
  }

  /// `UltraNested<Simple, Mixed>` can be rebuilt from copied handles and from its own handles.
  func testNestedGenericInit() {
    // The assertions run inside a method of a locally declared struct.
    // NOTE(review): presumably to exercise the generic fixture from a nested declaration
    // context; confirm before flattening.
    struct NestedGeneric {
      func function() {
        // Distinct Float values, so a handle-ordering mistake shows up as an inequality.
        let w = Tensor<Float>(0.1)
        let b = Tensor<Float>(0.2)
        let simple = Simple(w: w, b: b)
        let float = Tensor<Float>(0.3)
        let int = Tensor<Int32>(1)
        let mixed = Mixed(float: float, int: int)
        let genericSM = Generic<Simple, Mixed>(t: simple, u: mixed)
        let genericMS = Generic<Mixed, Simple>(t: mixed, u: simple)
        let generic = UltraNested(a: genericSM, b: genericMS)

        // Each tensor appears once in `a` and once in `b`, so every handle is copied twice.
        let wHandle1 = w.handle.makeCopy()
        let wHandle2 = w.handle.makeCopy()
        let bHandle1 = b.handle.makeCopy()
        let bHandle2 = b.handle.makeCopy()
        let floatHandle1 = float.handle.makeCopy()
        let floatHandle2 = float.handle.makeCopy()
        let intHandle1 = int.handle.makeCopy()
        let intHandle2 = int.handle.makeCopy()

        // First row fills `a` (Simple, then Mixed); second row fills `b` (Mixed, then Simple).
        let expectedGeneric = UltraNested<Simple, Mixed>(
          _handles: [
            wHandle1, bHandle1, floatHandle1, intHandle1,
            floatHandle2, intHandle2, wHandle2, bHandle2,
          ])
        XCTAssertEqual(expectedGeneric, generic)

        let reconstructedGeneric = UltraNested<Simple, Mixed>(
          _handles: generic.tfeTensorHandles)
        XCTAssertEqual(reconstructedGeneric, generic)
      }
    }

    NestedGeneric().function()
  }

  /// Name/method pairs for every test in this class; each name matches its method's spelling.
  static var allTests = [
    ("testEmptyList", testEmptyList),
    ("testSimpleTypeList", testSimpleTypeList),
    ("testSimpleInit", testSimpleInit),
    ("testMixedTypeList", testMixedTypeList),
    ("testMixedInit", testMixedInit),
    ("testNestedTypeList", testNestedTypeList),
    ("testNestedInit", testNestedInit),
    ("testGenericTypeList", testGenericTypeList),
    ("testGenericInit", testGenericInit),
    ("testNestedGenericTypeList", testNestedGenericTypeList),
    ("testNestedGenericInit", testNestedGenericInit),
  ]

}