from dataclasses import dataclass
from typing import Any, Dict, Tuple, Union

import torch

from .base import Strategy
from .ops import duplicate, duplicate_3dgrt, remove, reset_opa, split

# NOTE(review): this changes torch's print precision globally for every module in the
# process as soon as this file is imported; it looks like a debugging leftover —
# consider removing it (or scoping it to the debugging code that needs it).
torch.set_printoptions(precision=20)

@dataclass
class RayTracerDefaultStrategy(Strategy):
    """A default strategy that follows the original 3DGRT paper:

    `3D Gaussian Ray Tracer <https://arxiv.org/pdf/2407.07090>`_

    The strategy will:

    - Periodically duplicate (small) or split (large) GSs whose averaged 3D
      positional gradient is high.
    - Periodically prune GSs with low opacity and, once ``step > reset_every``,
      GSs that are too large.
    - Optionally cap the number of GSs by pruning the ones with the lowest
      accumulated contribution (``visibility_pruning``).
    - Periodically reset GSs to a lower opacity.

    Args:
        grow_grad3d (float): GSs whose averaged 3D positional gradient norm is above this
          value are duplicated or split. Default is 0.0002.
        grow_scale3d (float): GSs whose largest scale (``exp`` of ``params["scales"]``) is at
          most ``grow_scale3d * scene_scale`` are duplicated; larger ones are split.
          Default is 0.01.
        prune_opa (float): GSs with opacity below this value will be pruned. The same value
          is passed to ``reset_opa()`` on every opacity reset. Default is 0.01.
        prune_scale3d (float): Once ``step > reset_every``, GSs whose largest scale is above
          ``prune_scale3d * scene_scale`` are pruned as well. Default is 0.01.
        refine_start_iter (int): Start refining GSs after this iteration. Default is 500.
        refine_stop_iter (int): Stop refining GSs after this iteration. Default is 15_000.
        reset_every (int): Reset opacities every this steps. Default is 3000.
        refine_every (int): Refine GSs every this steps. Default is 100.
        pause_refine_after_reset (int): Pause refining GSs until this number of steps after
          reset. Default is 0 (no pause at all); one might want to set this number to the
          number of images in the training set.
        max_gaussian_count (int): With ``visibility_pruning``, the number of GSs above which
          the least-contributing GSs are pruned. Default is 3000000.
        prune_ratio (float): With ``visibility_pruning``, the additional fraction of
          ``max_gaussian_count`` pruned on top of the excess. Default is 0.1.
        verbose (bool): Whether to print verbose information. Default is False.
        visibility_pruning (bool): Whether to accumulate ``info["contributions"]`` and use it
          to cap the number of GSs. Default is False.
        revised_opacity (bool): Forwarded to ``split()`` as ``revised_opacity``.
          Default is False.

    Examples:

        >>> from gsplat import RayTracerDefaultStrategy, RayTracer
        >>> params: Dict[str, torch.nn.Parameter] | torch.nn.ParameterDict = ...
        >>> optimizers: Dict[str, torch.optim.Optimizer] = ...
        >>> strategy = RayTracerDefaultStrategy()
        >>> strategy.check_sanity(params, optimizers)
        >>> strategy_state = strategy.initialize_state()
        >>> for step in range(1000):
        ...     render_image, render_alpha, info = raytracer(...)
        ...     strategy.step_pre_backward(params, optimizers, strategy_state, step, info)
        ...     loss = ...
        ...     loss.backward()
        ...     # `enabled` defaults to False, which turns this call into a no-op.
        ...     strategy.step_post_backward(
        ...         params, optimizers, strategy_state, step, info, enabled=True
        ...     )

    """
    grow_grad3d: float = 0.0002
    grow_scale3d: float = 0.01
    prune_opa: float = 0.01
    prune_scale3d: float = 0.01
    refine_start_iter: int = 500
    refine_stop_iter: int = 15_000
    reset_every: int = 3000
    refine_every: int = 100
    pause_refine_after_reset: int = 0
    max_gaussian_count: int = 3000000
    prune_ratio: float = 0.1
    verbose: bool = False
    visibility_pruning: bool = False
    revised_opacity: bool = False

    def initialize_state(self, scene_scale: float = 1.0) -> Dict[str, Any]:
        """Initialize and return the running state for this strategy.

        The returned state should be passed to the `step_pre_backward()` and
        `step_post_backward()` functions.

        Args:
            scene_scale: Multiplier applied to ``grow_scale3d`` / ``prune_scale3d``.
        """
        # Tensor entries are allocated lazily by `_update_state()` so that they
        # land on the same device as the incoming gradients.
        # - grad3d: running accum of the norm of the 3D positional gradients for each GS.
        # - count: running accum of how many times each GS received a non-zero gradient.
        # - contributions: running accum of info["contributions"] (visibility pruning).
        # - scene_scale: scale of the scene, used by the grow/prune size thresholds.
        return {
            "grad3d": None,
            "count": None,
            "contributions": None,
            "scene_scale": scene_scale,
        }

    def check_sanity(
        self,
        params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict],
        optimizers: Dict[str, torch.optim.Optimizer],
    ):
        """Sanity check for the parameters and optimizers.

        Check if:
            * `params` and `optimizers` have the same keys.
            * Each optimizer has exactly one param_group, corresponding to each parameter.
            * The following keys are present: {"means", "scales", "quats", "opacities"}.

        Raises:
            AssertionError: If any of the above conditions is not met.

        .. note::
            It is not required but highly recommended for the user to call this function
            after initializing the strategy to ensure the convention of the parameters
            and optimizers is as expected.
        """
        super().check_sanity(params, optimizers)
        # The following keys are required for this strategy.
        for key in ["means", "scales", "quats", "opacities"]:
            assert key in params, f"{key} is required in params but missing."

    def step_pre_backward(
        self,
        params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict],
        optimizers: Dict[str, torch.optim.Optimizer],
        state: Dict[str, Any],
        step: int,
        info: Union[Dict[str, Any], None] = None,
    ):
        """Callback function to be executed before the `loss.backward()` call.

        Currently a no-op. ``info`` is accepted (and ignored) so that the call
        shape matches ``step_post_backward()`` and the class example.
        """

    def step_post_backward(
        self,
        params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict],
        optimizers: Dict[str, torch.optim.Optimizer],
        state: Dict[str, Any],
        step: int,
        info: Dict[str, Any],
        enabled: bool = False,
    ):
        """Callback function to be executed after the `loss.backward()` call.

        Does nothing unless ``enabled`` is True. Otherwise, while
        ``step < refine_stop_iter`` it:

        - accumulates per-GS gradient statistics from ``info`` on every call;
        - on refine steps (see ``_is_refine_step``) grows and prunes GSs and then
          clears the running statistics;
        - every ``reset_every`` steps resets opacities via ``reset_opa()``.

        Once ``step >= refine_stop_iter`` it only releases cached CUDA memory on
        what would have been refine steps.

        Args:
            params: GS parameters ("means", "scales", "quats", "opacities", ...).
            optimizers: One optimizer per parameter, keyed like ``params``.
            state: Running state from ``initialize_state()``; updated in place.
            step: Current training step.
            info: Must contain "grad_pos_densities" and, when ``visibility_pruning``
              is set, "contributions".
            enabled: Master switch; defaults to False (no-op).
        """
        if not enabled:
            return

        if step >= self.refine_stop_iter:
            # Refinement is over; keep only the periodic cache release.
            if self._is_refine_step(step):
                torch.cuda.empty_cache()
            return

        self._update_state(params, state, info)

        if self._is_refine_step(step):
            n_dupli, n_split = self._grow_gs(params, optimizers, state, step)
            if self.verbose:
                print(
                    f"[After _grow_gs] Step {step}: {n_dupli} GSs duplicated, {n_split} GSs split. "
                    f"Now having {len(params['means'])} GSs."
                )

            n_prune = self._prune_gs(params, optimizers, state, step)
            if self.verbose:
                print(
                    f"Step {step}: {n_prune} GSs pruned. "
                    f"Now having {len(params['means'])} GSs."
                )

            # Reset running stats for the next refinement window.
            state["grad3d"].zero_()
            state["count"].zero_()
            # Contributions are deliberately NOT cleared on steps that are also
            # opacity-reset steps, so they keep accumulating across the reset.
            # NOTE(review): confirm this asymmetry is intended.
            if step % self.reset_every != 0:
                state["contributions"].zero_()

            torch.cuda.empty_cache()

        # NOTE(review): this also fires at step 0 (0 % reset_every == 0), and it
        # resets to exactly the pruning threshold `prune_opa`; gsplat's
        # DefaultStrategy guards with `step > 0` and uses `2 * prune_opa` — confirm.
        if step % self.reset_every == 0:
            reset_opa(
                params=params,
                optimizers=optimizers,
                state=state,
                value=self.prune_opa,
            )

    def _is_refine_step(self, step: int) -> bool:
        """Whether ``step`` is a grow/prune step (ignoring ``refine_stop_iter``)."""
        return (
            step > self.refine_start_iter
            and step % self.refine_every == 0
            and step % self.reset_every >= self.pause_refine_after_reset
        )

    @torch.no_grad()
    def _update_state(
        self,
        params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict],
        state: Dict[str, Any],
        info: Dict[str, Any],
    ):
        """Accumulate per-GS gradient statistics (and contributions) from ``info``.

        On the first call, allocates ``state["grad3d"]``, ``state["count"]`` and
        ``state["contributions"]`` as zero vectors with one entry per GS, on the
        device of ``info["grad_pos_densities"]``.

        Raises:
            AssertionError: If a required ``info`` entry is missing.
        """
        for key in ["grad_pos_densities"]:
            assert key in info, f"{key} is required but missing."
        if self.visibility_pruning:
            assert "contributions" in info, "contributions is required for visibility pruning but missing."

        # Per-GS 3D positional gradients, scaled by half the distance of each
        # particle; the norm is taken over the last dim (presumably (N, 3) — TODO confirm).
        grads = info["grad_pos_densities"]
        device = grads.device
        n_gaussian = len(next(iter(params.values())))

        # Lazily allocate the running accumulators on the gradients' device.
        if state["grad3d"] is None:
            state["grad3d"] = torch.zeros(n_gaussian, device=device)
        if state["contributions"] is None:
            state["contributions"] = torch.zeros(n_gaussian, device=device)
        if state["count"] is None:
            state["count"] = torch.zeros(n_gaussian, device=device)

        if self.visibility_pruning:
            state["contributions"] += info["contributions"].flatten()

        # Only GSs that actually received a gradient contribute to the average.
        norms = grads.norm(dim=-1)  # [n_gaussian]
        sel = norms > 0.0  # [n_gaussian]
        gs_ids = torch.where(sel)[0]
        if gs_ids.numel() > 0:
            state["grad3d"].index_add_(0, gs_ids, norms[sel])
            state["count"].index_add_(0, gs_ids, torch.ones_like(gs_ids, dtype=torch.float32))

    @torch.no_grad()
    def _grow_gs(
        self,
        params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict],
        optimizers: Dict[str, torch.optim.Optimizer],
        state: Dict[str, Any],
        step: int,
    ) -> Tuple[int, int]:
        """Duplicate small and split large GSs whose averaged 3D gradient is high.

        Returns:
            ``(n_dupli, n_split)``: number of duplicated and split GSs.
        """
        # Average accumulated gradient norm; clamp avoids 0/0 for untouched GSs.
        grads = state["grad3d"] / state["count"].clamp_min(1)
        device = grads.device
        is_grad_high = grads > self.grow_grad3d

        is_small = (
            torch.exp(params["scales"]).max(dim=-1).values
            <= self.grow_scale3d * state["scene_scale"]
        )
        is_dupli = is_grad_high & is_small
        is_split = is_grad_high & ~is_small
        n_dupli = is_dupli.sum().item()
        n_split = is_split.sum().item()

        # First duplicate.
        if n_dupli > 0:
            duplicate(params=params, optimizers=optimizers, state=state, mask=is_dupli)

        # Pad the split mask with False for the n_dupli GSs added by duplication
        # (expected at the end) so that new GSs are not split in this round.
        is_split = torch.cat(
            [
                is_split,
                torch.zeros(n_dupli, dtype=torch.bool, device=device),
            ]
        )

        # Then split.
        if n_split > 0:
            split(
                params=params,
                optimizers=optimizers,
                state=state,
                mask=is_split,
                revised_opacity=self.revised_opacity,
            )

        return n_dupli, n_split

    @torch.no_grad()
    def _prune_gs(
        self,
        params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict],
        optimizers: Dict[str, torch.optim.Optimizer],
        state: Dict[str, Any],
        step: int,
    ) -> int:
        """Remove low-opacity, oversized and (optionally) low-contribution GSs.

        Returns:
            Number of pruned GSs.
        """
        is_prune = torch.sigmoid(params["opacities"].flatten()) < self.prune_opa
        # Size-based pruning only kicks in after the first opacity reset.
        if step > self.reset_every:
            is_too_big = (
                torch.exp(params["scales"]).max(dim=-1).values
                > self.prune_scale3d * state["scene_scale"]
            )
            is_prune = is_prune | is_too_big

        if self.visibility_pruning:
            n_total = state["contributions"].shape[0]
            if n_total > self.max_gaussian_count:
                # Remove the excess plus an extra `prune_ratio` margin, taking
                # the GSs with the lowest accumulated contribution first.
                n_margin = int(self.prune_ratio * self.max_gaussian_count)
                remove_count = (n_total - self.max_gaussian_count) + n_margin
                order = torch.argsort(state["contributions"].flatten())
                is_prune[order[:remove_count]] = True
                if self.verbose:
                    print("Prunning excess.")

        n_prune = is_prune.sum().item()
        if n_prune > 0:
            remove(params=params, optimizers=optimizers, state=state, mask=is_prune)
        return n_prune
