"""Contains the ray tracer code for forward and backward pass of the gradients."""
import torch
import torch.utils._pytree as pytree
import qr.spatial_py as spu
from enum import Enum

class GradientMode(Enum):
    """Gradient mode selector for the ray tracer.

    ``RayTracer.__init__`` forwards the integer ``.value`` of the chosen
    member to ``spu.general_field``; the meaning of each mode is defined by
    that native implementation and is not visible from this file.
    """

    STANDARD = 0
    # LEGACY is the default ``gradient_mode`` of ``RayTracer.__init__``.
    LEGACY = 1

class RayTracer(torch.nn.Module):
    """``torch.nn.Module`` front end for the native ``spu.general_field`` tracer.

    The module owns one native field object (``self.gf``) and dispatches each
    ``forward`` call either to the differentiable ``RayTracerFunction`` or to
    the native render-only entry point, depending on the training flag and on
    whether autograd is currently enabled.
    """

    def __init__(self, min_T, min_alpha,
        mode="3dgrt",
        is_training=False,
        bound_extent=1.0,
        enable_depth_training=False,
        enable_normal_training=False,
        return_rb_grads=False,
        enable_dd_loss=False,
        enable_uncertainty=False,
        uncertainty_grad_clip=0.5,
        gradient_mode = GradientMode.LEGACY,
        strategy="default",
        enable_profiling=False,
        profiler_path=None
    ):
        """Build the native field and record the tracing configuration.

        Args:
            min_T: Forwarded unchanged to every native forward/backward/render call.
            min_alpha: Forwarded unchanged to every native forward/backward/render call.
            mode: Tracer mode; must be ``"3dgrt"`` or ``"2dgrt"``.
            is_training: Initial value of ``self.training``.
            bound_extent, return_rb_grads, enable_depth_training,
            enable_normal_training, enable_dd_loss, enable_uncertainty,
            uncertainty_grad_clip: Forwarded positionally to ``spu.general_field``.
            gradient_mode: ``GradientMode`` member; its ``.value`` is forwarded
                to ``spu.general_field``.
            strategy: Stored on the module; ``RayTracerFunction.backward`` uses
                it to pick which contribution statistic goes into ``info_dict``.
            enable_profiling: When true *and* ``profiler_path`` is given, a dump
                path is passed to the native field as an extra trailing argument.
            profiler_path: Directory prefix for ``profile_raw.dump``.

        Raises:
            AssertionError: If ``mode`` is not one of the supported modes.
        """
        super().__init__()

        # Validate first so an unsupported mode never reaches the native constructor.
        assert mode in ["3dgrt", "2dgrt"], f"Unknown mode {mode}"

        field_args = [
            mode,
            bound_extent,
            return_rb_grads,
            enable_depth_training,
            enable_normal_training,
            enable_dd_loss,
            enable_uncertainty,
            uncertainty_grad_clip,
            gradient_mode.value,
        ]
        # The dump path is an optional trailing argument of the native constructor;
        # it is only supplied when profiling is requested and a path is available.
        if enable_profiling and profiler_path is not None:
            field_args.append(f"{profiler_path}/profile_raw.dump")
        self.gf = spu.general_field(*field_args)

        self.mode = mode
        self.min_T = min_T
        self.min_alpha = min_alpha
        # Overrides the default set by nn.Module.__init__ (True).
        self.training = is_training
        self.particles_loaded = False
        self.enable_depth = enable_depth_training
        self.enable_normal = enable_normal_training
        self.strategy = strategy
        # Filled in by RayTracerFunction.backward with per-step statistics.
        self.info_dict = {}

    def keep_gas(self, keep_gas=True):
        """Forward the ``keep_gas`` flag to the native field.

        If True, keep the Optix GAS buffer around after backward and don't
        build it during forward.
        """
        self.gf.keep_gas(keep_gas)

    @staticmethod
    def _as_contiguous_tensors(values):
        """Return a new plain dict whose values are all contiguous tensors.

        Non-tensor values (e.g. ints or floats that need no grads) are wrapped
        with ``torch.tensor`` because the C++ side errs on anything else;
        contiguity is required for raw buffer access in CUDA.
        """
        return {
            key: (value if torch.is_tensor(value) else torch.tensor(value)).contiguous()
            for key, value in values.items()
        }

    def forward(self, elsies, rb):
        """Render with the native tracer, differentiably when training.

        Args:
            elsies: Dict of values handed to the native tracer; entries may be
                tensors or plain Python numbers.
            rb: Second dict handed to the native tracer; same conventions.

        Returns:
            Whatever ``RayTracerFunction.apply`` returns when the module is in
            training mode and autograd is enabled, otherwise the result of
            ``gf.load_and_render_elsies``.

        Raises:
            AssertionError: If either argument is not a dict.
        """
        # Dict subclasses are fine: both inputs are rebuilt as plain dicts below.
        assert isinstance(elsies, dict)
        assert isinstance(rb, dict)

        elsies = self._as_contiguous_tensors(elsies)
        rb = self._as_contiguous_tensors(rb)

        # torch.is_grad_enabled() will tell us if we are in a no_grad() block.
        # This is important since load_and_render_elsies deletes the GAS buffer
        # but forward() does not.
        if self.training and torch.is_grad_enabled():
            return RayTracerFunction.apply(self, elsies, rb)

        return self.gf.load_and_render_elsies(elsies, rb, self.min_T, self.min_alpha)

# Allows for dict input to autograd Function 
# https://gist.github.com/albanD/804d5909295a1e71b5d726597dfbd605
def pytreeify(cls):
    """Class decorator that lets an autograd ``Function`` take and return pytrees.

    ``cls.apply``, ``cls.forward`` and ``cls.backward`` are replaced by thin
    wrappers that flatten nested inputs/outputs (dicts, tuples, ...) into the
    flat argument lists autograd works with, and rebuild the nested structure
    on the other side. The wrapped ``backward`` must return gradients with
    exactly the same tree structure as the inputs passed to ``apply``.

    Returns:
        The same class object, modified in place.

    Raises:
        AssertionError: If ``cls`` is not a ``torch.autograd.Function`` subclass.
    """
    assert issubclass(cls, torch.autograd.Function)

    # Capture the undecorated entry points before they are replaced below.
    wrapped_forward = cls.forward
    wrapped_backward = cls.backward
    wrapped_apply = cls.apply

    def flat_forward(ctx, in_spec, spec_sink, *leaves):
        # Rebuild the caller's nested arguments and run the user's forward on them.
        nested_args = pytree.tree_unflatten(leaves, in_spec)
        result = wrapped_forward(ctx, *nested_args)
        result_leaves, result_spec = pytree.tree_flatten(result)
        # Both specs are needed again by the backward wrapper.
        ctx._inp_struct = in_spec
        ctx._out_struct = result_spec
        # The output spec is handed back to tree_apply through this side channel.
        spec_sink.append(result_spec)
        return tuple(result_leaves)

    def flat_backward(ctx, *grad_leaves):
        nested_grads = pytree.tree_unflatten(grad_leaves, ctx._out_struct)
        # A non-tuple output (e.g. a single dict) is passed as one positional argument.
        backward_args = nested_grads if isinstance(nested_grads, tuple) else (nested_grads,)
        input_grads = wrapped_backward(ctx, *backward_args)
        input_grad_leaves, input_grad_spec = pytree.tree_flatten(input_grads)
        if input_grad_spec != ctx._inp_struct:
            raise RuntimeError("The backward generated an arg structure that doesn't "
                               "match the forward's input.")
        # The two leading Nones are the gradients of in_spec and spec_sink.
        return (None, None, *input_grad_leaves)

    def tree_apply(*inp):
        leaves, in_spec = pytree.tree_flatten(inp)
        spec_sink = []
        out_leaves = wrapped_apply(in_spec, spec_sink, *leaves)
        assert len(spec_sink) == 1
        return pytree.tree_unflatten(out_leaves, spec_sink[0])

    cls.apply = tree_apply
    cls.forward = flat_forward
    cls.backward = flat_backward
    return cls

@pytreeify
class RayTracerFunction(torch.autograd.Function):
    """Autograd bridge between PyTorch and the native field's forward/backward.

    Thanks to ``pytreeify`` the function is called as
    ``RayTracerFunction.apply(tracer_instance, elsies, rb)`` with dict
    arguments, and ``backward`` returns gradients in the same nested layout.
    """

    @staticmethod
    def forward(ctx, tracer_instance, elsies, rb):
        """Run the native forward pass and stash what ``backward`` needs on ``ctx``."""
        ctx.tracer_instance = tracer_instance
        ctx.elsies = elsies
        ctx.rb = rb

        return tracer_instance.gf.forward(
            elsies, rb, tracer_instance.min_T, tracer_instance.min_alpha
        )

    @staticmethod
    def _zero_non_finite(name, x):
        """Report NaN/Inf entries of ``x`` and replace them in place.

        Works for tensors on any device and of any rank; for tensors with at
        least one dimension the indices of the affected rows (dim 0) are
        printed as well. Does nothing when ``x`` is entirely finite.
        """
        bad = x.isnan() | x.isinf()
        if not bad.any():
            return

        print("WARNING:", name, "is not finite")
        if x.dim() > 0:
            print("Num rows:", x.shape[0])
            # Collapse all trailing dims so a row counts as bad if any entry in it is.
            bad_rows = bad.reshape(x.shape[0], -1).any(1)
            print("Row indices:", torch.nonzero(bad_rows).flatten())
        print("Non-finite values at row indices:", x[bad])
        x.nan_to_num_()

    @staticmethod
    def backward(ctx, loss_grads):
        """Run the native backward pass and shape its result for autograd.

        Args:
            loss_grads: Dict of gradients w.r.t. the forward outputs.

        Returns:
            ``(None, elsie_grads, rb_grads)``: no gradient for the tracer
            instance, then one dict per input dict with exactly the input's
            keys in the input's order; entries without a gradient are ``None``.

        Raises:
            NotImplementedError: If the tracer's ``strategy`` is not one of
                ``"trim"``, ``"default"`` or ``"mcmc"``.

        Side effects:
            Updates ``tracer_instance.info_dict`` with ``"grad_pos_densities"``
            and ``"contributions"``.
        """
        tracer_instance = ctx.tracer_instance
        elsies = ctx.elsies
        rb = ctx.rb

        elsie_grads, rb_grads, aux = tracer_instance.gf.backward(
            loss_grads, rb, tracer_instance.min_T, tracer_instance.min_alpha
        )

        # TODO get rid of this. It's slowing us down, but safe for adding new features.
        ####### DEBUG ############################################################
        combined_outputs = {}
        combined_outputs.update({k + "_grad": v for k, v in elsie_grads.items()})
        combined_outputs.update({k + "_grad": v for k, v in rb_grads.items()})
        combined_outputs.update(aux)
        # NOTE(review): this also sanitizes the incoming loss grads in place, after
        # the native backward has already consumed them -- confirm this is intended.
        combined_outputs.update({k + "_loss_grad": v for k, v in loss_grads.items()})

        for name, x in combined_outputs.items():
            if torch.is_tensor(x):
                RayTracerFunction._zero_non_finite(name, x)
        ############## END DEBUG ###################################################

        info = tracer_instance.info_dict
        info.update({"grad_pos_densities": aux["grad_pos_densities"]})

        strategy = tracer_instance.strategy
        if strategy == "trim":
            info.update({"contributions": aux["contributions_trim"] / aux["ray_count"]})
        elif strategy in ["default", "mcmc"]:
            info.update({"contributions": aux["contributions_default"] / aux["ray_count"]})
        else:
            raise NotImplementedError(f"{tracer_instance.strategy} strategy isn't supported.")

        # Rebuild both gradient dicts with exactly the input keys, in input order, or
        # else the structure check in pytreeify may fail. Inputs for which the native
        # backward produced no gradient get None (for rb as well as for elsies).
        produced_elsie_grads = dict(elsie_grads.items())
        produced_rb_grads = dict(rb_grads.items())
        elsie_grads = {k: produced_elsie_grads.get(k) for k in elsies}
        rb_grads = {k: produced_rb_grads.get(k) for k in rb}

        return None, elsie_grads, rb_grads
