//===----------------------------------------------------------------------===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2024 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//

#include "NullEditorConsumer.h"
#include "SourceKit/Core/Context.h"
#include "SourceKit/Core/LangSupport.h"
#include "SourceKit/Core/NotificationCenter.h"
#include "SourceKit/Support/Concurrency.h"
#include "SourceKit/SwiftLang/Factory.h"
#include "swift/Basic/LLVMInitialize.h"
#include "gtest/gtest.h"

#include <chrono>
#include <condition_variable>
#include <mutex>
#include <thread>

using namespace SourceKit;
using namespace llvm;

/// Returns the parent directory of the configured SWIFTLIB_DIR.
static StringRef getRuntimeLibPath() {
  StringRef SwiftLibDir = SWIFTLIB_DIR;
  return sys::path::parent_path(SwiftLibDir);
}

/// Builds the path "<parent of runtime lib path>/bin/swift-frontend".
static SmallString<128> getSwiftExecutablePath() {
  StringRef InstallRoot = sys::path::parent_path(getRuntimeLibPath());
  SmallString<128> FrontendPath = InstallRoot;
  sys::path::append(FrontendPath, "bin", "swift-frontend");
  return FrontendPath;
}

namespace {
class CompileTrackingConsumer final : public trace::TraceConsumer {
  std::mutex Mtx;
  std::condition_variable CV;
  bool HasStarted = false;

public:
  CompileTrackingConsumer() {}
  CompileTrackingConsumer(const CompileTrackingConsumer &) = delete;

  void operationStarted(uint64_t OpId, trace::OperationKind OpKind,
                        const trace::SwiftInvocation &Inv,
                        const trace::StringPairs &OpArgs) override {
    std::unique_lock<std::mutex> lk(Mtx);
    HasStarted = true;
    CV.notify_all();
  }

  void waitForBuildToStart() {
    std::unique_lock<std::mutex> lk(Mtx);
    auto secondsToWait = std::chrono::seconds(20);
    auto when = std::chrono::system_clock::now() + secondsToWait;
    CV.wait_until(lk, when, [&]() { return HasStarted; });
    HasStarted = false;
  }

  void operationFinished(uint64_t OpId, trace::OperationKind OpKind,
                         ArrayRef<DiagnosticEntryInfo> Diagnostics) override {}

  swift::OptionSet<trace::OperationKind> desiredOperations() override {
    return trace::OperationKind::PerformSema;
  }
};

class CloseTest : public ::testing::Test {
  std::shared_ptr<SourceKit::Context> Ctx;
  std::shared_ptr<CompileTrackingConsumer> CompileTracker;

  NullEditorConsumer Consumer;

public:
  CloseTest() {
    INITIALIZE_LLVM();
    Ctx = std::make_shared<SourceKit::Context>(
        getSwiftExecutablePath(), getRuntimeLibPath(),
        /*diagnosticDocumentationPath*/ "", SourceKit::createSwiftLangSupport,
        [](SourceKit::Context &Ctx){ return nullptr; },
        /*dispatchOnMain=*/false);
  }

  CompileTrackingConsumer &getCompileTracker() const { return *CompileTracker; }
  LangSupport &getLang() { return Ctx->getSwiftLangSupport(); }

  void SetUp() override {
    CompileTracker = std::make_shared<CompileTrackingConsumer>();
    trace::registerConsumer(CompileTracker.get());
  }

  void TearDown() override {
    trace::unregisterConsumer(CompileTracker.get());
    CompileTracker = nullptr;
  }

  void open(const char *DocName, StringRef Text, ArrayRef<const char *> CArgs) {
    auto Args = makeArgs(DocName, CArgs);
    auto Buf = MemoryBuffer::getMemBufferCopy(Text, DocName);
    getLang().editorOpen(DocName, Buf.get(), Consumer, Args, std::nullopt);
  }

  void close(const char *DocName, bool CancelBuilds, bool RemoveCache) {
    getLang().editorClose(DocName, CancelBuilds, RemoveCache);
  }

  void getDiagnosticsAsync(
      const char *DocName, ArrayRef<const char *> CArgs,
      std::function<void(const RequestResult<DiagnosticsResult> &)> callback) {
    auto Args = makeArgs(DocName, CArgs);
    getLang().getDiagnostics(DocName, Args, /*VFSOpts*/ std::nullopt,
                             /*CancelToken*/ {}, callback);
  }

private:
  std::vector<const char *> makeArgs(const char *DocName,
                                     ArrayRef<const char *> CArgs) {
    std::vector<const char *> Args = CArgs;
    Args.push_back(DocName);
    return Args;
  }
};

} // end anonymous namespace

/// Returns Swift source with three integer-literal-expressible structs, nine
/// overloads of `+` across them, and a single long chain of integer-literal
/// additions.
static const char *getComplexSourceText() {
  // Deliberately heavy on overloads: best of luck, type-checker.
  static const char SourceText[] =
    // Three types that can each be built from an integer literal.
    "struct A: ExpressibleByIntegerLiteral { init(integerLiteral value: Int) {} }\n"
    "struct B: ExpressibleByIntegerLiteral { init(integerLiteral value: Int) {} }\n"
    "struct C: ExpressibleByIntegerLiteral { init(integerLiteral value: Int) {} }\n"

    // Overloads returning A.
    "func + (lhs: A, rhs: B) -> A { fatalError() }\n"
    "func + (lhs: B, rhs: C) -> A { fatalError() }\n"
    "func + (lhs: C, rhs: A) -> A { fatalError() }\n"

    // Overloads returning B.
    "func + (lhs: B, rhs: A) -> B { fatalError() }\n"
    "func + (lhs: C, rhs: B) -> B { fatalError() }\n"
    "func + (lhs: A, rhs: C) -> B { fatalError() }\n"

    // Overloads returning C.
    "func + (lhs: C, rhs: B) -> C { fatalError() }\n"
    "func + (lhs: B, rhs: C) -> C { fatalError() }\n"
    "func + (lhs: A, rhs: A) -> C { fatalError() }\n"

    // The expression the overloads above apply to.
    "let x: C = 1 + 2 + 3 + 4 + 5 + 6 + 7 + 8\n";
  return SourceText;
}

/// Opens a document, requests diagnostics, waits for the build to start, then
/// closes the document with CancelBuilds = true and expects the diagnostics
/// request to complete as cancelled.
TEST_F(CloseTest, Cancel) {
  const char *Args[] = {"-parse-as-library"};
  const char *DocName = "test.swift";
  const char *Contents = getComplexSourceText();

  // RemoveCache = false appears twice on purpose: it covers both prior states
  // of the ASTProducer, cached and not cached.
  for (bool RemoveCache : {true, false, false}) {
    open(DocName, Contents, Args);

    Semaphore BuildResultSema(0);

    // Invoked with the diagnostics result; the request is expected to have
    // been cancelled by the close() below.
    auto OnDiagnostics = [&](const RequestResult<DiagnosticsResult> &Result) {
      EXPECT_TRUE(Result.isCancelled());
      BuildResultSema.signal();
    };
    getDiagnosticsAsync(DocName, Args, OnDiagnostics);

    // Only close once the tracker has seen an operation start (or its wait
    // has timed out).
    getCompileTracker().waitForBuildToStart();

    close(DocName, /*CancelBuilds*/ true, RemoveCache);

    // Abort outright if the callback never runs: it captures the local
    // semaphore by reference, so the loop must not move on without it.
    const bool TimedOut = BuildResultSema.wait(30 * 1000);
    if (TimedOut)
      llvm::report_fatal_error("Did not receive a response for the request");
  }
}