import os
from threading import Thread

import cv2
import numpy as np
from nerfstudio.models.raytracingfacto import RaytracingfactoModelConfig
from nerfstudio.pipelines.base_pipeline import VanillaPipelineConfig
from nerfstudio.cameras.camera_optimizers import CameraOptimizerConfig
from nerfstudio.data.datamanagers.full_images_datamanager import FullImageDatamanagerConfig, FullImageDatamanager
from nerfstudio.data.dataparsers.nerfstudio_dataparser import NerfstudioDataParserConfig
from nerfstudio.data.datasets.ert_dataset import ERTDataset
from nerfstudio.engine.trainer import TrainerConfig
from nerfstudio.engine.schedulers import ExponentialDecaySchedulerConfig
from nerfstudio.engine.optimizers import AdamOptimizerConfig
from nerfstudio.configs.base_config import ViewerConfig
import PIL.Image as Image

from pathlib import Path
import datetime
import glob
import re

class ERTTrainer:
  """Prepare a processed scene and train a raytracingfacto (3DGRT) model on it.

  Directory layout under ``scene_dir`` used by the methods below:
    processed/images/       input images
    processed/images_8/     1/8-resolution copies (created on demand)
    processed/mono_depth/   optional ``*.npy`` depth maps
    processed/depths_8/     1/8-resolution, edge-masked depths (created on demand)
    ert_runs/<scene_name>/  training outputs (``*.yml`` config and
                            ``nerfstudio_models/*.ckpt`` checkpoints)
  """

  # Training iterations scheduled per entry in processed/images.
  ITERATIONS_PER_IMAGE = 30
  # Checkpoint interval; also the slack allowed when deciding a run has finished.
  STEPS_PER_SAVE = 100
  # Resolution divisor shared by the images_N / depths_N caches and the dataparser.
  DOWNSCALE_FACTOR = 8

  def __init__(self, scene_dir : str, scene_name : str) -> None:
    self.processed_dir = Path(str(scene_dir)) / "processed"
    # Both are filled in by get_recon_path(); when set, start_training()
    # passes the first entries to the trainer as load_config / load_checkpoint.
    self.config = None
    self.ckpt = None
    print(scene_dir)
    print(scene_name)
    self.scene_dir = scene_dir
    self.scene_name = scene_name

  def get_ply_paths(self):
    """Return the paths (as strings) of the ``*.ply`` files directly inside processed/."""
    # Escape the directory part so glob metacharacters in scene_dir are literal.
    return glob.glob(os.path.join(glob.escape(str(self.processed_dir)), "*.ply"))

  def get_recon_path(self):
    """Return the config path of a finished run for this scene, else ``None``.

    Side effect: sets ``self.config`` (sorted ``*.yml`` paths) and ``self.ckpt``
    (checkpoint paths, latest step first). These are set even when the run is
    unfinished, so a later start_training() resumes from the newest checkpoint.

    A run counts as finished when its newest checkpoint step is within
    ``STEPS_PER_SAVE`` of the planned iteration count.
    """
    run_dir = os.path.join(
      glob.escape(str(self.scene_dir)), "ert_runs", glob.escape(str(self.scene_name))
    )
    self.config = sorted(glob.glob(os.path.join(run_dir, "*.yml")))
    # Latest checkpoint first, so index 0 is the one to check and to resume from.
    self.ckpt = sorted(
      glob.glob(os.path.join(run_dir, "nerfstudio_models", "*.ckpt")),
      key=self._ckpt_step,
      reverse=True,
    )
    if not self.config or not self.ckpt:
      return None
    if self._ckpt_step(self.ckpt[0]) < self._max_iterations() - self.STEPS_PER_SAVE:
      return None
    return self.config[0]

  @staticmethod
  def _ckpt_step(ckpt_path):
    """Return the last integer in the checkpoint's filename, or -1 if it has none."""
    # Only the basename is parsed, so digits in directory names are ignored.
    matches = re.findall(r"\d+", os.path.basename(str(ckpt_path)))
    return int(matches[-1]) if matches else -1

  def _max_iterations(self):
    """Return the planned training length: ITERATIONS_PER_IMAGE per entry in processed/images."""
    return self.ITERATIONS_PER_IMAGE * len(os.listdir(self.processed_dir / "images"))

  def _downscale_images(self):
    """Write 1/DOWNSCALE_FACTOR copies of processed/images into processed/images_N, once."""
    factor = self.DOWNSCALE_FACTOR
    images_path = self.processed_dir / "images"
    downscale_path = self.processed_dir / f"images_{factor}"
    # NOTE(review): an interrupted run leaves a partial directory that is then
    # treated as complete on the next call.
    if downscale_path.exists():
      return
    os.makedirs(downscale_path)
    for image in images_path.iterdir():
      if not image.is_file():
        continue
      # The context manager closes the underlying file handle.
      with Image.open(image) as im:
        size = (max(1, im.width // factor), max(1, im.height // factor))
        im.resize(size).save(downscale_path / image.name)

  def _downscale_depths(self):
    """Write 1/DOWNSCALE_FACTOR, edge-masked copies of processed/mono_depth/*.npy into processed/depths_N, once."""
    factor = self.DOWNSCALE_FACTOR
    mono_depth_path = self.processed_dir / "mono_depth"
    depth_downscale_path = self.processed_dir / f"depths_{factor}"
    if not mono_depth_path.exists() or depth_downscale_path.exists():
      return
    os.makedirs(depth_downscale_path)
    for depth_file in mono_depth_path.iterdir():
      if depth_file.suffix != ".npy":
        continue
      depth = np.load(depth_file)
      height, width = depth.shape[:2]
      # cv2.resize takes dsize as (width, height), the reverse of numpy's shape order.
      dsize = (max(1, width // factor), max(1, height // factor))
      depth = cv2.resize(depth, dsize, interpolation=cv2.INTER_NEAREST)
      # Ignore depth points on sharp edges to avoid floaters.
      depth[self._depth_edge_mask(depth)] = 0
      # NOTE(review): stored flattened to 1-D; presumably the dataset reshapes
      # it using the image size. Confirm against ERTDataset.
      np.save(depth_downscale_path / depth_file.name, depth.reshape(-1))

  @staticmethod
  def _depth_edge_mask(depth):
    """Return a boolean mask of Canny edges in ``depth`` after min/max scaling to uint8.

    Non-finite pixels and constant maps normalise to 0 instead of producing NaN.
    """
    finite = np.isfinite(depth)
    scaled = np.zeros(depth.shape, dtype=np.float64)
    if finite.any():
      values = depth[finite]
      lo = values.min()
      span = values.max() - lo
      if span > 0:
        scaled[finite] = 255 * (values - lo) / span
    depth_uint8 = scaled.astype(np.uint8)
    edges = cv2.Canny(image=depth_uint8, threshold1=10, threshold2=20)
    return edges > 0

  @staticmethod
  def _build_optimizers():
    """Return the per-parameter-group optimizer / scheduler configuration."""
    def adam(lr, scheduler=None):
      return {"optimizer": AdamOptimizerConfig(lr=lr, eps=1e-15), "scheduler": scheduler}

    # NOTE(review): the schedulers use a fixed max_steps=30000, while the trainer
    # runs ITERATIONS_PER_IMAGE * num_images steps. Confirm this is intended.
    return {
      "means": adam(0.00016, ExponentialDecaySchedulerConfig(
        lr_pre_warmup=1e-08, lr_final=1.6e-06, warmup_steps=0, max_steps=30000)),
      "features_dc": adam(0.0025),
      "features_rest": adam(0.000125),
      "opacities": adam(0.05),
      # Previously assigned three times; 0.005 was the value that took effect.
      "scales": adam(0.005),
      "quats": adam(0.001),
      "normals": adam(0.001),
      "camera_opt": adam(0.0001, ExponentialDecaySchedulerConfig(
        lr_pre_warmup=0, lr_final=5e-07, warmup_steps=1000, max_steps=30000)),
      "color_uncertainties": adam(0.001),
      "opacity_uncertainties": adam(0.001),
      "position_uncertainties": adam(0.001),
      "quat_uncertainties": adam(0.001),
      "normal_uncertainties": adam(0.001),
      "radiometry_opt": adam(0.0001, ExponentialDecaySchedulerConfig(
        lr_pre_warmup=0.0001, lr_final=5e-07, warmup_steps=1000, max_steps=30000)),
      "scale_uncertainties": adam(0.001),
    }

  def start_training(self, use_depth=False):
    """Preprocess inputs, build the raytracingfacto trainer config, and train.

    If get_recon_path() previously populated ``self.config`` / ``self.ckpt``,
    training is started with those as load_config / load_checkpoint.

    Args:
      use_depth: enable depth and normal supervision in the model.

    Returns:
      True once the trainer's ``train()`` call returns.
    """
    self._downscale_images()
    self._downscale_depths()

    data_config = NerfstudioDataParserConfig(
      data=self.processed_dir,
      downscale_factor=self.DOWNSCALE_FACTOR,
      depth_unit_scale_factor=1.0,
      load_3D_points=True,
    )

    datamanager_config = FullImageDatamanagerConfig(
      _target=FullImageDatamanager[ERTDataset],
      dataparser=data_config,
      cache_images_type="uint8",
      cache_images='cpu',
    )

    camera_optimizer_config = CameraOptimizerConfig(
      mode="SO3xR3",
      refine_intrinsics=True,
      refine_focal_length=True,
      refine_camera_center=True,
      refine_distortion_params=False,
    )

    model_config = RaytracingfactoModelConfig(
      mode="3dgrt",
      cull_alpha_thresh=0.03,
      cull_scale_thresh=0.15,
      densify_grad_thresh=0.0003,
      densify_size_thresh=0.05,
      stop_split_at=8000,
      camera_optimizer=camera_optimizer_config,
      lambda_normal=1.0,
      lambda_normal_smooth=1.0,
      min_alpha=0.01,
      adaptive_control_enabled=True,
      enable_depth_training=use_depth,
      enable_normal_training=use_depth,
      densify_norm_thresh=2e-9,
      visibility_pruning=True,
      normal_init=True,
    )

    pipeline_config = VanillaPipelineConfig(datamanager=datamanager_config, model=model_config)

    timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H%M%S")

    self.trainer_config = TrainerConfig(
      output_dir=Path(str(self.scene_dir)) / "ert_runs",
      project_name="ert-project",
      timestamp=timestamp,
      method_name="raytracingfacto",
      ert_layout=True,
      experiment_name=self.scene_name,
      pipeline=pipeline_config,
      vis='viewer',
      viewer=ViewerConfig(quit_on_train_completion=True),
      optimizers=self._build_optimizers(),
      steps_per_save=self.STEPS_PER_SAVE,
      steps_per_eval_batch=0,
      steps_per_eval_image=100,
      steps_per_eval_all_images=1000,
      max_num_iterations=self._max_iterations(),
      load_config=Path(self.config[0]) if self.config else None,
      load_checkpoint=Path(self.ckpt[0]) if self.ckpt else None,
    )

    self.trainer_config.save_config()
    self.trainer_config.print_to_terminal()
    # setup() presumably returns the trainer instance (setup/train are called on
    # it); the attribute name `model` is kept for existing callers.
    self.model = self.trainer_config.setup()
    self.model.setup()
    # NOTE(review): upstream nerfstudio's Trainer.setup() already builds the
    # optimizers, so this call may be redundant. Confirm for this fork.
    self.model.setup_optimizers()
    self.model.train()
    return True