import type { Mesh } from "three";
import { BufferAttribute, BufferGeometry, Matrix4 } from "three";
import { INTERSECTED, MeshBVH, NOT_INTERSECTED } from "three-mesh-bvh";

/**
 * Selects which per-vertex attribute IntersectionTool writes its proximity
 * results into. The string values below are the members' runtime values.
 */
export enum VisualisationMode {
  /** Scalar per vertex, written to the geometry's "heatmap" attribute. */
  Heatmap = "HEATMAP",
  /** RGB triple per vertex, written to the geometry's "color" attribute. */
  VertexColors = "VERTEX_COLORS",
}

/**
 * Configuration for IntersectionTool.
 */
export interface IntersectionToolOptions {
  /** Which visualisation is produced: Heatmap or VertexColors. */
  visualisationMode: VisualisationMode;

  /**
   * Proximity cutoff: a vertex whose (scale-normalised) distance to the other
   * mesh is not strictly below this value receives no heat.
   */
  maxDistance: number;

  /**
   * Heat cutoff; `IntersectionTool.setHeatThreshold` only accepts values in
   * [0, 1]. In Heatmap mode, heat at or below this value is written as 0.
   */
  heatThreshold?: number;

  /**
   * Invoked with mesh A once a recompute has filled the heatmap attribute
   * (only called in Heatmap mode).
   */
  onIntersectionsUpdated?: (mesh: Mesh) => void;
}

/**
 * Visualises how close the surface of mesh A comes to mesh B.
 *
 * Construction adds two per-vertex attributes to mesh A's geometry:
 * `"heatmap"` (one float, initialised to 0) and `"color"` (RGB, initialised to
 * white). `computeIntersections()` fills whichever of them matches the current
 * {@link VisualisationMode}.
 */
export class IntersectionTool {
  private readonly meshA: Mesh;
  private readonly meshB: Mesh;
  private readonly options: IntersectionToolOptions;

  /**
   * @param meshA Mesh whose vertices receive the heatmap / vertex colours.
   * @param meshB Mesh that distances are measured against.
   * @param options Overrides for the defaults
   *   (`maxDistance: 0.2`, `heatThreshold: 0.1`, `visualisationMode: Heatmap`).
   * @throws Error if mesh A's geometry has no position attribute.
   */
  constructor(meshA: Mesh, meshB: Mesh, options?: IntersectionToolOptions) {
    this.meshA = meshA;
    this.meshB = meshB;
    this.options = {
      maxDistance: 0.2,
      heatThreshold: 0.1,
      visualisationMode: VisualisationMode.Heatmap,
      ...options,
    };

    this.initialize();
  }

  /**
   * Throws unless `mesh` has a geometry carrying a position attribute.
   */
  private assertHasPositions(mesh: Mesh): void {
    if (!mesh.geometry?.attributes?.position) {
      throw new Error(
        "Mesh geometry is invalid or missing position attributes."
      );
    }
  }

  /**
   * (Re)builds `mesh.geometry.boundsTree` so BVH spatial queries can run on it.
   *
   * The BVH is built from a temporary geometry that carries only the position
   * attribute (and index, if present). It is then serialised and deserialised
   * onto the mesh's own geometry.
   *
   * Side effect: `computeVertexNormals()` recomputes, and thereby overwrites,
   * the geometry's normals on every call.
   *
   * NOTE(review): processGeometryAttribute reads `geometry.getIndex()` and
   * treats it as matching the BVH's triangle order. That assumes `deserialize`
   * installs the BVH-ordered index on the geometry (its default `setIndex`
   * behaviour). Confirm this against the three-mesh-bvh version in use.
   *
   * @throws Error if the geometry has no position attribute.
   */
  private prepareMeshForBVH(mesh: Mesh): Mesh {
    this.assertHasPositions(mesh);

    mesh.geometry.computeVertexNormals();
    const bufferGeometry = new BufferGeometry();
    bufferGeometry.setAttribute("position", mesh.geometry.attributes.position);

    if (mesh.geometry.index) {
      bufferGeometry.setIndex(mesh.geometry.index);
    }

    const bvh = new MeshBVH(bufferGeometry);
    const serializedBVH = MeshBVH.serialize(bvh);
    mesh.geometry.boundsTree = MeshBVH.deserialize(
      serializedBVH,
      mesh.geometry
    );

    return mesh;
  }

  /**
   * Adds the "heatmap" (1 float per vertex, all 0) and "color" (RGB per
   * vertex, all 1 = white) attributes to mesh A's geometry.
   *
   * @throws Error if mesh A's geometry has no position attribute.
   */
  private initialize(): void {
    this.assertHasPositions(this.meshA);
    const vertexCount = this.meshA.geometry.attributes.position.count;

    // A freshly allocated Float32Array is already zero-filled.
    const heatmap = new BufferAttribute(new Float32Array(vertexCount), 1);
    this.meshA.geometry.setAttribute("heatmap", heatmap);

    const colors = new BufferAttribute(
      new Float32Array(vertexCount * 3).fill(1),
      3
    );
    this.meshA.geometry.setAttribute("color", colors);
  }

  /**
   * Resets `attribute`, then reports a heat value for every vertex of mesh A
   * that lies within `maxDistance` of mesh B, among the triangles visited.
   *
   * Distances are measured in mesh B's local space. They are then divided by
   * the largest axis scale of the A-to-B transform before being compared with
   * `maxDistance`. Heat is `1 - distance / maxDistance`: 1 on contact, falling
   * towards 0 at `maxDistance`.
   *
   * NOTE(review): only triangles in BVH nodes whose bounds intersect mesh B's
   * triangles are visited. A vertex that is near B but sits in a node that does
   * not touch B therefore receives no heat.
   *
   * Does nothing (not even the reset) if mesh A's geometry has no index.
   *
   * @param attribute The attribute to fill (e.g. "heatmap" or "color").
   * @param resetFn Called once per element of `attribute` before processing.
   * @param updateFn Called with (vertex index, heat in (0, 1], attribute) for
   *   each in-range vertex visit. A vertex shared by several visited triangles
   *   is reported once per triangle.
   */
  private processGeometryAttribute(
    attribute: BufferAttribute,
    resetFn: (attr: BufferAttribute, i: number) => void,
    updateFn: (
      vertexIndex: number,
      heatValue: number,
      attr: BufferAttribute
    ) => void
  ): void {
    const { maxDistance } = this.options;

    // Transform that maps mesh A's local space into mesh B's local space.
    const relativeMatrix = new Matrix4()
      .copy(this.meshB.matrixWorld)
      .invert()
      .multiply(this.meshA.matrixWorld);

    // Distances found in B's space are divided by this scale before being
    // compared with maxDistance. The B-space search radius must therefore be
    // maxDistance * scale for the two cutoffs to agree.
    const scale = relativeMatrix.getMaxScaleOnAxis();
    const searchRadius = maxDistance * scale;

    const bvhA = this.meshA.geometry.boundsTree;
    const bvhB = this.meshB.geometry.boundsTree;

    const index = this.meshA.geometry.getIndex();
    if (!index) return;

    for (let i = 0; i < attribute.count; i++) {
      resetFn(attribute, i);
    }

    // Hoisted so the per-triangle callback does not rebuild it.
    const corners = ["a", "b", "c"] as const;

    if (bvhA && bvhB) {
      bvhA.shapecast({
        // Descend only into A's nodes whose box (moved into B's space)
        // intersects B's geometry.
        intersectsBounds: (box) =>
          bvhB.intersectsBox(box, relativeMatrix)
            ? INTERSECTED
            : NOT_INTERSECTED,
        intersectsTriangle: (triangle, triangleIndex) => {
          for (const [corner, side] of corners.entries()) {
            const vertexIndex = index.getX(triangleIndex * 3 + corner);

            // Clone so the triangle handed to us by the BVH is not mutated.
            const vertexInB = triangle[side].clone().applyMatrix4(relativeMatrix);
            const hit = bvhB.closestPointToPoint(
              vertexInB,
              undefined,
              0,
              searchRadius
            );

            // null: no part of B within the search radius, so no heat.
            // A hit at distance 0 (contact) must still count, which is why
            // this is an explicit null check rather than a truthiness test
            // on the distance.
            if (!hit) continue;

            const dist = hit.distance / scale;
            if (dist < maxDistance) {
              updateFn(vertexIndex, 1 - dist / maxDistance, attribute);
            }
          }
        },
      });
    }

    attribute.needsUpdate = true;
  }

  /**
   * Updates both meshes' world matrices and rebuilds both BVHs. Then fills the
   * attribute for the active visualisation mode:
   *
   * - Heatmap: each vertex's "heatmap" value becomes the largest heat it
   *   received above `heatThreshold` (0 otherwise). `onIntersectionsUpdated`
   *   is then called with mesh A.
   * - VertexColors: vertices with heat above `heatThreshold` are coloured
   *   (r, g, b) = (heat, 1 - heat, 0). All other vertices are reset to white.
   *
   * @throws Error if either mesh's geometry has no position attribute.
   */
  computeIntersections(): void {
    const {
      visualisationMode,
      onIntersectionsUpdated,
      heatThreshold = 0,
    } = this.options;

    // Make sure the matrixWorld values used for the relative transform are current.
    this.meshA.updateMatrixWorld(true);
    this.meshB.updateMatrixWorld(true);

    this.prepareMeshForBVH(this.meshA);
    this.prepareMeshForBVH(this.meshB);

    if (visualisationMode === VisualisationMode.Heatmap) {
      const heatmap = this.meshA.geometry.getAttribute(
        "heatmap"
      ) as BufferAttribute;

      this.processGeometryAttribute(
        heatmap,
        (attr, i) => attr.setX(i, 0),
        (vertexIndex, heatValue, attr) => {
          // Heat at or below the threshold is treated as no heat.
          const effectiveHeatValue = heatValue > heatThreshold ? heatValue : 0;
          // Keep the strongest heat seen across all triangles sharing the vertex.
          attr.setX(
            vertexIndex,
            Math.max(attr.getX(vertexIndex), effectiveHeatValue)
          );
        }
      );

      onIntersectionsUpdated?.(this.meshA);
    } else if (visualisationMode === VisualisationMode.VertexColors) {
      const colors = this.meshA.geometry.getAttribute(
        "color"
      ) as BufferAttribute;

      this.processGeometryAttribute(
        colors,
        (attr, i) => attr.setXYZ(i, 1, 1, 1),
        (vertexIndex, heatValue, attr) => {
          // Same threshold rule as Heatmap mode: weaker heat leaves the vertex white.
          if (heatValue > heatThreshold) {
            attr.setXYZ(vertexIndex, heatValue, 1 - heatValue, 0);
          }
        }
      );
    }
  }

  /**
   * Switches the visualisation mode and recomputes.
   */
  changeVisualisationMode(mode: VisualisationMode): void {
    this.options.visualisationMode = mode;
    this.computeIntersections();
  }

  /**
   * Updates `maxDistance` and recomputes.
   * @param maxDistance New proximity cutoff. Must be greater than 0.
   * @throws Error if `maxDistance` is not greater than 0 (including NaN).
   */
  setMaxDistance(maxDistance: number): void {
    // Negated comparison so NaN is rejected too.
    if (!(maxDistance > 0)) {
      throw new Error("maxDistance must be greater than 0.");
    }

    this.options.maxDistance = maxDistance;
    this.computeIntersections();
  }

  /**
   * Updates `heatThreshold` and recomputes.
   * @param heatThreshold New heat cutoff in [0, 1].
   * @throws Error if `heatThreshold` is outside [0, 1] (including NaN).
   */
  setHeatThreshold(heatThreshold: number): void {
    // Negated range check so NaN is rejected too.
    if (!(heatThreshold >= 0 && heatThreshold <= 1)) {
      throw new Error("heatThreshold must be between 0 and 1.");
    }

    this.options.heatThreshold = heatThreshold;
    this.computeIntersections();
  }

  /**
   * Clears the attribute for the current mode only. Heatmap values are set to
   * 0; vertex colours are set to white. The other mode's attribute is left
   * untouched.
   */
  reset(): void {
    const mode = this.options.visualisationMode;

    if (mode === VisualisationMode.Heatmap) {
      const heatmap = this.meshA.geometry.getAttribute("heatmap");
      if (heatmap) {
        for (let i = 0; i < heatmap.count; i++) {
          heatmap.setX(i, 0);
        }
        heatmap.needsUpdate = true;
      }
    } else if (mode === VisualisationMode.VertexColors) {
      const colors = this.meshA.geometry.getAttribute("color");
      if (colors) {
        for (let i = 0; i < colors.count; i++) {
          colors.setXYZ(i, 1, 1, 1);
        }
        colors.needsUpdate = true;
      }
    }
  }

  /**
   * Recomputes after the meshes' geometry or transforms have changed.
   * Equivalent to calling computeIntersections().
   */
  updateOnChange(): void {
    this.computeIntersections();
  }
}
