import numpy as np
import argparse
import matplotlib.pyplot as plt
from pathlib import Path
from pprint import pprint as print


def plot():
    """Create a 3x3 comparison figure and return a function that fills one row.

    The figure is created once and reused: every call of the returned
    ``plot_row`` redraws the three panels of the given row. Column titles
    ("Radiance", "Depth", "Surface Normal") are set whenever row 0 is drawn.

    Returns:
        plot_row(images, row, name): draws ``images`` (at most three) into
        row ``row`` and writes ``name`` as the row label on the first panel.
        Each image may be a ``Path`` to a ``.png`` (read with ``plt.imread``),
        a ``Path`` to a ``.npy`` (read with ``np.load``), or an array that
        ``imshow`` accepts.
    """
    _, ax = plt.subplots(ncols=3, nrows=3, figsize=(12, 12))

    def plot_row(images, row, name):
        for idx, img in enumerate(images):
            if isinstance(img, Path):
                if img.suffix == ".png":
                    img = plt.imread(img)
                elif img.suffix == ".npy":
                    img = np.load(img)

            cell = ax[row, idx]
            # The figure is reused for every frame: drop the previous frame's
            # image so artists don't accumulate (memory growth, stale autoscale).
            # clear() also resets ticks/labels, so they are re-applied below.
            cell.clear()
            cell.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
            cell.imshow(img)

            if idx == 0:
                cell.set_ylabel(name, fontsize=14, rotation=0, labelpad=30)

        if row == 0:
            ax[0, 0].set_title("Radiance", fontsize=14)
            ax[0, 1].set_title("Depth", fontsize=14)
            ax[0, 2].set_title("Surface Normal", fontsize=14)

    return plot_row


def plot_single(ax, img, title):
    """Read the image file ``img`` and show it on ``ax`` under ``title``.

    Tick marks and tick labels on both axes are hidden.
    """
    pixels = plt.imread(img)
    no_ticks = {"left": False, "bottom": False, "labelleft": False, "labelbottom": False}
    ax.tick_params(**no_ticks)
    ax.imshow(pixels)
    ax.set_title(title, fontsize=14)


def compare_mesh(scene_name, render_output):
    """Save GT / ERT-mesh / DNS-mesh side-by-side figures for a scene.

    Reads ``.png`` files from ``~/benchmark/<scene_name>/outputs/raytracingfacto``
    (ERT) and ``~/benchmark/<scene_name>/outputs/dn-splatter`` (DNS). Frames are
    matched by the last ``_``-separated token of each file stem. For every GT
    frame (ERT file whose name contains "gt") that also has a "mesh" image in
    both folders, writes ``<render_output>/<scene_name>/mesh_render_<frame>.png``.

    Frames missing from either method, or whose images fail to draw, are
    reported and skipped.

    Args:
        scene_name: Scene directory name under ``~/benchmark``.
        render_output: Base directory (``Path``) for the generated figures.
    """
    save_path = render_output.joinpath(scene_name)
    save_path.mkdir(parents=True, exist_ok=True)

    output_path = Path.home().joinpath("benchmark", scene_name, "outputs")

    def frame_id(f):
        return f.stem.split("_")[-1]

    ert_images = [f for f in output_path.joinpath("raytracingfacto").iterdir() if f.suffix == ".png"]
    dns_images = [f for f in output_path.joinpath("dn-splatter").iterdir() if f.suffix == ".png"]

    ert_mesh_images = {frame_id(f): f for f in ert_images if "mesh" in f.name}
    dns_mesh_images = {frame_id(f): f for f in dns_images if "mesh" in f.name}
    gt_renders = {frame_id(f): f for f in ert_images if "gt" in f.name}

    for image, gt_render in gt_renders.items():
        print(image)
        try:
            ert_render = ert_mesh_images[image]
            dns_render = dns_mesh_images[image]
        except KeyError as e:
            # This frame has no mesh render for one of the methods.
            print(e)
            continue

        fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(12, 4))
        try:
            plot_single(ax[0], gt_render, "Ground Truth")
            plot_single(ax[1], ert_render, "ERT Mesh output")
            plot_single(ax[2], dns_render, "DNS Mesh output")
        except Exception as e:
            # Best-effort: report an unreadable/undrawable image and move on.
            print(e)
        else:
            fig.tight_layout()
            fig.savefig(save_path.joinpath(f"mesh_render_{image}.png"))
        finally:
            # Release the figure on every path, including skipped frames.
            plt.close(fig)
    


def group_images(scene_name, render_output):
    """Save a 3x3 GT / ERT / DNS grid (radiance, depth, normal) per frame.

    Reads ``.png`` and ``.npy`` files from
    ``~/benchmark/<scene_name>/outputs/raytracingfacto`` (ERT) and
    ``~/benchmark/<scene_name>/outputs/dn-splatter`` (DNS). Frames are matched
    by the last ``_``-separated token of each file stem, and one grid is
    produced for every frame that has a GT depth (``.npy`` in the DNS folder).

    The three depth maps of a frame are min-max normalised with a shared range
    so their colours are comparable. Each grid is written to
    ``<render_output>/<scene_name>/render_<frame>.png``. Frames missing any
    input are reported and skipped.

    Args:
        scene_name: Scene directory name under ``~/benchmark``.
        render_output: Base directory (``Path``) for the generated figures.
    """
    output_path = Path.home().joinpath("benchmark", scene_name, "outputs")
    plotter = plot()

    save_path = render_output.joinpath(scene_name)
    save_path.mkdir(parents=True, exist_ok=True)

    ert_images = [f for f in output_path.joinpath("raytracingfacto").iterdir() if f.suffix in (".png", ".npy")]
    dns_images = [f for f in output_path.joinpath("dn-splatter").iterdir() if f.suffix in (".png", ".npy")]

    def frame_id(f):
        return f.stem.split("_")[-1]

    def is_npy(f):
        return f.suffix == ".npy"

    def is_png(f):
        return f.suffix == ".png"

    # Depths are read with np.load, so only .npy files are usable. GT files are
    # excluded from the per-method dicts so they cannot overwrite a prediction
    # that shares the same frame id.
    ert_depths = {frame_id(f): f for f in ert_images if is_npy(f) and "depth" in f.name and "gt" not in f.name}
    dns_depths = {frame_id(f): f for f in dns_images if is_npy(f) and "depth" in f.name and "gt" not in f.name}
    gt_depths = {frame_id(f): f for f in dns_images if is_npy(f) and "gt_depth" in f.name}
    print(gt_depths)

    # Normals and renders are passed to plot_row as paths, and plot_row only
    # decodes .png paths here.
    ert_normals = {frame_id(f): f for f in ert_images if is_png(f) and "surface" in f.name and "gt" not in f.name}
    dns_normals = {frame_id(f): f for f in dns_images if is_png(f) and "surface" in f.name and "gt" not in f.name}
    gt_normals = {frame_id(f): f for f in dns_images if is_png(f) and "gt_normal" in f.name}

    ert_renders = {frame_id(f): f for f in ert_images if is_png(f) and "render" in f.name and "gt" not in f.name}
    dns_renders = {frame_id(f): f for f in dns_images if is_png(f) and "render" in f.name and "gt" not in f.name}
    gt_renders = {frame_id(f): f for f in ert_images if is_png(f) and "gt" in f.name}

    row_names = ("Ground\nTruth", "ERT", "DNS")

    for image, gt_depth_path in gt_depths.items():
        try:
            depth_paths = (gt_depth_path, ert_depths[image], dns_depths[image])
            renders = (gt_renders[image], ert_renders[image], dns_renders[image])
            normals = (gt_normals[image], ert_normals[image], dns_normals[image])
        except KeyError:
            print(f"Skipping at {image}.")
            continue

        depths = [np.load(p) for p in depth_paths]
        # One shared range for all three maps keeps their colours comparable.
        min_d = min(d.min() for d in depths)
        max_d = max(d.max() for d in depths)
        span = max_d - min_d
        if span == 0:
            # Constant depth everywhere: avoid 0/0, which would give a NaN image.
            span = 1
        depths = [(d - min_d) / span for d in depths]

        for row, (name, render, depth, normal) in enumerate(zip(row_names, renders, depths, normals)):
            plotter([render, depth, normal], row, name=name)

        # plotter draws on the current pyplot figure, so save that one.
        plt.tight_layout()
        plt.savefig(save_path.joinpath(f"render_{image}.png"))



if __name__ == "__main__":
    # Command-line entry point: optional mesh comparison, then the render grids.
    parser = argparse.ArgumentParser()
    parser.add_argument("--scene_name", type=str, required=True, help="Scene to use for generating table.")
    parser.add_argument("--output_path", type=Path, default=Path("./compare_outputs/"))
    parser.add_argument("--compare_mesh", action="store_true", help="Add mesh render.")

    args = parser.parse_args()

    # Create the chosen output directory, including a user-supplied one.
    args.output_path.mkdir(parents=True, exist_ok=True)

    if args.compare_mesh:
        compare_mesh(args.scene_name, args.output_path)

    # group_images writes figures to disk; it returns nothing.
    group_images(args.scene_name, args.output_path)