import vtk
import wx
import numpy as np

class CustomInteractorStyle(vtk.vtkInteractorStyleTrackballCamera):
  """Trackball-camera interactor style with custom mouse and key bindings.

  Bindings registered in ``__init__``:

  * Left click: pick the prop under the cursor; if its object name contains
    ``"capture"`` it is added to the panel selection. Alt+left behaves like
    a middle-button press/release.
  * Left double click: move the camera focal point to the world point under
    the cursor, keeping the camera offset and view-up.
  * Middle drag: orbit (view-up re-pinned to +Z on every move);
    Shift+middle pans, Ctrl+middle zooms.
  * Mouse wheel: forwarded to the base-class wheel handlers.
  * ``f`` / ``v`` keys: delegate to ``vtk_panel`` fly-to helpers.

  Because an observer is registered for each of these events, the
  base-class default handler for that event is not run automatically; the
  handlers below forward to the base class explicitly where needed.

  ``vtk_panel`` defaults to ``None`` but most handlers dereference it
  (``camera``, ``renderer``, ``widget``, ``main_frame``, ...), so callers
  are expected to supply it.
  """

  def __init__(self, renderer, vtk_panel=None):
    """Store collaborators, initialise state flags and wire up observers.

    Args:
      renderer: renderer this style is associated with (stored as
        ``self.renderer``).
      vtk_panel: owning panel object whose camera/renderer/tree control the
        handlers use.
    """
    self.vtk_panel = vtk_panel
    self.renderer = renderer

    # Interaction-state flags. Motion is forwarded to the base class only
    # while one of orbiting/panning/zooming is set.
    self.orbiting = False
    self.panning = False
    self.zooming = False
    self.picking = False
    self.hide_flag = False
    # True while an Alt+left drag (emulated middle button) is in progress,
    # so the matching left-button release can end it.
    self._alt_left_drag = False
    # Last prop hit by a left-click pick, or None.
    self.selected_prop = None
    self.picker = vtk.vtkPropPicker()

    self.AutoAdjustCameraClippingRangeOff()
    self.SetMouseWheelMotionFactor(0.5)

    self.AddObserver("LeftButtonPressEvent", self.left_button_press_event)
    self.AddObserver("LeftButtonReleaseEvent", self.left_button_release_event)
    self.AddObserver("LeftButtonDoubleClickEvent", self.left_button_double_click_event)
    self.AddObserver("MiddleButtonPressEvent", self.middle_button_press_event)
    self.AddObserver("MiddleButtonReleaseEvent", self.middle_button_release_event)
    self.AddObserver("MouseMoveEvent", self.mouse_move_event)
    self.AddObserver("MouseWheelForwardEvent", self.mouse_wheel_forward_event)
    self.AddObserver("MouseWheelBackwardEvent", self.mouse_wheel_backward_event)
    self.AddObserver("CharEvent", self.character_event)

    self.add_sphere()

  def add_sphere(self):
    """Placeholder for a pick-marker sphere actor; currently a no-op."""
    # TODO: toggle this based on toolbar mode
    # self.sphereSource = vtk.vtkSphereSource()
    # self.sphereSource.SetRadius(0.05)
    # self.sphereMapper = vtk.vtkPolyDataMapper()
    # self.sphereMapper.SetInputConnection(self.sphereSource.GetOutputPort())
    # self.sphereActor = vtk.vtkActor()
    # self.sphereActor.SetMapper(self.sphereMapper)
    # self.sphereActor.GetProperty().SetColor(1.0, 0.0, 0.0)  # Red color
    # self.sphereActor.GetProperty().SetOpacity(0.8)
    # self.renderer.AddActor(self.sphereActor)
    pass

  def hide_check(self):
    """If ``hide_flag`` is set, clear it and ask the panel to show frustums."""
    if self.hide_flag:
      self.hide_flag = False
      self.vtk_panel.show_frustums()

  def left_button_double_click_event(self, obj=None, event=None):
    """Re-centre the camera on the world point under the cursor.

    The camera's offset from its focal point and its view-up vector are
    preserved, so the view is translated rather than rotated.

    Args:
      obj, event: supplied by VTK when invoked as an observer callback;
        unused. Defaulted so the method can also be called directly.
    """
    interactor = self.GetInteractor()
    x, y = interactor.GetEventPosition()
    renderer = interactor.FindPokedRenderer(x, y)

    picker = vtk.vtkWorldPointPicker()
    picker.Pick(x, y, 0, renderer)

    # NOTE(review): GetPickList() appears to always return a (possibly empty)
    # collection, so this guard likely never rejects a pick — confirm.
    if picker.GetPickList() is None:
      return

    camera = self.vtk_panel.camera
    view_up = camera.GetViewUp()
    position = camera.GetPosition()
    focal_point = camera.GetFocalPoint()
    # Vector from the focal point to the camera; re-applied at the new
    # focal point so distance and viewing direction are unchanged.
    offset = [p - f for p, f in zip(position, focal_point)]

    target = picker.GetPickPosition()
    camera.SetFocalPoint(target[0], target[1], target[2])
    camera.SetPosition([t + d for t, d in zip(target, offset)])
    camera.SetViewUp(view_up)
    # self.vtk_panel.main_frame.status_bar.SetStatusText("\tClicked: X: " + str(target[0])
    #                                                    + " Y: " + str(target[1])
    #                                                    + " Z: " + str(target[2]))
    renderer.GetRenderWindow().Render()

  def mouse_wheel_forward_event(self, obj, event):
    """Forward a wheel-forward step to the base class, then run hide_check."""
    self.OnMouseWheelForward()
    self.hide_check()

  def mouse_wheel_backward_event(self, obj, event):
    """Forward a wheel-backward step to the base class, then run hide_check."""
    self.OnMouseWheelBackward()
    self.hide_check()

  def start_orbiting(self):
    """Begin an orbit via the base-class left-button-down handler."""
    self.OnLeftButtonDown()
    self.orbiting = True

  def end_orbiting(self):
    """End an orbit via the base-class left-button-up handler."""
    self.OnLeftButtonUp()
    self.orbiting = False

  def start_panning(self):
    """Begin a pan via the base-class middle-button-down handler."""
    self.OnMiddleButtonDown()
    self.panning = True

  def end_panning(self):
    """End a pan via the base-class middle-button-up handler."""
    self.OnMiddleButtonUp()
    self.panning = False

  def start_zooming(self):
    """Begin a zoom via the base-class right-button-down handler."""
    self.OnRightButtonDown()
    self.zooming = True

  def end_zooming(self):
    """End a zoom via the base-class right-button-up handler."""
    self.OnRightButtonUp()
    self.zooming = False

  def mouse_move_event(self, obj, event):
    """Forward motion to the base class only during an active interaction.

    While orbiting, the camera view-up is re-pinned to +Z after each move so
    the orbit never rolls the camera.
    """
    if self.orbiting or self.panning or self.zooming:
      self.OnMouseMove()
      if self.orbiting:
        self.vtk_panel.camera.SetViewUp(0, 0, 1)

  def character_event(self, obj, event):
    """Handle key presses: ``f`` flies toward and ``v`` flies to the selection.

    For ``v``, the last item selected in the database tree control is also
    displayed in its frustum; nothing further happens if no item is selected.
    """
    key = self.vtk_panel.widget.GetKeyCode()
    if key == "f":
      self.vtk_panel.fly_toward_selected()
    elif key == "v":
      self.vtk_panel.fly_to_selected()
      tree_ctrl = self.vtk_panel.main_frame.treectrl_database
      selections = tree_ctrl.GetSelections()
      # Indexing an empty selection used to raise IndexError.
      if not selections:
        return
      data = tree_ctrl.GetItemData(selections[-1])
      self.vtk_panel.display_capture_in_frustum(data)

  def left_button_press_event(self, obj, event):
    """Pick the prop under the cursor, or start an Alt+left camera drag.

    With Alt held, the press is treated as a middle-button press. Otherwise,
    the prop under the cursor is picked, remembered in ``selected_prop``, and
    added to the panel selection when its object name contains "capture".
    """
    if wx.GetKeyState(wx.WXK_ALT):
      # Remember the emulation so the matching left release ends the drag.
      self._alt_left_drag = True
      self.middle_button_press_event(obj, event)
      return

    x, y = self.GetInteractor().GetEventPosition()
    self.picker.Pick(x, y, 0, self.vtk_panel.renderer)
    prop = self.picker.GetProp3D()

    if prop is not None:
      self.selected_prop = prop
      name = prop.GetObjectName()
      #self.vtk_panel.selected_prop_name = name
      if name and "capture" in name:
        self.vtk_panel.add_selection(name)

    # NOTE: earlier picking-mode implementation (see set_picking), kept for
    # reference; not executed.
    # if self.picking:
    #   clickPos = self.GetInteractor().GetEventPosition()
    #   renderer = self.GetInteractor().FindPokedRenderer(clickPos[0], clickPos[1])
    #   picker = vtk.vtkWorldPointPicker()
    #   picker.Pick(clickPos[0], clickPos[1], 0, renderer)
    #   if picker.GetPickList() is not None:
    #     intersection_pos = picker.GetPickPosition()
    #     # self.sphereActor.SetPosition(intersection_pos)
    #     # self.vtk_panel.main_frame.status_bar.SetStatusText("\tClicked: X: " + str(intersection_pos[0])
    #     #                                                    + " Y: " + str(intersection_pos[1])
    #     #                                                    + " Z: " + str(intersection_pos[2]))
    #     renderer.GetRenderWindow().Render()
    # else:
    #   click_position = self.GetInteractor().GetEventPosition()
    #   self.picker.Pick(click_position[0], click_position[1], 0, self.vtk_panel.renderer)
    #   prop = self.picker.GetProp3D()
    #   if prop is not None:
    #     self.selected_prop = prop
    #     name = prop.GetObjectName()
    #     self.vtk_panel.selected_prop_name = name
    #     if name is not None and name != "":
    #       self.vtk_panel.main_frame.treectrl_database.UnselectAll()
    #       self.vtk_panel.main_frame.treectrl_database.select_item(name)
    #       cell_pick = vtk.vtkCellPicker()
    #       cell_pick.SetTolerance(0.001)
    #       cell_pick.Pick(click_position[0], click_position[1], 0, self.vtk_panel.renderer)
    #       if cell_pick.GetCellId() < 0:
    #         print("No cell picked.")
    #       else:
    #         print("Picked cell ID:", cell_pick.GetCellId())
    #         polydata = prop.GetMapper().GetInput()
    #         cell = polydata.GetCell(cell_pick.GetCellId())
    #         num_vertices = cell.GetNumberOfPoints()
    #         total_depth = 0.0
    #         for vertex_id in range(num_vertices):
    #           vertex_depth = np.array(polydata.GetPoint(cell.GetPointId(vertex_id)))
    #           vertex_depth = np.dot(vertex_depth - self.vtk_panel.centroid, self.vtk_panel.normal)
    #           total_depth += vertex_depth
    #         average_depth = total_depth / num_vertices
    #         print(average_depth)

  def left_button_release_event(self, obj, event):
    """End an Alt+left drag, or otherwise defer to the base left-up handler.

    Registering this observer suppresses the base class's automatic
    ``OnLeftButtonUp``, so it is called explicitly in the normal case.
    """
    if self._alt_left_drag:
      self._alt_left_drag = False
      self.middle_button_release_event(obj, event)
    else:
      self.OnLeftButtonUp()

  def middle_button_press_event(self, obj, event):
    """Start pan (Shift), zoom (Ctrl) or orbit (no modifier), then hide_check."""
    interactor = obj.GetInteractor()
    if interactor.GetShiftKey():
      self.start_panning()
    elif interactor.GetControlKey():
      self.start_zooming()
    else:
      self.start_orbiting()

    self.hide_check()

  def middle_button_release_event(self, obj, event):
    """End any orbit/pan/zoom in progress and clear all state flags."""
    self.end_orbiting()
    self.end_panning()
    self.end_zooming()
    self.hide_check()

  def set_picking(self, value):
    """Set the ``picking`` flag."""
    self.picking = value