from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import numpy as np
import os
import tarfile
import tensorflow as tf

from six.moves import urllib
from tensorflow.python.platform import gfile

# Command-line flags. 'output_json' and 'jpeg_pattern' have no usable default
# (None); 'model_dir' and 'architecture' fall back to the Inception v3 setup.
flags = tf.app.flags
flags.DEFINE_string(
    'output_json', None,
    'specify output json file path')
flags.DEFINE_string(
    'jpeg_pattern', None,
    'specify target jpeg file pattern')
flags.DEFINE_string(
    'model_dir', '/tmp/inception_v3',
    'specify model directory')
flags.DEFINE_string(
    'architecture', 'inception_v3',
    'specify model architecture')
# Shared flag-values object read by the functions below.
FLAGS = flags.FLAGS

def maybe_download_and_extract(data_url):
  """Download and extract the model tar file unless it is already on disk.

  The archive is saved in, and unpacked into, FLAGS.model_dir, which is
  created when missing. The presence of the archive file is the only check
  performed: when a file with the archive's name already exists, nothing is
  downloaded or extracted.

  If the download or the extraction fails, the (possibly partial) archive is
  removed before the error is re-raised, so that a later run retries instead
  of treating the leftover file as a complete model.

  Args:
    data_url: Web location of the tar file containing the pretrained model.
  """
  dest_directory = FLAGS.model_dir
  if not os.path.exists(dest_directory):
    os.makedirs(dest_directory)
  filename = data_url.split('/')[-1]
  filepath = os.path.join(dest_directory, filename)
  if os.path.exists(filepath):
    tf.logging.info('Not extracting or downloading files, model already present in disk')
    return

  try:
    urllib.request.urlretrieve(data_url, filepath)
    # The with-block closes the archive's file handle once extraction is done.
    # NOTE(review): extractall() trusts the member paths stored in the
    # archive; only use this with archives from a trusted source.
    with tarfile.open(filepath, 'r:gz') as archive:
      archive.extractall(dest_directory)
  except Exception:
    # The file did not exist before this call, so whatever is there now is
    # our own incomplete or unusable download.
    if os.path.exists(filepath):
      os.remove(filepath)
    raise

def create_model_info(architecture='inception_v3'):
  """Return the settings needed to download and run a pretrained model.

  Recognized names (matched case-insensitively) are 'inception_v3' and
  'mobilenet_<version>_<size>[_quantized]', where <version> is one of
  '1.0', '0.75', '0.50', '0.25' and <size> is one of '224', '192', '160',
  '128'.

  Args:
    architecture: Name of the model architecture.

  Returns:
    Dictionary with the keys 'data_url', 'logits_tensor_name',
    'probabilities_tensor_name', 'bottleneck_tensor_size', 'input_width',
    'input_height', 'input_depth', 'resized_input_tensor_name',
    'model_file_name', 'input_mean', 'input_std' and 'quantize_layer'.
    None is returned (after logging the problem) when a 'mobilenet_' name is
    malformed.

  Raises:
    ValueError: If the name is neither 'inception_v3' nor starts with
      'mobilenet_'.
  """
  architecture = architecture.lower()
  is_quantized = False
  if architecture == 'inception_v3':
    # pylint: disable=line-too-long
    data_url = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
    # pylint: enable=line-too-long
    logits_tensor_name = 'softmax/logits:0'
    probabilities_tensor_name = 'softmax:0'
    bottleneck_tensor_size = 1008
    input_width = 299
    input_height = 299
    input_depth = 3
    resized_input_tensor_name = 'Mul:0'
    model_file_name = 'classify_image_graph_def.pb'
    input_mean = 128
    input_std = 128
  elif architecture.startswith('mobilenet_'):
    # Expected shape: mobilenet_<version>_<size> plus an optional suffix.
    parts = architecture.split('_')
    if len(parts) not in (3, 4):
      tf.logging.error("Couldn't understand architecture name '%s'",
                       architecture)
      return None
    version_string = parts[1]
    if version_string not in ('1.0', '0.75', '0.50', '0.25'):
      tf.logging.error(
          "The Mobilenet version should be '1.0', '0.75', '0.50', or '0.25', "
          "but found '%s' for architecture '%s'",
          version_string, architecture)
      return None
    size_string = parts[2]
    if size_string not in ('224', '192', '160', '128'):
      tf.logging.error(
          "The Mobilenet input size should be '224', '192', '160', or '128', "
          "but found '%s' for architecture '%s'",
          size_string, architecture)
      return None
    if len(parts) == 4:
      # The only accepted fourth component is the 'quantized' marker.
      if parts[3] != 'quantized':
        tf.logging.error(
            "Couldn't understand architecture suffix '%s' for '%s'", parts[3],
            architecture)
        return None
      is_quantized = True

    # Both variants share the tensor names of the outputs; they differ in the
    # archive name, the input tensor and the location of the graph file.
    logits_tensor_name = 'MobilenetV1/Predictions/Reshape:0'
    probabilities_tensor_name = 'MobilenetV1/Predictions/Reshape_1:0'
    data_url = 'http://download.tensorflow.org/models/mobilenet_v1_'
    if is_quantized:
      data_url += version_string + '_' + size_string + '_quantized_frozen.tgz'
      resized_input_tensor_name = 'Placeholder:0'
      model_dir_name = ('mobilenet_v1_' + version_string + '_' + size_string +
                        '_quantized_frozen')
      model_base_name = 'quantized_frozen_graph.pb'
    else:
      data_url += version_string + '_' + size_string + '_frozen.tgz'
      resized_input_tensor_name = 'input:0'
      model_dir_name = 'mobilenet_v1_' + version_string + '_' + size_string
      model_base_name = 'frozen_graph.pb'

    bottleneck_tensor_size = 1001
    input_width = int(size_string)
    input_height = int(size_string)
    input_depth = 3
    model_file_name = os.path.join(model_dir_name, model_base_name)
    input_mean = 127.5
    input_std = 127.5
  else:
    tf.logging.error("Couldn't understand architecture name '%s'", architecture)
    raise ValueError('Unknown architecture', architecture)

  return {
      'data_url': data_url,
      'logits_tensor_name': logits_tensor_name,
      'probabilities_tensor_name': probabilities_tensor_name,
      'bottleneck_tensor_size': bottleneck_tensor_size,
      'input_width': input_width,
      'input_height': input_height,
      'input_depth': input_depth,
      'resized_input_tensor_name': resized_input_tensor_name,
      'model_file_name': model_file_name,
      'input_mean': input_mean,
      'input_std': input_std,
      'quantize_layer': is_quantized,
  }

def create_model_graph(graph, model_info):
  """Imports the saved GraphDef file into `graph` and looks up its tensors.

  The GraphDef is read from model_info['model_file_name'] under
  FLAGS.model_dir and imported with an empty name prefix.

  Args:
    graph: Graph that the GraphDef is imported into.
    model_info: Dictionary containing information about the model architecture.

  Returns:
    The logits tensor, the probabilities tensor and the resized-input tensor,
    in that order, looked up by the names stored in model_info.
  """
  with graph.as_default():
    model_path = os.path.join(FLAGS.model_dir, model_info['model_file_name'])
    tf.logging.info('Model path: {}'.format(model_path))
    with gfile.FastGFile(model_path, 'rb') as model_file:
      graph_def = tf.GraphDef()
      graph_def.ParseFromString(model_file.read())
      # Tensors to hand back, in the order the caller unpacks them.
      wanted_names = [
          model_info['logits_tensor_name'],
          model_info['probabilities_tensor_name'],
          model_info['resized_input_tensor_name'],
      ]
      imported = tf.import_graph_def(
          graph_def, name='', return_elements=wanted_names)
      logits_tensor, probabilities_tensor, resized_input_tensor = imported
  return logits_tensor, probabilities_tensor, resized_input_tensor

def add_jpeg_decoding(input_width, input_height, input_depth, input_mean,
                      input_std):
  """Adds operations that perform JPEG decoding and resizing to the graph.

  The decoded image is cast to float, given a leading dimension, resized
  bilinearly to (input_height, input_width), and finally normalized as
  (pixel - input_mean) * (1.0 / input_std).

  Args:
    input_width: Desired width of the image fed into the recognizer graph.
    input_height: Desired height of the image fed into the recognizer graph.
    input_depth: Desired channels of the image fed into the recognizer graph.
    input_mean: Pixel value that should be zero in the image for the graph.
    input_std: How much to divide the pixel values by before recognition.

  Returns:
    Tensors for the node to feed JPEG data into, and the output of the
      preprocessing steps.
  """
  # Ops are created in the same order as before so that the automatically
  # generated op names in the graph do not change.
  jpeg_input = tf.placeholder(tf.string, name='DecodeJPGInput')
  pixels = tf.image.decode_jpeg(jpeg_input, channels=input_depth)
  batched_pixels = tf.expand_dims(tf.cast(pixels, dtype=tf.float32), 0)
  target_size = tf.cast(tf.stack([input_height, input_width]), dtype=tf.int32)
  resized = tf.image.resize_bilinear(batched_pixels, target_size)
  normalized = tf.multiply(tf.subtract(resized, input_mean), 1.0 / input_std)
  return jpeg_input, normalized

def main():
  """Runs the selected model on every JPEG matching --jpeg_pattern.

  For each matching image one JSON object is written on its own line to
  --output_json, with the keys 'key' (file name without directory and
  extension), 'vector' (logits) and 'prob' (probabilities). The output is
  flushed after every image.

  Raises:
    ValueError: If --output_json or --jpeg_pattern is not given, or if no
      model information could be created for --architecture.
  """
  output_json = FLAGS.output_json
  target_jpeg = FLAGS.jpeg_pattern
  architecture = FLAGS.architecture
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.logging.info('TensorFlow version: {}'.format(tf.__version__))
  tf.logging.info('output_json: {}'.format(output_json))
  tf.logging.info('target_jpeg: {}'.format(target_jpeg))
  tf.logging.info('architecture: {}'.format(architecture))

  # Both flags default to None; fail with a clear message instead of an
  # obscure error from open() or Glob() further down.
  if not output_json or not target_jpeg:
    raise ValueError('Both --output_json and --jpeg_pattern must be specified')

  model_info = create_model_info(architecture)
  if model_info is None:
    # create_model_info returns None (after logging why) for malformed
    # mobilenet names.
    raise ValueError('Could not create model info for architecture',
                     architecture)
  maybe_download_and_extract(model_info['data_url'])

  graph = tf.Graph()
  logits_tensor, probabilities_tensor, resized_input_tensor = (
      create_model_graph(graph, model_info))

  with graph.as_default():
    jpeg_data_tensor, decoded_image_tensor = add_jpeg_decoding(
        model_info['input_width'], model_info['input_height'],
        model_info['input_depth'], model_info['input_mean'],
        model_info['input_std'])

  with open(output_json, 'w') as json_file:
    with tf.Session(graph=graph) as sess:
      image_paths = gfile.Glob(target_jpeg)
      if not image_paths:
        tf.logging.warning(
            'No files matched jpeg_pattern: {}'.format(target_jpeg))
      for image_path in image_paths:
        # Close each image file right after reading it so handles do not pile
        # up over a long run.
        with gfile.FastGFile(image_path, 'rb') as image_file:
          image_data = image_file.read()
        # First decode/resize the JPEG, then feed the result to the model.
        resized_input_values = sess.run(decoded_image_tensor,
                                        {jpeg_data_tensor: image_data})
        logits_values, probabilities_values = sess.run(
            [logits_tensor, probabilities_tensor],
            {resized_input_tensor: resized_input_values})

        logits_values = np.squeeze(logits_values)
        probabilities_values = np.squeeze(probabilities_values)

        jobj = {
            'key': os.path.splitext(os.path.basename(image_path))[0],
            'vector': logits_values.tolist(),
            'prob': probabilities_values.tolist()
        }
        json_file.write(json.dumps(jobj) + '\n')
        json_file.flush()

# Script entry point: call main() only when this file is executed directly.
if __name__ == '__main__':
  main()