# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions for constructing the vocabulary, converting examples to integer
format, and building the masks required for batch computation.

Author: aneelakantan (Arvind Neelakantan)
"""

from __future__ import print_function

import copy
import numbers
import numpy as np
import wiki_data


def return_index(a):
  """Return the position of the first element of `a` that equals 1.0.

  Returns None when no element equals 1.0.
  """
  return next((idx for idx, value in enumerate(a) if value == 1.0), None)


def _add_to_vocab(word, utility):
  """Register `word` in the vocabulary held by `utility`, or bump its count.

  A previously unseen word gets the next free id (the current size of
  `utility.word_ids`) and is recorded in `words`, `word_count`, `word_ids`
  and `reverse_word_ids`.
  """
  if word in utility.word_ids:
    utility.word_count[word] += 1
  else:
    utility.words.append(word)
    utility.word_count[word] = 1
    utility.word_ids[word] = len(utility.word_ids)
    utility.reverse_word_ids[utility.word_ids[word]] = word


def construct_vocab(data, utility, add_word=False):
  """Snapshot each example's table and grow the vocabulary from it.

  For every example, deep copies of `number_columns`, `word_columns`,
  `number_column_names` and `word_column_names` are stored as
  `original_nc`, `original_wc`, `original_nc_names` and `original_wc_names`.
  Unless `add_word` is True, or the example is flagged `is_bad_example`,
  every non-numeric token of the question and of the word/number column
  names is added to (or counted in) the vocabulary.

  Args:
    data: iterable of examples.
    utility: object holding `words`, `word_count`, `word_ids` and
      `reverse_word_ids`; these are updated in place.
    add_word: when True, only take the snapshots and leave the vocabulary
      untouched.
  """
  for example in data:
    # Keep pristine copies; later processing rewrites these fields in place.
    example.original_nc = copy.deepcopy(example.number_columns)
    example.original_wc = copy.deepcopy(example.word_columns)
    example.original_nc_names = copy.deepcopy(example.number_column_names)
    example.original_wc_names = copy.deepcopy(example.word_column_names)
    if add_word or example.is_bad_example:
      continue
    # Same order as before: question, then word column names, then number
    # column names -- this order determines the ids that are assigned.
    token_sources = [example.question]
    token_sources.extend(example.word_column_names)
    token_sources.extend(example.number_column_names)
    for tokens in token_sources:
      for word in tokens:
        if not isinstance(word, numbers.Number):
          _add_to_vocab(word, utility)


def word_lookup(word, utility):
  """Return `word` if it is in the vocabulary, else `utility.unk_token`.

  Membership is tested with `in` rather than `dict.has_key`, which does not
  exist on Python 3.
  """
  if word in utility.word_ids:
    return word
  return utility.unk_token


def convert_to_int_2d_and_pad(a, utility):
  """Turn a list of token lists into a fixed-width matrix of word ids.

  Each entry of `a` is truncated, or padded with `utility.dummy_token`, to
  exactly `utility.FLAGS.max_entry_length` tokens. Every token is then
  mapped through word_lookup (unknown tokens become the unk token) and
  `utility.word_ids`.

  NOTE: entries shorter than the width are padded IN PLACE, so the caller's
  lists gain dummy tokens. Longer entries are truncated on a copy and left
  untouched.
  """
  width = utility.FLAGS.max_entry_length
  rows = []
  for entry in a:
    if len(entry) > width:
      entry = entry[:width]
    elif len(entry) < width:
      entry.extend([utility.dummy_token] * (width - len(entry)))
    assert len(entry) == width
    rows.append([utility.word_ids[word_lookup(token, utility)]
                 for token in entry])
  return rows


def convert_to_bool_and_pad(a, utility):
  """Binarize a 2-D array and right-pad each row with False.

  A value becomes False when it is < 1 and True otherwise. Each row is then
  extended with False up to `utility.FLAGS.max_elements` entries.

  Args:
    a: object with a `tolist()` method yielding a list of rows (e.g. a numpy
      array).
    utility: supplies `FLAGS.max_elements`.

  Returns:
    list of lists of bools.
  """
  width = utility.FLAGS.max_elements
  padded = []
  for row in a.tolist():
    flags = [not (value < 1) for value in row]
    padded.append(flags + [False] * (width - len(flags)))
  return padded


# Module-level registry of table keys: complete_wiki_processing stores each
# processed (non-bad) example's `table_key` here with value 1. It persists
# across calls for the lifetime of the module.
seen_tables = {}


def partial_match(question, table, number):
  """Flag table cells that share a token with the question.

  Args:
    question: list of question tokens.
    table: list of columns, each a list of cells.
    number: if True, a cell matches when it equals some question token;
      otherwise a cell (a collection of tokens) matches when it contains
      some question token.

  Returns:
    (answer, match): `answer` mirrors `table`, with 1.0 at matching cells
    and 0 elsewhere. `match` maps the index of every column that has a match
    to 1.0.
  """
  answer = [[0] * len(column) for column in table]
  match = {}
  for col_idx, column in enumerate(table):
    for row_idx, cell in enumerate(column):
      # Evaluate every token (no short-circuit), exactly like a plain loop.
      if number:
        hits = [token == cell for token in question]
      else:
        hits = [token in cell for token in question]
      if any(hits):
        answer[col_idx][row_idx] = 1.0
        match[col_idx] = 1.0
  return answer, match


def exact_match(question, table, number):
  """Find table cells that appear verbatim in the question.

  Args:
    question: list of question tokens.
    table: list of columns, each a list of cells. A cell is a scalar when
      `number` is True and a list of tokens otherwise.
    number: selects scalar equality (True) or contiguous token-span
      matching (False).

  Returns:
    (answer, match, matched_indices):
      answer: mirrors `table`, with 1.0 at matching cells and 0 elsewhere.
      match: maps the index of every column with a match to 1.0.
      matched_indices: a (start, length) question span for every token-span
        match, in discovery order. Always empty when `number` is True.
  """
  answer = [[0] * len(column) for column in table]
  match = {}
  matched_indices = []
  q_len = len(question)
  for col_idx, column in enumerate(table):
    for row_idx, cell in enumerate(column):
      if number:
        for token in question:
          if token == cell:
            match[col_idx] = 1.0
            answer[col_idx][row_idx] = 1.0
        continue
      width = len(cell)
      for start in range(q_len):
        # Spans that would run past the end of the question cannot match,
        # and neither can any later start position.
        if start + width > q_len:
          break
        if cell == question[start:start + width]:
          match[col_idx] = 1.0
          answer[col_idx][row_idx] = 1.0
          matched_indices.append((start, width))
  return answer, match, matched_indices


def partial_column_match(question, table, number):
  """Flag column names that share at least one token with the question.

  Args:
    question: list of question tokens.
    table: list of column names, each a collection of tokens.
    number: unused; kept so the signature matches the other matchers.

  Returns:
    list with 1.0 for every column whose name contains some question token,
    and 0 otherwise.
  """
  scores = [0] * len(table)
  for idx, name in enumerate(table):
    hits = [token in name for token in question]
    if any(hits):
      scores[idx] = 1.0
  return scores


def exact_column_match(question, table, number):
  """Find column names that occur as contiguous token spans in the question.

  Args:
    question: list of question tokens.
    table: list of column names, each a list of tokens.
    number: unused; kept so the signature matches the other matchers.

  Returns:
    (answer, matched_indices): `answer` has 1.0 for every column whose name
    occurs in the question and 0 otherwise. `matched_indices` lists a
    (start, length) span for each occurrence.
  """
  scores = [0] * len(table)
  spans = []
  q_len = len(question)
  for idx, name in enumerate(table):
    width = len(name)
    for start in range(q_len):
      if start + width > q_len:
        break
      if name == question[start:start + width]:
        scores[idx] = 1.0
        spans.append((start, width))
  return scores, spans


def get_max_entry(a):
  """Return the most frequent entry of `a`, provided it occurs more than once.

  The string "UNK, " is ignored when counting. Among equally frequent
  entries, the one seen first wins.

  Args:
    a: iterable of hashable entries (numbers, or strings built by list_join).

  Returns:
    The most frequent entry, or -1.0 when `a` has no counted entries or no
    entry occurs more than once.
  """
  counts = {}
  # Plain dict membership instead of the Python-2-only `dict.has_key`.
  for entry in a:
    if entry != "UNK, ":
      counts[entry] = counts.get(entry, 0) + 1
  if not counts:
    return -1.0
  # max() returns the first maximal item in insertion order, which matches
  # the tie-breaking of the stable descending sort this replaces.
  key, val = max(counts.items(), key=lambda kv: kv[1])
  return key if val > 1 else -1.0


def list_join(a):
  """Render `a` as the concatenation of its items, each followed by ", ".

  For example, ["x", 1] becomes "x, 1, ", and an empty input gives "".
  """
  return "".join(str(item) + ", " for item in a)


def group_by_max(table, number):
  """Flag, per column, the cells holding that column's most frequent value.

  Args:
    table: list of columns. Number columns are compared by value. For word
      columns, each cell is first rendered to a string with list_join.
    number: True for number columns, False for word columns.

  Returns:
    list of per-column lists with 1.0 where the cell equals the value that
    get_max_entry picks for the column, and 0.0 elsewhere.

  NOTE(review): get_max_entry returns -1.0 when nothing repeats, so a number
  cell equal to -1.0 is flagged in that case -- confirm this is acceptable.
  """
  flags = []
  for column in table:
    keys = column if number else [list_join(cell) for cell in column]
    most_common = get_max_entry(keys)
    flags.append([1.0 if most_common == key else 0.0 for key in keys])
  return flags


def pick_one(a):
  """Return True if any row of `a` contains the value 1.0."""
  return any(1.0 in row for row in a)


def check_processed_cols(col, utility):
  """Return True if `col` holds at least one informative value.

  A value is uninformative when it equals `utility.FLAGS.pad_int` or
  `utility.FLAGS.bad_number_pre_process`.
  """
  pad_value = utility.FLAGS.pad_int
  bad_value = utility.FLAGS.bad_number_pre_process
  informative = [y for y in col if y != pad_value and y != bad_value]
  return len(informative) > 0


def complete_wiki_processing(data, utility, train=True):
  """Convert examples to integer ids and build padded features and masks.

  For every example not flagged `is_bad_example`, this function:
    * computes entry and column-name match features;
    * rewrites the question into a left-padded list of word ids, pulling
      the first two number/date tokens out into `question_number` and
      `question_number_1`;
    * pads number columns to `utility.FLAGS.max_number_cols` and word
      columns to `utility.FLAGS.max_word_cols`, each with
      `utility.FLAGS.max_elements` entries;
    * builds the answer / lookup tensors.
  Examples are mutated in place.

  Args:
    data: iterable of examples. Each must already carry the `original_*`
      snapshots (e.g. as set by construct_vocab).
    utility: holds the vocabulary (`word_ids`), special tokens and FLAGS.
    train: unused in this function.

  Returns:
    list of processed examples. Bad examples are skipped. Examples whose
    `len_col` exceeds `utility.FLAGS.max_elements` are also left out, but
    only after being fully processed (and recorded in `seen_tables`).
  """
  #convert to integers and padding
  processed_data = []
  # NOTE(review): num_bad_examples is counted but never used or returned.
  num_bad_examples = 0
  for example in data:
    number_found = 0
    if (example.is_bad_example):
      num_bad_examples += 1
    if (not (example.is_bad_example)):
      example.string_question = example.question[:]
      #entry match
      example.processed_number_columns = example.processed_number_columns[:]
      example.processed_word_columns = example.processed_word_columns[:]
      example.word_exact_match, word_match, matched_indices = exact_match(
          example.string_question, example.original_wc, number=False)
      example.number_exact_match, number_match, _ = exact_match(
          example.string_question, example.original_nc, number=True)
      # Only when neither word nor number entries matched exactly, fall back
      # to partial (any shared token) matching of word entries. In that case
      # matched_indices is necessarily empty.
      if (not (pick_one(example.word_exact_match)) and not (
          pick_one(example.number_exact_match))):
        assert len(word_match) == 0
        assert len(number_match) == 0
        example.word_exact_match, word_match = partial_match(
            example.string_question, example.original_wc, number=False)
      #group by max
      example.word_group_by_max = group_by_max(example.original_wc, False)
      example.number_group_by_max = group_by_max(example.original_nc, True)
      #column name match
      # NOTE(review): wcol_matched_indices / ncol_matched_indices are unused.
      example.word_column_exact_match, wcol_matched_indices = exact_column_match(
          example.string_question, example.original_wc_names, number=False)
      example.number_column_exact_match, ncol_matched_indices = exact_column_match(
          example.string_question, example.original_nc_names, number=False)
      # Same fallback for column names: partial matching on both column
      # types only when no column name matched exactly.
      if (not (1.0 in example.word_column_exact_match) and not (
          1.0 in example.number_column_exact_match)):
        example.word_column_exact_match = partial_column_match(
            example.string_question, example.original_wc_names, number=False)
        example.number_column_exact_match = partial_column_match(
            example.string_question, example.original_nc_names, number=False)
      # Append marker tokens to the question when an entry / column name
      # matched.
      if (len(word_match) > 0 or len(number_match) > 0):
        example.question.append(utility.entry_match_token)
      if (1.0 in example.word_column_exact_match or
          1.0 in example.number_column_exact_match):
        example.question.append(utility.column_match_token)
      example.string_question = example.question[:]
      # The lookup matrices are transposed here; presumably so the first
      # axis indexes columns like the other per-column features -- confirm.
      example.number_lookup_matrix = np.transpose(
          example.number_lookup_matrix)[:]
      example.word_lookup_matrix = np.transpose(example.word_lookup_matrix)[:]
      # Shallow copies: the inner lists are still shared with the originals.
      example.columns = example.number_columns[:]
      example.word_columns = example.word_columns[:]
      example.len_total_cols = len(example.word_column_names) + len(
          example.number_column_names)
      example.column_names = example.number_column_names[:]
      example.word_column_names = example.word_column_names[:]
      example.string_column_names = example.number_column_names[:]
      example.string_word_column_names = example.word_column_names[:]
      example.sorted_number_index = []
      example.sorted_word_index = []
      example.column_mask = []
      example.word_column_mask = []
      example.processed_column_mask = []
      example.processed_word_column_mask = []
      example.word_column_entry_mask = []
      example.question_attention_mask = []
      # -1 means "no number found in the question".
      example.question_number = example.question_number_1 = -1
      # NOTE(review): redundant; question_attention_mask was already reset.
      example.question_attention_mask = []
      example.ordinal_question = []
      example.ordinal_question_one = []
      new_question = []
      # Number of rows, taken from the first column.
      # NOTE(review): raises IndexError if the table has no columns at all.
      if (len(example.number_columns) > 0):
        example.len_col = len(example.number_columns[0])
      else:
        example.len_col = len(example.word_columns[0])
      # Replace question tokens covered by an exact word-entry match with
      # the unk token.
      for (start, length) in matched_indices:
        for j in range(length):
          example.question[start + j] = utility.unk_token
      #print example.question
      # Split the question: numbers and dates (per wiki_data.is_date) are
      # removed from the word sequence. The first becomes question_number,
      # the second question_number_1, and any further ones are dropped.
      # ordinal_question(_one) hold one 0.0 per kept word; the entry of the
      # word just before the first (second) number is set to 1.0.
      for word in example.question:
        if (isinstance(word, numbers.Number) or wiki_data.is_date(word)):
          if (not (isinstance(word, numbers.Number)) and
              wiki_data.is_date(word)):
            # Strip "X" and "-" characters from date strings.
            word = word.replace("X", "").replace("-", "")
          number_found += 1
          if (number_found == 1):
            example.question_number = word
            if (len(example.ordinal_question) > 0):
              example.ordinal_question[len(example.ordinal_question) - 1] = 1.0
            else:
              example.ordinal_question.append(1.0)
          elif (number_found == 2):
            example.question_number_1 = word
            if (len(example.ordinal_question_one) > 0):
              example.ordinal_question_one[len(example.ordinal_question_one) -
                                           1] = 1.0
            else:
              example.ordinal_question_one.append(1.0)
        else:
          new_question.append(word)
          example.ordinal_question.append(0.0)
          example.ordinal_question_one.append(0.0)
      # Map the remaining words to ids; out-of-vocabulary words become unk.
      example.question = [
          utility.word_ids[word_lookup(w, utility)] for w in new_question
      ]
      example.question_attention_mask = [0.0] * len(example.question)
      #when the first question number occurs before a word
      example.ordinal_question = example.ordinal_question[0:len(
          example.question)]
      example.ordinal_question_one = example.ordinal_question_one[0:len(
          example.question)]
      #question-padding
      # Left-pad to FLAGS.question_length: dummy ids in the question and
      # -10000.0 in the attention mask. NOTE(review): if the question is
      # longer than question_length the pad count is negative, nothing is
      # added, and the lists stay longer than question_length.
      example.question = [utility.word_ids[utility.dummy_token]] * (
          utility.FLAGS.question_length - len(example.question)
      ) + example.question
      example.question_attention_mask = [-10000.0] * (
          utility.FLAGS.question_length - len(example.question_attention_mask)
      ) + example.question_attention_mask
      example.ordinal_question = [0.0] * (utility.FLAGS.question_length -
                                          len(example.ordinal_question)
                                         ) + example.ordinal_question
      example.ordinal_question_one = [0.0] * (utility.FLAGS.question_length -
                                              len(example.ordinal_question_one)
                                             ) + example.ordinal_question_one
      # Always-true guard; appears to be a leftover toggle.
      if (True):
        #number columns and related-padding
        num_cols = len(example.columns)
        start = 0
        for column in example.number_columns:
          # NOTE(review): a column whose processed values are all pad/bad
          # gets no processed_column_mask entry, so the list can end up
          # shorter than max_number_cols.
          if (check_processed_cols(example.processed_number_columns[start],
                                   utility)):
            example.processed_column_mask.append(0.0)
          # Row indices ordered by processed value, largest first, then
          # padded with pad_int.
          sorted_index = sorted(
              range(len(example.processed_number_columns[start])),
              key=lambda k: example.processed_number_columns[start][k],
              reverse=True)
          sorted_index = sorted_index + [utility.FLAGS.pad_int] * (
              utility.FLAGS.max_elements - len(sorted_index))
          example.sorted_number_index.append(sorted_index)
          example.columns[start] = column + [utility.FLAGS.pad_int] * (
              utility.FLAGS.max_elements - len(column))
          example.processed_number_columns[start] += [utility.FLAGS.pad_int] * (
              utility.FLAGS.max_elements -
              len(example.processed_number_columns[start]))
          start += 1
          example.column_mask.append(0.0)
        # Add all-padding number columns up to max_number_cols; their masks
        # are -100000000.0.
        for remaining in range(num_cols, utility.FLAGS.max_number_cols):
          example.sorted_number_index.append([utility.FLAGS.pad_int] *
                                             (utility.FLAGS.max_elements))
          example.columns.append([utility.FLAGS.pad_int] *
                                 (utility.FLAGS.max_elements))
          example.processed_number_columns.append([utility.FLAGS.pad_int] *
                                                  (utility.FLAGS.max_elements))
          example.number_exact_match.append([0.0] *
                                            (utility.FLAGS.max_elements))
          example.number_group_by_max.append([0.0] *
                                             (utility.FLAGS.max_elements))
          example.column_mask.append(-100000000.0)
          example.processed_column_mask.append(-100000000.0)
          example.number_column_exact_match.append(0.0)
          example.column_names.append([utility.dummy_token])
        #word column  and related-padding
        start = 0
        word_num_cols = len(example.word_columns)
        for column in example.word_columns:
          if (check_processed_cols(example.processed_word_columns[start],
                                   utility)):
            example.processed_word_column_mask.append(0.0)
          sorted_index = sorted(
              range(len(example.processed_word_columns[start])),
              key=lambda k: example.processed_word_columns[start][k],
              reverse=True)
          sorted_index = sorted_index + [utility.FLAGS.pad_int] * (
              utility.FLAGS.max_elements - len(sorted_index))
          example.sorted_word_index.append(sorted_index)
          # Each cell becomes max_entry_length word ids (convert_to_int_2d_and_pad
          # pads short cell lists in place).
          column = convert_to_int_2d_and_pad(column, utility)
          # NOTE: the padding rows added here are one shared list object.
          example.word_columns[start] = column + [[
              utility.word_ids[utility.dummy_token]
          ] * utility.FLAGS.max_entry_length] * (utility.FLAGS.max_elements -
                                                 len(column))
          example.processed_word_columns[start] += [utility.FLAGS.pad_int] * (
              utility.FLAGS.max_elements -
              len(example.processed_word_columns[start]))
          # 0 for real entries; padded entries get the dummy token's word id.
          # NOTE(review): using a word id as a mask value looks odd -- confirm.
          example.word_column_entry_mask.append([0] * len(column) + [
              utility.word_ids[utility.dummy_token]
          ] * (utility.FLAGS.max_elements - len(column)))
          start += 1
          example.word_column_mask.append(0.0)
        # Add all-padding word columns up to max_word_cols.
        for remaining in range(word_num_cols, utility.FLAGS.max_word_cols):
          example.sorted_word_index.append([utility.FLAGS.pad_int] *
                                           (utility.FLAGS.max_elements))
          example.word_columns.append([[utility.word_ids[utility.dummy_token]] *
                                       utility.FLAGS.max_entry_length] *
                                      (utility.FLAGS.max_elements))
          example.word_column_entry_mask.append(
              [utility.word_ids[utility.dummy_token]] *
              (utility.FLAGS.max_elements))
          example.word_exact_match.append([0.0] * (utility.FLAGS.max_elements))
          example.word_group_by_max.append([0.0] * (utility.FLAGS.max_elements))
          example.processed_word_columns.append([utility.FLAGS.pad_int] *
                                                (utility.FLAGS.max_elements))
          example.word_column_mask.append(-100000000.0)
          example.processed_word_column_mask.append(-100000000.0)
          example.word_column_exact_match.append(0.0)
          example.word_column_names.append([utility.dummy_token] *
                                           utility.FLAGS.max_entry_length)
        seen_tables[example.table_key] = 1
      #convert column and word column names to integers
      example.column_ids = convert_to_int_2d_and_pad(example.column_names,
                                                     utility)
      example.word_column_ids = convert_to_int_2d_and_pad(
          example.word_column_names, utility)
      # Right-pad every per-column match / group-by feature to max_elements.
      for i_em in range(len(example.number_exact_match)):
        example.number_exact_match[i_em] = example.number_exact_match[
            i_em] + [0.0] * (utility.FLAGS.max_elements -
                             len(example.number_exact_match[i_em]))
        example.number_group_by_max[i_em] = example.number_group_by_max[
            i_em] + [0.0] * (utility.FLAGS.max_elements -
                             len(example.number_group_by_max[i_em]))
      for i_em in range(len(example.word_exact_match)):
        example.word_exact_match[i_em] = example.word_exact_match[
            i_em] + [0.0] * (utility.FLAGS.max_elements -
                             len(example.word_exact_match[i_em]))
        example.word_group_by_max[i_em] = example.word_group_by_max[
            i_em] + [0.0] * (utility.FLAGS.max_elements -
                             len(example.word_group_by_max[i_em]))
      # Combined features: number columns first, then word columns.
      example.exact_match = example.number_exact_match + example.word_exact_match
      example.group_by_max = example.number_group_by_max + example.word_group_by_max
      example.exact_column_match = example.number_column_exact_match + example.word_column_exact_match
      #answer and related mask, padding
      if (example.is_lookup):
        example.answer = example.calc_answer
        # Float copies of the (transposed) lookup matrices, padded per
        # column to max_elements.
        example.number_print_answer = example.number_lookup_matrix.tolist()
        example.word_print_answer = example.word_lookup_matrix.tolist()
        for i_answer in range(len(example.number_print_answer)):
          example.number_print_answer[i_answer] = example.number_print_answer[
              i_answer] + [0.0] * (utility.FLAGS.max_elements -
                                   len(example.number_print_answer[i_answer]))
        for i_answer in range(len(example.word_print_answer)):
          example.word_print_answer[i_answer] = example.word_print_answer[
              i_answer] + [0.0] * (utility.FLAGS.max_elements -
                                   len(example.word_print_answer[i_answer]))
        example.number_lookup_matrix = convert_to_bool_and_pad(
            example.number_lookup_matrix, utility)
        example.word_lookup_matrix = convert_to_bool_and_pad(
            example.word_lookup_matrix, utility)
        for remaining in range(num_cols, utility.FLAGS.max_number_cols):
          example.number_lookup_matrix.append([False] *
                                              utility.FLAGS.max_elements)
          example.number_print_answer.append([0.0] * utility.FLAGS.max_elements)
        for remaining in range(word_num_cols, utility.FLAGS.max_word_cols):
          example.word_lookup_matrix.append([False] *
                                            utility.FLAGS.max_elements)
          example.word_print_answer.append([0.0] * utility.FLAGS.max_elements)
        example.print_answer = example.number_print_answer + example.word_print_answer
      else:
        example.answer = example.calc_answer
        # NOTE: all rows are the same list object (list multiplication);
        # safe only as long as nobody mutates a row.
        example.print_answer = [[0.0] * (utility.FLAGS.max_elements)] * (
            utility.FLAGS.max_number_cols + utility.FLAGS.max_word_cols)
      #question_number masks
      if (example.question_number == -1):
        example.question_number_mask = np.zeros([utility.FLAGS.max_elements])
      else:
        example.question_number_mask = np.ones([utility.FLAGS.max_elements])
      if (example.question_number_1 == -1):
        example.question_number_one_mask = -10000.0
      else:
        example.question_number_one_mask = np.float64(0.0)
      # Tables with more rows than max_elements are dropped here, after all
      # the processing above has already been applied.
      if (example.len_col > utility.FLAGS.max_elements):
        continue
      processed_data.append(example)
  return processed_data


def _register_special_token(utility, token):
  """Append `token` to the vocabulary with the next free id; return that id."""
  utility.words.append(token)
  utility.word_ids[token] = len(utility.word_ids)
  token_id = utility.word_ids[token]
  utility.reverse_word_ids[token_id] = token
  return token_id


def add_special_words(utility):
  """Add the entry-match, column-match, dummy and unk tokens to the vocabulary.

  The tokens are registered in that order, each with the next free id. The
  ids of the first three are stored on `utility` as `entry_match_token_id`,
  `column_match_token_id` and `dummy_token_id`. The entry- and column-match
  ids are printed.
  """
  utility.entry_match_token_id = _register_special_token(
      utility, utility.entry_match_token)
  print("entry match token: ", utility.word_ids[
      utility.entry_match_token], utility.entry_match_token_id)
  utility.column_match_token_id = _register_special_token(
      utility, utility.column_match_token)
  # Previously mislabelled as "entry match token" (copy-paste error).
  print("column match token: ", utility.word_ids[
      utility.column_match_token], utility.column_match_token_id)
  utility.dummy_token_id = _register_special_token(utility, utility.dummy_token)
  _register_special_token(utility, utility.unk_token)


def perform_word_cutoff(utility):
  """Drop rare words from the vocabulary.

  Every word whose count in `utility.word_count` is below
  `utility.FLAGS.word_cutoff` is removed from `utility.word_ids` and
  `utility.words`. The special tokens (unk, dummy, entry-match and
  column-match) are always kept, as are words with no count recorded.
  The remaining ids are NOT renumbered, and `word_count` /
  `reverse_word_ids` are left untouched. Does nothing unless the cutoff is
  positive.
  """
  cutoff = utility.FLAGS.word_cutoff
  if not cutoff > 0:
    return
  protected = (utility.unk_token, utility.dummy_token,
               utility.entry_match_token, utility.column_match_token)
  # Iterate over a snapshot of the keys: popping from a dict while iterating
  # its live key view raises RuntimeError on Python 3.
  for word in list(utility.word_ids):
    if (word in utility.word_count and utility.word_count[word] < cutoff and
        word not in protected):
      utility.word_ids.pop(word)
      utility.words.remove(word)


def word_dropout(question, utility):
  """Randomly replace question token ids with the unk token's id.

  `utility.FLAGS.word_dropout_prob` acts as a keep probability. A token
  other than `utility.dummy_token_id` is replaced by
  `utility.word_ids[utility.unk_token]` when `utility.random.random()`
  exceeds it. Dummy (padding) tokens are never replaced and consume no
  random draw. When the flag is not positive, `question` itself is returned
  unchanged; otherwise a new list is returned.
  """
  keep_prob = utility.FLAGS.word_dropout_prob
  if not keep_prob > 0.0:
    return question

  def _maybe_drop(token):
    if token != utility.dummy_token_id and utility.random.random() > keep_prob:
      return utility.word_ids[utility.unk_token]
    return token

  return [_maybe_drop(token) for token in question]


def generate_feed_dict(data, curr, batch_size, gr, train=False, utility=None):
  """Build the feed dictionary for one mini-batch.

  Args:
    data: sequence of processed examples.
    curr: index in `data` of the first example of the batch.
    batch_size: number of consecutive examples to take.
    gr: object whose `batch_*` attributes are used as the dictionary keys.
    train: when True, questions are passed through word_dropout.
    utility: forwarded to word_dropout; only needed when `train` is True.

  Returns:
    dict mapping each `gr.batch_*` key to the per-example values, in a
    fixed insertion order. question_number, question_number_1 and
    question_number_one_mask are packed into numpy arrays of shape
    (batch_size, 1); everything else is a plain list.
  """
  batch = [data[curr + offset] for offset in range(batch_size)]

  if train:
    questions = [word_dropout(example.question, utility) for example in batch]
  else:
    questions = [example.question for example in batch]
  feed_dict = {gr.batch_question: questions}

  # (attribute of `gr` used as key, attribute read from each example,
  #  whether to pack the values into a (batch_size, 1) numpy column).
  fields = (
      ("batch_question_attention_mask", "question_attention_mask", False),
      ("batch_answer", "answer", False),
      ("batch_number_column", "columns", False),
      ("batch_processed_number_column", "processed_number_columns", False),
      ("batch_processed_sorted_index_number_column", "sorted_number_index",
       False),
      ("batch_processed_sorted_index_word_column", "sorted_word_index", False),
      ("batch_question_number", "question_number", True),
      ("batch_question_number_one", "question_number_1", True),
      ("batch_question_number_mask", "question_number_mask", False),
      ("batch_question_number_one_mask", "question_number_one_mask", True),
      ("batch_print_answer", "print_answer", False),
      ("batch_exact_match", "exact_match", False),
      ("batch_group_by_max", "group_by_max", False),
      ("batch_column_exact_match", "exact_column_match", False),
      ("batch_ordinal_question", "ordinal_question", False),
      ("batch_ordinal_question_one", "ordinal_question_one", False),
      ("batch_number_column_mask", "column_mask", False),
      ("batch_number_column_names", "column_ids", False),
      ("batch_processed_word_column", "processed_word_columns", False),
      ("batch_word_column_mask", "word_column_mask", False),
      ("batch_word_column_names", "word_column_ids", False),
      ("batch_word_column_entry_mask", "word_column_entry_mask", False),
  )
  for key_name, field, as_column in fields:
    values = [getattr(example, field) for example in batch]
    if as_column:
      values = np.array(values).reshape((batch_size, 1))
    feed_dict[getattr(gr, key_name)] = values
  return feed_dict