"""Render a PLY mesh with Blender Cycles from the cameras listed in a JSON file."""
try:
    import bpy
    import mathutils
except ImportError as err:
    # Catch only import failures (a bare ``except:`` would also swallow
    # KeyboardInterrupt/SystemExit) and chain the original error so the
    # underlying reason for the failure stays visible in the traceback.
    raise ModuleNotFoundError("""bpy and mathutils missing. 
Install it using:
    pip install bpy==4.0.0 mathutils==3.3.0""") from err
import json
import math
import argparse
import sys

from pathlib import Path


def init_world_sky():
    """Create a sky-texture world and make it the active scene's world.

    A new world datablock named "GradientWorld" is created. Its node tree is
    rebuilt as Sky Texture -> Background (strength 1) -> World Output.
    Note that ``render`` uses ``init_world_uniform`` rather than this function.
    """
    sky_world = bpy.data.worlds.new("GradientWorld")
    bpy.context.scene.world = sky_world
    sky_world.use_nodes = True

    tree = sky_world.node_tree
    # Drop the default nodes so only the graph built below remains.
    tree.nodes.clear()

    world_output = tree.nodes.new(type='ShaderNodeOutputWorld')
    background = tree.nodes.new(type='ShaderNodeBackground')
    sky = tree.nodes.new(type='ShaderNodeTexSky')

    # Sky-model parameters. NOTE(review): these are Nishita-model settings,
    # presumably relying on the node's default sky_type -- confirm.
    sky_settings = {
        "sun_size": 0.174533,       # ~10 degrees, in radians
        "sun_elevation": 1.0472,    # ~60 degrees, in radians
        "sun_rotation": 2.0944,     # ~120 degrees, in radians
        "altitude": 18,
        "ozone_density": 1.98592,
        "sun_intensity": 0.1,
    }
    for prop_name, prop_value in sky_settings.items():
        setattr(sky, prop_name, prop_value)

    # Socket 1 of the Background node is its strength input.
    background.inputs[1].default_value = 1

    tree.links.new(sky.outputs['Color'], background.inputs['Color'])
    tree.links.new(background.outputs['Background'], world_output.inputs['Surface'])


def init_world_uniform():
    """Rebuild the active scene's world shader as a colour-ramped gradient background.

    The world node tree is replaced with
    Gradient Texture (SPHERICAL) -> Color Ramp -> Background (strength 2.1)
    -> World Output. The ramp runs from RGB (0.7, 0.3, 0.8) at position 0.0 to
    (0.9, 0.83, 0.88) at position 1.0. If the scene has no world, a new one is
    created and assigned first.

    NOTE(review): the gradient's Vector input is left unlinked, so the node
    evaluates its default coordinates. In a world shader these appear to be
    unit-length directions, for which a spherical gradient is constant. That
    would explain the "uniform" name -- confirm.
    """
    scene = bpy.context.scene
    world = scene.world
    if world is None:
        # A scene can exist without a world (e.g. a file saved without one);
        # create one instead of failing on ``None.use_nodes``.
        world = bpy.data.worlds.new("World")
        scene.world = world
    world.use_nodes = True
    nodes = world.node_tree.nodes
    links = world.node_tree.links

    # Start from an empty tree so repeated calls do not accumulate nodes.
    nodes.clear()

    # Create nodes
    output_node = nodes.new(type='ShaderNodeOutputWorld')
    background_node = nodes.new(type='ShaderNodeBackground')
    gradient_node = nodes.new(type='ShaderNodeTexGradient')
    color_ramp_node = nodes.new(type='ShaderNodeValToRGB')

    # Configure gradient
    gradient_node.gradient_type = 'SPHERICAL'
    color_ramp = color_ramp_node.color_ramp

    # Set the two colour stops of the ramp.
    color_ramp.elements[0].position = 0.0
    color_ramp.elements[0].color = (0.7, 0.3, 0.8, 1)  # Dark color
    color_ramp.elements[1].position = 1.0
    color_ramp.elements[1].color = (0.9, 0.83, 0.88, 1)  # Light color

    # Link nodes
    links.new(gradient_node.outputs['Color'], color_ramp_node.inputs['Fac'])
    links.new(color_ramp_node.outputs['Color'], background_node.inputs['Color'])
    links.new(background_node.outputs['Background'], output_node.inputs['Surface'])

    # Set world background strength
    background_node.inputs['Strength'].default_value = 2.1
    print("Initialized with uniform.")


def get_material():
    """Create and return a new node-based material used to shade the mesh.

    The material's Principled BSDF gets base colour
    RGBA (0.499473, 0.264397, 0.272829, 1) and roughness 0.99. If the default
    "Principled BSDF" node is not present, the material is returned unmodified.
    Each call creates a new datablock; Blender may suffix the name
    (e.g. "material.001") when "material" already exists.
    """
    new_material = bpy.data.materials.new(name="material")

    # Enable 'Use Nodes' so the default Principled BSDF node tree is created.
    new_material.use_nodes = True

    principled_node = new_material.node_tree.nodes.get("Principled BSDF")
    if principled_node is not None:
        # Address sockets by name: the Principled BSDF socket order was
        # reorganised between Blender versions (4.0 redesign), so numeric
        # indices can silently hit the wrong socket or fail on type mismatch.
        principled_node.inputs["Base Color"].default_value = (0.499473, 0.264397, 0.272829, 1)
        principled_node.inputs["Roughness"].default_value = 0.99
    return new_material


# Remove existing mesh and camera objects from the scene, then import the PLY mesh.
def import_mesh(mesh_path):
    """Replace the scene's mesh and camera objects with a mesh imported from PLY.

    Args:
        mesh_path: ``pathlib.Path`` or path string of the mesh file. Only files
            with a ``.ply`` suffix (in any letter case) are imported.

    Returns:
        The file stem of *mesh_path*. The caller uses it as the key into
        ``bpy.data.objects``; this presumes the PLY importer names the new
        object after the file stem -- confirm.
        Returns ``None`` when *mesh_path* is ``None`` or not a PLY file. In
        that case the scene is left untouched.
    """
    if mesh_path is None:
        return None
    mesh_path = Path(mesh_path)
    if mesh_path.suffix.lower() != ".ply":
        return None

    # Delete existing mesh and camera objects; other object types are kept.
    bpy.ops.object.select_all(action='DESELECT')
    for obj in bpy.data.objects:
        if obj.type in {'MESH', 'CAMERA'}:
            obj.select_set(True)
    bpy.ops.object.delete()

    bpy.ops.wm.ply_import(filepath=mesh_path.as_posix())
    print("Mesh imported.")

    return mesh_path.stem


def _get_render_camera(scene):
    """Return the "RenderCamera" object, creating it if needed, as *scene*'s active camera."""
    render_cam = bpy.data.objects.get("RenderCamera")
    if render_cam is None:
        camera_data = bpy.data.cameras.new(name="RenderCamera")
        render_cam = bpy.data.objects.new("RenderCamera", camera_data)

    try:
        scene.collection.objects.link(render_cam)
    except RuntimeError:
        # link() raises when the object is already in this collection.
        print("Camera is already linked.")
    # Assigned outside the try so an already-linked camera still becomes active.
    scene.camera = render_cam
    return render_cam


def _assign_material(mesh, material):
    """Put *material* into the first material slot of *mesh*'s data."""
    if not mesh.data.materials:
        print("No mesh material")
        mesh.data.materials.append(material)
    else:
        print("Assigning material")
        mesh.data.materials[0] = material


def _configure_cycles(scene, samples, compute_device_type):
    """Configure *scene* to render PNGs with Cycles on the GPU."""
    scene.render.image_settings.file_format = 'PNG'
    scene.render.engine = 'CYCLES'
    scene.cycles.samples = samples
    # The FOV is matched to the exact w x h from the camera file, so output
    # must not be rescaled by a resolution percentage other than 100.
    scene.render.resolution_percentage = 100

    cycles_prefs = bpy.context.preferences.addons["cycles"].preferences
    cycles_prefs.compute_device_type = compute_device_type
    scene.cycles.device = 'GPU'
    # Query Cycles' device list for the selected backend.
    cycles_prefs.get_devices()
    for other_scene in bpy.data.scenes:
        other_scene.cycles.device = 'GPU'


def _camera_matrix(c2w, mesh):
    """Return the world-space matrix for camera pose *c2w*, expressed in *mesh*'s frame.

    Only the first three rows of *c2w* are used. They are completed with the
    homogeneous row (0, 0, 0, 1), so both 3x4 and 4x4 poses are accepted.
    """
    top_rows = [list(row) for row in c2w[:3]]
    cam_mat = mathutils.Matrix(top_rows + [[0.0, 0.0, 0.0, 1.0]])
    return mesh.matrix_world @ cam_mat


def render(mesh_name, camera_path, render_path, *, samples=128, compute_device_type="OPTIX"):
    """Render the object *mesh_name* once for every camera in *camera_path*.

    Args:
        mesh_name: Name of the mesh object in ``bpy.data.objects``.
        camera_path: JSON file mapping an image file name to a camera record
            with keys:
            ``c2w`` -- camera pose as nested rows; the first three rows are used.
            ``k`` -- intrinsics; only ``k[0][0]`` (fx) and ``k[1][1]`` (fy) are read.
            ``w`` / ``h`` -- image size in pixels.
        render_path: Directory (Path or string) that receives one
            ``<stem of image file name>.png`` per camera.
        samples: Cycles sample count per render.
        compute_device_type: Cycles GPU backend, e.g. "OPTIX".

    Before rendering, the world shader and a fresh material are set up and the
    material is assigned to the mesh.
    """
    init_world_uniform()
    material = get_material()
    print(material.name)

    with open(camera_path, "r", encoding="utf-8") as fs:
        data = json.load(fs)

    scene = bpy.context.scene
    render_path = Path(render_path)
    mesh = bpy.data.objects[mesh_name]
    render_cam = _get_render_camera(scene)
    _assign_material(mesh, material)
    _configure_cycles(scene, samples, compute_device_type)

    # The same for every view; only the FOV angles change per camera.
    render_cam.data.sensor_fit = "VERTICAL"
    render_cam.data.lens_unit = "FOV"

    for fname, cam_data in data.items():
        print(f"Active File: {fname}")
        K = cam_data['k']
        w = cam_data['w']
        h = cam_data['h']
        # Pinhole field of view from a focal length in pixels: 2 * atan(size / (2 f)).
        fov_x = 2 * math.atan2(w, 2 * K[0][0])
        fov_y = 2 * math.atan2(h, 2 * K[1][1])

        # NOTE(review): angle_x and angle_y both drive the camera's single lens
        # value, so the angle_y assignment presumably decides the final focal
        # length (consistent with sensor_fit VERTICAL). The horizontal FOV then
        # follows from w / h, so fx != fy is not reproduced exactly.
        # The principal point (k[0][2], k[1][2]) is ignored, which assumes it is
        # at the image centre.
        render_cam.data.angle_x = fov_x
        render_cam.data.angle_y = fov_y
        # NOTE(review): c2w is used directly as Blender's camera matrix; Blender
        # cameras look along local -Z with +Y up, so the poses must already use
        # that convention -- confirm for each dataset.
        render_cam.matrix_world = _camera_matrix(cam_data['c2w'], mesh)

        image_path = render_path.joinpath(f"{Path(fname).stem}.png").as_posix()
        print(f"Rendering at: {image_path}")

        scene.render.filepath = image_path
        # Resolution properties are integers; convert in case the JSON stores floats.
        scene.render.resolution_x = int(w)
        scene.render.resolution_y = int(h)

        # Render image
        bpy.ops.render.render()
        image = bpy.data.images.get("Render Result")
        if image is not None:
            image.save_render(filepath=image_path)
        else:
            print(f"No render result for {fname}; nothing saved.")


def main():
    """Command-line entry point: import a PLY mesh and render it from each camera.

    Renders are written to a ``mesh_renders`` directory created next to the
    camera file. When the script runs inside the Blender executable
    (``blender -b --python <script> -- --mesh_path ...``), Blender's own options
    remain in ``sys.argv``, so only the arguments after ``--`` are parsed.
    """
    parser = argparse.ArgumentParser(description="Render the images from the camera.")
    parser.add_argument("--mesh_path", type=Path, required=True, help="Path to generated mesh.")
    parser.add_argument("--camera_path", type=Path, required=True, help="Path to camera data for given dataset.")

    argv = sys.argv[1:]
    if "--" in argv:
        argv = argv[argv.index("--") + 1:]
    args = parser.parse_args(argv)

    # Fail fast, before the (potentially slow) mesh import.
    if not args.camera_path.is_file():
        parser.error(f"Camera file not found: {args.camera_path}")

    mesh_name = import_mesh(args.mesh_path)
    # parser.error rather than assert, so the check also runs under `python -O`.
    if mesh_name is None:
        parser.error(f"Mesh import failed: {args.mesh_path} is not a .ply file.")

    render_path = args.camera_path.parent / "mesh_renders"
    render_path.mkdir(exist_ok=True)
    render(mesh_name, args.camera_path, render_path)


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()