# ruff: noqa: E741
# Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Gaussian ray tracing implementation that combines many recent advancements.
"""

from __future__ import annotations

import math
from dataclasses import dataclass, field
from typing import Dict, List, Literal, Optional, Tuple, Type, Union

import open3d as o3d
from tqdm import tqdm
import numpy as np
import torch
import matplotlib
from gsplat.strategy import RayTracerDefaultStrategy, TrimStrategy

try:
    from gsplat.rendering import rasterization
except ImportError:
    print("Please install gsplat>=1.0.0")
from pytorch_msssim import SSIM
from torch.nn import Parameter

from nerfstudio.cameras.camera_optimizers import CameraOptimizer, CameraOptimizerConfig
from nerfstudio.cameras.cameras import Cameras
from nerfstudio.data.scene_box import OrientedBox
from nerfstudio.engine.callbacks import TrainingCallback, TrainingCallbackAttributes, TrainingCallbackLocation
from nerfstudio.engine.optimizers import Optimizers
from nerfstudio.model_components.lib_bilagrid import BilateralGrid, color_correct, slice, total_variation_loss
from nerfstudio.models.base_model import Model, ModelConfig
from nerfstudio.model_components.losses import (
    MidasLoss, LocalMidasLoss, PearsonDepthLoss, LocalPearsonDepthLoss, NormalLoss,
    estimate_scale_and_bias_midas, normalize_depth_midas, mle_loss
)
from nerfstudio.utils import writer
from nerfstudio.utils.dn_utils import normal_from_depth_image
from nerfstudio.utils.colors import get_color
from nerfstudio.utils.misc import torch_compile
from nerfstudio.utils.rich_utils import CONSOLE
from nerfstudio.models.splatfacto import SplatfactoModelConfig
import qr.spatial_py as spu
import torch.nn.functional as F
from nerfstudio.models.RayTracer import RayTracer

# Small epsilon for numerical stability; passed as `eps` to F.normalize when
# ray directions are normalized in create_image_ray_bundle.
EPS = 1e-6

def num_sh_bases(degree: int) -> int:
    """
    Count the spherical harmonic basis functions up to and including `degree`.

    Degrees above 4 are rejected with an assertion.
    """
    assert degree <= 4, "We don't support degree greater than 4."
    bands = degree + 1
    return bands * bands


def rotate_vector_to_vector(v1: torch.Tensor, v2: torch.Tensor):
    """
    Returns a rotation matrix that rotates v1 to align with v2.

    Args:
        v1: source vector(s), shape (3,) or (N, 3). Normalized internally.
        v2: target vector(s), same number of dimensions as v1.

    Returns:
        Tensor of shape (N, 3, 3); N is 1 when the inputs are unbatched.
        Parallel inputs give the identity. Antiparallel inputs give a
        180-degree rotation about an axis perpendicular to v1, so the result
        is always a proper rotation (determinant +1).
    """
    assert v1.dim() == v2.dim()
    assert v1.shape[-1] == 3
    if v1.dim() == 1:
        v1 = v1[None, ...]
        v2 = v2[None, ...]
    N = v1.shape[0]

    u = v1 / torch.norm(v1, dim=-1, keepdim=True)
    Ru = v2 / torch.norm(v2, dim=-1, keepdim=True)
    I = torch.eye(3, 3, device=v1.device).unsqueeze(0).repeat(N, 1, 1)

    # the cos angle between the vectors
    c = torch.bmm(u.view(N, 1, 3), Ru.view(N, 3, 1)).squeeze(-1)

    eps = 1.0e-10
    # the cross product matrix of a vector to rotate around
    K = torch.bmm(Ru.unsqueeze(2), u.unsqueeze(1)) - torch.bmm(
        u.unsqueeze(2), Ru.unsqueeze(1)
    )
    # Rodrigues' formula; rows with c == -1 divide by zero and are replaced below
    ans = I + K + (K @ K) / (1 + c)[..., None]
    same_direction_mask = (torch.abs(c - 1.0) < eps).squeeze(-1)
    opposite_direction_mask = (torch.abs(c + 1.0) < eps).squeeze(-1)

    # Match ans.dtype so masked assignment also works for non-float32 inputs.
    identity = torch.eye(3, device=v1.device, dtype=ans.dtype)
    ans[same_direction_mask] = identity

    if opposite_direction_mask.any():
        # -I would map u to -u but is a point reflection (det = -1), not a
        # rotation. Use a half-turn about an axis perpendicular to u instead:
        # R = 2 a a^T - I, which satisfies R u = -u and det(R) = +1.
        u_opp = u[opposite_direction_mask].to(ans.dtype)
        # Start from the coordinate axis least aligned with u, then remove its
        # component along u (one Gram-Schmidt step) to get a perpendicular axis.
        helper = identity[u_opp.abs().argmin(dim=-1)]
        axis = helper - (helper * u_opp).sum(dim=-1, keepdim=True) * u_opp
        axis = axis / torch.norm(axis, dim=-1, keepdim=True)
        ans[opposite_direction_mask] = (
            2.0 * axis.unsqueeze(2) * axis.unsqueeze(1) - identity
        )
    return ans

def matrix_to_quaternion(rotation_matrix: torch.Tensor):
    """
    Convert 3x3 rotation matrices to quaternions ordered (w, x, y, z).

    A single (3, 3) matrix is treated as a batch of one, so the result always
    has shape (N, 4). Each matrix is handled individually; the branch is picked
    from the trace and the largest diagonal entry.
    """
    if rotation_matrix.dim() == 2:
        rotation_matrix = rotation_matrix[None, ...]
    assert rotation_matrix.shape[1:] == (3, 3)

    batch_size = rotation_matrix.shape[0]
    all_traces = torch.vmap(torch.trace)(rotation_matrix)
    quaternion = torch.zeros(
        batch_size,
        4,
        dtype=rotation_matrix.dtype,
        device=rotation_matrix.device,
    )
    for idx in range(batch_size):
        m = rotation_matrix[idx]
        tr = all_traces[idx]
        m00, m11, m22 = m[0, 0], m[1, 1], m[2, 2]
        if tr > 0:
            # w is the dominant component
            s = torch.sqrt(tr + 1.0) * 2
            qw = 0.25 * s
            qx = (m[2, 1] - m[1, 2]) / s
            qy = (m[0, 2] - m[2, 0]) / s
            qz = (m[1, 0] - m[0, 1]) / s
        elif (m00 > m11) and (m00 > m22):
            # x is the dominant component
            s = torch.sqrt(1.0 + m00 - m11 - m22) * 2
            qw = (m[2, 1] - m[1, 2]) / s
            qx = 0.25 * s
            qy = (m[0, 1] + m[1, 0]) / s
            qz = (m[0, 2] + m[2, 0]) / s
        elif m11 > m22:
            # y is the dominant component
            s = torch.sqrt(1.0 + m11 - m00 - m22) * 2
            qw = (m[0, 2] - m[2, 0]) / s
            qx = (m[0, 1] + m[1, 0]) / s
            qy = 0.25 * s
            qz = (m[1, 2] + m[2, 1]) / s
        else:
            # z is the dominant component
            s = torch.sqrt(1.0 + m22 - m00 - m11) * 2
            qw = (m[1, 0] - m[0, 1]) / s
            qx = (m[0, 2] + m[2, 0]) / s
            qy = (m[1, 2] + m[2, 1]) / s
            qz = 0.25 * s

        quaternion[idx] = torch.tensor(
            [qw, qx, qy, qz], dtype=m.dtype, device=m.device
        )
    return quaternion


def quat_to_rotmat(quat):
    """
    Build 3x3 matrices from quaternions stored as (w, x, y, z) on the last axis.

    The input is used as-is (no normalization is performed here). The result
    has shape ``quat.shape[:-1] + (3, 3)``.
    """
    assert quat.shape[-1] == 4, quat.shape
    w, x, y, z = torch.unbind(quat, dim=-1)

    # Pairwise products shared between the matrix entries.
    xx, yy, zz = x**2, y**2, z**2
    xy, xz, yz = x * y, x * z, y * z
    wx, wy, wz = w * x, w * y, w * z

    row0 = torch.stack([1 - 2 * (yy + zz), 2 * (xy - wz), 2 * (xz + wy)], dim=-1)
    row1 = torch.stack([2 * (xy + wz), 1 - 2 * (xx + zz), 2 * (yz - wx)], dim=-1)
    row2 = torch.stack([2 * (xz - wy), 2 * (yz + wx), 1 - 2 * (xx + yy)], dim=-1)
    return torch.stack([row0, row1, row2], dim=-2)


def random_quat_tensor(N):
    """
    Sample N random quaternions and return them as a tensor of shape (N, 4).
    """
    # Three independent uniform draws; the draw order (u, v, w) is kept so the
    # random stream is consumed the same way.
    u, v, w = torch.rand(N), torch.rand(N), torch.rand(N)

    two_pi = 2 * math.pi
    radius_a = torch.sqrt(1 - u)
    radius_b = torch.sqrt(u)
    angle_a = two_pi * v
    angle_b = two_pi * w

    components = [
        radius_a * torch.sin(angle_a),
        radius_a * torch.cos(angle_a),
        radius_b * torch.sin(angle_b),
        radius_b * torch.cos(angle_b),
    ]
    return torch.stack(components, dim=-1)

def create_image_ray_bundle(W, H, pose=None, K=None):
    """
    Returns ray origins and directions for every pixel of an image of size HxW.

    Parameters:
    - W: Width of the image.
    - H: Height of the image.
    - pose: (4, 4) torch.Tensor representing the camera-to-world transformation
      matrix. Defaults to the identity when None.
    - K: (3, 3) torch.Tensor of camera intrinsics. When None, a pinhole camera
      with focal length max(H, W) / 2 and principal point (W // 2, H // 2) is used.

    Returns:
    - rays_o: tensor of shape (H*W, 3) holding the camera position for each ray.
    - rays_d: tensor of shape (H*W, 3) holding normalized ray directions.

    Pixels are enumerated row by row and rays pass through pixel centers
    (+0.5 offset). `pose` and `K` are cloned and moved to CPU float32, so the
    outputs are CPU float32 tensors.
    """
    if pose is None:
        pose = torch.eye(4, dtype=torch.float32)

    if K is None:
        f = 0.5 * max(H, W)  # Default focal length
        K = torch.tensor([[f, 0, W // 2], [0, f, H // 2], [0, 0, 1]], dtype=torch.float32)

    # Work on CPU float32 copies so the caller's tensors are never modified.
    K = K.clone().cpu().float()
    pose = pose.clone().cpu().float()

    # Pixel grid; indexing="ij" is the current default, spelled out explicitly
    # to avoid the torch.meshgrid deprecation warning.
    ys, xs = torch.meshgrid(torch.arange(H), torch.arange(W), indexing="ij")
    coords = torch.stack([xs, ys], -1)  # H,W,2 as (x, y)

    coords = coords.reshape(-1, 2).float() + 0.5  # H*W,2 (pixel centers)
    coords = torch.cat([coords, torch.ones(H * W, 1, dtype=torch.float32)], 1)  # H*W,3 homogeneous

    # TODO analytical K inverse
    dirs = coords @ torch.inverse(K).T  # H*W,3

    # Rotate ray directions from camera to world coordinates
    rays_d = torch.matmul(dirs, pose[:3, :3].t())  # (H*W, 3)
    rays_d = F.normalize(rays_d, dim=-1, eps=EPS)

    # Ray origins: the camera center, broadcast to one row per ray
    rays_o = pose[:3, 3].expand(rays_d.shape)  # (H*W, 3)

    return rays_o, rays_d

def RGB2SH(rgb):
    """
    Map RGB values in [0, 1] to the degree-0 spherical harmonic coefficient.
    """
    sh_dc_basis = 0.28209479177387814
    centered = rgb - 0.5
    return centered / sh_dc_basis


def SH2RGB(sh):
    """
    Map the degree-0 spherical harmonic coefficient back to RGB values in [0, 1].
    """
    sh_dc_basis = 0.28209479177387814
    scaled = sh * sh_dc_basis
    return scaled + 0.5


def resize_image(image: torch.Tensor, d: int):
    """
    Downscale an image by averaging d x d pixel blocks (like opencv's 'area' mode).

    :param image shape [H, W, C]
    :param d downscale factor (must be 2, 4, 8, etc.)

    return downscaled image in shape [H//d, W//d, C]
    """
    import torch.nn.functional as tf

    image = image.to(torch.float32)
    # Averaging kernel: every entry is 1 / d^2.
    box_filter = (1.0 / (d * d)) * torch.ones((1, 1, d, d), dtype=torch.float32, device=image.device)
    # Treat each channel as its own single-channel image: [C, 1, H, W].
    per_channel = image.permute(2, 0, 1).unsqueeze(1)
    pooled = tf.conv2d(per_channel, box_filter, stride=d)
    # Back to [H//d, W//d, C].
    return pooled.squeeze(1).permute(1, 2, 0)


@torch_compile()
def get_viewmat(optimized_camera_to_world):
    """
    Convert camera-to-world matrices into gsplat world-to-camera matrices.

    Returns a pair ``(viewmat, viewmat_rt)``: ``viewmat`` is the (N, 4, 4)
    world-to-camera matrix, and ``viewmat_rt`` has the shape of the input and
    holds the axis-flipped rotation together with the original translation.
    """
    rot = optimized_camera_to_world[:, :3, :3]  # N x 3 x 3
    trans = optimized_camera_to_world[:, :3, 3:4]  # N x 3 x 1

    # flip the y and z axes to align with gsplat conventions
    axis_flip = torch.tensor([[[1, -1, -1]]], device=rot.device, dtype=rot.dtype)
    rot = rot * axis_flip

    # camera-to-world with the flipped rotation
    flipped_c2w = torch.zeros_like(optimized_camera_to_world)
    flipped_c2w[:, :3, :3] = rot
    flipped_c2w[:, :3, 3:4] = trans

    # analytic inverse of the rigid transform: R^T and -R^T t
    rot_inv = rot.transpose(1, 2)
    trans_inv = -torch.bmm(rot_inv, trans)

    world_to_cam = torch.zeros(rot.shape[0], 4, 4, device=rot.device, dtype=rot.dtype)
    world_to_cam[:, 3, 3] = 1.0  # homogeneous row
    world_to_cam[:, :3, :3] = rot_inv
    world_to_cam[:, :3, 3:4] = trans_inv
    return world_to_cam, flipped_c2w

@dataclass
class RaytracingfactoModelConfig(SplatfactoModelConfig):
    """Raytracingfacto Model Config, Inherits existing gaussian splatfacto"""

    _target: Type = field(default_factory=lambda: RaytracingfactoModel)

    mode: str = "3dgrt"
    """2dgrt or 3dgrt"""

    randomize_data: bool = False
    """If True, Gaussian noise is added to the loaded means, quats, scales and opacities after a checkpoint is loaded."""

    warmup_length: int = 500
    """period of steps where refinement is turned off"""
    refine_every: int = 100
    """period of steps where gaussians are culled and densified"""
    resolution_schedule: int = 3000
    """training starts at 1/d resolution, every n steps this is doubled"""
    background_color: Literal["random", "black", "white"] = "random"
    """Whether to randomize the background color."""
    num_downscales: int = 0
    """at the beginning, resolution is 1/2^d, where d is this number"""
    cull_alpha_thresh: float = 0.1
    """threshold of alpha for culling gaussians. One can set it to a lower value (e.g. 0.005) for higher quality."""
    cull_scale_thresh: float = 0.5
    """threshold of scale for culling huge gaussians"""
    reset_alpha_every: int = 30
    """Every this many refinement steps, reset the alpha"""
    densify_grad_thresh: float = 0.0008
    """threshold of positional gradient norm for densifying gaussians"""
    densify_size_thresh: float = 0.01
    """below this size, gaussians are *duplicated*, otherwise split"""
    sh_degree_interval: int = 1000
    """every n intervals turn on another sh degree"""
    random_init: bool = False
    """whether to initialize the positions uniformly randomly (not SFM points)"""
    num_random: int = 50000
    """Number of gaussians to initialize if random init is used"""
    random_scale: float = 10.0
    """Size of the cube to initialize random gaussians within"""
    lambda_ssim: float = 0.2
    """weight of ssim loss"""
    lambda_depth: float = 0.2
    """weight of the depth loss"""
    lambda_normal: float = 0.1
    """Regularization weight for Normal L1 loss"""
    lambda_normal_smooth: float = 0.5
    """Regularization weight for Normal Smooth loss"""
    lambda_uncertainty: float = 1.0
    """Regularization weight for uncertainty losses"""
    stop_split_at: int = 15000
    """stop splitting at this step"""
    sh_degree: int = 3
    """maximum degree of spherical harmonics to use. Set to 0 to only estimate RGB"""
    use_scale_regularization: bool = False
    """If enabled, a scale regularization introduced in PhysGauss (https://xpandora.github.io/PhysGaussian/) is used for reducing huge spikey gaussians."""
    max_gauss_ratio: float = 10.0
    """threshold of ratio of gaussian max to min scale before applying regularization
    loss from the PhysGaussian paper
    """
    min_transmittance: float = 0.001
    """minimum allowable transmittance per ray
    """
    min_alpha: float = 0.0
    """minimum alpha
    """
    adaptive_control_enabled: bool = False
    """
    whether to use adaptive control of the gaussians
    Mostly turned off in case of fine tuning existing splats
    """
    pause_refine_after_reset: int = 0
    """Lower bound on the refinement pause after an opacity reset (default strategy).
    The model always uses at least `num_train_data + refine_every`."""
    # NOTE(review): not referenced in the model setup code visible here; meaning to be confirmed.
    output_depth_during_training: bool = False
    camera_optimizer: CameraOptimizerConfig = field(default_factory=lambda: CameraOptimizerConfig(mode="off"))
    """Config of the camera optimizer to use"""
    color_corrected_metrics: bool = False
    """If True, apply color correction to the rendered images before computing the metrics."""

    enable_depth_training: bool = False
    """If True, calculates depth gradients and compares rendered depth to GT depth. Otherwise, doesn't use anything."""

    enable_normal_training: bool = False
    """If True, uses normal loss gradients during training and compares rendered normal to GT normal. Otherwise, doesn't use anything."""

    enable_normal_consistency_loss: bool = False
    """If True, uses normal loss gradients during training and compares rendered normal to rendered depth gradient normal."""

    enable_dd_loss: bool = False
    """If True, enables dd loss"""

    # NOTE(review): presumably a parameter of the dd (depth distortion) loss; confirm where it is consumed.
    depth_distortion_alpha: int = 1000

    enable_uncertainty: bool = False
    """If true, learn an uncertainty for each Gaussian parameter which can be rendered through uncertainty propagation"""

    uncertainty_grad_clip: float = 0.5
    """Absolute value of gradients when propagating uncertainty above this threshold will be clipped. Good for ensuring uncertainty is not unreasonably high"""

    # This is a step count, so it is annotated as int (it was previously
    # annotated as bool, which does not match the value it holds).
    uncertainty_start_step: int = 0  # 5000
    """Iteration step to start uncertainty training"""

    initialize_uncertainty_with_hessian: bool = False
    """Whether to initialize the learned per-Gaussian uncertainty with the Hessian diagonal or by heuristic means"""

    hessian_diag_pad: float = 1e-3
    """The diagonal padding for the Hessian before inverting to initialize the per-Gaussian covariance (uncertainty)"""

    invert_gt_depth: bool = False
    """Indicates whether the GT depth uses brightest-is-near rather than brightest-is-far."""

    max_gaussian_count: int = 3_000_000
    """Limit the maximum number of gaussian in the current field."""

    prune_ratio: float = 0.1
    """Pruning ratio forwarded to the densification strategy (`prune_ratio`)."""

    bound_extent: float = 1.0
    """Use the large bounds for AABB during BVH construction."""

    # re-3dgs uses 2e-9: src -> https://github.com/NJU-3DV/Relightable3DGaussian/blob/a102f14c7e34b7bd5e421990e6f1593d39ebab83/arguments/__init__.py#L105C49-L105C50
    densify_norm_thresh: float = 2e-7
    """threshold forwarded to the trim strategy as `grow_norm3d` for densifying gaussians"""

    trim_every: int = 1000
    """Number of steps in which we trigger the trimming."""

    only_trim: bool = False
    """Doesn't reset the opacity and only trims the gaussians."""

    trim_ratio: float = 0.1
    """Ratio of the gaussians to trim during pruning stage."""

    strategy: Literal["trim", "default"] = "default"
    """Densification strategy: "default" uses RayTracerDefaultStrategy, "trim" uses the trim GS strategy"""

    normal_init: bool = False
    """If True, estimate normals from the seed points and initialize flattened gaussians whose local z-axis follows them."""

    pure_2d_gaussians: bool = False
    """If set, leave the z-axis scales at the minimum and don't optimize them (only applicable for mode==3dgrt)"""

    visibility_pruning: bool = True
    """Forwarded to the default strategy (`visibility_pruning`)."""

    enable_profiling: bool = True
    """Forwarded to the ray tracer (`enable_profiling`)."""

    refine_radiometry: bool = False
    """If True, refines a color correction matrix per camera."""

    radiometry_refine_warmup: int = 2000
    """Optimize for this many steps before turning on radiometry refinement."""


class RaytracingfactoModel(Model):
    """Nerfstudio's implementation of Gaussian Splatting

    Args:
        config: Raytracingfacto configuration to instantiate model
    """

    config: RaytracingfactoModelConfig

    def __init__(
        self,
        *args,
        seed_points: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        gf_profiling_path: Optional[str] = None,
        **kwargs,
    ):
        """
        Store the seed points, profiling path and a private copy of the cameras,
        run the base-class constructor, then create the loss helpers.

        Args:
            seed_points: optional (locations, colors) pair used to initialize the gaussians.
            gf_profiling_path: optional path handed to the ray tracer's profiler.
            **kwargs: must contain "cameras"; everything is forwarded to the base class.
        """
        # These attributes are assigned before the base-class constructor runs.
        self.seed_points = seed_points
        self.gf_profiling_path = gf_profiling_path

        # Work on a clone so tagging each camera with its index leaves the caller's object alone.
        train_cameras = kwargs["cameras"].clone()
        train_cameras.metadata = {"cam_idx": torch.arange(train_cameras.length())}
        self.cameras = train_cameras

        super().__init__(*args, **kwargs)

        # Depth losses (Pearson-correlation based) and normal losses.
        self.depth_loss = PearsonDepthLoss()
        self.local_depth_loss = LocalPearsonDepthLoss()
        self.normal_l1 = NormalLoss("L1")
        self.normal_smooth = NormalLoss("smooth")

        # Colormap kept on the model.
        self.cmap = matplotlib.colormaps.get_cmap("Spectral_r")

    def populate_modules(self):
        """
        Create the gaussian parameters and all helper modules of the model.

        In order, this:
        1. initializes means (seed points or uniform random), log-scales from the
           mean distance to the 3 nearest neighbours, rotations, SH color
           features and opacities, and stores them in ``self.gauss_params``;
        2. optionally adds per-gaussian log-variance ("uncertainty") entries;
        3. sets up the camera optimizer, the per-camera radiometry matrices,
           the image metrics, the background color, the densification strategy
           and the ray tracer.
        """
        if self.seed_points is not None and not self.config.random_init:
            means = torch.nn.Parameter(self.seed_points[0])  # (Location, Color)
        else:
            # Uniform samples in a cube of side `random_scale` centered at the origin.
            means = torch.nn.Parameter((torch.rand((self.config.num_random, 3)) - 0.5) * self.config.random_scale)
        distances, _ = self.k_nearest_sklearn(means.data, 3)
        distances = torch.from_numpy(distances)
        # find the average of the three nearest neighbors for each point and use that as the scale
        avg_dist = distances.mean(dim=-1, keepdim=True)
        # NOTE(review): `scales` and `quats` are assigned again in both branches of the
        # `normal_init` check below, so these two values are replaced before use.
        scales = torch.nn.Parameter(torch.log(avg_dist.repeat(1, 3)))
        num_points = means.shape[0]
        quats = torch.nn.Parameter(random_quat_tensor(num_points))
        dim_sh = num_sh_bases(self.config.sh_degree)

        if (
            self.seed_points is not None
            and not self.config.random_init
            # We can have colors without points.
            and self.seed_points[1].shape[0] > 0
        ):
            # NOTE(review): this tensor is created with a hard-coded .cuda() call.
            shs = torch.zeros((self.seed_points[1].shape[0], dim_sh, 3)).float().cuda()
            if self.config.sh_degree > 0:
                # Seed colors are divided by 255 before conversion to the DC coefficient.
                shs[:, 0, :3] = RGB2SH(self.seed_points[1] / 255)
                # NOTE(review): the last axis has size 3, so the `3:` slice selects nothing.
                shs[:, 1:, 3:] = 0.0
            else:
                CONSOLE.log("use color only optimization with sigmoid activation")
                shs[:, 0, :3] = torch.logit(self.seed_points[1] / 255, eps=1e-10)
            features_dc = torch.nn.Parameter(shs[:, 0, :])
            features_rest = torch.nn.Parameter(shs[:, 1:, :])
        else:
            features_dc = torch.nn.Parameter(torch.rand(num_points, 3))
            features_rest = torch.nn.Parameter(torch.zeros((num_points, dim_sh - 1, 3)))


        # Compute normals using quats and scales.
        # Note: we are not using the initialization that is being used in DNSplatter. They use
        # a flatter version of initialization. It looks wrong when being used in our dataset.
        if self.config.normal_init:
            def _load_points3D_normals(points_3d):
                # Estimate one unit normal per point with Open3D (hybrid KD-tree search).
                # The transform is the top 3x4 of an identity matrix, so the normals
                # pass through unchanged.
                transform_matrix = torch.eye(4, dtype=torch.float, device="cpu")[:3, :4]
                pcd = o3d.geometry.PointCloud()
                pcd.points = o3d.utility.Vector3dVector(points_3d.cpu().numpy())
                pcd.estimate_normals(
                    search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=0.1, max_nn=30)
                )
                pcd.normalize_normals()
                points3D_normals = torch.from_numpy(np.asarray(pcd.normals, dtype=np.float32))
                points3D_normals = (
                    torch.cat(
                        (points3D_normals, torch.ones_like(points3D_normals[..., :1])), -1
                    )
                    @ transform_matrix.T
                )
                return points3D_normals


            # NOTE(review): this branch indexes self.seed_points[0] without a None check.
            self.normals_seed = _load_points3D_normals(self.seed_points[0]).float()  # type: ignore
            self.normals_seed = self.normals_seed / torch.norm(
                self.normals_seed, dim=-1, keepdim=True
            )
            #normals = torch.nn.Parameter(self.normals_seed.detach())
            # Flatten each gaussian along its local z-axis: that scale is avg_dist / 10.
            scales = torch.log(avg_dist.repeat(1, 3))
            scales[:, 2] = torch.log((avg_dist / 10)[:, 0])
            scales = torch.nn.Parameter(scales.detach())
            # NOTE(review): this zero tensor is replaced by matrix_to_quaternion(rots) below.
            quats = torch.zeros(len(self.normals_seed), 4)
            # Rotation taking the local z-axis (0, 0, 1) onto each estimated normal.
            rots = rotate_vector_to_vector(
                torch.tensor(
                    [0, 0, 1], dtype=torch.float, device=self.normals_seed.device
                ).repeat(self.normals_seed.shape[0], 1),
                self.normals_seed,
            )
            quats = matrix_to_quaternion(rots)
            quats = torch.nn.Parameter(quats.detach())
        else: 
            # Isotropic scales and freshly drawn random rotations.
            scales = torch.nn.Parameter(torch.log(avg_dist.repeat(1, 3)))
            quats = torch.nn.Parameter(random_quat_tensor(num_points))
            rots = quat_to_rotmat(quats)

            # init random normals based on the above scales and quats
            #normals = F.one_hot(torch.argmin(scales, dim=-1), num_classes=3).float()
            #normals = torch.bmm(rots, normals[:, :, None]).squeeze(-1)
            #normals = F.normalize(normals, dim=1, eps=EPS)
            #normals = torch.nn.Parameter(normals.detach())
        
        # Only 2D scales for 2DGRT
        if self.config.mode == "2dgrt":
            scales = scales[:, :2]

        # Every gaussian starts at opacity 0.1 (stored as a logit).
        opacities = torch.nn.Parameter(torch.logit(0.1 * torch.ones(num_points, 1)))
        self.gauss_params = torch.nn.ParameterDict(
            {
                "means": means,
                "scales": scales,
                "quats": quats,
                "features_dc": features_dc,
                "features_rest": features_rest,
                "opacities": opacities,
                # Third column of each rotation matrix, i.e. the rotated local z-axis.
                "normals": rots[:, :, 2],
            }
        )
        
        if self.config.enable_uncertainty:
            """
            if self.config.initialize_uncertainty_with_hessian:
                # Set uncertainties to zero until we initialize them at uncertainty_start_step
                self.gauss_params.update({
                    "position_uncertainties": torch.log(1e-45 + torch.zeros_like(means)),
                    "scale_uncertainties": torch.log(1e-45 + torch.zeros_like(scales)),
                    "quat_uncertainties": torch.log(1e-45 + torch.zeros_like(quats)),
                    "color_uncertainties": torch.log(1e-45 + torch.zeros_like(features_dc)), # Only estimate RGB uncertainty for colors, ignore SH for efficiency
                    "opacity_uncertainties": torch.log(1e-45 + torch.zeros_like(opacities)),
                })
            else:
            """
            # Base variance; position/scale/quat variances are multiplied by depth^2 below.
            init_var = 1e-3
            # Scale uncertainties based on distance to closest camera that sees the point
            cameras = self.cameras
            #camera_scale_fac = self._get_downscale_factor()
            #cameras.rescale_output_resolution(1 / camera_scale_fac)
            # Second return value of get_viewmat: camera-to-world with flipped y/z axes.
            _, c2ws = get_viewmat(cameras.camera_to_worlds)
            Ks = cameras.get_intrinsics_matrices()
            # Start at a huge distance; gaussians seen by no camera keep this value.
            means_dists = 1e9 * torch.ones_like(means[:, 0:1])
            for i in range(c2ws.shape[0]):
                c2w = c2ws[i]
                K = Ks[i]
                w = cameras.width[i]
                h = cameras.height[i]
                # World -> camera: R^T (p - t)
                means_in_C = c2w[None, :3, :3].mT @ (means[..., None] - c2w[None, :3, 3:4])
                # Perspective divide by depth, then project with the intrinsics.
                means_in_C_hom = means_in_C.clone() / means_in_C[:, 2:3]
                means_2d = (K[None] @ means_in_C_hom)[:, :2, 0:1]
                # Inside the image and more than 0.01 in front of the camera.
                in_bounds = (means_2d[:, 0] >= 0) & (means_2d[:, 0] < w) \
                        & (means_2d[:, 1] >= 0) & (means_2d[:, 1] < h) & (means_in_C[:, 2] > 0.01)
                # Keep the smallest depth over all cameras that see the gaussian.
                means_dists = torch.where(in_bounds, torch.minimum(means_dists, means_in_C[:, 2]), means_dists)
            
            var_scale = means_dists.detach() ** 2
            #var_scale = means_dists.detach().exp()
            # All uncertainties are stored as log-variances.
            self.gauss_params.update({
                "position_uncertainties": torch.log(var_scale.repeat(1, 3) * init_var),
                "scale_uncertainties": torch.log(var_scale.repeat(1, scales.shape[1]) * init_var),
                "quat_uncertainties": torch.log(var_scale.repeat(1, 4) * init_var),
                "color_uncertainties": torch.log(torch.ones_like(means) * init_var), # Only estimate RGB uncertainty for colors, ignore SH for efficiency
                "opacity_uncertainties": torch.log(torch.ones_like(means[:, :1]) * init_var),
            })

        self.camera_optimizer: CameraOptimizer = self.config.camera_optimizer.setup(
            num_cameras=self.num_train_data, device="cpu"
        )

        # self.exposure_mapping = {cam_info.image_name: idx for idx, cam_info in enumerate(cam_infos)}
        # One 3x4 matrix per training camera, initialized to [I | 0].
        # NOTE(review): created with a hard-coded device="cuda".
        radiometry = torch.eye(3, 4, device="cuda")[None].repeat(self.num_train_data, 1, 1)
        self.radiometry = torch.nn.Parameter(radiometry)
        self.radiometry_params = torch.nn.ParameterDict({"radiometry": self.radiometry}) 

        # metrics
        from torchmetrics.image import PeakSignalNoiseRatio
        from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity

        self.psnr = PeakSignalNoiseRatio(data_range=1.0)
        self.ssim = SSIM(data_range=1.0, size_average=True, channel=3)
        self.lpips = LearnedPerceptualImagePatchSimilarity(normalize=True)
        self.step = 0

        self.crop_box: Optional[OrientedBox] = None
        if self.config.background_color == "random":
            self.background_color = torch.tensor(
                [0.1490, 0.1647, 0.2157]
            )  # This color is the same as the default background color in Viser. This would only affect the background color when rendering.
        else:
            self.background_color = get_color(self.config.background_color)

        # Densification / pruning strategy selected by config.strategy.
        if self.config.strategy == "default":
            self.strategy = RayTracerDefaultStrategy(
                prune_opa=self.config.cull_alpha_thresh,
                grow_grad3d=self.config.densify_grad_thresh,
                grow_scale3d=self.config.densify_size_thresh,
                prune_scale3d=self.config.cull_scale_thresh,
                refine_start_iter=self.config.warmup_length,
                refine_stop_iter=self.config.stop_split_at,
                reset_every=self.config.reset_alpha_every*self.config.refine_every,
                refine_every=self.config.refine_every,
                pause_refine_after_reset=max(self.num_train_data + self.config.refine_every, self.config.pause_refine_after_reset),
                max_gaussian_count=self.config.max_gaussian_count,
                prune_ratio=self.config.prune_ratio,
                visibility_pruning=self.config.visibility_pruning,
                verbose=True
            )
        elif self.config.strategy == "trim":
            pause_after_reset_multiplier = 2
            # trim_every rounded down to a multiple of num_train_data, plus
            # (1 + pause multiplier) extra multiples of num_train_data.
            trim_step = ((self.config.trim_every // self.num_train_data) + (1 + pause_after_reset_multiplier)) * self.num_train_data
            self.strategy = TrimStrategy(
                prune_opa=self.config.cull_alpha_thresh,
                grow_grad3d=self.config.densify_grad_thresh,
                grow_norm3d=self.config.densify_norm_thresh,
                grow_scale3d=self.config.densify_size_thresh,
                prune_scale3d=self.config.cull_scale_thresh,
                refine_start_iter=self.config.warmup_length,
                refine_stop_iter=self.config.stop_split_at,
                reset_every=self.config.reset_alpha_every*self.config.refine_every,
                refine_every=self.config.refine_every,
                pause_refine_before_reset=(trim_step - self.num_train_data),
                pause_refine_after_reset=pause_after_reset_multiplier * self.num_train_data,
                max_gaussian_count=self.config.max_gaussian_count,
                prune_ratio=self.config.prune_ratio,
                trim_every=trim_step,
                only_trim=self.config.only_trim,
                trim_ratio=self.config.trim_ratio,
                verbose=True
            )
        else:
            raise NotImplementedError("This strategy is not implemented.")

        self.strategy_state = self.strategy.initialize_state(scene_scale=1.0)

        # Ray tracer. Normal outputs are enabled for either normal loss, and ray
        # gradients are returned when the camera optimizer is active or refines intrinsics.
        self.tracer = RayTracer(
            min_T=self.config.min_transmittance, 
            min_alpha=self.config.min_alpha, 
            mode=self.config.mode,
            is_training=self.training,
            bound_extent=self.config.bound_extent,
            enable_depth_training=self.config.enable_depth_training,
            enable_normal_training=self.config.enable_normal_training or self.config.enable_normal_consistency_loss,
            return_rb_grads=self.config.camera_optimizer.mode != "off" or self.config.camera_optimizer.refine_intrinsics,
            enable_dd_loss=self.config.enable_dd_loss,
            enable_uncertainty=self.config.enable_uncertainty,
            uncertainty_grad_clip=self.config.uncertainty_grad_clip,
            strategy=self.config.strategy,
            enable_profiling = self.config.enable_profiling,
            profiler_path = str(self.gf_profiling_path) if self.gf_profiling_path is not None else None 
        )
        self.camera : Optional[Cameras] = None

    @property
    def colors(self):
        """RGB colors derived from the DC features: SH-decoded when sh_degree > 0, otherwise a sigmoid."""
        dc = self.features_dc
        if self.config.sh_degree > 0:
            return SH2RGB(dc)
        return torch.sigmoid(dc)

    @property
    def shs_0(self):
        """Degree-0 SH coefficients: the DC features as stored when sh_degree > 0, otherwise sigmoid colors re-encoded as SH."""
        dc = self.features_dc
        if self.config.sh_degree > 0:
            return dc
        return RGB2SH(torch.sigmoid(dc))

    @property
    def shs_rest(self):
        """Higher-order SH coefficients; an alias for ``features_rest``."""
        rest = self.features_rest
        return rest

    @property
    def num_points(self):
        """Number of gaussians, taken from the first dimension of ``means``."""
        count = self.means.shape[0]
        return count

    @property
    def means(self):
        """The ``"means"`` entry of ``gauss_params``."""
        params = self.gauss_params
        return params["means"]

    @property
    def scales(self):
        """The ``"scales"`` entry of ``gauss_params``."""
        params = self.gauss_params
        return params["scales"]

    @property
    def quats(self):
        """The ``"quats"`` entry of ``gauss_params``."""
        params = self.gauss_params
        return params["quats"]

    @property
    def features_dc(self):
        """The ``"features_dc"`` entry of ``gauss_params``."""
        params = self.gauss_params
        return params["features_dc"]

    @property
    def features_rest(self):
        """The ``"features_rest"`` entry of ``gauss_params``."""
        params = self.gauss_params
        return params["features_rest"]

    @property
    def opacities(self):
        """The ``"opacities"`` entry of ``gauss_params``."""
        params = self.gauss_params
        return params["opacities"]

    @property
    def normals(self):
        """The ``"normals"`` entry of ``gauss_params``."""
        params = self.gauss_params
        return params["normals"]

    def load_state_dict(self, dict, **kwargs):  # type: ignore
        """
        Load a checkpoint whose gaussian count may differ from the current model.

        Sets ``self.step`` to 30000, remaps old-style top-level parameter names to
        the ``gauss_params.*`` names, re-allocates every gaussian parameter with
        the checkpoint's point count, then delegates to the base class. When
        ``config.randomize_data`` is set, Gaussian noise is added to the loaded
        means, quats, scales and opacities.
        """
        self.step = 30000
        if "means" in dict:
            # For backwards compatibility: old checkpoints store e.g. "means"
            # where the model now expects "gauss_params.means".
            legacy_names = ("means", "scales", "quats", "features_dc", "features_rest", "opacities")
            for legacy in legacy_names:
                dict[f"gauss_params.{legacy}"] = dict[legacy]

        # resize the parameters to match the new number of points
        num_loaded = dict["gauss_params.means"].shape[0]
        for key, current in self.gauss_params.items():
            resized_shape = (num_loaded,) + current.shape[1:]
            self.gauss_params[key] = torch.nn.Parameter(torch.zeros(resized_shape, device=self.device))
        super().load_state_dict(dict, **kwargs)

        if not self.config.randomize_data:
            return

        # Perturb the loaded parameters; a negative factor subtracts the noise.
        noise_factors = (
            ("means", 0.03),
            ("quats", 1.23),
            ("scales", -0.83),
            ("opacities", 0.53),
        )
        with torch.no_grad():
            for key, factor in noise_factors:
                self.gauss_params[key] += torch.randn_like(self.gauss_params[key]) * factor

        
    def k_nearest_sklearn(self, x: torch.Tensor, k: int):
        """Find k-nearest neighbors using sklearn's NearestNeighbors.

        Args:
            x: The data tensor of shape [num_samples, num_features]
            k: The number of neighbors to retrieve

        Returns:
            A ``(distances, indices)`` pair of numpy arrays with the first
            neighbor column dropped. Both arrays, including the indices, are
            cast to float32.
        """
        from sklearn.neighbors import NearestNeighbors

        points = x.cpu().numpy()

        # Query the fitted points themselves with k + 1 neighbors: the first
        # column is treated as the point itself and sliced off below.
        knn = NearestNeighbors(n_neighbors=k + 1, algorithm="auto", metric="euclidean")
        knn = knn.fit(points)
        neighbor_dists, neighbor_ids = knn.kneighbors(points)

        trimmed_dists = neighbor_dists[:, 1:].astype(np.float32)
        trimmed_ids = neighbor_ids[:, 1:].astype(np.float32)
        return trimmed_dists, trimmed_ids

    def set_crop(self, crop_box: Optional[OrientedBox]):
        """Store the crop box; ``get_outputs`` only crops when it is not ``None`` and the model is not training."""
        self.crop_box = crop_box

    def set_background(self, background_color: torch.Tensor):
        """Store the background color.

        Args:
            background_color: tensor whose shape must be exactly ``(3,)``
        """
        expected_shape = (3,)
        assert background_color.shape == expected_shape
        self.background_color = background_color

    def step_post_backward(self, step):
        """Forward the post-backward hook to the densification strategy.

        Registered in ``get_training_callbacks`` to run after the loss backward
        pass and before the optimizer step.

        Args:
            step: current training step; must equal ``self.step``
        """
        assert step == self.step
        strategy_kwargs = {
            "params": self.gauss_params,
            "optimizers": self.optimizers,
            "state": self.strategy_state,
            "step": self.step,
            "info": self.tracer.info_dict,
            "enabled": self.config.adaptive_control_enabled,
        }
        self.strategy.step_post_backward(**strategy_kwargs)
        
    def step_post_optim_step(self, step):
        """Trigger the Hessian-based uncertainty initialization on its configured step.

        The initialization only runs when uncertainty is enabled, ``step``
        equals ``config.uncertainty_start_step`` and
        ``config.initialize_uncertainty_with_hessian`` is set.

        Args:
            step: current training step
        """
        cfg = self.config
        if not cfg.enable_uncertainty:
            return
        if not step == cfg.uncertainty_start_step:
            return
        if cfg.initialize_uncertainty_with_hessian:
            self.initialize_uncertainty_with_hessian()

    def get_training_callbacks(
        self, training_callback_attributes: TrainingCallbackAttributes
    ) -> List[TrainingCallback]:
        """Build the training callbacks this model needs.

        Args:
            training_callback_attributes: supplies the optimizers handed to ``step_cb``

        Returns:
            Three callbacks, in order: ``step_cb`` before each iteration,
            ``step_post_backward`` between backward and the optimizer step, and
            ``step_post_optim_step`` after each iteration.
        """
        before_iteration = TrainingCallback(
            [TrainingCallbackLocation.BEFORE_TRAIN_ITERATION],
            self.step_cb,
            args=[training_callback_attributes.optimizers],
        )
        after_backward = TrainingCallback(
            [TrainingCallbackLocation.DURING_TRAIN_AFTER_LOSS_BACKWARD_BEFORE_OPTIM_STEP],
            self.step_post_backward,
        )
        after_iteration = TrainingCallback(
            [TrainingCallbackLocation.AFTER_TRAIN_ITERATION],
            self.step_post_optim_step,
        )
        return [before_iteration, after_backward, after_iteration]

    def step_cb(self, optimizers: Optimizers, step):
        """Record the current training step and the optimizer mapping.

        Registered in ``get_training_callbacks`` to run before each training iteration.

        Args:
            optimizers: wrapper whose ``optimizers`` attribute is stored on the model
            step: current training step
        """
        self.step = step
        self.optimizers = optimizers.optimizers

    def get_gaussian_param_groups(self) -> Dict[str, List[Parameter]]:
        """Return one single-element parameter group per entry of ``gauss_params``.

        Kept as its own method so that subclasses can override it and add
        further optimizable Gaussian parameters.

        Returns:
            Mapping from each ``gauss_params`` key to a list holding that parameter
        """
        return {name: [param] for name, param in self.gauss_params.items()}



    def get_param_groups(self) -> Dict[str, List[Parameter]]:
        """Obtain the parameter groups for the optimizers

        Starts from the Gaussian groups, lets the camera optimizer add its own
        groups in place, then appends the radiometry parameters under
        ``"radiometry_opt"``.

        Returns:
            Mapping of different parameter groups
        """
        param_groups = self.get_gaussian_param_groups()
        # The camera optimizer mutates the mapping it is given.
        self.camera_optimizer.get_param_groups(param_groups=param_groups)
        radiometry = self.radiometry_params["radiometry"]
        param_groups["radiometry_opt"] = [radiometry]
        return param_groups

    def _get_downscale_factor(self):
        if self.training:
            return 2 ** max(
                (self.config.num_downscales - self.step // self.config.resolution_schedule),
                0,
            )
        else:
            return 1

    def _downscale_if_required(self, image):
        d = self._get_downscale_factor()
        if d > 1:
            return resize_image(image, d)
        return image

    @staticmethod
    def get_empty_outputs(width: int, height: int, background: torch.Tensor) -> Dict[str, Union[torch.Tensor, List]]:
        """Build placeholder outputs for a view that contains no Gaussians.

        Args:
            width: image width in pixels
            height: image height in pixels
            background: background color, tiled to fill the image

        Returns:
            Dict with the tiled background as ``rgb``, a constant ``depth`` of
            10, a zero ``accumulation`` and the ``background`` itself.
        """
        rgb = background.repeat(height, width, 1)
        image_hw = rgb.shape[:2]
        outputs = {
            "rgb": rgb,
            "depth": background.new_ones(*image_hw, 1) * 10,
            "accumulation": background.new_zeros(*image_hw, 1),
            "background": background,
        }
        return outputs

    def _get_background_color(self):
        if self.config.background_color == "random":
            if self.training:
                background = torch.rand(3, device=self.device)
            else:
                background = self.background_color.to(self.device)
        elif self.config.background_color == "white":
            background = torch.ones(3, device=self.device)
        elif self.config.background_color == "black":
            background = torch.zeros(3, device=self.device)
        else:
            raise ValueError(f"Unknown background color {self.config.background_color}")
        return background

    def _apply_bilateral_grid(self, rgb: torch.Tensor, cam_idx: int, H: int, W: int) -> torch.Tensor:
        """Run ``rgb`` through the bilateral grid selected by ``cam_idx``.

        Args:
            rgb: image to correct
            cam_idx: index of the grid to use
            H: image height used to build the pixel-coordinate grid
            W: image width used to build the pixel-coordinate grid

        Returns:
            The ``"rgb"`` entry of the result of ``slice``.
        """
        # Normalised pixel coordinates in [0, 1], stacked as (x, y) with a leading batch axis.
        ys = torch.linspace(0, 1.0, H, device=self.device)
        xs = torch.linspace(0, 1.0, W, device=self.device)
        grid_y, grid_x = torch.meshgrid(ys, xs, indexing="ij")
        pixel_xy = torch.stack([grid_x, grid_y], dim=-1).unsqueeze(0)

        grid_index = torch.tensor(cam_idx, device=self.device, dtype=torch.long)
        sliced = slice(bil_grids=self.bil_grids, rgb=rgb, xy=pixel_xy, grid_idx=grid_index)
        return sliced["rgb"]

    def get_outputs(self, camera_orig: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
        """Takes in a camera and returns a dictionary of outputs.

        The Gaussians (optionally cropped at eval time) are handed to the ray
        tracer together with one ray per pixel. The traced ray distances are
        converted to z-depth, normals are remapped to [0, 1] for display, and
        during training the per-camera radiometry correction is applied once
        its warm-up has passed.

        Side effects: sets ``self.last_size``, ``self.info`` and ``self.camera``
        and, while training, calls ``strategy.step_pre_backward``.

        Args:
            camera_orig: The camera(s) for which output images are rendered. It should have
            all the needed information to compute the outputs.

        Returns:
            Outputs of model. (ie. rendered colors). An empty dict is returned
            when ``camera_orig`` is not a ``Cameras`` instance.
        """
        if not isinstance(camera_orig, Cameras):
            print("Called get_outputs with not a camera")
            return {}

        if self.training:
            assert camera_orig.shape[0] == 1, "Only one camera at a time"
            camera = self.camera_optimizer.apply_to_camera(camera_orig)
        else:
            camera = camera_orig

        optimized_camera_to_world = camera.camera_to_worlds

        # cropping
        if self.crop_box is not None and not self.training:
            crop_ids = self.crop_box.within(self.means).squeeze()
            if crop_ids.sum() == 0:
                return self.get_empty_outputs(
                    int(camera.width.item()), int(camera.height.item()), self.background_color
                )
        else:
            crop_ids = None

        if crop_ids is not None:
            opacities_crop = self.opacities[crop_ids]
            means_crop = self.means[crop_ids]
            features_dc_crop = self.features_dc[crop_ids]
            features_rest_crop = self.features_rest[crop_ids]
            scales_crop = self.scales[crop_ids]
            quats_crop = self.quats[crop_ids]
        else:
            opacities_crop = self.opacities
            means_crop = self.means
            features_dc_crop = self.features_dc
            features_rest_crop = self.features_rest
            scales_crop = self.scales
            quats_crop = self.quats

        colors_crop = torch.cat((features_dc_crop[:, None, :], features_rest_crop), dim=1)

        # Temporarily rescale the camera to read intrinsics and image size at
        # the current training resolution, then restore it.
        camera_scale_fac = self._get_downscale_factor()
        camera.rescale_output_resolution(1 / camera_scale_fac)
        _, viewmat_rt = get_viewmat(optimized_camera_to_world)
        assert camera_orig.distortion_params is None or camera_orig.distortion_params.norm() < 1e-12, "Distortion not yet supported in raytracingfacto"
        K = camera.get_intrinsics_matrices().cuda()
        W, H = int(camera.width.item()), int(camera.height.item())
        self.last_size = (H, W)
        camera.rescale_output_resolution(camera_scale_fac)  # type: ignore

        quats_crop = torch.nn.functional.normalize(quats_crop, dim=-1)

        if self.config.sh_degree > 0:
            sh_degree_to_use = min(self.step // self.config.sh_degree_interval, self.config.sh_degree)
        else:
            colors_crop = torch.sigmoid(colors_crop).squeeze(1)  # [N, 1, 3] -> [N, 3]
            # NOTE(review): `torch.tensor(None)` below raises, so this branch
            # cannot currently reach the tracer. The value the tracer expects
            # for "sh_lobes" without SH is not visible here — TODO confirm.
            sh_degree_to_use = None

        # NOTE(review): a sanity check on `r_d` used to sit here, but it compared
        # the boolean result of `.any()` against 1 and therefore could never
        # trigger. It was dropped together with its `breakpoint()` call. If ray
        # directions must be unit length, validate with
        # `(r_d.norm(dim=-1) > 1 + eps).any()` after confirming what
        # `create_image_ray_bundle` returns.
        r_o, r_d = create_image_ray_bundle(W, H, viewmat_rt[0], K[0])

        rb = {"origins": r_o.cuda(), "directions": r_d.cuda()}
        colors_crop = colors_crop.reshape(colors_crop.shape[0], -1)

        exp_scales = torch.exp(scales_crop)
        if self.config.mode != "2dgrt" and self.config.pure_2d_gaussians:
            # Flatten the Gaussians by zeroing the third scale component.
            exp_scales = torch.cat([exp_scales[:, :2], torch.zeros_like(exp_scales[:, 2:3])], 1)

        elsies = {
            "sh_lobes": torch.tensor(sh_degree_to_use),
            "colors": colors_crop,
            "positions": means_crop,
            "opacities": torch.sigmoid(opacities_crop),
            "scales": exp_scales,
            "quats": quats_crop,
        }

        if self.config.enable_uncertainty:
            # We estimate log uncertainty
            elsies.update({
                k: v.exp() for k, v in self.gauss_params.items() if k.endswith("uncertainties")
            })

            # Uncertainty propagation for scale uncertainty (exp function) and opacity (sigmoid)
            # The uncertainty propagation for f(x) is df/dx @ Sigma_x @ df/dx^T

            # d(exp(x))/dx = exp(x)
            elsies["scale_uncertainties"] = exp_scales.detach()**2 * elsies["scale_uncertainties"]

            # d(sigmoid(x))/dx = exp(-x) / (1 + exp(-x))^2
            exp_neg_o = torch.exp(-opacities_crop.detach())
            dsigmoido_do = exp_neg_o / (1 + exp_neg_o)**2
            elsies["opacity_uncertainties"] = dsigmoido_do**2 * elsies["opacity_uncertainties"]

        render_dict = self.tracer(elsies=elsies, rb=rb)

        (
            render, transmittance, depth, median_depth, normal, dd_loss
        ) = [
            render_dict[k] for k in [
                "rgb", "transmittance", "depth", "median_depth", "normal", "depth_distortion_loss"
            ]
        ]

        # NOTE: the GT normals are transformed to world space instead of moving
        # the predicted normals to camera space, so that the normal uncertainty
        # stays valid. Performing an uncertainty propagation here would densify
        # the diagonal covariance matrix.
        normals_im = torch.nn.functional.normalize(normal, dim=-1)
        normals_im = (normals_im + 1) / 2

        # Per-pixel tracer outputs are laid out as images; only the RGB render keeps a batch axis.
        render = render.reshape(1, H, W, 3)
        true_depth = depth.reshape(H, W, 1)
        true_median_depth = median_depth.reshape(H, W, 1)
        normals_im = normals_im.reshape(H, W, 3)
        dd_loss = dd_loss.reshape(H, W, 1)

        u = torch.arange(W, dtype=torch.float32).to(true_depth.device)
        v = torch.arange(H, dtype=torch.float32).to(true_depth.device)
        # "ij" is what the implicit default produced; spelling it out avoids the deprecation warning.
        grid = torch.meshgrid(u, v, indexing="ij")
        coords_x = (grid[0] - K[0,0,2]) / K[0,0,0]
        coords_y = (grid[1] - K[0,1,2]) / K[0,1,1]

        # Divide the distance along each ray by the norm of the pixel's
        # homogeneous direction (x, y, 1) to obtain z-depth. This has better
        # numerical stability than the equivalent trigonometric formulation.
        coords_z = torch.ones_like(coords_x)
        coords_hom = torch.stack((coords_x, coords_y, coords_z), dim=-1)
        coords_hom = coords_hom.permute(1, 0, 2)  # [W, H, 3] -> [H, W, 3]
        ray_norm = coords_hom.norm(dim=-1, keepdim=True) + 1e-5
        depth = true_depth / ray_norm
        median_depth = true_median_depth / ray_norm

        self.info = {}
        if self.training:
            self.strategy.step_pre_backward(
                self.gauss_params, self.optimizers, self.strategy_state, self.step
            )

        background = self._get_background_color()

        # radiometry refinement
        # 1xHxWx3 * 1x3x3
        # hxwx3 * 3x3
        # TODO figure out why this breaks TSDF script
        idx = camera.metadata.get('cam_idx', None) if camera.metadata is not None else None
        has_cam_idx = idx is not None

        if has_cam_idx and self.step > self.config.radiometry_refine_warmup and self.config.refine_radiometry and self.training:
            render = torch.matmul(render, self.radiometry[idx, :3, :3]) + self.radiometry[idx, :3, 3].view(1, 1, 1, 3)

        render = torch.clamp(render, 0.0, 1.0)

        if background.shape[0] == 3 and not self.training:
            background = background.expand(H, W, 3)

        # Remembered so that the loss and eval-image methods can read the intrinsics later.
        self.camera = camera

        surface_normal = normal_from_depth_image(
            depths=depth.detach(),
            fx=camera.fx.item(),
            fy=camera.fy.item(),
            cx=camera.cx.item(),
            cy=camera.cy.item(),
            img_size=(depth.shape[1], depth.shape[0]), # because of training downscaling self.config.num_downscales
            c2w=torch.eye(4, dtype=torch.float, device=depth.device),
            device=self.device,
            smooth=False,
        )
        # Flip the y and z components, then remap from [-1, 1] to [0, 1].
        surface_normal = surface_normal @ torch.diag(
            torch.tensor([1, -1, -1], device=depth.device, dtype=depth.dtype)
        )
        surface_normal = (1 + surface_normal) / 2

        ret = {
            "rgb": render.squeeze(0),  # type: ignore
            "depth": depth,  # type: ignore
            "median_depth": median_depth, # type: ignore
            "normal": normals_im,
            "surface_normal": surface_normal,
            "background": background,  # type: ignore
            "dd_loss": dd_loss,
            "refined_camera": camera,
        }  # type: ignore

        if self.config.enable_uncertainty:
            for k, v in render_dict.items():
                if k.endswith("uncertainty"):
                    v = v.reshape(H, W, -1)
                    ret[k] = v

                    # Make a separate visualization that can be viewed as a heat map
                    v_viz = v.clone().mean(-1, keepdim=True)
                    max_viz = 1e4 # Some uncertainty can be very large where normals are orthogonal to the ray
                    v_viz[v_viz > max_viz] = max_viz
                    ret[k + "_viz"] = v_viz

        return ret

    def get_gt_img(self, image: torch.Tensor):
        """Compute groundtruth image with iteration dependent downscale factor for evaluation purpose

        uint8 input is converted to float and divided by 255; any other dtype
        is used as-is. The result is downscaled if currently required and moved
        to ``self.device``.

        Args:
            image: tensor.Tensor in type uint8 or float32
        """
        is_byte_image = image.dtype == torch.uint8
        as_float = image.float() / 255.0 if is_byte_image else image
        return self._downscale_if_required(as_float).to(self.device)

    def composite_with_background(self, image, background) -> torch.Tensor:
        """Composite the ground truth image with a background color when it has an alpha channel.

        Images whose last axis does not have four channels are returned unchanged.

        Args:
            image: the image to composite
            background: the background color
        """
        if image.shape[2] != 4:
            return image
        # Broadcast the alpha channel over the three color channels.
        alpha = image[..., -1:].expand(-1, -1, 3)
        foreground = image[..., :3]
        return alpha * foreground + (1 - alpha) * background

    def get_metrics_dict(self, outputs, batch) -> Dict[str, torch.Tensor]:
        """Compute and returns metrics.

        Reports PSNR (plus a color-corrected PSNR when configured), the
        Gaussian count, and min/max/mean/median statistics for every
        per-Gaussian ``*uncertainties`` parameter and every rendered
        ``*uncertainty`` output. The camera optimizer then adds its own entries.

        Args:
            outputs: the output to compute loss dict to
            batch: ground truth batch corresponding to outputs
        """
        stat_names = ("min", "max", "mean", "median")

        target_rgb = self.composite_with_background(self.get_gt_img(batch["image"]), outputs["background"])
        rendered_rgb = outputs["rgb"]

        metrics_dict = {"psnr": self.psnr(rendered_rgb, target_rgb)}
        if self.config.color_corrected_metrics:
            corrected_rgb = color_correct(rendered_rgb, target_rgb)
            metrics_dict["cc_psnr"] = self.psnr(corrected_rgb, target_rgb)

        metrics_dict["gaussian_count"] = self.num_points

        for name, log_value in self.gauss_params.items():
            if not name.endswith("uncertainties"):
                continue
            # Stored in log space, so exponentiate before summarising.
            value = log_value.exp()
            for stat in stat_names:
                metrics_dict[f"gauss_uncertainties/{name}_{stat}"] = getattr(value, stat)()

        for name, value in outputs.items():
            if not name.endswith("uncertainty"):
                continue
            for stat in stat_names:
                metrics_dict[f"output_uncertainties/{name}_{stat}"] = getattr(value, stat)()

        self.camera_optimizer.get_metrics_dict(metrics_dict)
        return metrics_dict

    def get_loss_dict(self, outputs, batch, metrics_dict=None) -> Dict[str, torch.Tensor]:
        """Computes and returns the losses dict.

        Combines an L1 + SSIM photometric loss with optional terms: scale
        regularization, depth supervision, normal supervision / normal-depth
        consistency, a scale loss for the "3dgrt" mode, depth distortion, and
        maximum-likelihood uncertainty losses once
        ``config.uncertainty_start_step`` has passed. While training, the
        camera optimizer adds its own losses.

        Args:
            outputs: the output to compute loss dict to
            batch: ground truth batch corresponding to outputs
            metrics_dict: dictionary of metrics, some of which we can use for loss

        Returns:
            Mapping from loss name to value. Terms that are disabled are left
            at a plain ``0`` / ``0.0`` rather than a tensor.
        """
        gt_img = self.composite_with_background(self.get_gt_img(batch["image"]), outputs["background"])
        pred_img = outputs["rgb"]

        # Set masked part of both ground-truth and rendered image to black.
        # This is a little bit sketchy for the SSIM loss.
        if "mask" in batch:
            # batch["mask"] : [H, W, 1]
            mask = self._downscale_if_required(batch["mask"])
            mask = mask.to(self.device)
            assert mask.shape[:2] == gt_img.shape[:2] == pred_img.shape[:2]
            gt_img = gt_img * mask
            pred_img = pred_img * mask

        L_uncert = 0.0
        calc_uncert_loss = self.config.enable_uncertainty and self.step > self.config.uncertainty_start_step

        L_rgb = torch.abs(gt_img - pred_img).mean()
        if calc_uncert_loss:
            L_uncert += self.config.lambda_uncertainty * mle_loss(gt_img, pred_img, outputs["rgb_uncertainty"])

        simloss = 1 - self.ssim(gt_img.permute(2, 0, 1)[None, ...], pred_img.permute(2, 0, 1)[None, ...])

        # Penalise Gaussians whose largest/smallest scale ratio exceeds
        # max_gauss_ratio; only evaluated every 10th step.
        if self.config.use_scale_regularization and self.step % 10 == 0:
            scale_exp = torch.exp(self.scales)
            scale_reg = (
                torch.maximum(
                    scale_exp.amax(dim=-1) / scale_exp.amin(dim=-1),
                    torch.tensor(self.config.max_gauss_ratio),
                )
                - self.config.max_gauss_ratio
            )
            scale_reg = 0.1 * scale_reg.mean()
        else:
            scale_reg = torch.tensor(0.0).to(self.device)

        ## Depth loss
        depth_loss = 0
        if "depth_image" in batch and self.config.enable_depth_training:

            gt_depth = self.get_gt_img(batch["depth_image"]).to(self.device)
            gt_depth = torch.clamp(gt_depth, max=1e4)
            pred_depth = outputs["depth"]

            min_depth = 0.01
            # nan_to_num_ works in place, so NaNs are also cleared in outputs["depth"].
            valid_mask = (gt_depth.nan_to_num_() > min_depth) & (pred_depth.nan_to_num_() > min_depth)
            if valid_mask.int().sum() > 0:

                # Invert depth on valid pixels only; the rest keep their original values.
                pred_inv_depth = pred_depth.clone()
                gt_inv_depth = gt_depth.clone()
                pred_inv_depth[valid_mask] = 1.0 / pred_depth[valid_mask]
                gt_inv_depth[valid_mask] = 1.0 / gt_depth[valid_mask]

                global_depth_loss = self.depth_loss(pred_inv_depth, gt_inv_depth)[valid_mask].mean()
                local_depth_loss = self.local_depth_loss(pred_inv_depth, gt_inv_depth, valid_mask)

                # NOTE(review): lambda_depth multiplies the local term twice — confirm this is intended.
                depth_loss = self.config.lambda_depth * (
                    global_depth_loss + self.config.lambda_depth * local_depth_loss)

                assert not depth_loss.isnan().any()

                if calc_uncert_loss:
                    # Scale gt depth into pred depth scale before sending to MLE loss
                    gt_depth_norm = normalize_depth_midas(gt_depth[valid_mask])
                    # NOTE(review): scale/bias come from the full pred_depth while the
                    # loss below uses only the valid pixels — confirm this is intended.
                    scale, bias = estimate_scale_and_bias_midas(pred_depth)
                    gt_depth_pred_scale = scale * gt_depth_norm + bias
                    L_uncert += self.config.lambda_uncertainty * \
                            mle_loss(gt_depth_pred_scale, pred_depth[valid_mask], outputs["depth_uncertainty"][valid_mask])

        def _world_normal_image(depths: torch.Tensor) -> torch.Tensor:
            """Normals from a depth image, rotated to world space and remapped from [-1, 1] to [0, 1]."""
            normals = normal_from_depth_image(
                depths=depths,
                fx=self.camera.fx.item(),
                fy=self.camera.fy.item(),
                cx=self.camera.cx.item(),
                cy=self.camera.cy.item(),
                img_size=(self.camera.width.item(), self.camera.height.item()),
                c2w=torch.eye(4, dtype=torch.float, device=self.device),
                device=self.device,
                smooth=False,
            )
            normals = normals @ torch.diag(
                torch.tensor([1, -1, -1], device=self.device, dtype=normals.dtype)
            )
            # Transform normal from camera space to world space
            normals = normals @ outputs["refined_camera"].camera_to_worlds.squeeze(0)[:3, :3].T
            return (1 + normals) / 2

        normal_loss = 0.0
        scale_loss = 0.0
        if self.config.enable_normal_training or self.config.enable_normal_consistency_loss:
            pred_normal = outputs["normal"]
            normal_loss += self.config.lambda_normal_smooth * self.normal_smooth(pred_normal)

            if self.config.enable_normal_training:
                assert "depth_image" in batch, "Need depth image in batch for GT normals"
                gt_normal = _world_normal_image(batch["depth_image"].to(self.device))

                normal_loss += self.normal_l1(pred_normal, gt_normal)

                if calc_uncert_loss:
                    # Convert normal back to [-1,1] like it originally was
                    L_uncert += self.config.lambda_uncertainty * \
                            mle_loss(2 * gt_normal - 1, 2 * pred_normal - 1, outputs["normal_uncertainty"])

            if self.config.enable_normal_consistency_loss:
                # Fixed: the depth-derived normals are now the ones rotated to
                # world space. This branch used to rotate `gt_normal`, which is
                # undefined when normal training is disabled, and left these
                # normals in camera space.
                gt_normal_depth = _world_normal_image(outputs["depth"].detach())

                normal_loss += self.normal_l1(pred_normal, gt_normal_depth)

                if calc_uncert_loss:
                    # Convert normal back to [-1,1] like it originally was
                    L_uncert += self.config.lambda_uncertainty * \
                            mle_loss(2 * gt_normal_depth - 1, 2 * pred_normal - 1, outputs["normal_uncertainty"])

            if self.config.mode == "3dgrt":
                # Normal is computed as 3rd column of rotation matrix. Only z-axis scale should be small.
                scale_loss = torch.exp(self.scales)[:, 2].mean()

        loss_dict = {
            "main_loss": (1 - self.config.lambda_ssim) * L_rgb + self.config.lambda_ssim * simloss,
            "uncertainty_loss": L_uncert,
            "depth_loss": depth_loss,
            "normal_loss": self.config.lambda_normal * normal_loss,
            "scale_loss": scale_loss,
            "scale_reg": scale_reg,
        }

        if self.config.enable_dd_loss:
            loss_dict["dd_loss"] = self.config.depth_distortion_alpha * outputs["dd_loss"].mean()

        if self.training:
            # Add loss from camera optimizer
            self.camera_optimizer.get_loss_dict(loss_dict)

        return loss_dict

    @torch.no_grad()
    def get_outputs_for_camera(self, camera: Cameras, obb_box: Optional[OrientedBox] = None) -> Dict[str, torch.Tensor]:
        """Render the model outputs for a camera without tracking gradients.

        Overridden for a camera-based gaussian model: the crop box is stored
        first and the camera is moved to the model's device before rendering.

        Args:
            camera: the camera to render from
            obb_box: optional crop box; ``None`` clears any previously set box
        """
        assert camera is not None, "must provide camera to gaussian model"
        self.set_crop(obb_box)
        camera_on_device = camera.to(self.device)
        return self.get_outputs(camera_on_device)  # type: ignore

    def get_image_metrics_and_images(
        self, outputs: Dict[str, torch.Tensor], batch: Dict[str, torch.Tensor]
    ) -> Tuple[Dict[str, float], Dict[str, torch.Tensor]]:
        """Compute evaluation metrics and assemble the comparison image.

        The image places ground truth on the left and the prediction on the
        right. When the batch has a ``depth_image``, a color-mapped depth row
        and a normal row are stacked underneath so that all views can be
        inspected in one image.

        Args:
            outputs: Outputs of the model.
            batch: Batch of data.

        Returns:
            A dictionary of scalar metrics (PSNR, SSIM, LPIPS, plus their
            color-corrected variants when configured) and a dictionary holding
            the combined image under ``"img"``.
        """
        gt_hwc = self.composite_with_background(self.get_gt_img(batch["image"]), outputs["background"])
        pred_hwc = outputs["rgb"]

        # Ground truth on the left, prediction on the right.
        panel = torch.cat([gt_hwc, pred_hwc], dim=1)

        cc_nchw = None
        if self.config.color_corrected_metrics:
            cc_hwc = color_correct(pred_hwc, gt_hwc)
            cc_nchw = torch.moveaxis(cc_hwc, -1, 0)[None, ...]

        # Switch images from [H, W, C] to [1, C, H, W] for metrics computations
        gt_nchw = torch.moveaxis(gt_hwc, -1, 0)[None, ...]
        pred_nchw = torch.moveaxis(pred_hwc, -1, 0)[None, ...]

        psnr_value = self.psnr(gt_nchw, pred_nchw)
        ssim_value = self.ssim(gt_nchw, pred_nchw)
        lpips_value = self.lpips(gt_nchw, pred_nchw)

        # all of these metrics will be logged as scalars
        metrics_dict = {
            "psnr": float(psnr_value.item()),
            "ssim": float(ssim_value),
            "lpips": float(lpips_value),
        }  # type: ignore

        if self.config.color_corrected_metrics:
            assert cc_nchw is not None
            cc_psnr_value = self.psnr(gt_nchw, cc_nchw)
            cc_ssim_value = self.ssim(gt_nchw, cc_nchw)
            cc_lpips_value = self.lpips(gt_nchw, cc_nchw)
            metrics_dict["cc_psnr"] = float(cc_psnr_value.item())
            metrics_dict["cc_ssim"] = float(cc_ssim_value)
            metrics_dict["cc_lpips"] = float(cc_lpips_value)

        if "depth_image" in batch:
            # Min-max normalise both depth maps independently before color mapping.
            gt_depth = self.get_gt_img(batch["depth_image"])
            rows, cols = gt_depth.shape[:2]
            gt_depth = gt_depth.reshape(rows, cols)
            gt_depth = (gt_depth - gt_depth.min()) / (gt_depth.max() - gt_depth.min())

            pred_depth = outputs["depth"].reshape(rows, cols)
            pred_depth = (pred_depth - pred_depth.min()) / (pred_depth.max() - pred_depth.min())

            # Convert depth to color map using matplotlib for better visualization.
            gt_depth_rgb = self.cmap(gt_depth.cpu().numpy())[:, :, :3]
            pred_depth_rgb = self.cmap(pred_depth.cpu().numpy())[:, :, :3]
            depth_row = np.concatenate((gt_depth_rgb, pred_depth_rgb), axis=1)
            depth_row = torch.from_numpy(depth_row).to(panel.device)

            # Stack below the RGB row so every view lives in the same image
            # and can be browsed together inside tensorboard.
            panel = torch.cat((panel, depth_row), dim=0)

            # Add normal to viewing
            gt_normal = normal_from_depth_image(
                depths=batch["depth_image"].to(pred_depth.device),
                fx=self.camera.fx.item(),
                fy=self.camera.fy.item(),
                cx=self.camera.cx.item(),
                cy=self.camera.cy.item(),
                img_size=(self.camera.width.item(), self.camera.height.item()),
                c2w=torch.eye(4, dtype=torch.float, device=pred_depth.device),
                device=self.device,
                smooth=False,
            )
            axis_flip = torch.diag(
                torch.tensor([1, -1, -1], device=gt_normal.device, dtype=gt_normal.dtype)
            )
            gt_normal = gt_normal @ axis_flip
            # Transform normal from camera space to world space
            world_rotation = outputs["refined_camera"].camera_to_worlds.squeeze(0)[:3, :3]
            gt_normal = gt_normal @ world_rotation.T
            gt_normal = (1 + gt_normal) / 2

            normal_row = torch.cat((gt_normal, outputs["normal"]), dim=1)
            panel = torch.cat((panel, normal_row), dim=0)

        images_dict = {"img": panel}
        return metrics_dict, images_dict
    
    # Initialize the per-Gaussian uncertainty with inverse of Hessian diagonal
    def initialize_uncertainty_with_hessian(self):
        """Initialize the per-Gaussian uncertainties from a diagonal Hessian estimate.

        For every camera in ``self.cameras`` the scene is rendered and
        back-propagated once per output channel: the three RGB channels, the
        depth map when ``config.enable_depth_training`` is set, and the three
        normal channels when ``config.enable_normal_training`` is set. The
        squared parameter gradients of every pass are summed per parameter,
        padded with ``config.hessian_diag_pad``, inverted, and stored in log
        space in the ``*_uncertainties`` entries of ``self.gauss_params``.

        When the model is not in training mode, ``self.tracer`` is replaced by a
        training-mode tracer for the gradient passes and by an evaluation-mode
        tracer afterwards, even if one of the passes raises.
        """

        def build_tracer(is_training):
            # Both tracers share every setting except the training flag.
            return RayTracer(
                min_T=self.config.min_transmittance,
                min_alpha=self.config.min_alpha,
                mode=self.config.mode,
                is_training=is_training,
                bound_extent=self.config.bound_extent,
                enable_depth_training=self.config.enable_depth_training,
                enable_normal_training=self.config.enable_normal_training or self.config.enable_normal_consistency_loss,
                return_rb_grads=self.config.camera_optimizer.mode != "off" or self.config.camera_optimizer.refine_intrinsics,
                enable_dd_loss=self.config.enable_dd_loss,
                enable_uncertainty=self.config.enable_uncertainty,
                strategy=self.config.strategy,
                enable_profiling=self.config.enable_profiling,
                profiler_path=str(self.gf_profiling_path) if self.gf_profiling_path is not None else None,
            )

        cameras = self.cameras.clone()
        gauss_params = {k: v.requires_grad_() for k, v in self.gauss_params.items()}
        params = list(gauss_params.values())
        fake_optimizer = torch.optim.SGD(params, 0.0)  # For zero_grad only
        H_diag = {
            k: torch.zeros_like(v)
            for k, v in gauss_params.items()
            if not (k.endswith("uncertainties") or k == "features_rest" or k == "normals")
        }

        # (output key, channel index) for every backward pass of one view.
        # A channel of None means the whole output is summed.
        passes = [("rgb", channel) for channel in range(3)]
        if self.config.enable_depth_training:
            passes.append(("depth", None))
        if self.config.enable_normal_training:
            passes.extend(("normal", channel) for channel in range(3))

        # Drop gradients left over from earlier backward calls so they are not
        # summed into the first squared-gradient term.
        fake_optimizer.zero_grad()

        swap_tracer = not self.training
        if swap_tracer:
            self.tracer = build_tracer(is_training=True)
        try:
            # TODO make this work so we can run backward multiple times without rendering every time.
            for i in tqdm(range(cameras.length()), "Initializing per-Gaussian uncertainty"):
                camera = cameras[i:i + 1]
                # A backward pass with identity loss gradients is done for each
                # output channel separately; the view is re-rendered every time.
                for output_name, channel in passes:
                    render_dict = self.get_outputs(camera)
                    output = render_dict[output_name]
                    if channel is not None:
                        output = output[:, :, channel]
                    fake_loss = output.sum()
                    fake_loss.backward()
                    for k in H_diag:
                        grad = gauss_params[k].grad
                        # No gradient for this pass means a zero contribution.
                        if grad is not None:
                            H_diag[k] += grad.detach() ** 2
                    fake_optimizer.zero_grad()
        finally:
            # Never leave the training-mode tracer installed on an eval model.
            if swap_tracer:
                self.tracer = build_tracer(is_training=False)

        # Diagonal pad
        H_diag = {k: v + self.config.hessian_diag_pad for k, v in H_diag.items()}

        # Compute the covariance; the clamp keeps the reciprocal finite.
        Sigma_diag = {k: 1.0 / torch.clamp(v, min=1e-12) for k, v in H_diag.items()}

        if self.config.sh_degree > 0:
            # Uncertainty propagate through the SH. Ignore view-dependent parts.
            C0 = 0.28209479177387814
            # rgb = sh * C0 + 0.5
            Sigma_diag["features_dc"] = C0**2 * Sigma_diag["features_dc"]

        self.gauss_params.update({
            "position_uncertainties": torch.log(1e-12 + Sigma_diag["means"]),
            "scale_uncertainties": torch.log(1e-12 + Sigma_diag["scales"]),
            "quat_uncertainties": torch.log(1e-12 + Sigma_diag["quats"]),
            "color_uncertainties": torch.log(1e-12 + Sigma_diag["features_dc"]),
            "opacity_uncertainties": torch.log(1e-12 + Sigma_diag["opacities"]),
        })
