# ##### BEGIN GPL LICENSE BLOCK #####
#
#  This program is free software; you can redistribute it and/or
#  modify it under the terms of the GNU General Public License
#  as published by the Free Software Foundation; either version 2
#  of the License, or (at your option) any later version.
#
#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU General Public License for more details.
#
#  You should have received a copy of the GNU General Public License
#  along with this program; if not, write to the Free Software Foundation,
#  Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
#
# ##### END GPL LICENSE BLOCK #####

import bpy
import bmesh
from mathutils import Vector, Matrix
import numpy as np
from math import floor, ceil

from .LilyPack import pack

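# Profiling is disabled: the no-op Timer and Stats stand-ins below replace
# `from .profiling import Timer, Stats`; restore that import to re-enable it.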
class Timer:
    """No-op stand-in for the profiling timer; nothing is ever measured."""

    def ellapsed(self):
        """Report the elapsed time, which is always None for this stub."""
        return None
class Stats:
    """No-op stand-in for the profiling statistics collector."""

    def add_sample(self, name, value):
        """Discard *value*, which would otherwise be recorded under *name*."""
        return None

    def print(self):
        """Do nothing: this stub keeps no statistics to report."""
        return None


# Hypotheses:
#  * One UV map per object
#  * One material per object
#  * One texture per material is packed: preferably the image feeding the
#    Principled BSDF Base Color input, falling back to other image textures otherwise
#  * Texture is mapped using UV
#  * Objects don't share textures
#  * UDIMs are not used

# Module-wide profiling collector that pack_textures reports its samples to.
stats = Stats()

# Index of the input socket on a Principled BSDF node whose image textures are
# favoured when choosing which texture to pack (see list_image_nodes).
BASE_COLOR_INDEX = 0


def get_material_output(material):
    """Return the Material Output node of *material*'s node tree, or None.

    A node tree may contain several ``OUTPUT_MATERIAL`` nodes. The one flagged
    ``is_active_output`` is preferred. If none is flagged, the first output
    node encountered is returned.

    Returns None when *material* is None, has no node tree, or its tree has
    no material output node.
    """
    if material is None or material.node_tree is None:
        return None
    first_output = None
    for node in material.node_tree.nodes:
        if node.type != "OUTPUT_MATERIAL":
            continue
        # getattr keeps this working with node objects lacking the flag.
        if getattr(node, "is_active_output", False):
            return node
        if first_output is None:
            first_output = node
    return first_output


def list_image_nodes(node, weight=0):
    """Collect every image texture node feeding *node*, with a priority weight.

    The node graph is walked upstream, depth first, through every link of
    every input socket. Each hop away from *node* costs 1 weight. Going
    through the Base Color input of a Principled BSDF adds 100, so textures
    on that path win when the caller picks the heaviest entry.

    Returns:
        A list of ``(image_node, weight)`` pairs in traversal order. A node
        reachable through several paths appears once per path.
    """
    if node.type == 'TEX_IMAGE':
        return [(node, weight)]

    found = []
    is_principled = node.type == 'BSDF_PRINCIPLED'
    for index, socket in enumerate(node.inputs):
        bonus = 100 if (is_principled and index == BASE_COLOR_INDEX) else 0
        child_weight = weight + bonus - 1
        for link in socket.links:
            found.extend(list_image_nodes(link.from_node, weight=child_weight))
    return found


def get_image_node(obj):
    """Return the image texture node whose image should be packed for *obj*.

    Only the material in the object's first material slot is inspected. Among
    the image nodes feeding its material output, the one with the highest
    priority weight (as computed by ``list_image_nodes``) is chosen.

    Returns None when the object has no material slot, the slot is empty, the
    material has no node tree or no output node, or no image node is found.
    """
    if not obj.material_slots:
        return None
    material = obj.material_slots[0].material
    if material is None or material.node_tree is None:
        return None
    material_output = get_material_output(material)
    if material_output is None:
        return None
    image_nodes = list_image_nodes(material_output)
    if not image_nodes:
        return None
    # max() returns the first of equally-weighted entries, so ties resolve in
    # traversal order.
    best_node, _ = max(image_nodes, key=lambda pair: pair[1])
    return best_node


# Module-level UV scratch array with one (u, v) row per entry. It starts as a
# single zero row and is reassigned by the UV helpers that declare it global.
all_uvs = np.zeros(shape=(1, 2), dtype=np.float32)
def transform_uv(obj, translate=(0,0), scale=(1,1), flip=False):
    """Apply an affine transform to every UV of *obj*'s active UV map, in place.

    Each UV ``(u, v)`` becomes ``(u * scale[0] + translate[0],
    v * scale[1] + translate[1])``. When *flip* is true, u and v are swapped
    before scaling and translating.

    The math is done in float64 and the result is stored back as float32.
    This keeps the buffer handed to ``foreach_set`` (and the module-level
    ``all_uvs`` scratch array) float32, matching Blender's float UV storage.
    """
    global all_uvs
    me = obj.data
    uv_data = me.uv_layers.active.data

    uvs = np.empty((len(me.loops), 2), dtype=np.float32)
    uv_data.foreach_get('uv', uvs.ravel())

    if flip:
        # Swap the u and v columns.
        uvs = uvs[:, ::-1]

    scale_vec = np.asarray(scale, dtype=np.float64)
    offset_vec = np.asarray(translate, dtype=np.float64)
    uvs = (uvs * scale_vec + offset_vec).astype(np.float32)

    uv_data.foreach_set('uv', uvs.ravel())
    all_uvs = uvs
    me.update()


def get_pixel_uv_bbox(obj, img):
    """Return the pixel rectangle of *img* covered by *obj*'s active UV map.

    The UV bounding box is expanded outward to whole pixels, plus one pixel of
    margin past the maximum edge, and then clipped to the image.

    Returns:
        ``(x, y, width, height)`` in pixels, with ``x + width <= img.size[0]``
        and ``y + height <= img.size[1]``. Returns ``(0, 0, 0, 0)`` for a mesh
        without loops. Width and height are 0 when the UVs lie entirely outside
        the image.
    """
    global all_uvs
    me = obj.data
    iw = img.size[0]
    ih = img.size[1]

    if len(me.loops) == 0:
        return 0, 0, 0, 0

    uvs = np.empty((len(me.loops), 2), dtype=np.float32)
    me.uv_layers.active.data.foreach_get('uv', uvs.ravel())
    all_uvs = uvs

    u_min, v_min = uvs.min(axis=0)
    u_max, v_max = uvs.max(axis=0)

    # Clip the start into [0, size] BEFORE deriving the extent, so UVs outside
    # [0, 1] cannot make the rectangle run past the image border.
    x0 = min(max(floor(float(u_min) * iw), 0), iw)
    y0 = min(max(floor(float(v_min) * ih), 0), ih)
    x1 = min(ceil(float(u_max) * iw) + 1, iw)
    y1 = min(ceil(float(v_max) * ih) + 1, ih)
    return x0, y0, max(x1 - x0, 0), max(y1 - y0, 0)


def _read_pixels(img):
    """Return the pixels of *img* as a float32 array of shape (height, width, 4).

    A fresh buffer is allocated per call. Resizing one shared buffer with
    ``ndarray.resize`` raises as soon as any other reference to it is alive.
    """
    width, height = img.size[0], img.size[1]
    # NOTE(review): assumes 4 values per pixel, as the rest of this module
    # does; confirm for 1- or 3-channel float images.
    flat = np.empty(width * height * 4, dtype=np.float32)
    img.pixels.foreach_get(flat)
    return flat.reshape(height, width, 4)


def pack_textures(objects, spacing=0, max_texture_size=0):
    """Pack the main texture of each object into shared atlas images.

    For every object, the image chosen by ``get_image_node`` is cropped to the
    pixel bounding box of the object's UVs. The crop sizes are passed to
    ``pack``, which returns one ``(x, y, layer, flip)`` placement per crop plus
    the atlas width and height. One new image named
    ``LilyPackedImage-NNNN`` is created per layer, the crops are copied into
    it (transposed when ``flip == 1``), and each object's UVs and image node
    are rewritten to point at its slot in the atlas.

    Args:
        objects: objects to process. They are iterated several times, so any
            iterable is first materialised into a list.
        spacing: pixels added to every packed rectangle size. The source crop
            is shrunk by the same amount, half on each side.
        max_texture_size: forwarded unchanged to ``pack``.

    Objects without an image texture are skipped, as are objects whose crop
    is empty. Nothing happens when no object qualifies.
    """
    objects = list(objects)
    if not objects:
        return

    timer = Timer()
    candidates = []
    for obj in objects:
        node = get_image_node(obj)
        # An object without an image texture cannot be packed; skip it
        # instead of failing on ``node.image`` below.
        if node is not None and node.image is not None:
            candidates.append((obj, node))
    stats.add_sample("get_image_node", timer.ellapsed())
    if not candidates:
        return
    objects = [obj for obj, _ in candidates]
    source_image_nodes = [node for _, node in candidates]
    source_images = [node.image for node in source_image_nodes]

    timer = Timer()
    source_rectangles = [get_pixel_uv_bbox(obj, img) for obj, img in zip(objects, source_images)]
    rect_sizes = np.array([[r[2], r[3]] for r in source_rectangles])
    rect_sizes += np.array([[spacing, spacing]])
    stats.add_sample("get rects", timer.ellapsed())

    wm = bpy.context.window_manager
    wm.progress_begin(0, 100)
    try:
        wm.progress_update(0)

        # Pack
        timer = Timer()
        rect_packing, packed_w, packed_h = pack(rect_sizes, max_texture_size)
        layer_count = int(rect_packing[:, 2].max()) + 1
        stats.add_sample("packing", timer.ellapsed())

        # Create output images, one per layer
        timer = Timer()
        packed_images = [
            bpy.data.images.new(f"LilyPackedImage-{layer:04d}", width=packed_w, height=packed_h)
            for layer in range(layer_count)
        ]
        packed_pixels = np.zeros((layer_count, packed_h, packed_w, 4), dtype=np.float32)
        stats.add_sample("create output image", timer.ellapsed())

        half_spacing = spacing // 2
        total = len(objects)
        placements = zip(objects, source_images, source_rectangles, source_image_nodes, rect_packing)
        for index, (obj, img, rect, node, placement) in enumerate(placements):
            # The progress range was opened as 0..100, so report a percentage.
            wm.progress_update(100 * index // total)
            x, y, layer, f = (int(v) for v in placement)

            # NOTE(review): the crop shrinks by `spacing` while the packed slot
            # grows by `spacing`; confirm this is the intended padding scheme.
            rx, ry, rw, rh = rect
            rx += half_spacing
            ry += half_spacing
            rw -= spacing
            rh -= spacing
            if rw <= 0 or rh <= 0:
                continue

            timer = Timer()
            # .copy() releases the full-image buffer before the next object is read.
            pixels = _read_pixels(img)[ry:ry + rh, rx:rx + rw, :].copy()
            stats.add_sample("convert image to np", timer.ellapsed())

            timer = Timer()
            # NumPy clips the slice at the image border. Use the real crop
            # size from here on, so the UVs match the copied pixels.
            src_h, src_w = pixels.shape[:2]
            flipped = f == 1
            stats.add_sample("copy pixels - part 1", timer.ellapsed())
            if src_h == 0 or src_w == 0:
                continue

            timer = Timer()
            if flipped:
                pixels = pixels.transpose(1, 0, 2)
            h, w = pixels.shape[:2]
            packed_pixels[layer, y:y + h, x:x + w, :] = pixels
            stats.add_sample("copy pixels - part 2", timer.ellapsed())

            # Transform UVs: first map the crop onto [0, 1]^2, then place it in
            # its atlas slot (swapping u/v when the crop was transposed).
            timer = Timer()
            transform_uv(obj,
                translate=(-rx / src_w, -ry / src_h),
                scale=(img.size[0] / src_w, img.size[1] / src_h))
            transform_uv(obj,
                translate=(x / packed_w, y / packed_h),
                scale=(w / packed_w, h / packed_h),
                flip=flipped)
            stats.add_sample("transform UVs", timer.ellapsed())

            # Update image in materials
            node.image = packed_images[layer]

        timer = Timer()
        for img, layer_pixels in zip(packed_images, packed_pixels):
            img.pixels.foreach_set(layer_pixels.ravel())
        stats.add_sample("convert image back from np", timer.ellapsed())
    finally:
        # Always close the progress indicator, even if packing fails.
        wm.progress_end()
    stats.print()

