import wx
import zipfile
import json
import time
import glob
from path import Path
import os
import re
import time
import open3d as o3d
import numpy as np
from google.protobuf.json_format import MessageToJson, ParseDict
from modules.scene.data_processing import process_arkit_data
from modules.proto import scene_pb2
# from modules.scene.train_ert import ERTTrainer
# from nerfstudio.scripts.exporter import ExportGaussianSplat
from multiprocessing import Queue

# def export_ert_ply(fp, scale, recon_path):
#   dir = os.path.dirname(fp)
#   file = os.path.basename(fp)
#   pcd_exporter = ExportGaussianSplat(Path(recon_path), Path(dir), file)
#   pcd_exporter.main()

#   pcd = o3d.io.read_point_cloud(fp)
#   scaling_mat = np.array([[scale, 0, 0, 0], [0, scale, 0, 0], [0, 0, scale, 0], [0, 0, 0, 1]])
#   pcd.transform(scaling_mat)
#   o3d.io.write_point_cloud(fp, pcd)

class Scene():
  """A scene stored as a ``.qdb`` (binary protobuf) or ``.json`` scene file.

  Holds the ``scene_pb2.SceneData`` message read from disk (``self.scene``),
  the scene name, mesh nodes attached by other code (``self.mesh_nodes``),
  the number of processed images, and a cached path to the config of a
  trained ERT run found under ``<scene dir>/ert_runs/<name>/``.
  """

  # A run counts as trained once its newest checkpoint step reaches
  # image_count * _ERT_STEPS_PER_IMAGE - _ERT_STEP_SLACK.
  # NOTE(review): values carried over unchanged from the original inline
  # expression; their rationale is not visible here -- confirm.
  _ERT_STEPS_PER_IMAGE = 30
  _ERT_STEP_SLACK = 100

  def __init__(self, scene_path=None, json_zipped = False, file_type = 0):
    """Create a scene, optionally loading it from ``scene_path``.

    Args:
      scene_path: Scene file to load, or None for an empty, unloaded scene
        (populate it later with one of the ``instantiate_scene_with_*``
        methods).
      json_zipped: Only used when ``file_type == 1``. If True and
        ``scene_path`` is a zip archive, the JSON is read from inside it
        (the layout written by ``serialize(..., compressed=True)``).
      file_type: 0 for a binary ``.qdb`` protobuf file, 1 for a JSON file.

    Raises:
      ValueError: if ``scene_path`` is given and ``file_type`` is not 0 or 1.
    """
    self.loaded = False
    self.mesh_nodes = []
    self.ert = None
    # Set up front so get_name()/get_ert() also work on empty scenes.
    self.name = None
    self.scene_path = None
    self.image_count = -1

    if scene_path is None:
      return

    if file_type not in (0, 1):
      raise ValueError(
        f"Unsupported scene file_type {file_type!r}; expected 0 (.qdb) or 1 (.json)")

    self.scene_path = Path(scene_path)
    if file_type == 0:
      self.deserialize_qdb(scene_path)
    else:
      self.deserialize_json(scene_path, json_zipped)
    self.loaded = True
    self._refresh_image_count()

  def _refresh_image_count(self):
    """Set ``image_count`` to the entry count of ``<dir>/processed/images`` (-1 if missing)."""
    images_dir = os.path.join(self.get_dir(), "processed", "images")
    if os.path.isdir(images_dir):
      self.image_count = len(os.listdir(images_dir))
    else:
      self.image_count = -1

  def get_image_count(self):
    """Return the number of processed images, or -1 if unknown."""
    return self.image_count

  def get_dir(self):
    """Return the scene directory as a string.

    If ``scene_path`` names an existing file, this is the file's parent
    directory. Otherwise ``scene_path`` itself is returned.
    """
    # os.fspath keeps the original TypeError when scene_path is None.
    scene_path = str(os.fspath(self.scene_path))
    if os.path.isfile(scene_path):
      return os.path.dirname(scene_path)
    return scene_path

  def get_name(self):
    """Return the scene name, or None if no scene has been loaded or created."""
    return self.name

  def export_ply(self, fp, scale, recon_path):
    """Export the ERT reconstruction as a scaled PLY -- currently a no-op.

    The implementation (the module-level ``export_ert_ply`` run in a separate
    process) is commented out in this file. This method therefore ignores its
    arguments and returns None.
    """
    # TODO: restore once export_ert_ply / ExportGaussianSplat is re-enabled.
    return None

  def unload_ert(self):
    """Forget the cached ERT config path so the next get_ert() searches again."""
    self.ert = None

  def get_ert(self):
    """Return the config path of a sufficiently trained ERT run, or None.

    Looks under ``<scene dir>/ert_runs/<name>/``. The step number is the last
    run of digits in each ``nerfstudio_models/*.ckpt`` file name. The run is
    accepted when the highest step is at least
    ``image_count * _ERT_STEPS_PER_IMAGE - _ERT_STEP_SLACK``.

    On success, the first ``*.yml`` file in the run directory (sorted by name)
    is cached in ``self.ert`` and returned. None is returned (and nothing is
    cached) in any of these cases:
      * the scene has no name or path;
      * no checkpoint carries a step number;
      * training has not reached the threshold;
      * no config file exists.
    """
    if self.ert is not None:
      return self.ert
    if not self.name or self.scene_path is None:
      return None

    # Escape so brackets/asterisks in the directory or scene name are literal.
    run_dir = glob.escape(os.path.join(self.get_dir(), "ert_runs", self.name))

    steps = []
    for ckpt in glob.glob(os.path.join(run_dir, "nerfstudio_models", "*.ckpt")):
      numbers = re.findall(r'\d+', os.path.basename(ckpt))
      if numbers:
        steps.append(int(numbers[-1]))
    if not steps:
      return None

    required = self.get_image_count() * self._ERT_STEPS_PER_IMAGE - self._ERT_STEP_SLACK
    if max(steps) < required:
      return None

    configs = sorted(glob.glob(os.path.join(run_dir, "*.yml")))
    if not configs:
      return None
    self.ert = configs[0]
    return self.ert

  def instantiate_scene_with_video(self, scene_name : str, scene_dir : str, video_path : str):
    """Create ``<scene_dir>/<scene_name>.qdb`` for a video-based scene.

    Stores ``video_path`` on the instance (it is not written to the file) and
    returns the new scene path.
    """
    self.scene_path = Path(scene_dir + f"/{scene_name}.qdb")
    self.name = scene_name
    self.video_path = video_path
    self.serialize(self.scene_path)
    return self.scene_path

  def instantiate_scene_with_arkit(self, scene_name : str, scene_dir : str, result_queue : Queue):
    """Create ``<scene_dir>/<scene_name>.qdb`` and process ARKit data for it.

    Runs ``process_arkit_data`` on ``<scene_dir>/raw_images``. Progress is
    reported as ``(percent, message)`` tuples on ``result_queue``. Returns the
    new scene path.
    """
    self.name = scene_name
    self.scene_path = Path(scene_dir + f"/{scene_name}.qdb")
    self.serialize(self.scene_path)
    result_queue.put((10, "QDB Created. Processing Data..."))
    process_arkit_data(scene_dir+"/raw_images")
    # Processing may have produced processed/images; refresh the count so
    # get_ert()'s training threshold is not based on a stale value.
    self._refresh_image_count()
    result_queue.put((65, "Data processed. Generating masks..."))
    # NOTE(review): no mask generation runs between these two messages.
    result_queue.put((100, "Processing Complete"))
    return self.scene_path

  def deserialize(self):
    """Copy fields from the parsed ``self.scene`` message onto the instance."""
    # NOTE(review): only the name is restored. self.scene.meshes are not
    # converted into self.mesh_nodes here, so load -> serialize writes no
    # meshes unless other code repopulates mesh_nodes. Confirm this is intended.
    self.name = self.scene.name

  def deserialize_json(self, scene_path, json_zipped):
    """Load a JSON ``SceneData`` message from ``scene_path``.

    If ``json_zipped`` is True and ``scene_path`` is a zip archive, the first
    ``.json`` member is read (or the first member if none ends in ``.json``).
    Otherwise ``scene_path`` is read as a plain UTF-8 JSON file.

    Raises:
      ValueError: if the zip archive is empty.
    """
    if json_zipped and zipfile.is_zipfile(scene_path):
      with zipfile.ZipFile(scene_path) as archive:
        names = archive.namelist()
        candidates = [n for n in names if n.lower().endswith(".json")] or names
        if not candidates:
          raise ValueError(f"Zip archive {scene_path!s} contains no files")
        data = json.loads(archive.read(candidates[0]))
    else:
      with open(scene_path, "r", encoding="utf-8") as file:
        data = json.load(file)

    scene = scene_pb2.SceneData()
    ParseDict(data, scene)
    self.scene = scene
    self.deserialize()

  def deserialize_qdb(self, scene_path):
    """Load a binary ``SceneData`` protobuf from ``scene_path``."""
    with open(scene_path, "rb") as scene_file:
      raw = scene_file.read()
    scene = scene_pb2.SceneData()
    scene.ParseFromString(raw)
    self.scene = scene
    self.deserialize()

  def _export_mesh_ply(self, mesh, target_dir):
    """Write ``mesh``'s PLY bytes to ``<target_dir>/<mesh.get_path()>``, then clear them in memory.

    If the mesh has no in-memory PLY data, the bytes are read from
    ``mesh.get_path()`` relative to the current scene file's directory.
    """
    ply_data = mesh.get_ply()
    if ply_data == b'':
      source_path = os.path.join(os.path.dirname(str(self.scene_path)), mesh.get_path())
      with open(source_path, "rb") as ply_file:
        ply_data = ply_file.read()

    mesh_path = os.path.join(target_dir, mesh.get_path())
    mesh_dir = os.path.dirname(mesh_path)
    if mesh_dir:
      # makedirs (not mkdir) so nested mesh sub-directories are created too.
      os.makedirs(mesh_dir, exist_ok=True)
    with open(mesh_path, "wb") as mesh_file:
      mesh_file.write(ply_data)

    # Cleared before mesh.get_data() is taken, presumably so the JSON scene
    # refers to the PLY by path instead of embedding its bytes.
    mesh.set_ply(b'')

  def serialize(self, scene_path, save_dialog = None, compressed=False):
    """Save the scene to ``scene_path``, choosing the format from its extension.

    Formats:
      * ``.qdb``: binary ``SceneData`` protobuf.
      * ``.json``: JSON produced by ``MessageToJson``. Each mesh's PLY bytes
        are also written to ``mesh.get_path()`` relative to the target
        directory. With ``compressed=True``, the JSON is stored (LZMA) inside
        ``<stem>.zip`` next to ``scene_path`` instead of a plain file.

    Args:
      scene_path: Destination file path.
      save_dialog: Optional wx window. It is closed via ``wx.CallAfter`` when
        the save finishes, even if the save raised.
      compressed: Only used for ``.json`` output.

    Raises:
      ValueError: if the extension is neither ``.json`` nor ``.qdb``.
    """
    target = str(scene_path)
    target_dir = os.path.dirname(target)
    suffix = os.path.splitext(target)[1].lower()
    try:
      if suffix not in (".json", ".qdb"):
        raise ValueError(
          f"Unsupported scene file extension {suffix!r}; expected '.json' or '.qdb'")

      meshes = []
      for mesh in self.mesh_nodes:
        if suffix == ".json":
          self._export_mesh_ply(mesh, target_dir)
        meshes.append(mesh.get_data())

      save_scene = scene_pb2.SceneData(name=self.name, meshes=meshes)

      if suffix == ".qdb":
        with open(target, "wb") as scene_file:
          scene_file.write(save_scene.SerializeToString())
      elif compressed:
        json_data = MessageToJson(save_scene)
        zip_path = os.path.splitext(target)[0] + ".zip"
        with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_LZMA) as archive:
          archive.writestr(os.path.basename(target), json_data)
      else:
        json_data = MessageToJson(save_scene)
        with open(target, "w", encoding="utf-8") as scene_file:
          scene_file.write(json_data)
    finally:
      if save_dialog is not None:
        wx.CallAfter(save_dialog.Close)

  def is_loaded(self):
    """Return True if the scene was loaded from a file in the constructor."""
    return self.loaded
