import { updateRaycaster } from '../../ModelViewer';
import type { ScanReviewRecord } from '../ScanReviewRecordTypes';
import type { ScanReviewPartialScene } from '../ScanReviewSceneTypes';
import type { ScanReviewViewManager } from '../ScanReviewViewTypes';
import {
    type AdjacencyMatrix,
    buildMeshAdjacency,
    ensureMeshIndex,
    getNeighbors,
    AttributeName,
} from '@orthly/forceps';
import { Jaw } from '@orthly/shared-types';
import * as THREE from 'three';
import type { MeshBVH } from 'three-mesh-bvh';

/**
 * Result of a shade pick: the averaged vertex color sampled around a clicked point on a jaw scan.
 */
export interface ScanReviewShadePick {
    /**
     * RGB value with each component in range of 0-255.
     * The picker in this module floors each channel, so components are integers.
     */
    color: [number, number, number];
    /**
     * The center of the sampling area on the surface of the mesh, i.e. the point where the pick ray hit it.
     * NOTE(review): the coordinate frame (geometry-local vs. world) depends on the scan mesh's transform —
     * confirm before using this to position scene objects.
     */
    center: THREE.Vector3;
}

/**
 * Spatial acceleration structures for a single jaw scan, built once up front so that repeated
 * shade picks only pay for queries: a BVH for ray casting and adjacency data for neighborhood lookups.
 */
export class ScanReviewRecordAccelerationData {
    /** The scan whose mesh geometry these structures were built from. */
    scanRecord: ScanReviewRecord;
    /** Adjacency information for the scan geometry; built after the BVH (see constructor). */
    adjacencyMatrix: AdjacencyMatrix;
    /** BVH over the scan geometry, used to ray cast against the mesh. */
    bvhIndex: MeshBVH;

    constructor(scanRecord: ScanReviewRecord) {
        const { geometry } = scanRecord.scanMesh;
        this.scanRecord = scanRecord;

        // Order matters here. ensureMeshIndex may reorder the geometry's index attribute, and any
        // face-to-face adjacency computed before that would then be stale. So build the BVH first
        // and derive adjacency from the (possibly reordered) geometry afterwards.
        this.bvhIndex = ensureMeshIndex(geometry);
        this.adjacencyMatrix = buildMeshAdjacency(geometry);
    }
}

/**
 * Minimal read-only view of a per-vertex color attribute.
 * `THREE.BufferAttribute` and `THREE.InterleavedBufferAttribute` both satisfy it structurally,
 * which keeps the averaging helper below independent of three.js.
 */
interface ShadeSampleColorAttribute {
    getX(index: number): number;
    getY(index: number): number;
    getZ(index: number): number;
}

/**
 * Averages the colors of the given vertices and converts the result to 0-255 integer RGB.
 *
 * Each channel is squared before summing and square-rooted after averaging. This approximately
 * reverses the compression applied when storing RGB; a plain linear mean skews dark.
 * For more information, see: https://graphicdesign.stackexchange.com/questions/113884/calculating-average-of-two-rgb-values
 * Scaling by 255 and flooring happen only at the very end to avoid compounding floating point error.
 *
 * @param colors - Color attribute to read. When `undefined`, every channel reads as 0.
 *   Channel values are assumed to lie in 0-1 — TODO confirm for non-normalized integer attributes.
 * @param vertices - Vertex indices to sample.
 * @returns The averaged color, or `null` when `vertices` is empty (an average is undefined there).
 */
function averageVertexColorRgb255(
    colors: ShadeSampleColorAttribute | undefined,
    vertices: Iterable<number>,
): [number, number, number] | null {
    let count = 0;
    let sumR = 0;
    let sumG = 0;
    let sumB = 0;
    for (const vertex of vertices) {
        const r = colors?.getX(vertex) ?? 0;
        const g = colors?.getY(vertex) ?? 0;
        const b = colors?.getZ(vertex) ?? 0;
        sumR += r * r;
        sumG += g * g;
        sumB += b * b;
        count += 1;
    }
    if (count === 0) {
        return null;
    }
    const toChannel = (sumOfSquares: number): number => Math.floor(Math.sqrt(sumOfSquares / count) * 255);
    return [toChannel(sumR), toChannel(sumG), toChannel(sumB)];
}

/**
 * Picks a shade (averaged vertex color) from whichever jaw scan is currently selected,
 * using a pick ray derived from mouse events.
 */
export class ScanReviewShadeMatchingPicker {
    private readonly lowerJawAccelerationData: ScanReviewRecordAccelerationData | null;
    private readonly upperJawAccelerationData: ScanReviewRecordAccelerationData | null;
    /** Acceleration data for the jaw picks are made against; `null` when no jaw is selected. */
    private jawAccelerationData: ScanReviewRecordAccelerationData | null = null;

    /** Reused pick ray: updated by `respondToMouseEvent`, consumed by `pickShadeFromVertexColors`. */
    private readonly rayCaster = new THREE.Raycaster();

    constructor(
        public scene: ScanReviewPartialScene,
        public viewManager: ScanReviewViewManager,
    ) {
        this.lowerJawAccelerationData = scene.lowerJaw ? new ScanReviewRecordAccelerationData(scene.lowerJaw) : null;
        this.upperJawAccelerationData = scene.upperJaw ? new ScanReviewRecordAccelerationData(scene.upperJaw) : null;
    }

    /**
     * Shows only the selected jaw and targets subsequent picks at it.
     * Passing `null` hides both jaws and disables picking.
     */
    setCurrentJawType(jawType: Jaw | null) {
        // Explicit nullish check rather than `!jawType`, so a falsy enum member is never mistaken for "no jaw".
        if (jawType == null) {
            this.scene.setUpperJawVisibility(false);
            this.scene.setLowerJawVisibility(false);
            this.jawAccelerationData = null;
            return;
        }
        const isUpper = jawType === Jaw.UPPER;
        this.scene.setUpperJawVisibility(isUpper);
        this.scene.setLowerJawVisibility(!isUpper);
        this.jawAccelerationData = isUpper ? this.upperJawAccelerationData : this.lowerJawAccelerationData;
    }

    /** Updates the pick ray from a mouse event. Does nothing until the view has both a canvas and a camera. */
    respondToMouseEvent(evt: MouseEvent) {
        if (!this.viewManager.canvas || !this.viewManager.camera) {
            return;
        }
        updateRaycaster(this.rayCaster, this.viewManager.canvas, this.viewManager.camera, evt);
    }

    /**
     * Casts the current pick ray against the selected jaw (front faces only) and averages the vertex
     * colors of the vertices `getNeighbors` returns around the hit.
     *
     * @param maxRadiusMm - Sampling radius passed to `getNeighbors`. Non-positive or NaN values fall back to 1.
     * @returns The picked shade, or `null` when no jaw is selected or the ray misses the mesh.
     */
    pickShadeFromVertexColors(maxRadiusMm: number): ScanReviewShadePick | null {
        const accelerationData = this.jawAccelerationData;
        if (!accelerationData) {
            return null;
        }

        // NOTE(review): MeshBVH.raycastFirst operates in the geometry's local frame, while this ray comes
        // from updateRaycaster (presumably world space). This assumes the scan mesh has an identity world
        // transform — confirm, or transform the ray by the inverse of scanMesh.matrixWorld first.
        const intersection = accelerationData.bvhIndex.raycastFirst(this.rayCaster.ray, THREE.FrontSide);

        // We did not click on the mesh.
        if (!intersection?.face) {
            return null;
        }

        const { adjacencyMatrix, scanRecord } = accelerationData;
        const geometry = scanRecord.scanMesh.geometry;
        const mainHandle = intersection.face.a;
        const neighbors = getNeighbors({
            adjacencyMatrix,
            mainHandle,
            // Zero, negative or NaN radii cannot describe a sampling area; fall back to 1.
            maxRadiusMm: maxRadiusMm > 0 ? maxRadiusMm : 1,
            geometry,
        });

        // Look the attribute up once instead of three times per sampled vertex.
        const colorAttribute = geometry.getAttribute(AttributeName.Color);
        // If the neighborhood search yields no vertices, sample the hit vertex alone.
        // Averaging over zero samples would otherwise produce NaN channels.
        const color = averageVertexColorRgb255(colorAttribute, neighbors.length > 0 ? neighbors : [mainHandle]);
        if (!color) {
            return null;
        }

        return {
            color,
            center: intersection.point,
        };
    }
}
