# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of the Neural Programmer model described in https://openreview.net/pdf?id=ry2YOrcge

This file calls functions to load & pre-process data, construct the TF graph
and performs training or evaluation as specified by the flag evaluator_job
Author: aneelakantan (Arvind Neelakantan)
"""
from __future__ import print_function

import time
from random import Random
import numpy as np
import tensorflow as tf
import model
import wiki_data
import parameters
import data_utils

tf.flags.DEFINE_integer("train_steps", 100001, "Number of steps to train")
tf.flags.DEFINE_integer("eval_cycle", 500,
                        "Evaluate model at every eval_cycle steps")
tf.flags.DEFINE_integer("max_elements", 100,
                        "maximum rows that are  considered for processing")
tf.flags.DEFINE_integer(
    "max_number_cols", 15,
    "maximum number columns that are considered for processing")
tf.flags.DEFINE_integer(
    "max_word_cols", 25,
    "maximum number columns that are considered for processing")
tf.flags.DEFINE_integer("question_length", 62, "maximum question length")
tf.flags.DEFINE_integer("max_entry_length", 1, "")
tf.flags.DEFINE_integer("max_passes", 4, "number of operation passes")
tf.flags.DEFINE_integer("embedding_dims", 256, "")
tf.flags.DEFINE_integer("batch_size", 20, "")
tf.flags.DEFINE_float("clip_gradients", 1.0, "")
tf.flags.DEFINE_float("eps", 1e-6, "")
tf.flags.DEFINE_float("param_init", 0.1, "")
tf.flags.DEFINE_float("learning_rate", 0.001, "")
tf.flags.DEFINE_float("l2_regularizer", 0.0001, "")
tf.flags.DEFINE_float("print_cost", 50.0,
                      "weighting factor in the objective function")
tf.flags.DEFINE_string("job_id", "temp", """job id""")
tf.flags.DEFINE_string("output_dir", "../model/",
                       """output_dir""")
tf.flags.DEFINE_string("data_dir", "../data/",
                       """data_dir""")
tf.flags.DEFINE_integer("write_every", 500, "wrtie every N")
tf.flags.DEFINE_integer("param_seed", 150, "")
tf.flags.DEFINE_integer("python_seed", 200, "")
tf.flags.DEFINE_float("dropout", 0.8, "dropout keep probability")
tf.flags.DEFINE_float("rnn_dropout", 0.9,
                      "dropout keep probability for rnn connections")
tf.flags.DEFINE_float("pad_int", -20000.0,
                      "number columns are padded with pad_int")
tf.flags.DEFINE_string("data_type", "double", "float or double")
tf.flags.DEFINE_float("word_dropout_prob", 0.9, "word dropout keep prob")
tf.flags.DEFINE_integer("word_cutoff", 10, "")
tf.flags.DEFINE_integer("vocab_size", 10800, "")
tf.flags.DEFINE_boolean("evaluator_job", False,
                        "wehther to run as trainer/evaluator")
tf.flags.DEFINE_float(
    "bad_number_pre_process", -200000.0,
    "number that is added to a corrupted table entry in a number column")
tf.flags.DEFINE_float("max_math_error", 3.0,
                      "max square loss error that is considered")
tf.flags.DEFINE_float("soft_min_value", 5.0, "")
FLAGS = tf.flags.FLAGS


class Utility:
  #holds FLAGS and other variables that are used in different files
  def __init__(self):
    global FLAGS
    self.FLAGS = FLAGS
    self.unk_token = "UNK"
    self.entry_match_token = "entry_match"
    self.column_match_token = "column_match"
    self.dummy_token = "dummy_token"
    self.tf_data_type = {}
    self.tf_data_type["double"] = tf.float64
    self.tf_data_type["float"] = tf.float32
    self.np_data_type = {}
    self.np_data_type["double"] = np.float64
    self.np_data_type["float"] = np.float32
    self.operations_set = ["count"] + [
        "prev", "next", "first_rs", "last_rs", "group_by_max", "greater",
        "lesser", "geq", "leq", "max", "min", "word-match"
    ] + ["reset_select"] + ["print"]
    self.word_ids = {}
    self.reverse_word_ids = {}
    self.word_count = {}
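    #seeded RNG used to shuffle the training data reproducibly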
    self.random = Random(FLAGS.python_seed)


def evaluate(sess, data, batch_size, graph, i):
  #computes accuracy
  num_examples = 0.0
  gc = 0.0
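  #iterate over full batches only; a trailing partial batch is skipped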
  for j in range(0, len(data) - batch_size + 1, batch_size):
    [ct] = sess.run([graph.final_correct],
                    feed_dict=data_utils.generate_feed_dict(data, j, batch_size,
                                                            graph))
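    #final_correct is the mean correctness over the batch, so scale by
    #batch_size to recover the number of correct examples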
    gc += ct * batch_size
    num_examples += batch_size
  print("dev set accuracy   after ", i, " : ", gc / num_examples)
  print(num_examples, len(data))
  print("--------")


def Train(graph, utility, batch_size, train_data, sess, model_dir, saver):
  #performs training
  curr = 0
  train_set_loss = 0.0
  utility.random.shuffle(train_data)
  start = time.time()
  for i in range(utility.FLAGS.train_steps):
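    #periodically checkpoint; the evaluator job polls model_dir for new files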
    if (i > 0 and i % FLAGS.write_every == 0):
      model_file = model_dir + "/model_" + str(i)
      saver.save(sess, model_file)
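    #reshuffle and restart from the beginning when a full batch no longer fits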
    if curr + batch_size >= len(train_data):
      curr = 0
      utility.random.shuffle(train_data)
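    #run one gradient update step on the current batch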
    step, cost_value = sess.run(
        [graph.step, graph.total_cost],
        feed_dict=data_utils.generate_feed_dict(
            train_data, curr, batch_size, graph, train=True, utility=utility))
    curr = curr + batch_size
    train_set_loss += cost_value
    if (i > 0 and i % FLAGS.eval_cycle == 0):
      end = time.time()
      time_taken = end - start
      print("step ", i, " ", time_taken, " seconds ")
      start = end
      print(" printing train set loss: ", train_set_loss / utility.FLAGS.eval_cycle)
      train_set_loss = 0.0


def master(train_data, dev_data, utility):
  #creates TF graph and calls trainer or evaluator
  batch_size = utility.FLAGS.batch_size
  model_dir = utility.FLAGS.output_dir + "/model" + utility.FLAGS.job_id + "/"
  #create all parameters of the model
  param_class = parameters.Parameters(utility)
  params, global_step, init = param_class.parameters(utility)
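  #the evaluator builds the graph in test mode, the trainer in train mode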
  key = "test" if (FLAGS.evaluator_job) else "train"
  graph = model.Graph(utility, batch_size, utility.FLAGS.max_passes, mode=key)
  graph.create_graph(params, global_step)
  #start session
  with tf.Session() as sess:
    sess.run(init.name)
    sess.run(graph.init_op.name)
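    #save the model parameters, keeping up to 500 checkpoints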
    to_save = params.copy()
    saver = tf.train.Saver(to_save, max_to_keep=500)
    if (FLAGS.evaluator_job):
      while True:
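        #poll model_dir forever, evaluating each checkpoint the trainer writes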
        selected_models = {}
        file_list = tf.gfile.ListDirectory(model_dir)
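        #keep only checkpoint data files; strip the ".data-*" suffix so the
        #name matches the checkpoint prefix expected by saver.restore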
        for model_file in file_list:
          if ("checkpoint" in model_file or "index" in model_file or
              "meta" in model_file):
            continue
          if ("data" in model_file):
            model_file = model_file.split(".")[0]
          model_step = int(model_file.split("_")[-1])
          selected_models[model_step] = model_file
        file_list = sorted(selected_models.items(), key=lambda x: x[0])
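        #drop the newest checkpoint as the trainer may still be writing it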
        if (len(file_list) > 0):
          file_list = file_list[:-1]
        print("list of models: ", file_list)
        for model_file in file_list:
          model_file = model_file[1]
          print("restoring: ", model_file)
          saver.restore(sess, model_dir + "/" + model_file)
          model_step = int(model_file.split("_")[-1])
          print("evaluating on dev ", model_file, model_step)
          evaluate(sess, dev_data, batch_size, graph, model_step)
    else:
      print("model dir: ", model_dir)
      if (not (tf.gfile.IsDirectory(utility.FLAGS.output_dir))):
        print("create dir: ", utility.FLAGS.output_dir)
        tf.gfile.MkDir(utility.FLAGS.output_dir)
      if (not (tf.gfile.IsDirectory(model_dir))):
        print("create dir: ", model_dir)
        tf.gfile.MkDir(model_dir)
      Train(graph, utility, batch_size, train_data, sess, model_dir, saver)


def main(args):
  utility = Utility()
  train_name = "random-split-1-train.examples"
  dev_name = "random-split-1-dev.examples"
  test_name = "pristine-unseen-tables.examples"
  #load data
  dat = wiki_data.WikiQuestionGenerator(train_name, dev_name, test_name,
                                        FLAGS.data_dir)
  train_data, dev_data, test_data = dat.load()
  utility.words = []
  utility.word_ids = {}
  utility.reverse_word_ids = {}
  #construct vocabulary
  data_utils.construct_vocab(train_data, utility)
  data_utils.construct_vocab(dev_data, utility, True)
  data_utils.construct_vocab(test_data, utility, True)
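  #add UNK/entry_match/column_match/dummy tokens, then (presumably) map words
  #below the word_cutoff frequency to UNK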
  data_utils.add_special_words(utility)
  data_utils.perform_word_cutoff(utility)
  #convert data to int format and pad the inputs
  train_data = data_utils.complete_wiki_processing(train_data, utility, True)
  dev_data = data_utils.complete_wiki_processing(dev_data, utility, False)
  test_data = data_utils.complete_wiki_processing(test_data, utility, False)
  print("# train examples ", len(train_data))
  print("# dev examples ", len(dev_data))
  print("# test examples ", len(test_data))
  print("running open source")
  #construct TF graph and train or evaluate
  master(train_data, dev_data, utility)


if __name__ == "__main__":
  tf.app.run()