import ctypes
import math
import sys
import sdl2
import sdl2.ext
import time
import numpy
import tensorflow as tf
from . import nativehelpers as helpers

# Radius of the planet in kilometres, used to turn angles on the unit sphere
# into distances and areas.
# NOTE(review): the value looks like Earth's equatorial radius -- confirm.
planet_radius_km = 6378.1
# The same radius in miles (1.60934 km per mile).
planet_radius_miles = planet_radius_km / 1.60934


def run(m):
  """Open a preview window for the model's heightmap and handle mouse clicks.

  The window is re-rendered from `m.heightmap()` at most once every 15
  seconds. Releasing a mouse button inside the window prints the area of the
  clicked land mass, the great-circle distance to the previously clicked
  point (if any), and the clicked point itself.

  Args:
    m: object whose `heightmap()` method returns a callable that maps the
      (width * height, 3) input tensor to a 2-D output tensor with one row
      per pixel.

  Returns:
    0 once the window has been closed.
  """
  width, height = 1024, 512
  render_interval = 15  # seconds between two renders of the heightmap

  window = sdl2.ext.Window(
      "Training neural network: memorizing fictional world map",
      size=(width, height))
  window.show()
  running = True
  last_clicked = None
  windowsurface = ctypes.pointer(window.get_surface())
  inputs = tf.reshape(
      inputs_equirectangular(width, height), (width * height, 3))
  window.refresh()
  # time.monotonic() has an undefined reference point, so "never rendered"
  # is tracked with None rather than a magic negative timestamp. This also
  # guarantees `outputs` exists before any click is handled.
  last_rendered = None
  outputs = None
  while running:
    if (last_rendered is None or
        time.monotonic() - last_rendered > render_interval):
      outputs = m.heightmap()(inputs)
      image = tensor_to_surface(
          tf.math.multiply(
              255,
              colourize_heightmap(
                  tf.reshape(outputs, (height, width, outputs.shape[1])))))
      sdl2.SDL_BlitSurface(image, None, windowsurface, None)
      last_rendered = time.monotonic()
      sdl2.SDL_FreeSurface(image)
    events = sdl2.ext.get_events()
    for event in events:
      if event.type == sdl2.SDL_QUIT:
        running = False
        break
      elif event.type == sdl2.SDL_MOUSEBUTTONUP:
        x, y = event.button.x, event.button.y
        # A button released outside the window (e.g. after a drag) can
        # report out-of-range coordinates; never index the tensors or call
        # into native code with them.
        if not (0 <= x < width and 0 <= y < height):
          continue
        point = list(inputs[y * width + x].numpy())
        solid_angle = landmass_steradians(
            tf.reshape(outputs, (height, width, 1)), x, y)
        print('Clicked land mass: %s square Km | %s square miles' %
              (solid_angle * planet_radius_km**2,
               solid_angle * planet_radius_miles**2))
        if last_clicked is not None:
          # Squared straight-line (chord) distance between the two points.
          chord_squared = sum(
              (a - b)**2 for a, b in zip(last_clicked, point))
          # Cosine rule with a == b == 1. Clamp to acos's domain because
          # rounding can push the value marginally outside [-1, 1].
          cos_arc = max(-1.0, min(1.0, 1 - chord_squared / 2))
          arc = math.acos(cos_arc)
          print('%s Km | %s miles' %
                (arc * planet_radius_km, arc * planet_radius_miles))
        print(point)
        last_clicked = point
    window.refresh()
  return 0


def inputs_equirectangular(width, height):
  """Build the per-pixel input tensor for a width x height image.

  Allocates a zeroed (height, width, 3) float64 buffer, lets the native
  `helpers.equirectangular` routine fill it in place, and returns the result
  as a constant tensor.

  NOTE(review): the three values per pixel are presumably the coordinates of
  a point on the unit sphere -- confirm against the native helper.
  """
  double_ptr = ctypes.POINTER(ctypes.c_double)
  buffer = numpy.zeros((height, width, 3), numpy.float64)
  buffer_ptr = ctypes.cast(ctypes.c_void_p(buffer.ctypes.data), double_ptr)
  helpers.equirectangular(buffer_ptr, ctypes.c_int(width),
                          ctypes.c_int(height))
  return tf.constant(buffer)


def tensor_to_surface(source):
  """Hand an image tensor to the native renderer and return its result.

  `source` is cast to uint8 and passed to `helpers.render_tensor` together
  with its second, first and third dimension sizes (in that order).

  NOTE(review): the return value is presumably an SDL surface that the
  caller must free -- confirm against the native helper.
  """
  dims = source.shape
  pixels = tf.cast(source, tf.uint8).numpy()
  pixel_ptr = ctypes.c_void_p(pixels.ctypes.data)
  size_args = [ctypes.c_int(dims[axis]) for axis in (1, 0, 2)]
  return helpers.render_tensor(pixel_ptr, *size_args)


def colourize_heightmap(source):
  """Turn a heightmap tensor into a three-channel colour tensor.

  `source` is cast to float64 and passed to the native
  `helpers.colourize_heightmap` routine, which writes into a freshly
  allocated (rows, cols, 3) float64 buffer. That buffer is returned as a
  constant tensor.
  """
  rows, cols = source.shape[0], source.shape[1]
  double_ptr = ctypes.POINTER(ctypes.c_double)
  colours = numpy.zeros((rows, cols, 3), numpy.float64)
  heights = tf.cast(source, tf.float64).numpy()
  colours_ptr = ctypes.cast(ctypes.c_void_p(colours.ctypes.data), double_ptr)
  heights_ptr = ctypes.cast(ctypes.c_void_p(heights.ctypes.data), double_ptr)
  helpers.colourize_heightmap(colours_ptr, heights_ptr, ctypes.c_int(cols),
                              ctypes.c_int(rows))
  return tf.constant(colours)


def landmass_steradians(source, x, y):
  """Query the native helper about the land mass at pixel (x, y).

  `source` is cast to float64 and passed to `helpers.landmass_steradians`
  along with its second and first dimension sizes and the pixel coordinates.
  The helper's return value is passed through unchanged.

  NOTE(review): presumably the result is the land mass's solid angle in
  steradians -- confirm against the native helper.
  """
  rows, cols = source.shape[0], source.shape[1]
  heights = tf.cast(source, tf.float64).numpy()
  heights_ptr = ctypes.cast(
      ctypes.c_void_p(heights.ctypes.data), ctypes.POINTER(ctypes.c_double))
  int_args = [ctypes.c_int(value) for value in (cols, rows, x, y)]
  return helpers.landmass_steradians(heights_ptr, *int_args)