"""Various GS mesh exporters"""
import sys, os
sys.path.append(os.path.abspath('../../../'))
sys.path.append(os.path.abspath('../'))
sys.path.append(os.path.abspath('./'))

import random
from dataclasses import dataclass
from pathlib import Path
from typing import Literal, Optional, Tuple, Union

import matplotlib
import numpy as np
import cv2
import torch
import torch.nn.functional as F
import tyro
from tqdm import tqdm
from typing_extensions import Annotated

from nerfstudio.benchmark_utils.utils import (
    get_colored_points_from_depth,
    get_means3d_backproj,
    project_pix,
    quat_to_rotmat
)

from nerfstudio.data.utils.colmap_parsing_utils import (
    read_points3D_binary,
)

from matplotlib import pyplot as plt
import matplotlib.lines as mlines
from mpl_toolkits.mplot3d import Axes3D
import numpy as np

from nerfstudio.cameras.cameras import Cameras
from nerfstudio.models.raytracingfacto import RaytracingfactoModel
from nerfstudio.utils.eval_utils import eval_setup
from nerfstudio.utils.rich_utils import CONSOLE
from nerfstudio.utils.dn_utils import normal_from_depth_image
from nerfstudio.model_components.losses import (
    estimate_scale_and_bias_midas, normalize_depth_midas    
)

# Plot the error vs uncertainty 3-sigma bounds
def plot_uncertainty(err, var, save_path, err_name="rgb",
        ticksize=12, fontsize=14, ylim=1.5):
    """Scatter-plot errors against predicted standard deviations with 3-sigma bounds.

    Each point is one sample: x = error, y = sqrt(variance). The dashed red lines
    mark |err| = 3 * sigma; points above them lie within the 3-sigma bound. The
    figure is saved as ``<save_path>/<err_name>_uncertainty.png`` and then closed.

    Args:
        err: Tensor of signed errors; flattened before plotting.
        var: Tensor of predicted variances; clamped to >= 1e-12 before the square
            root. Should hold one variance per element of ``err``.
        save_path: Directory the figure is written into.
        err_name: Label of the form ``"<source>_<modality>"`` (e.g. ``"learned_rgb"``).
            ``<source>`` labels the y axis and ``<modality>`` the x axis. A name
            without ``_`` is used for both axes.
        ticksize: Unused; kept for backward compatibility.
        fontsize: Font size of the axis labels.
        ylim: Upper limit of the sigma axis. The 3-sigma lines are drawn up to it.

    Returns:
        None. When ``err`` is empty, nothing is plotted and only a message is printed.
    """
    err_np = err.reshape(-1).detach().cpu().numpy()
    if err_np.size == 0:
        print(err_name + ": no samples, skipping uncertainty plot")
        return

    var = torch.maximum(var, torch.full_like(var, 1e-12))
    std = var.reshape(-1).sqrt().detach().cpu().numpy()

    print(err_name + "_err min, mean, max", err_np.min(), err_np.mean(), err_np.max())
    print(err_name + "_var min, mean, max", var.min(), var.mean(), var.max())

    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(3, 3), dpi=200)
    try:
        ax.scatter(err_np, std)
        # 3-sigma bounds |err| = 3 * sigma, spanning the whole visible sigma range.
        for sign in (-1.0, 1.0):
            ax.add_line(mlines.Line2D([0.0, sign * 3.0 * ylim], [0.0, ylim],
                                      color="r", linestyle="--", linewidth=0.7))

        ax.set_ylim(0, ylim)
        xlim = float(np.abs(err_np).max())
        # A zero-width or non-finite range is not a usable limit; let matplotlib autoscale.
        if np.isfinite(xlim) and xlim > 0:
            ax.set_xlim(-xlim, xlim)
        ax.grid(True)

        name_parts = err_name.split("_")
        source = name_parts[0]
        modality = name_parts[1] if len(name_parts) > 1 else name_parts[0]
        ax.set_ylabel(source + r" $\sigma$", fontsize=fontsize)
        ax.set_xlabel(modality + " error", fontsize=fontsize)
        fig.tight_layout(pad=0.2)

        fig_path = os.path.join(save_path, err_name + "_uncertainty.png")
        print("Saving figure to", fig_path)
        fig.savefig(fig_path)
    finally:
        # Release the figure; pyplot otherwise keeps every created figure alive.
        plt.close(fig)

# torch.quantile raises "quantile() input tensor is too large" above this many
# elements; larger tensors fall back to an explicit sort.
_TORCH_QUANTILE_MAX_NUMEL = 2 ** 24


def _sorted_quantile(flat, q):
    """Linearly interpolated ``q``-quantile of a 1-D tensor via a full sort (no size limit)."""
    vals, _ = torch.sort(flat)
    last = vals.numel() - 1
    pos = q * last
    lower = int(pos)
    upper = min(lower + 1, last)
    frac = pos - lower
    return vals[lower] + (vals[upper] - vals[lower]) * frac


def normalize(x, q=0.99, eps=1e-12):
    """Min/quantile-normalize a tensor for colormap visualization.

    Maps ``x.min()`` to 0 and the ``q``-quantile of ``x`` to 1. Values above the
    quantile map to numbers greater than 1.

    Args:
        x: Non-empty floating-point tensor of any shape.
        q: Quantile used as the upper reference value (default 0.99).
        eps: Lower bound on the denominator, so a constant input maps to 0
            instead of NaN.

    Returns:
        Tensor with the same shape as ``x``.
    """
    lo = x.min()
    flat = x.reshape(-1)
    if flat.numel() <= _TORCH_QUANTILE_MAX_NUMEL:
        hi = torch.quantile(flat, q)
    else:
        hi = _sorted_quantile(flat, q)
    return (x - lo) / (hi - lo).clamp_min(eps)

@dataclass
class VariancePlotter:
    """Plot model error against predicted uncertainty (3-sigma calibration plots).

    ``main`` makes two passes over every ``skip``-th training view:
    - first with the model's learned uncertainty;
    - then after calling ``model.initialize_uncertainty_with_hessian()``.

    Each pass writes:
    - per-view uncertainty / squared-error visualizations under
      ``output_dir/images/<pass>_<modality>/``;
    - one error-vs-sigma scatter plot per modality into ``output_dir``.

    Depth and normals are evaluated only for views whose data contains ``"depth_image"``.
    """

    load_config: Path
    """Path to the trained config YAML file."""

    output_dir: Path = Path("./uncertainty_eval")
    """Path to the output directory."""

    name: str = "var3sigma"
    """Name of this evaluation (currently not used by `main`)."""

    skip: int = 10
    """Evaluate only every `skip`-th training image (must be >= 1)."""

    def main(self) -> None:
        """Load the trained pipeline and produce plots/images for both uncertainty sources.

        Raises:
            ValueError: if ``skip`` is smaller than 1.
        """
        if self.skip < 1:
            raise ValueError(f"skip must be >= 1, got {self.skip}")

        out_imgs_dir = self.output_dir / "images"
        # parents=True also creates output_dir itself.
        out_imgs_dir.mkdir(parents=True, exist_ok=True)

        _, pipeline, _, _ = eval_setup(self.load_config)
        model = pipeline.model
        cmap = matplotlib.colormaps["inferno"]

        # Once for learned uncertainty, once for Hessian uncertainty.
        for prefix in ("learned_", "hessian_"):
            if prefix == "hessian_":
                # Deliberately outside torch.no_grad(); presumably the Hessian
                # initialization needs autograd -- confirm.
                model.initialize_uncertainty_with_hessian()
            self._evaluate_views(pipeline, model, out_imgs_dir, prefix, cmap)

    def _evaluate_views(self, pipeline, model, out_imgs_dir: Path, prefix: str, cmap) -> None:
        """Run the model on every `skip`-th training view for one uncertainty source.

        Collects masked errors and variances per modality, writes the per-view
        visualizations, and saves one scatter plot per modality via `plot_uncertainty`.
        """
        kinds = ("rgb", "depth", "normal")
        for kind in kinds:
            (out_imgs_dir / (prefix + kind)).mkdir(parents=True, exist_ok=True)

        errs = {kind: [] for kind in kinds}
        variances = {kind: [] for kind in kinds}

        with torch.no_grad():
            dataset = pipeline.datamanager.train_dataset
            cameras: Cameras = dataset.cameras  # type: ignore
            # TODO: do eval dataset as well

            # Index only the views that are kept instead of loading and discarding every image.
            view_indices = range(0, len(dataset), self.skip)  # type: ignore
            for image_idx in tqdm(view_indices, "Running model on all views"):
                data = dataset[image_idx]

                if "mask" in data:
                    mask = data["mask"]
                else:
                    mask = torch.ones_like(data["image"][:, :, 0])
                if mask.dim() == 3:
                    # Masks may carry a trailing channel axis. Reduce to (H, W), like the
                    # default mask, so boolean indexing lines up with the image tensors.
                    mask = mask[:, :, 0]
                camera = cameras[image_idx : image_idx + 1]
                outputs = model.get_outputs_for_camera(camera=camera)
                device = outputs["rgb"].device
                mask = mask.to(torch.bool).to(device)

                # RGB
                rgb_err = data["image"].to(device) - outputs["rgb"]
                errs["rgb"].append(rgb_err[mask].reshape(-1).cpu())
                variances["rgb"].append(outputs["rgb_uncertainty"][mask].reshape(-1).cpu())
                self._save_error_viz(
                    out_imgs_dir / (prefix + "rgb") / f"{image_idx:09d}.png",
                    outputs["rgb_uncertainty"],
                    rgb_err ** 2,
                    cmap,
                )

                if "depth_image" not in data:
                    continue

                # Depth. NaN/inf are sanitized out-of-place so dataset/output tensors are not mutated.
                gt_depth = data["depth_image"].to(device).nan_to_num()
                pred_depth = outputs["depth"].nan_to_num()
                min_depth = 0.01
                # NOTE(review): `mask` is not applied to depth, unlike RGB/normals -- confirm intended.
                valid_mask = (gt_depth > min_depth) & (pred_depth > min_depth)
                # Map normalized GT depth into the predicted depth's scale/offset before comparing.
                gt_depth_norm = normalize_depth_midas(gt_depth)
                scale, bias = estimate_scale_and_bias_midas(pred_depth)
                gt_depth_pred_scale = scale * gt_depth_norm + bias
                depth_err = (gt_depth_pred_scale - pred_depth) * valid_mask.float()
                errs["depth"].append(depth_err[valid_mask].reshape(-1).cpu())
                variances["depth"].append(outputs["depth_uncertainty"][valid_mask].reshape(-1).cpu())
                self._save_error_viz(
                    out_imgs_dir / (prefix + "depth") / f"{image_idx:09d}.png",
                    outputs["depth_uncertainty"],
                    depth_err ** 2,
                    cmap,
                )

                # Normals derived from the GT depth image, computed in camera space.
                gt_normal = normal_from_depth_image(
                    depths=gt_depth,
                    fx=camera.fx.item(),
                    fy=camera.fy.item(),
                    cx=camera.cx.item(),
                    cy=camera.cy.item(),
                    img_size=(camera.width.item(), camera.height.item()),
                    c2w=torch.eye(4, dtype=torch.float, device=device),
                    device=device,
                    smooth=False,
                )
                # Flip the y and z axes of the normals.
                gt_normal = gt_normal @ torch.diag(
                    torch.tensor([1, -1, -1], device=device, dtype=gt_normal.dtype)
                )
                # Transform normal from camera space to world space.
                gt_normal = gt_normal @ outputs["refined_camera"].camera_to_worlds.squeeze(0)[:3, :3].T

                # Predicted normals are presumably stored in [0, 1]; 2 * n - 1 maps them to [-1, 1].
                normal_err = gt_normal - (2 * outputs["normal"] - 1)
                normal_err = normal_err * mask[:, :, None].float()
                errs["normal"].append(normal_err[mask].reshape(-1).cpu())
                variances["normal"].append(outputs["normal_uncertainty"][mask].reshape(-1).cpu())
                self._save_error_viz(
                    out_imgs_dir / (prefix + "normal") / f"{image_idx:09d}.png",
                    outputs["normal_uncertainty"],
                    normal_err ** 2,
                    cmap,
                )

            # Depth/normal lists are only non-empty when some view had a depth image.
            for kind in kinds:
                if errs[kind]:
                    plot_uncertainty(
                        torch.cat(errs[kind]),
                        torch.cat(variances[kind]),
                        self.output_dir,
                        err_name=prefix + kind,
                    )

    @staticmethod
    def _save_error_viz(path: Path, uncertainty: torch.Tensor, sq_err: torch.Tensor, cmap) -> None:
        """Write the normalized uncertainty and squared error as one color-mapped PNG.

        The two maps are normalized independently and concatenated along dim 1
        (side by side, assuming an (H, W, C) layout -- confirm). The result is
        averaged over the last axis and colored with ``cmap``.
        """
        viz = torch.cat([normalize(uncertainty), normalize(sq_err)], 1).mean(-1)
        rgb = cmap(viz.cpu().numpy())[:, :, :3]
        # Colormaps return RGB(A) while cv2.imwrite expects BGR channel order.
        bgr = np.ascontiguousarray(rgb[:, :, ::-1])
        cv2.imwrite(str(path), (255 * bgr).astype(np.uint8))


# The single CLI subcommand, registered under the name "var3sigma".
_VAR3SIGMA_COMMAND = Annotated[VariancePlotter, tyro.conf.subcommand(name="var3sigma")]

# CLI command union parsed by `entrypoint`. FlagConversionOff is applied to the whole union.
# NOTE(review): a Union with a single member collapses to that member, so tyro may
# not require the "var3sigma" subcommand name on the command line -- confirm.
Commands = tyro.conf.FlagConversionOff[Union[_VAR3SIGMA_COMMAND]]


def entrypoint():
    """Parse the command line into a `Commands` instance and run its `main` (pyproject script entry)."""
    tyro.extras.set_accent_color("bright_yellow")
    command = tyro.cli(Commands)
    command.main()


if __name__ == "__main__":
    entrypoint()
#     # tyro.cli(GaussiansToPoisson).main()
#     # tyro.cli(DepthAndNormalMapsPoisson).main()
#     # tyro.cli(LevelSetExtractor).main()
#     # tyro.cli(MarchingCubesMesh).main()
#     # tyro.cli(TSDFFusion).main()
#     tyro.cli(Open3DTSDFFusion).main()
