import os
import collections
import numpy as np
import shutil
try:
    import depth_pro
except ImportError as err:
    # Only import failures are translated; anything else (including
    # KeyboardInterrupt) propagates unchanged. The original error is chained
    # so the real cause stays visible in the traceback.
    raise ModuleNotFoundError(
        "Depth pro: the 'depth_pro' package is required but could not be imported."
    ) from err
import struct
import torch
import open3d as o3d
import matplotlib
import cv2
from PIL import Image
from tqdm import tqdm
from pathlib import Path
import OpenEXR
import Imath

# Plain record types for the entities stored in a COLMAP model.
CameraModel = collections.namedtuple(
    "CameraModel", ("model_id", "model_name", "num_params")
)
Camera = collections.namedtuple(
    "Camera", ("id", "model", "width", "height", "params")
)
# The tuple type itself is called "Image"; the Python name is BaseImage so the
# Image subclass defined further down can extend it.
BaseImage = collections.namedtuple(
    "Image", ("id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids")
)
Point3D = collections.namedtuple(
    "Point3D", ("id", "xyz", "rgb", "error", "image_ids", "point2D_idxs")
)

# One CameraModel per (model_id, model_name, num_params) triple.
CAMERA_MODELS = {
    CameraModel(model_id=model_id, model_name=model_name, num_params=num_params)
    for model_id, model_name, num_params in (
        (0, "SIMPLE_PINHOLE", 3),
        (1, "PINHOLE", 4),
        (2, "SIMPLE_RADIAL", 4),
        (3, "RADIAL", 5),
        (4, "OPENCV", 8),
        (5, "OPENCV_FISHEYE", 8),
        (6, "FULL_OPENCV", 12),
        (7, "FOV", 5),
        (8, "SIMPLE_RADIAL_FISHEYE", 4),
        (9, "RADIAL_FISHEYE", 5),
        (10, "THIN_PRISM_FISHEYE", 12),
    )
}
# Lookup tables keyed by numeric id and by name respectively.
CAMERA_MODEL_IDS = {model.model_id: model for model in CAMERA_MODELS}
CAMERA_MODEL_NAMES = {model.model_name: model for model in CAMERA_MODELS}

def qvec2rotmat(qvec):
    """Convert a quaternion into a 3x3 matrix.

    :param qvec: indexable with at least four numeric entries; index 0 is used
        as the scalar part and indices 1-3 as the vector part. The input is
        not normalised here, so the result is a rotation matrix only for a
        unit quaternion.
    :return: 3x3 NumPy array
    """
    w, x, y, z = qvec[0], qvec[1], qvec[2], qvec[3]

    row0 = [
        1 - 2 * y ** 2 - 2 * z ** 2,
        2 * x * y - 2 * w * z,
        2 * z * x + 2 * w * y,
    ]
    row1 = [
        2 * x * y + 2 * w * z,
        1 - 2 * x ** 2 - 2 * z ** 2,
        2 * y * z - 2 * w * x,
    ]
    row2 = [
        2 * z * x - 2 * w * y,
        2 * y * z + 2 * w * x,
        1 - 2 * x ** 2 - 2 * y ** 2,
    ]
    return np.array([row0, row1, row2])

class Image(BaseImage):
    """Image record that can turn its own quaternion into a matrix."""

    def qvec2rotmat(self):
        """Return the module-level ``qvec2rotmat`` applied to ``self.qvec``."""
        rotation = qvec2rotmat(self.qvec)
        return rotation

def write_next_bytes(fid, data, format_char_sequence, endian_character="<"):
    """Pack ``data`` with ``struct`` and write the result to ``fid``.

    :param fid: writable binary file object
    :param data: a single value, or a list/tuple of values that are packed
        together in one call
    :param format_char_sequence: struct format characters, one per value
        (e.g. from {c, e, f, d, h, H, i, I, l, L, q, Q})
    :param endian_character: any of {@, =, <, >, !}
    """
    fmt = endian_character + format_char_sequence
    # A lone value is handled like a one-element sequence.
    values = data if isinstance(data, (list, tuple)) else (data,)
    fid.write(struct.pack(fmt, *values))

def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"):
    """Read ``num_bytes`` from ``fid`` and unpack them with ``struct``.

    :param fid: readable binary file object
    :param num_bytes: number of bytes to read; must match the size implied by
        the format, otherwise ``struct.unpack`` raises ``struct.error``
    :param format_char_sequence: struct format characters
    :param endian_character: any of {@, =, <, >, !}
    :return: tuple of the unpacked values
    """
    raw = fid.read(num_bytes)
    fmt = endian_character + format_char_sequence
    return struct.unpack(fmt, raw)


def read_cameras_binary(path_to_model_file):
    """Parse a binary camera file into ``Camera`` records.

    The file starts with a uint64 camera count. Each camera is stored as
    ``(camera_id: int32, model_id: int32, width: uint64, height: uint64)``
    followed by one float64 per model parameter.

    :param path_to_model_file: path of the binary cameras file
    :return: dict mapping camera id to ``Camera``
    :raises KeyError: if a model id is not present in ``CAMERA_MODEL_IDS``
    """
    cameras = {}
    with open(path_to_model_file, "rb") as fid:
        (num_cameras,) = read_next_bytes(fid, 8, "Q")
        for _ in range(num_cameras):
            camera_id, model_id, width, height = read_next_bytes(
                fid, num_bytes=24, format_char_sequence="iiQQ"
            )
            model = CAMERA_MODEL_IDS[model_id]
            param_count = model.num_params
            raw_params = read_next_bytes(
                fid,
                num_bytes=8 * param_count,
                format_char_sequence="d" * param_count,
            )
            cameras[camera_id] = Camera(
                id=camera_id,
                model=model.model_name,
                width=width,
                height=height,
                params=np.array(raw_params),
            )
        # Fails when two records shared the same camera id.
        assert len(cameras) == num_cameras
    return cameras

def read_images_binary(path_to_model_file):
    """Parse a binary images file into ``Image`` records.

    see: src/colmap/scene/reconstruction.cc
        void Reconstruction::ReadImagesBinary(const std::string& path)
        void Reconstruction::WriteImagesBinary(const std::string& path)

    :param path_to_model_file: path of the binary images file
    :return: dict mapping image id to ``Image``
    """
    images = {}
    with open(path_to_model_file, "rb") as fid:
        (num_reg_images,) = read_next_bytes(fid, 8, "Q")
        for _ in range(num_reg_images):
            # int32 id, 4 x float64 quaternion, 3 x float64 translation,
            # int32 camera id.
            header = read_next_bytes(
                fid, num_bytes=64, format_char_sequence="idddddddi"
            )
            image_id, camera_id = header[0], header[8]
            qvec = np.array(header[1:5])
            tvec = np.array(header[5:8])

            # The image name is stored as a NUL-terminated byte string.
            name_chars = []
            while True:
                (char,) = read_next_bytes(fid, 1, "c")
                if char == b"\x00":
                    break
                name_chars.append(char)
            image_name = b"".join(name_chars).decode("utf-8")

            (num_points2D,) = read_next_bytes(
                fid, num_bytes=8, format_char_sequence="Q"
            )
            # Flat sequence of (x: float64, y: float64, point3D_id: int64).
            flat = read_next_bytes(
                fid,
                num_bytes=24 * num_points2D,
                format_char_sequence="ddq" * num_points2D,
            )
            xs = tuple(float(value) for value in flat[0::3])
            ys = tuple(float(value) for value in flat[1::3])
            ids = tuple(int(value) for value in flat[2::3])

            images[image_id] = Image(
                id=image_id,
                qvec=qvec,
                tvec=tvec,
                camera_id=camera_id,
                name=image_name,
                xys=np.column_stack([xs, ys]),
                point3D_ids=np.array(ids),
            )
    return images

def write_points3D_binary(points3D, path_to_model_file):
    """Serialise ``Point3D`` records into a binary points file.

    see: src/colmap/scene/reconstruction.cc
        void Reconstruction::ReadPoints3DBinary(const std::string& path)
        void Reconstruction::WritePoints3DBinary(const std::string& path)

    :param points3D: dict whose values expose ``id``, ``xyz``, ``rgb``,
        ``error``, ``image_ids`` and ``point2D_idxs``
    :param path_to_model_file: output path; an existing file is overwritten
    """
    with open(path_to_model_file, "wb") as fid:
        write_next_bytes(fid, len(points3D), "Q")
        for point in points3D.values():
            write_next_bytes(fid, point.id, "Q")
            write_next_bytes(fid, point.xyz.tolist(), "ddd")
            write_next_bytes(fid, point.rgb.tolist(), "BBB")
            write_next_bytes(fid, point.error, "d")
            # The track is written as its length followed by
            # (image_id, point2D_idx) pairs.
            num_observations = point.image_ids.shape[0]
            write_next_bytes(fid, num_observations, "Q")
            for observation in zip(point.image_ids, point.point2D_idxs):
                write_next_bytes(fid, list(observation), "ii")

def generate_depth(dataset_path: Path):
    """Run Depth Pro on the dataset images and cache the predicted depth maps.

    For every ``.png`` / ``.jpg`` / ``.jpeg`` file (case-insensitive) in
    ``<dataset_path>/images`` two files are written to
    ``<dataset_path>/mono_depth``: the raw depth as ``<stem>.npy`` and a
    ``Spectral_r`` colour-mapped preview saved under the image's own file name.
    If ``mono_depth`` already exists it is returned as-is and nothing is
    recomputed.

    :param dataset_path: dataset root; its last path component must be
        ``processed`` and it must contain ``sparse/0/cameras.bin``
    :return: path of the ``mono_depth`` folder
    :raises AssertionError: if the folder name or the ``sparse`` directory
        does not match the expected layout
    """
    assert dataset_path.name == "processed", "Make sure the dataset_path ends in processed."

    sparse_dir = dataset_path.joinpath("sparse")
    assert sparse_dir.is_dir(), f"Make sure {sparse_dir} exists.."

    # Check the cache first so a reused folder never triggers the checkpoint
    # download or the model load.
    depth_folder = dataset_path.joinpath("mono_depth")
    if depth_folder.exists():
        print(f"[Warning] Reusing the depth folder: {depth_folder}")
        return depth_folder

    # The focal length of the first camera in the model is passed to the
    # network for every image.
    cameras = read_cameras_binary(os.path.join(sparse_dir, "0/cameras.bin"))
    camera = next(iter(cameras.values()))
    focal_length = camera.params[0]

    checkpoints_dir = Path("./checkpoints")
    checkpoints_dir.mkdir(parents=True, exist_ok=True)

    checkpoint_url = "https://ml-site.cdn-apple.com/models/depth-pro/depth_pro.pt"
    checkpoint_path = checkpoints_dir.joinpath("depth_pro.pt")
    if not checkpoint_path.exists():
        import urllib.request
        print(f"Downloading {checkpoint_url} to {checkpoint_path}")
        # Download under a temporary name so an interrupted transfer is not
        # mistaken for a complete checkpoint on the next run.
        partial_path = checkpoints_dir.joinpath("depth_pro.pt.part")
        urllib.request.urlretrieve(checkpoint_url, partial_path)
        partial_path.replace(checkpoint_path)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model, transform = depth_pro.create_model_and_transforms(device=device)
    model.eval()

    images_folder = dataset_path.joinpath("images")
    filenames = sorted(
        name
        for name in os.listdir(images_folder)
        if name.lower().endswith((".png", ".jpg", ".jpeg"))
    )

    # Created only once everything that can fail up front has succeeded, so a
    # failed start does not leave an empty folder that later runs would reuse.
    os.makedirs(depth_folder, exist_ok=True)
    cmap = matplotlib.colormaps.get_cmap('Spectral_r')

    for filename in filenames:
        image_path = os.path.join(images_folder, filename)
        # Load and preprocess an image.
        image, _, _ = depth_pro.load_rgb(image_path)
        image = transform(image)

        # Run inference with the focal length read from the camera model.
        prediction = model.infer(image, f_px=focal_length)
        depth = prediction["depth"].detach().cpu().numpy()  # Depth in [m].
        print(f"Stats for {filename}: mean = {depth.mean()}, std = {depth.std()}, max = {depth.max()}, min = {depth.min()}")

        # Name the array after the image stem; a plain ".png" -> ".npy"
        # replacement would leave other extensions untouched.
        npy_name = Path(filename).with_suffix(".npy").name
        np.save(os.path.join(depth_folder, npy_name), depth)

        # Normalise to 0..255 for the preview; a constant map has zero range
        # and would otherwise divide by zero.
        depth_min = depth.min()
        depth_range = depth.max() - depth_min
        if depth_range > 0:
            normalized = (depth - depth_min) / depth_range * 255.0
        else:
            normalized = np.zeros_like(depth)
        normalized = normalized.astype(np.uint8)
        # Drop alpha and flip RGB -> BGR because OpenCV writes BGR.
        preview = (cmap(normalized)[:, :, :3] * 255)[:, :, ::-1].astype(np.uint8)
        cv2.imwrite(os.path.join(depth_folder, filename), preview)
    return depth_folder

def read_exr_depth(filepath, channel="V"):
    """Reads a depth map from an EXR file.

    Args:
        filepath: Path to the EXR file.
        channel: Name of the EXR channel to read as 32-bit float. Defaults to
            'V', the channel this module has always read; pass another name
            (e.g. 'Z') for files that store depth elsewhere.

    Returns:
        A float32 NumPy array of shape (height, width) taken from the file's
        data window. The array wraps the buffer returned by OpenEXR and is
        presumably read-only; copy it before modifying in place.
    """
    exr_file = OpenEXR.InputFile(filepath)
    try:
        dw = exr_file.header()['dataWindow']
        width = dw.max.x - dw.min.x + 1
        height = dw.max.y - dw.min.y + 1
        raw = exr_file.channel(channel, Imath.PixelType(Imath.PixelType.FLOAT))
    finally:
        # Release the file handle even when the header or channel read fails.
        exr_file.close()

    return np.frombuffer(raw, dtype=np.float32).reshape(height, width)

def generate_depth_GT(dataset_path: Path):
    """Convert the EXR depth maps of a dataset into the cached depth format.

    Every ``.exr`` file (case-insensitive) in ``<dataset_path>/depth`` is
    written to ``<dataset_path>/mono_depth`` as ``<stem>.npy`` plus a
    ``Spectral_r`` colour-mapped ``<stem>.png`` preview. If ``mono_depth``
    already exists it is returned as-is and nothing is recomputed.

    :param dataset_path: dataset root; its last path component must be
        ``processed``
    :return: path of the ``mono_depth`` folder
    :raises AssertionError: if the folder name is not ``processed``
    """
    assert dataset_path.name == "processed", "Make sure the dataset_path ends in processed."

    depth_folder = dataset_path.joinpath("mono_depth")
    if depth_folder.exists():
        print(f"[Warning] Reusing the depth folder: {depth_folder}")
        return depth_folder

    images_folder = dataset_path.joinpath("depth")
    # List the inputs before creating the output folder: if the source folder
    # is missing, no empty "mono_depth" is left behind to be reused later.
    # Only EXR files are handed to the EXR reader.
    filenames = sorted(
        name for name in os.listdir(images_folder) if name.lower().endswith(".exr")
    )

    os.makedirs(depth_folder, exist_ok=True)
    cmap = matplotlib.colormaps.get_cmap('Spectral_r')

    for filename in filenames:
        depth = read_exr_depth(os.path.join(images_folder, filename))
        print(f"Stats for {filename}: mean = {depth.mean()}, std = {depth.std()}, max = {depth.max()}, min = {depth.min()}")

        stem = Path(filename).stem
        np.save(os.path.join(depth_folder, stem + ".npy"), depth)

        # Normalise to 0..255 for the preview; a constant map has zero range
        # and would otherwise divide by zero.
        depth_min = depth.min()
        depth_range = depth.max() - depth_min
        if depth_range > 0:
            normalized = (depth - depth_min) / depth_range * 255.0
        else:
            normalized = np.zeros_like(depth)
        normalized = normalized.astype(np.uint8)
        # Drop alpha and flip RGB -> BGR because OpenCV writes BGR.
        preview = (cmap(normalized)[:, :, :3] * 255)[:, :, ::-1].astype(np.uint8)
        cv2.imwrite(os.path.join(depth_folder, stem + ".png"), preview)
    return depth_folder

def write_o3d_pointcloud_to_colmap_format(pointcloud, out_path):
    """Write the points of ``pointcloud`` as a binary points3D file.

    :param pointcloud: object exposing ``points`` and ``colors``; the colours
        are multiplied by 255 before being stored as integers
    :param out_path: destination path handed to ``write_points3D_binary``
    """
    positions = np.asarray(pointcloud.points)
    colors = np.asarray(pointcloud.colors)

    points3D = {
        idx: Point3D(
            id=idx,
            xyz=positions[idx].astype(np.float64),
            rgb=(255 * colors[idx]).astype(np.int64),
            error=np.array(0.0),
            # Placeholder track so the colmap viewer shows the points by default.
            image_ids=np.array([1, 2, 3]),
            point2D_idxs=np.array([1, 2, 3]),
        )
        for idx in range(len(positions))
    }

    write_points3D_binary(points3D, out_path)

# TODO change voxel_size based on scene size
def create_depth_seed_points(processed_root, voxel_size=0.01, sdf_trunc=0.02, depth_trunc=1e9):
    """Fuse the cached depth maps into a TSDF volume and export seed points.

    Reads ``sparse/0/cameras.bin`` and ``sparse/0/images.bin``, integrates each
    image together with its ``mono_depth/*.npy`` depth map, then writes the
    extracted mesh to ``mono_depth_mesh.ply`` and overwrites
    ``sparse/0/points3D.bin`` with the extracted point cloud. The ``sparse``
    directory is copied to ``sparse_orig`` first, unless that backup exists.

    :param processed_root: dataset root containing ``sparse``, ``images`` and
        ``mono_depth``
    :param voxel_size: TSDF voxel length
    :param sdf_trunc: TSDF truncation distance
    :param depth_trunc: depth truncation passed to the RGBD image constructor
    :raises FileNotFoundError: if an image or its depth map cannot be read
    """
    print("Fusing mono depth maps to create Gaussian seed points")
    sparse_path = os.path.join(processed_root, "sparse")
    sparse_path_backup = os.path.join(processed_root, "sparse_orig")
    if os.path.exists(sparse_path_backup):
        # copytree refuses to overwrite, and on a re-run "sparse" may already
        # hold fused points, so the first backup is the one worth keeping.
        print(f"Backup {sparse_path_backup} already exists, leaving it untouched")
    else:
        print(f"Backing up original COLMAP sparse directory to {sparse_path_backup}")
        shutil.copytree(sparse_path, sparse_path_backup)

    cameras = read_cameras_binary(os.path.join(sparse_path, "0/cameras.bin"))
    images = read_images_binary(os.path.join(sparse_path, "0/images.bin"))

    volume = o3d.pipelines.integration.ScalableTSDFVolume(
            voxel_length=voxel_size,
            sdf_trunc=sdf_trunc,
            depth_sampling_stride=4,
            color_type=o3d.pipelines.integration.TSDFVolumeColorType.RGB8
    )

    depth_dir = os.path.join(processed_root, "mono_depth")
    for image in tqdm(images.values(), "Fusing mono depths"):
        image_file = os.path.join(processed_root, "images", image.name)
        depth_file = _find_depth_file(depth_dir, image.name)

        # 4x4 pose assembled from the image's rotation and translation.
        pose = np.eye(4)
        pose[:3, :3] = image.qvec2rotmat()
        pose[:3, 3] = image.tvec

        bgr = cv2.imread(image_file)
        if bgr is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image {image_file}")
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        # Hand Open3D a contiguous float32 buffer.
        depth = np.ascontiguousarray(np.load(depth_file), dtype=np.float32)

        fx, fy, cx, cy = _pinhole_parameters(cameras[image.camera_id])
        intrinsic = o3d.camera.PinholeCameraIntrinsic(
            width = rgb.shape[1],
            height = rgb.shape[0],
            fx = fx,
            fy = fy,
            cx = cx,
            cy = cy,
        )

        im_rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth(
            o3d.geometry.Image(rgb),
            o3d.geometry.Image(depth),
            depth_trunc=depth_trunc,
            depth_scale=1,
            convert_rgb_to_intensity=False
        )

        volume.integrate(im_rgbd, intrinsic, pose)

    pointcloud = volume.extract_point_cloud()
    mesh = volume.extract_triangle_mesh()

    # TODO load these normals into raytracingfacto
    mesh.compute_vertex_normals()

    mesh_path = os.path.join(processed_root, 'mono_depth_mesh.ply')
    print('Writing mesh to', mesh_path)
    o3d.io.write_triangle_mesh(mesh_path, mesh)
    write_o3d_pointcloud_to_colmap_format(pointcloud, os.path.join(sparse_path, "0/points3D.bin"))


# COLMAP camera models whose parameter list starts with (f, cx, cy) rather
# than (fx, fy, cx, cy).
_SINGLE_FOCAL_MODELS = frozenset(
    {
        "SIMPLE_PINHOLE",
        "SIMPLE_RADIAL",
        "RADIAL",
        "SIMPLE_RADIAL_FISHEYE",
        "RADIAL_FISHEYE",
    }
)


def _pinhole_parameters(camera):
    """Return ``(fx, fy, cx, cy)`` for a camera, ignoring distortion terms."""
    params = camera.params
    if camera.model in _SINGLE_FOCAL_MODELS:
        focal, cx, cy = (float(value) for value in params[:3])
        return focal, focal, cx, cy
    fx, fy, cx, cy = (float(value) for value in params[:4])
    return fx, fy, cx, cy


def _find_depth_file(depth_dir, image_name):
    """Return the path of the cached ``.npy`` depth map for ``image_name``."""
    by_stem = os.path.join(depth_dir, str(Path(image_name).with_suffix(".npy")))
    # np.save appends ".npy" to a target name that lacks it, so the depth map
    # of e.g. "foo.JPG" may have been stored as "foo.JPG.npy".
    appended = os.path.join(depth_dir, image_name + ".npy")
    if not os.path.exists(by_stem) and os.path.exists(appended):
        return appended
    return by_stem
