import type { Object3D } from "three";
import {
  BufferAttribute,
  ConeGeometry,
  CylinderGeometry,
  DoubleSide,
  Group,
  Matrix4,
  Mesh,
  Ray,
  Vector3,
} from "three";
import { STLExporter } from "three/examples/jsm/exporters/STLExporter.js";
import {
  CONTAINED,
  INTERSECTED,
  MeshBVH,
  NOT_INTERSECTED,
} from "three-mesh-bvh";

// Search threshold passed to `closestPointToPoint` when painting the heatmap, and the divisor
// used to normalise the measured distance. It is also the fallback distance when no surface
// point is found within the threshold. Units are whatever the cylinder's local space uses.
const HEATMAP_MAX_DISTANCE = 1.5;

/**
 * Serialises `object` to STL with three.js' `STLExporter` and returns a blob URL for the data.
 *
 * The caller owns the returned URL. Release it with `URL.revokeObjectURL` when done;
 * otherwise the blob stays referenced for the lifetime of the document.
 *
 * @param object - Object passed to `STLExporter.parse`.
 * @returns A promise resolving to an object URL for a `model/stl` blob.
 * @throws Rejects with `Error("Error creating STL file")` if export or blob/URL creation fails.
 *   The underlying failure is attached as `cause`.
 */
export const generate3DObjectSTLURL = async (
  object: Object3D
): Promise<string> => {
  try {
    // Export is synchronous, so no Promise constructor is needed; the async function
    // turns a thrown error into a rejection just as the executor did.
    const stlData = new STLExporter().parse(object);
    const blob = new Blob([stlData], { type: "model/stl" });
    return URL.createObjectURL(blob);
  } catch (error: unknown) {
    // Keep the original message for callers, but do not lose the root cause.
    throw Object.assign(new Error("Error creating STL file"), { cause: error });
  }
};

/**
 * Builds a two-part arrow as a `Group`: a cylindrical shaft plus a conical head, each a
 * `Mesh` with no explicit material.
 *
 * The shaft is 75% of `arrowLength` and the head the remaining 25%. The head's base radius
 * is twice the shaft radius. Both parts are built along the Y axis, translated to `origin`,
 * and then rotated about the X axis by the angle between `insertionDirection` and +Z.
 *
 * NOTE(review): `rotateX` runs after `translate`, so the rotation pivots about the local
 * origin (0,0,0) and moves `origin` as well. Only the angle to +Z is used; the rotation axis
 * is always X. Confirm both are intended.
 *
 * @param insertionDirection - Vector whose angle to +Z sets the X rotation; need not be normalised.
 * @param inferredDirection - When > 0 the arrow is flipped: head below the shaft (−Y side)
 *   with its tip pointing −Y. Otherwise the head sits above with its tip pointing +Y.
 * @param origin - Where the shaft is centred before the rotation is applied.
 * @param arrowWidth - Shaft radius (the head's base radius is `2 * arrowWidth`).
 * @param arrowLength - Total length of shaft plus head.
 * @returns A `Group` containing the shaft mesh followed by the head mesh.
 */
export const createCustomArrow = (
  insertionDirection: Vector3,
  inferredDirection: number,
  origin: Vector3,
  arrowWidth: number,
  arrowLength: number
) => {
  // `angleTo` clamps the cosine to [-1, 1] before `acos`. A raw dot product of normalised
  // vectors can drift past ±1 from rounding and yield NaN. For a zero-length input it
  // returns PI/2, the same value the normalise-then-acos approach produced.
  const angle = insertionDirection.angleTo(new Vector3(0, 0, 1));
  const pointsDown = inferredDirection > 0;

  const arrowBodyLength = arrowLength * 0.75;
  const arrowHeadLength = arrowLength * 0.25;

  const arrowBody = new CylinderGeometry(
    arrowWidth,
    arrowWidth,
    arrowBodyLength,
    12
  );
  const arrowHead = new ConeGeometry(arrowWidth * 2, arrowHeadLength, 12);

  if (pointsDown) {
    // Flip both parts so the cone's tip points toward −Y.
    arrowBody.rotateX(Math.PI);
    arrowHead.rotateX(Math.PI);
  }

  arrowBody.translate(origin.x, origin.y, origin.z).rotateX(angle);

  // Place the head flush against the end of the shaft on the side it points to.
  const headCenterY = pointsDown
    ? origin.y - arrowBodyLength / 2 - arrowHeadLength / 2
    : origin.y + arrowBodyLength / 2 + arrowHeadLength / 2;
  arrowHead.translate(origin.x, headCenterY, origin.z).rotateX(angle);

  const arrowGroup = new Group();
  arrowGroup.add(new Mesh(arrowBody));
  arrowGroup.add(new Mesh(arrowHead));

  return arrowGroup;
};

/**
 * Returns cylinder arguments sized to an inspection window. The keys mirror
 * `CylinderGeometry`'s constructor parameters.
 *
 * The cylinder is untapered: top and bottom radius both equal half the window diameter.
 * Its height is the window height, with 16 radial segments and 1 height segment.
 */
export const getDefaultCylinderArgs = (point: InspectionWindow) => {
  const radius = point.diameter / 2;

  return {
    radiusTop: radius,
    radiusBottom: radius,
    height: point.height,
    radialSegments: 16,
    heightSegments: 1,
  };
};

/**
 * Converts an inspection window's stored position into a `Vector3` for placing its cylinder.
 *
 * Axis mapping is x → x, z → y (raised by half the window height), and y → −z.
 * NOTE(review): this looks like a Z-up → Y-up conversion. The half-height lift presumably
 * puts the base of a centred cylinder at `position_z`; confirm against the data source.
 */
export const getPositionCorrectionForInspectionWindow = (
  point: InspectionWindow
) => {
  const { position_x: x, position_y: y, position_z: z, height } = point;
  const liftedY = z + height / 2;

  return new Vector3(x, liftedY, -y);
};

/**
 * Paints a per-vertex `"heatmap"` attribute on `mesh.geometry` for vertices that lie inside
 * `cylinder`.
 *
 * How it works:
 * - The geometry's existing BVH (`geometry.boundsTree`) is shape-cast against a BVH built
 *   from `cylinder.geometry`.
 * - Each vertex found inside the cylinder that has no non-zero heatmap value yet receives
 *   `distanceToCylinderSurface / scale / HEATMAP_MAX_DISTANCE`.
 * - The distance is measured in the cylinder's local space. If no surface point lies within
 *   `HEATMAP_MAX_DISTANCE`, that maximum is used instead.
 * - `scale` is the largest axis scale of the geometry → cylinder transform.
 *
 * Vertices outside the cylinder, or already painted, are left unchanged, so 0 means
 * "not painted". The attribute is created (zero-filled) if it does not exist yet.
 *
 * Does nothing when the geometry has no index or no `boundsTree`.
 *
 * @param mesh - Mesh whose geometry is tested and receives the heatmap.
 * @param cylinder - Probe volume; its geometry and `matrixWorld` define the tested region.
 * @param intersectionClonedMesh - Object whose `matrixWorld` places `mesh.geometry` in the
 *   world. Presumably a clone of `mesh`; confirm with callers.
 */
export const detectIntersections = (
  mesh: Mesh,
  cylinder: Mesh,
  intersectionClonedMesh: Mesh
) => {
  const geometry = mesh.geometry;
  const meshBVH = geometry.boundsTree;
  const index = geometry.getIndex();

  // Check preconditions before building the cylinder BVH, which is the expensive step.
  if (!index || !meshBVH) {
    return;
  }

  const cylinderBVH = new MeshBVH(cylinder.geometry);

  // Maps points from the mesh geometry's local space into the cylinder's local space.
  const relativeMatrix = new Matrix4()
    .copy(cylinder.matrixWorld)
    .invert()
    .multiply(intersectionClonedMesh.matrixWorld);
  const scale = relativeMatrix.getMaxScaleOnAxis();

  // Reuse an existing heatmap so earlier results are kept. Otherwise create one;
  // Float32Array is zero-initialised, i.e. every vertex starts "not painted".
  const existingHeatmap = geometry.getAttribute("heatmap") as
    | BufferAttribute
    | undefined;
  const heatmap =
    existingHeatmap ??
    new BufferAttribute(
      new Float32Array(geometry.getAttribute("position").count),
      1,
      false
    );
  if (!existingHeatmap) {
    geometry.setAttribute("heatmap", heatmap);
  }

  // Point-in-cylinder probe (point given in cylinder space). Cast a ray along +Z; the point
  // is inside when the first face hit has its normal pointing along the ray (exiting).
  const ray = new Ray();
  ray.direction.set(0, 0, 1);
  const isInsideCylinder = (point: Vector3): boolean => {
    ray.origin.copy(point);
    const hit = cylinderBVH.raycastFirst(ray, DoubleSide);
    return (hit?.face?.normal?.dot(ray.direction) ?? 0) > 0;
  };

  // Scratch vector: the point currently being tested, transformed into cylinder space.
  const probe = new Vector3();
  const corners = ["a", "b", "c"] as const;

  meshBVH.shapecast({
    intersectsBounds: (box) => {
      if (cylinderBVH.intersectsBox(box, relativeMatrix)) {
        return INTERSECTED;
      }

      // No cylinder triangle crosses the box. Assuming a closed cylinder surface, the box is
      // then wholly inside or wholly outside, and one corner decides which.
      probe.copy(box.min).applyMatrix4(relativeMatrix);
      return isInsideCylinder(probe) ? CONTAINED : NOT_INTERSECTED;
    },
    intersectsTriangle: (triangle, triangleIndex, contained) => {
      // For a contained triangle, the first measured distance is reused for the remaining
      // corners as a cheap approximation. 0 means "not measured yet".
      let dist = 0;

      for (const [corner, side] of corners.entries()) {
        const vertexIndex = index.getX(triangleIndex * 3 + corner);

        // Keep the value from whichever triangle/call painted this vertex first.
        if (heatmap.getX(vertexIndex)) continue;

        if (!(dist && contained)) {
          // Transform exactly once, into the scratch vector (the shapecast triangle is left
          // untouched). The same cylinder-space point feeds both the inside test and the
          // distance query.
          probe.copy(triangle[side]).applyMatrix4(relativeMatrix);

          // Triangles that straddle the boundary need a per-corner inside test.
          if (!contained && !isInsideCylinder(probe)) continue;

          const measure = cylinderBVH.closestPointToPoint(
            probe,
            undefined,
            0,
            HEATMAP_MAX_DISTANCE
          );
          // `null` means no surface point within the threshold, so saturate at the maximum.
          // `??` rather than `||` so that a real distance of 0 is not turned into the maximum.
          dist =
            (measure?.distance ?? HEATMAP_MAX_DISTANCE) /
            scale /
            HEATMAP_MAX_DISTANCE;
        }

        heatmap.setX(vertexIndex, dist);
      }
    },
  });

  heatmap.needsUpdate = true;
};

/**
 * Clears and removes the `"heatmap"` vertex attribute from `mesh.geometry`.
 *
 * Every value is set to 0 and the attribute is flagged for re-upload before it is deleted
 * from the geometry. Without the attribute, `detectIntersections` allocates a fresh one on
 * its next run.
 *
 * Returns early when either argument is null or the geometry has no heatmap.
 * `intersectionClonedMesh` is only used in that null check.
 */
export const resetHeatmap = (
  mesh: Mesh | null,
  intersectionClonedMesh: Mesh | null
) => {
  if (!mesh || !intersectionClonedMesh) return;

  const { geometry } = mesh;
  const heatmap = geometry.getAttribute("heatmap");
  if (!heatmap) return;

  // Zero the attribute in place before detaching it, presumably so any other holder of this
  // attribute object sees cleared values.
  for (let vertex = 0; vertex < heatmap.count; vertex += 1) {
    heatmap.setX(vertex, 0);
  }

  heatmap.needsUpdate = true;
  geometry.deleteAttribute("heatmap");
};
