import os
import shutil
import subprocess
import argparse
from pathlib import Path
from depth_generator import generate_depth
from nerfstudio.process_data.colmap_utils import colmap_to_json
from nerfstudio.data.utils.colmap_parsing_utils import (
    read_points3D_binary,
)

# Training iterations for every benchmark run (passed as --max-num-iterations).
NUM_ITERS = 12000

# Checkpoint interval in steps (passed as --steps-per-save).
STEPS_PER_SAVE = 1000

# Passed as --save-only-latest-checkpoint: by default each save overwrites the
# previous checkpoint. Disable only if earlier checkpoints are actually needed.
SAVE_ONLY_LATEST = True


def check_conda_env(env_names):
    """Check that every conda environment in *env_names* exists.

    Runs ``conda env list`` and compares whole environment names against the
    first column of its output. A plain substring search would accept e.g.
    ``"qr"`` whenever any listed path or name merely contains ``qr``.

    Args:
        env_names: Iterable of conda environment names to look for.

    Returns:
        True if all environments exist. Otherwise False, after printing the
        first missing name.

    Raises:
        subprocess.CalledProcessError: If ``conda env list`` fails.
    """
    result = subprocess.run(
        ["conda", "env", "list"],
        capture_output=True,
        text=True,
        check=True,
    )

    available = set()
    for line in result.stdout.splitlines():
        line = line.strip()
        # Skip blank lines and the "# conda environments:" header.
        if not line or line.startswith("#"):
            continue
        # The first column is the environment name; for unnamed environments
        # it is the prefix path, which cannot be used with ``conda run -n``.
        available.add(line.split()[0])

    for env in env_names:
        if env not in available:
            print(f"No conda environment: {env} found. Make sure to create it for corresponding run.")
            return False
    return True


def generate_transforms(dataset_path):
    """Regenerate ``transforms.json`` (with depth maps) for a COLMAP dataset.

    Deletes any existing ``transforms.json``, runs ``generate_depth`` on the
    dataset, and calls nerfstudio's ``colmap_to_json`` on ``sparse/0``. Only
    ``.npy`` / ``.exr`` files from the depth output directory are passed in.

    Args:
        dataset_path: ``Path`` to the processed dataset directory.

    Raises:
        AssertionError: If ``transforms.json`` was not produced.
    """
    transforms = dataset_path.joinpath("transforms.json")
    if transforms.exists():
        print("Removing exsisting transforms.json")
        os.remove(transforms)

    recon_dir = dataset_path.joinpath("sparse/0")
    depth_path = generate_depth(dataset_path)

    # The original condition `suffix == ".npy" or ".exr"` was always true, so
    # directories and unrelated files leaked into the mapping.
    # NOTE(review): colmap_to_json looks depth paths up by image id, while the
    # keys here are file stems. Confirm that generate_depth names its files to match.
    depth_dict = {
        depth_file.stem: depth_file
        for depth_file in depth_path.iterdir()
        if depth_file.is_file() and depth_file.suffix in (".npy", ".exr")
    }

    colmap_to_json(recon_dir, dataset_path, image_id_to_depth_path=depth_dict)
    assert transforms.exists(), f"transforms.json not found in {dataset_path}"


def get_output_dir(scene_name):
    """Return ``~/benchmark/<scene_name>``, creating it (and parents) if needed.

    Args:
        scene_name: Scene name used as the last path component.

    Returns:
        The output directory as a ``Path``.
    """
    data = Path.home().joinpath(f"benchmark/{scene_name}")
    # exist_ok avoids the check-then-create race and makes repeated calls idempotent.
    data.mkdir(parents=True, exist_ok=True)
    return data


def get_version_name(output_dir, exp_name, model_name, new_version):
    """Return the ``v<N>`` timestamp folder name for a model's run.

    Looks at ``output_dir/exp_name/model_name`` for subdirectories named
    ``v<integer>``. Other entries are ignored.

    Args:
        output_dir: Root output directory (``Path`` or ``str``).
        exp_name: Experiment name (first path component under *output_dir*).
        model_name: Model/method folder name.
        new_version: If true, return the version after the latest existing one.

    Returns:
        The latest version (or the next one when *new_version* is set). Returns
        ``"v0"`` when no version folder exists yet.
    """
    current_outputs = Path(output_dir).joinpath(exp_name, model_name)
    if not current_outputs.is_dir():
        return "v0"

    versions = [
        int(entry.name[1:])
        for entry in current_outputs.iterdir()
        if entry.is_dir() and entry.name[:1] == "v" and entry.name[1:].isdecimal()
    ]
    # An empty model folder used to raise IndexError; treat it as "no runs yet".
    if not versions:
        return "v0"

    latest = max(versions)
    return f"v{latest + 1 if new_version else latest}"


def generate_ert_command(dataset_path, exp_name, scene_name, new_version):
    """Build the ``ns-train`` command for the ERT (``raytracingfacto-depth``) run.

    Side effect: regenerates depth maps and ``transforms.json`` in
    *dataset_path* (via ``generate_transforms``) before building the command.

    Args:
        dataset_path: ``Path`` to the processed dataset directory.
        exp_name: Value for ``--experiment-name``.
        scene_name: Selects the ``~/benchmark/<scene_name>`` output root.
        new_version: If true, train into a fresh ``v<N+1>`` timestamp folder.

    Returns:
        The command as a list of strings.
    """
    # First generate depth and transforms.json
    generate_transforms(dataset_path)

    output_dir = get_output_dir(scene_name)
    # NOTE(review): versions are looked up under "raytracingfacto", while the
    # method trained is "raytracingfacto-depth". Confirm that ns-train writes
    # this method's runs into the "raytracingfacto" folder.
    version_name = get_version_name(output_dir, exp_name, "raytracingfacto", new_version)

    enable_depth_training = True
    enable_normal_training = True
    # Leading commas make a missing separator between arguments easy to spot.
    train_command = [
          "ns-train", "raytracingfacto-depth"
        , "--output-dir", str(output_dir)
        , "--experiment-name", str(exp_name)
        , "--timestamp", str(version_name)
        , "--steps-per-save", str(STEPS_PER_SAVE)
        , "--max-num-iterations", str(NUM_ITERS)
        , "--save-only-latest-checkpoint", str(SAVE_ONLY_LATEST)
        , "--pipeline.datamanager.cache-images", "cpu"
        # Use "viewer+tensorboard" to also launch the live viewer.
        , "--vis", "tensorboard"
        , "--viewer.quit-on-train-completion", "True"

        # Mode settings
        , "--pipeline.model.adaptive_control_enabled", "True"
        , "--pipeline.model.enable_depth_training", str(enable_depth_training)
        , "--pipeline.model.enable_normal_training", str(enable_normal_training)
        , "--pipeline.model.min_particle_opacity", "0.01"  # presumably the lower bound on particle opacity
        , "--pipeline.model.warmup_length", "500"  # when to start adaptive control
        # When to stop adaptive control. This is larger than NUM_ITERS, so it
        # stays active for the whole run.
        , "--pipeline.model.stop_split_at", "15000"
        , "--pipeline.model.refine_every", "100"  # run adaptive control every 100 iterations
        , "--pipeline.model.reset_alpha_every", "30"  # reset opacity every 30 * refine_every iterations
        , "--pipeline.model.cull_alpha_thresh", "0.03"  # prune particles with opacity below 0.03
        # NOT MENTIONED IN PAPER: presumably prunes particles whose scale exceeds
        # this fraction of the scene scale.
        , "--pipeline.model.cull_scale_thresh", "0.1"
        , "--pipeline.model.densify_grad_thresh", "0.0003"  # NOT MENTIONED IN PAPER: 3D gradient threshold
        , "--pipeline.model.densify_size_thresh", "0.01"  # 1% of scene scale
        , "--pipeline.model.bound_extent", "3.0"  # NOTE(review): originally "use full extent"; confirm meaning

        # Downscale
        , "--pipeline.datamanager.camera-res-scale-factor", "0.125"
    ]
    train_command.extend([
        "nerfstudio-data",
        "--downscale-factor", "1",
        "--depth_unit_scale_factor", "1.0",
        "--data", str(dataset_path),
    ])
    return train_command


def generate_dns_command(dataset_path, exp_name, scene_name, new_version):
    """Build the ``ns-train`` command for the DN-Splatter run.

    Side effect: regenerates depth maps and ``transforms.json`` in
    *dataset_path* (via ``generate_transforms``) before building the command.

    Args:
        dataset_path: ``Path`` to the processed dataset directory.
        exp_name: Value for ``--experiment_name``.
        scene_name: Selects the ``~/benchmark/<scene_name>`` output root.
        new_version: If true, train into a fresh ``v<N+1>`` timestamp folder.

    Returns:
        The command as a list of strings.
    """
    # First generate depth and transforms.json
    generate_transforms(dataset_path)

    output_dir = get_output_dir(scene_name)
    version_name = get_version_name(output_dir, exp_name, "dn-splatter", new_version)

    # Leading commas make a missing separator between arguments easy to spot.
    train_cmd = [
          "ns-train", "dn-splatter"
        , "--output-dir", str(output_dir)
        , "--experiment_name", str(exp_name)
        , "--timestamp", str(version_name)
        , "--steps-per-save", str(STEPS_PER_SAVE)
        , "--max-num-iterations", str(NUM_ITERS)
        , "--save-only-latest-checkpoint", str(SAVE_ONLY_LATEST)
        # Use "viewer+tensorboard" to also launch the live viewer.
        , "--vis", "tensorboard"
        , "--viewer.quit-on-train-completion", "True"
        , "--pipeline.datamanager.cache-images", "cpu"

        # Model configs
        , "--pipeline.model.use-depth-loss", "True"
        , "--pipeline.model.depth-loss-type", "PearsonDepth"
        , "--pipeline.model.depth-lambda", "0.2"
        , "--pipeline.model.use-depth-smooth-loss", "True"
        , "--pipeline.model.use-normal-loss", "True"
        , "--pipeline.model.use-normal-tv-loss", "True"
        , "--pipeline.model.normal-supervision", "depth"
        , "--pipeline.model.cull_alpha_thresh", "0.005"
        , "--pipeline.model.continue_cull_post_densification", "False"

        # Downscale
        , "--pipeline.datamanager.camera-res-scale-factor", "0.125"
    ]
    # Build the dataparser arguments as a list; splitting a formatted string on
    # spaces would break a dataset path that contains spaces into several args.
    train_cmd.extend([
        "coolermap",
        "--data", str(dataset_path),
        "--downscale-factor", "1",
        "--train-split-fraction", "1.0",
    ])
    return train_cmd


# Run each benchmark command inside its conda environment.
def run_command(benchmark, print_only=False):
    """Run each command of *benchmark* via ``conda run`` in its environment.

    Commands run one after another in the mapping's order. A failing command
    is reported and does not stop the remaining ones.

    Args:
        benchmark: Mapping of conda environment name -> command (argument list).
        print_only: If true, only print *benchmark* and run nothing.

    Returns:
        ``None`` when *print_only* is set. Otherwise a dict mapping each
        environment name to the command's exit code, or ``None`` if the
        process could not be started.
    """
    if print_only:
        print(benchmark)
        return None

    return_codes = {}
    for conda_env, command in benchmark.items():
        args = [str(part) for part in command]
        full_command = ["conda", "run", "--no-capture-output", "-n", conda_env, *args]
        print(f"Running in conda env: {conda_env}")
        print("command:\n", " ".join(args))
        print(full_command)
        try:
            # Pass an argument list without a shell, so paths containing spaces
            # or shell metacharacters reach the program verbatim.
            completed = subprocess.run(full_command, check=False)
        except OSError as err:
            # e.g. the `conda` executable is not on PATH.
            print(f"Error: could not start command in conda env {conda_env}: {err}")
            return_codes[conda_env] = None
            continue

        return_codes[conda_env] = completed.returncode
        if completed.returncode != 0:
            print(f"Error: command in conda env {conda_env} exited with code {completed.returncode}")
    return return_codes


def autocompute_tsdf_params(sparse_dir, depth_trunc):
    """Derive TSDF fusion parameters from the COLMAP sparse point cloud.

    The largest side of the point cloud's axis-aligned bounding box caps
    *depth_trunc*. The voxel size is that capped value divided by 1024, and the
    SDF truncation is five voxels.

    Args:
        sparse_dir: COLMAP model directory containing ``points3D.bin``.
        depth_trunc: Upper bound for the truncation depth used in the computation.

    Returns:
        ``(voxel_size, sdf_truc)``.

    Raises:
        ValueError: If the reconstruction contains no 3D points.
    """
    pts = read_points3D_binary(os.path.join(sparse_dir, "points3D.bin"))
    if not pts:
        raise ValueError(f"No 3D points found in {sparse_dir}")

    # Per-axis bounding-box extent. The previous running update compared each
    # new max against the running *min*, so it never tracked the true maximum.
    extents = []
    for axis in range(3):
        coords = [pt.xyz[axis] for pt in pts.values()]
        extents.append(max(coords) - min(coords))

    # TODO add these as args
    max_diam = max(extents)
    depth_trunc = min(max_diam, depth_trunc)
    voxel_size = depth_trunc / 1024
    sdf_truc = 5.0 * voxel_size
    print(f"Voxel size: {voxel_size}, SDF Truc: {sdf_truc}")
    return voxel_size, sdf_truc



if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Dataset related arguments.
    parser.add_argument("--dataset", type=Path, required=True, help="Path to the dataset folder. eg: /Datasets/mipnerf360 or /Datasets/internal")
    parser.add_argument("--scene", type=str, required=True, help="Select one of the scenes in the dataset")
    parser.add_argument("--new", action="store_true", help="Save checkpoints at a new folder.")
    parser.add_argument("--optimize", action="store_true", help="Enable optimization.")
    parser.add_argument("--render", action="store_true", help="Enable rendering and comparision generation.")
    parser.add_argument("--mesh", action="store_true", help="Mesh the geometry using TSDF")
    parser.add_argument("--render_mesh", action="store_true", help="Render the geometry produced in the field using blender")

    args = parser.parse_args()

    # Validate with parser.error rather than assert: asserts vanish under -O.
    if args.dataset.name not in ["mipnerf360", "internal"]:
        parser.error("Only [mipnerf360/internal] dataset is supported at the moment.")

    dataset_path = args.dataset.joinpath(args.scene, "processed")
    if not dataset_path.exists():
        parser.error(f"Make sure {dataset_path} is present.")

    # Get the active Conda environment name
    conda_env = os.getenv('CONDA_DEFAULT_ENV')
    print(f"Active Conda environment: {conda_env}")
    print("Working on benchmark system.")

    # 1. Optimize commands
    if args.optimize:
        # NOTE: each generator calls generate_transforms, so depth maps and
        # transforms.json are produced once per model.
        test_benchmark = {
            "qr": generate_ert_command(
                dataset_path,
                exp_name=args.dataset.name,
                scene_name=args.scene,
                new_version=args.new,
            ),
            "nerfstudio": generate_dns_command(
                dataset_path,
                exp_name=args.dataset.name,
                scene_name=args.scene,
                new_version=args.new,
            ),
        }
        if not check_conda_env(test_benchmark.keys()):
            raise NameError("One of the specified conda environment is missing at the moment.")
        run_command(test_benchmark, False)

    # Steps 2 and 3 consume runs that already exist, so they always resolve the
    # latest existing version (new_version=False). Forwarding --new here would
    # name a "v<N+1>" folder that was never trained.

    # 2. Render commands
    if args.render:
        output_dir = get_output_dir(args.scene)
        ert_version_name = get_version_name(output_dir, args.dataset.name, "raytracingfacto", False)
        dns_version_name = get_version_name(output_dir, args.dataset.name, "dn-splatter", False)

        generate_cmd = [
            "python", "benchmark_utils/generate_render.py",
            "--output_scene", str(args.scene),
            "--dataset", args.dataset.name,
        ]
        test_benchmark = {
            "qr": generate_cmd + ["--timestamp", ert_version_name, "--model_name", "raytracingfacto"],
            "nerfstudio": generate_cmd + ["--timestamp", dns_version_name, "--model_name", "dn-splatter"],
        }
        run_command(test_benchmark)

    # 3. Mesh commands
    if args.mesh:
        output_dir = get_output_dir(args.scene)
        ert_version_name = get_version_name(output_dir, args.dataset.name, "raytracingfacto", False)
        dns_version_name = get_version_name(output_dir, args.dataset.name, "dn-splatter", False)

        sparse_dir = dataset_path.joinpath("sparse/0")

        ert_mesh_name = f"{args.dataset.name}_{args.scene}_ERT"
        dns_mesh_name = f"{args.dataset.name}_{args.scene}_DNS"

        # NOTE(review): ERT uses a fixed depth truncation of 3.0, while DNS uses
        # 8.0 below. Confirm this asymmetry is intended.
        mesh_cmd = [
            "python", "mesh_extraction.py", "o3dtsdf",
            "--sparse_dir", sparse_dir.as_posix(),
            "--depth-trunc", "3.0",
            "--name", ert_mesh_name,
        ]

        depth_trunc = 8.0
        voxel_size, sdf_truc = autocompute_tsdf_params(sparse_dir, depth_trunc)

        dns_cmd = [
            "gs-mesh", "o3dtsdf",
            "--voxel-size", str(voxel_size),
            "--sdf-truc", str(sdf_truc),
            "--depth-trunc", str(depth_trunc),
        ]

        # The DNS mesh is expected at this fixed path. Remove any stale copy so
        # the existence check after the run reflects this run only.
        dns_output = Path("./mesh_exports/Open3dTSDFfusion_mesh.ply")
        if dns_output.exists():
            os.remove(dns_output)

        ert_config = output_dir.joinpath(args.dataset.name, "raytracingfacto", ert_version_name, "config.yml")
        dns_config = output_dir.joinpath(args.dataset.name, "dn-splatter", dns_version_name, "config.yml")

        mesh_generation = {
            "qr": mesh_cmd + ["--load-config", ert_config.as_posix()],
            "nerfstudio": dns_cmd + ["--load-config", dns_config.as_posix()],
        }
        run_command(mesh_generation)

        if not dns_output.exists():
            raise RuntimeError("Mesh generation failed for DNS.")

        shutil.copy(dns_output, f"./mesh_exports/{dns_mesh_name}.ply")

    # 4. Mesh rendering
    if args.render_mesh:
        # Render the exported mesh and create a comparison of the outputs.
        result_path = Path.home().joinpath(f"benchmark/{args.scene}")
        ert_output_path = result_path.joinpath("outputs", "raytracingfacto", "camera.json")
        dns_output_path = result_path.joinpath("outputs", "dn-splatter", "camera.json")

        ert_mesh_name = f"{args.dataset.name}_{args.scene}_ERT.ply"
        dns_mesh_name = f"{args.dataset.name}_{args.scene}_DNS.ply"

        ert_command = [
            "python", "benchmark_utils/render_mesh.py",
            "--mesh_path", Path(f"./mesh_exports/{ert_mesh_name}").absolute().as_posix(),
            "--camera_path", ert_output_path.as_posix(),
        ]
        dns_command = [
            "python", "benchmark_utils/render_mesh.py",
            "--mesh_path", Path(f"./mesh_exports/{dns_mesh_name}").absolute().as_posix(),
            "--camera_path", dns_output_path.as_posix(),
        ]

        # NOTE(review): both renders run in the "qr" env (presumably the one with
        # Blender available); confirm. They are submitted separately because a
        # dict cannot hold the same env key twice.
        run_command({"qr": ert_command})
        run_command({"qr": dns_command})
