import wx
import wx.aui
import os
import threading
import time
import glob
import re
import torch
from modules.gui.PyCSI import FrameMainGUI
from modules.gui.dialog_about import AboutDialog
from modules.gui.panel_top_bar import PanelTopBar
from modules.gui.dialog_export import DialogExport
from modules.scene.scene import Scene
from modules.gui.dialog_scale import DialogScale
from modules.scene.train_ert import ERTTrainer
from modules.gui.dialog_confirm_overwrite import DialogConfirmOverwrite
from multiprocessing import Process, Queue

class ReconThread():
  """Owns the child process that runs a reconstruction.

  Despite the name, this is not a ``threading.Thread``: the work is launched
  in a ``multiprocessing.Process`` whose target is ``run_recon``.
  """

  def __init__(self, dir, name, result_queue) -> None:
    """Remember the launch parameters; no process is started here.

    Args:
      dir: Forwarded as the first argument of ``run_recon``.
      name: Forwarded as the second argument of ``run_recon``.
      result_queue: Queue forwarded to ``run_recon`` as its last argument.
    """
    super().__init__()
    self.process = None
    self.result_queue = result_queue
    self.name = name
    self.dir = dir

  def run_training(self, use_depth) -> None:
    """Spawn the child process that executes ``run_recon``.

    Args:
      use_depth: Forwarded unchanged to ``run_recon``.
    """
    worker_args = (self.dir, self.name, use_depth, self.result_queue)
    self.process = Process(target=run_recon, args=worker_args)
    self.process.start()

  def stop(self) -> None:
    """Terminate and join the child process, if one is held."""
    worker = self.process
    if worker is None:
      return
    worker.terminate()
    worker.join()
    self.process = None

  def is_running(self) -> bool:
    """Return True while a process handle is held.

    NOTE: this does not check liveness. The handle is only cleared by
    ``stop()``, so this stays True after the child exits on its own.
    """
    return not (self.process is None)

def run_recon(dir, name, use_depth, result_queue : Queue) -> None:
  """Child-process entry point: train, then signal completion.

  Args:
    dir: First constructor argument of ``ERTTrainer``.
    name: Second constructor argument of ``ERTTrainer``.
    use_depth: Forwarded to ``ERTTrainer.start_training``.
    result_queue: Receives ``None`` once training has returned.

  NOTE: if training raises, the exception propagates and nothing is put on
  the queue.
  """
  trainer = ERTTrainer(dir, name)
  trainer.start_training(use_depth)
  # None is the completion sentinel that result_listener checks for.
  completion_sentinel = None
  result_queue.put(completion_sentinel)

def result_listener(result_queue : Queue, callback):
  """Block for one queue item and invoke ``callback`` if it is ``None``.

  Args:
    result_queue: Queue to read a single item from (blocking ``get``).
    callback: Zero-argument callable run only when the item is ``None``;
      any other item is consumed and ignored.
  """
  if result_queue.get() is not None:
    return
  callback()

class ProgressListener(threading.Thread):
  """Background thread that polls a directory for a changing ``.ckpt`` file.

  Whenever exactly one ``*.ckpt`` file is present and its path differs from
  the one seen on an earlier poll, the last run of digits in the previously
  seen path is reported through ``update_bar_callback``.

  NOTE(review): the reported step comes from the *previous* checkpoint name,
  not the newly detected one, and the first checkpoint is never reported.
  Confirm this is intended.
  """

  # Seconds between directory polls.
  POLL_INTERVAL_SECONDS = 5

  def __init__(self, ckpt_dir, update_bar_callback):
    """Store the polling target and the progress callback.

    Args:
      ckpt_dir: Directory that is globbed for ``*.ckpt`` files.
      update_bar_callback: Called with one int (the parsed step). It runs
        on this thread, not on the caller's thread.
    """
    super().__init__()
    self.stop_event = threading.Event()
    self.ckpt_dir = ckpt_dir
    self.update_bar_callback = update_bar_callback

  def run(self):
    """Poll until ``stop()`` is called."""
    ckpt = None
    while not self.stop_event.is_set():
      files = glob.glob(f"{self.ckpt_dir}/*.ckpt")
      if len(files) == 1:
        current = files[0]
        if ckpt is None:
          ckpt = current
        elif ckpt != current:
          match = re.findall(r'\d+', ckpt)
          # Skip the report rather than kill the thread with an IndexError
          # when the path contains no digits at all.
          if match:
            self.update_bar_callback(int(match[-1]))
          ckpt = current
      # Waiting on the event instead of sleeping lets stop() end the loop
      # immediately rather than after the rest of the interval.
      self.stop_event.wait(self.POLL_INTERVAL_SECONDS)

  def stop(self):
    """Ask the polling loop to exit; returns without joining the thread."""
    self.stop_event.set()


class FrameMainExtended(FrameMainGUI):
  def __init__(self, parent, dark, path):
    """Build the main frame: viewport pane, top bar, scene and recon runner.

    Args:
      parent: Forwarded to the ``FrameMainGUI`` constructor.
      dark: Stored as ``self.dark_mode``; selects the dark or light scheme.
      path: Forwarded to the ``Scene`` constructor.
    """
    super().__init__(parent)

    # Plain state used by the reconstruction workflow.
    self.progress_listener = None
    self.dark_mode = dark
    self.stop_flag = False

    # Child windows managed by the AUI manager (self.m_mgr is not created here).
    self.panel_viewport = wx.Panel(self)
    top_bar_size = self.FromDIP(wx.Size(4000,80))
    self.panel_top_bar = PanelTopBar(self, top_bar_size)

    viewport_info = (
      wx.aui.AuiPaneInfo()
      .Center()
      .Name("Viewport")
      .Caption( u"Viewport" )
      .MinSize(0,0)
      .CloseButton( False )
      .Dock()
      .Resizable()
      .FloatingSize( wx.DefaultSize )
      .Floatable( False )
      .PaneBorder(True)
    )
    self.m_mgr.AddPane(self.panel_viewport, viewport_info)

    top_bar_info = (
      wx.aui.AuiPaneInfo()
      .Top()
      .CloseButton(False)
      .CaptionVisible(False)
      .PaneBorder(False)
      .Fixed()
    )
    self.m_mgr.AddPane(self.panel_top_bar, top_bar_info)

    self.m_mgr.Update()
    self.set_initial_color_mode()

    # The scene supplies the directory and name the reconstruction runs with.
    self.scene = Scene(path)
    self.recon_thread = ReconThread(
      self.scene.get_dir(),
      self.scene.get_name(),
      Queue(),
    )

    # Force a full relayout and repaint now that all panes exist.
    self.Layout()
    self.Update()
    self.Refresh()
    self.PostSizeEventToParent()

  
  def scale(self):
    """Show the scale dialog and, on OK, apply the chosen scale.

    The value returned by ``DialogScale.get_scale`` is passed to
    ``self.scale_model``. Nothing is applied unless the dialog returns
    ``wx.ID_OK``.
    """
    dia = DialogScale(self)
    try:
      dia.Center()
      if dia.ShowModal() == wx.ID_OK:
        self.scale_model(dia.get_scale())
    finally:
      # Destroy explicitly, even if applying the scale raises; otherwise
      # the dialog would be leaked.
      dia.Destroy()
  
  def toggle_measure(self, toggle):
    """Forward the measure toggle to ``toggle_measure_vieport``.

    Args:
      toggle: Passed through unchanged; presumably a bool — TODO confirm.
    """
    # "vieport" is the spelling of the target method, which is defined
    # outside this class body.
    self.toggle_measure_vieport(toggle)

  def check_recon(self):
    """Insert the existing reconstruction mesh, or disable the top-bar buttons.

    Stores the result of ``scene.get_recon_path(False)`` in
    ``self.recon_path``. A truthy path is handed to ``insert_mesh``;
    otherwise ``panel_top_bar.disable_buttons()`` is called.
    """
    self.recon_path = self.scene.get_recon_path(False)
    if not self.recon_path:
      self.panel_top_bar.disable_buttons()
      return
    self.insert_mesh(self.recon_path)

  def reconstruct_scene(self):
    """Start a reconstruction run in a child process.

    Steps:
      1. Ensure ``<scene dir>/ert_runs/<scene name>`` exists.
      2. If a ``nerfstudio_models`` folder is already there, let
         ``data_exists_handler`` decide whether to continue.
      3. Ask whether depth should be used (OK -> True, Cancel -> abort,
         any other answer -> False).
      4. Put the UI into its "Reconstructing..." state, unload the current
         model and free cached CUDA memory.
      5. Launch training plus two helper threads: one waits for the
         completion sentinel, one polls the checkpoint folder.
    """
    self.stop_flag = False

    recon_dir = f"{self.scene.get_dir()}/ert_runs/{self.scene.get_name()}"
    os.makedirs(recon_dir, exist_ok=True)
    models_dir = recon_dir+"/nerfstudio_models"

    if os.path.exists(models_dir) and not self.data_exists_handler(recon_dir):
      return

    dia = DialogConfirmOverwrite(self, msg= "Use Depth?")
    try:
      dia.Center()
      code = dia.ShowModal()
    finally:
      # Destroy on every path, including cancel; the early return used to
      # skip this and leak the dialog.
      dia.Destroy()
    if code == wx.ID_CANCEL:
      return
    use_depth = code == wx.ID_OK

    self.panel_top_bar.button_recon_scene.Disable()
    self.panel_top_bar.button_recon_scene.set_label("Reconstructing...")
    self.panel_top_bar.Refresh()
    self.recon_start = time.time()
    self.scene.unload_ert()
    self.unload_ert_viewport()
    self.set_viewport_to_recon()
    torch.cuda.empty_cache()

    self.recon_thread.run_training(use_depth)

    # Both callbacks fire on worker threads but end up touching wx widgets.
    # wx.CallAfter re-schedules them onto the GUI thread's event loop.
    self.result_listener = threading.Thread(
      target=result_listener,
      args=(
        self.recon_thread.result_queue,
        lambda: wx.CallAfter(self.recon_complete),
      ),
    )
    self.result_listener.start()

    self.progress_listener = ProgressListener(
      models_dir,
      lambda step: wx.CallAfter(self.update_loading_bar, step),
    )
    self.progress_listener.start()
    

  def recon_complete(self):
    """Finish a reconstruction run: stop polling, reload, restore the button.

    Prints the elapsed time since ``self.recon_start``, stops the progress
    listener if one exists, reloads the reconstruction, re-enables the
    reconstruct button with its idle label, and sets ``self.stop_flag``.
    """
    elapsed = time.time() - self.recon_start
    print(f"Reconstruction complete in {elapsed} seconds")
    #self.update_loading_bar(11)

    listener = self.progress_listener
    if listener:
      listener.stop()

    self.load_recon()

    top_bar = self.panel_top_bar
    top_bar.button_recon_scene.Enable()
    top_bar.button_recon_scene.set_label("Reconstruct Scene")
    top_bar.Refresh()
    self.stop_flag = True
    

  def data_exists_handler(self, recon_dir):
    """Decide whether reconstruction may proceed given existing checkpoints.

    Args:
      recon_dir: Run directory; ``nerfstudio_models/*.ckpt`` below it is
        inspected.

    Returns:
      True when there is no checkpoint, when the first checkpoint's step is
      below the "finished" threshold, or when the user confirms starting
      over. False when the user declines.
    """
    found = glob.glob(recon_dir+"/nerfstudio_models/*.ckpt")
    if not found:
      return True

    # The step count is the last run of digits in the checkpoint path.
    step = int(re.findall(r'\d+', found[0])[-1])
    # The checkpoint step is never exactly n_images * 30, hence the slack of 100.
    finished_threshold = (self.scene.get_image_count()*30)-100
    if step < finished_threshold:
      return True

    dia = DialogConfirmOverwrite(self, msg= "Training data detected. \nStart over?")
    dia.Center()
    code = dia.ShowModal()
    dia.Destroy()
    return code == wx.ID_OK
    
  def export_recon(self):
    """Export the reconstruction as PLY into a user-chosen directory.

    The viewport model is unloaded first and reloaded afterwards, whether
    or not the export happened. Any exception raised while exporting is
    shown in an error message box instead of propagating.
    """
    self.unload_ert_viewport()
    dia = DialogExport(self)
    confirmed = dia.ShowModal() == wx.ID_OK
    if confirmed:
      try:
        payload = self.get_export_data()
        target_dir = dia.get_export_dir()
        self.scene.export_ply(target_dir, payload[0], payload[1])
      except Exception as e:
        wx.MessageBox(f"Export failed (Try again?): {e}", "Error", wx.OK | wx.ICON_ERROR)
    self.load_recon()
    dia.Destroy()

  # Applies the base AUI dock-art styling, then the dark or light scheme selected by self.dark_mode
  def set_initial_color_mode(self):
    """Apply the base dock-art styling and then the configured colour scheme.

    Colours, metrics and the caption font are set on the AUI art provider,
    the manager is updated, and finally ``toggle_dark_mode`` or
    ``toggle_light_mode`` is invoked depending on ``self.dark_mode``.
    """
    dock_art = self.m_mgr.GetArtProvider()

    base_colours = (
      (wx.aui.AUI_DOCKART_INACTIVE_CAPTION_COLOUR, wx.Colour(144,144,144)),           # caption
      (wx.aui.AUI_DOCKART_INACTIVE_CAPTION_GRADIENT_COLOUR, wx.Colour(51,51,51)),     # caption gradient
      # (wx.aui.AUI_DOCKART_INACTIVE_CAPTION_TEXT_COLOUR, wx.Colour(255,255,255)),    # caption text
      (wx.aui.AUI_DOCKART_BORDER_COLOUR, wx.Colour(252,252,252)),                     # border
      (wx.aui.AUI_DOCKART_GRIPPER_COLOUR, wx.Colour(255,255,255)),                    # gripper
      (wx.aui.AUI_DOCKART_BACKGROUND_COLOUR, wx.Colour(255,255,255)),                 # background
    )
    for setting, colour in base_colours:
      dock_art.SetColour(setting, colour)

    base_metrics = (
      (wx.aui.AUI_DOCKART_PANE_BORDER_SIZE, 0),
      (wx.aui.AUI_DOCKART_SASH_SIZE, 1),
      (wx.aui.AUI_DOCKART_CAPTION_SIZE, 30),
    )
    for setting, value in base_metrics:
      dock_art.SetMetric(setting, value)

    caption_font = wx.Font( 11, wx.FONTFAMILY_SWISS, wx.FONTSTYLE_NORMAL, wx.FONTWEIGHT_LIGHT, False, "Calibri Light" )
    dock_art.SetFont(wx.aui.AUI_DOCKART_CAPTION_FONT, caption_font)

    self.m_mgr.Update()

    # The handlers take an event argument they do not read; wx.EVT_MENU is
    # passed as a placeholder, as in the original code.
    apply_scheme = self.toggle_dark_mode if self.dark_mode else self.toggle_light_mode
    apply_scheme(wx.EVT_MENU)

    
  
  def toggle_view(self, view_mode : int):
    """Forward the requested view mode to ``toggle_point``.

    Args:
      view_mode: Integer mode identifier, passed through unchanged. The
        meaning of each value is defined by ``toggle_point``, which is not
        visible in this class body.
    """
    self.toggle_point(view_mode)

  def open_about_dialog(self, event):
    """Show the modal About dialog and dispose of it afterwards.

    Args:
      event: The triggering event; not read by this handler.
    """
    about = AboutDialog(self)
    try:
      about.ShowModal()
    finally:
      # The dialog was previously never destroyed, so each opening left
      # another hidden dialog attached to the frame.
      about.Destroy()

  # Flips a pane on/off
  def flipPane(self, paneToFlip, menuItem):
    """Toggle a managed pane's visibility and sync its menu icon.

    Args:
      paneToFlip: Window whose AUI pane is shown or hidden.
      menuItem: Menu item whose bitmap becomes ``self.eye_view`` when it is
        checked and ``self.eye_hide`` otherwise.
    """
    info = self.m_mgr.GetPane(paneToFlip)

    if not info.IsShown():
      info.Show()
    else:
      # Record the current size as the best size before hiding.
      info.BestSize(info.window.GetSize())
      info.Hide()

    # The eye icon follows the menu item's checked state.
    icon = self.eye_view if menuItem.IsChecked() else self.eye_hide
    menuItem.SetNormalBitmap(icon)

    self.m_mgr.Update()
  
  def load_recon(self):
    """Load the scene's reconstruction into the viewport, if one exists.

    Does nothing when ``scene.get_ert()`` returns a falsy value.
    """
    recon = self.scene.get_ert()
    if not recon:
      return
    self.load_recon_ert(recon)
      

  def reset_viewport_camera(self):
    """Reset the viewport camera by delegating to ``reset_cam``."""
    self.reset_cam()

  # Resets each pane inside of the aui manager to its saved state on startup
  def menuselect_reset_layout(self, event):
    """Restore every pane to the layout captured by ``menuselect_save_layout``.

    Args:
      event: The triggering menu event; not read by this handler.

    NOTE: reads ``self.paneInfoDict``, which within this class is only
    assigned in ``menuselect_save_layout``.
    """
    manager = self.m_mgr

    # Pass 1: hide whatever is visible so sizes and positions get reset.
    for info in manager.GetAllPanes():
      if not info.IsShown():
        continue
      info.Hide()

    manager.Update()

    # Pass 2: show each pane again and reload its saved pane info.
    for info in manager.GetAllPanes():
      info.Show()
      manager.LoadPaneInfo(self.paneInfoDict[info], info)

    # Mark every window menu flag as checked again.
    self.menuitem_viewport_check = True
    self.menuitem_viewport_options_check = True
    self.menuitem_database_check = True
    self.menuitem_log_check = True

    manager.Update()
    
  # Saves current layout in app for use in reset layout
  def menuselect_save_layout(self, event):
    """Capture each pane's current layout into ``self.paneInfoDict``.

    Args:
      event: The triggering menu event; not read by this handler.

    Hidden panes are shown temporarily while their info string is saved, so
    the hidden state itself is not recorded. They are hidden again afterwards.
    """
    self.paneInfoDict = {}

    for info in self.m_mgr.GetAllPanes():
      info.BestSize(info.window.GetSize())

      was_hidden = not info.IsShown()
      if was_hidden:
        info.Show(True)

      self.paneInfoDict[info] = self.m_mgr.SavePaneInfo(info)

      if was_hidden:
        info.Hide()

  # Flips log pane on/off        
  def menuselect_log(self, event):
    """Invert the log menu flag, sync the menu item and flip the log pane.

    Args:
      event: The triggering menu event; not read by this handler.
    """
    now_checked = not self.menuitem_log_check
    self.menuitem_log_check = now_checked
    self.menuitem_log.Check(now_checked)

    self.flipPane(self.panel_log, self.menuitem_log)

  # Flips database pane on/off  
  def menuselect_database(self, event):
    """Invert the database menu flag, sync the menu item and flip the pane.

    Args:
      event: The triggering menu event; not read by this handler.
    """
    now_checked = not self.menuitem_database_check
    self.menuitem_database_check = now_checked
    self.menuitem_database.Check(now_checked)

    self.flipPane(self.treectrl_database, self.menuitem_database)

  # Flips viewport pane on/off
  def menuselect_viewport(self, event):
    """Invert the viewport menu flag, sync the menu item and flip the pane.

    Args:
      event: The triggering menu event; not read by this handler.
    """
    now_checked = not self.menuitem_viewport_check
    self.menuitem_viewport_check = now_checked
    self.menuitem_viewport.Check(now_checked)

    self.flipPane(self.panel_viewport, self.menuitem_viewport)

  # Event for toggling dark mode
  def toggle_dark_mode(self, event):
    """Switch to the dark palette.

    Args:
      event: The triggering event; not read by this handler.
    """
    self.dark_mode = True
    dark_palette = (
      wx.Colour(33,33,33),     # background (neutral 20)
      wx.Colour(51,51,51),     # first text colour (neutral 70)
      wx.Colour(252,252,252),  # second text colour (neutral 99)
    )
    self.change_color_scheme(*dark_palette)

  # Event for toggling light mode
  def toggle_light_mode(self, event):
    """Switch to the light palette.

    Args:
      event: The triggering event; not read by this handler.
    """
    self.dark_mode = False
    light_palette = (
      wx.Colour(252,252,252),  # background
      wx.Colour(172,172,172),  # first text colour
      wx.Colour(33,33,33),     # second text colour
    )
    self.change_color_scheme(*light_palette)
    

  # Sets everything to color scheme provided
  def change_color_scheme(self, background, text1, text2):
    """Apply a colour scheme to the AUI dock art and refresh the manager.

    Args:
      background: Colour used for the inactive caption gradient.
      text1: Colour used for the inactive caption.
      text2: Colour used for the inactive caption text.

    The sash colour is always set to (252, 252, 252) regardless of scheme.
    """
    dock_art = self.m_mgr.GetArtProvider()

    # Named constants replace the former literals 9, 10, 12 and 6; the
    # values are the same, and this matches set_initial_color_mode.
    dock_art.SetColour(wx.aui.AUI_DOCKART_INACTIVE_CAPTION_COLOUR, text1) # caption color
    dock_art.SetColour(wx.aui.AUI_DOCKART_INACTIVE_CAPTION_GRADIENT_COLOUR, background) # caption gradient color
    dock_art.SetColour(wx.aui.AUI_DOCKART_INACTIVE_CAPTION_TEXT_COLOUR, text2) # caption text color
    #dock_art.SetColour(wx.aui.AUI_DOCKART_BORDER_COLOUR, text1) # border color
    dock_art.SetColour(wx.aui.AUI_DOCKART_SASH_COLOUR, wx.Colour(252, 252, 252)) # sash color
    #dock_art.SetMetric(wx.aui.AUI_DOCKART_PANE_BORDER_SIZE, 0) # Border width

    self.m_mgr.Update()

  def OnClose(self, event):
    """Shut down background work before the frame closes.

    Args:
      event: The close event; ``Skip()`` is called on it so that default
        handling continues.
    """
    recon = self.recon_thread
    if recon.is_running():
      # A non-None item makes result_listener return without running its
      # completion callback.
      recon.result_queue.put("Kill")
      recon.stop()

    listener = self.progress_listener
    if listener:
      listener.stop()

    self.exit_panda()
    event.Skip()