import { ZAxis } from '../../utils';
import { DepthMapComputer } from '../GpuAccel/ComputeDepthMap';
import type { RestorativeModel, InsertionDepthDataMap } from './FinishingApp.types';
import type { ScanKey } from './SceneAppearanceManager.types';
import { scanKeys } from './SceneAppearanceManager.types';
import { AttributeName, getInsertionAxisFromOrientation } from '@orthly/forceps';
import { ToothUtils } from '@orthly/items';
import type { ArrayMin1 } from '@orthly/runtime-utils';
import * as THREE from 'three';

// Margin added on every side of the oriented bounding boxes (via `Box3.expandByScalar`) before they are used to frame
// the depth-map cameras, so geometry at the box edges is not clipped. Units presumably millimetres per the `_MM`
// suffix — NOTE(review): confirm against the scene's units.
const DEPTH_MAP_BUFFER_MM = 1.0;

/**
 * Generates depth maps along the insertion axes for the restorative models and jaw scans, to be used with the undercut
 * depth shaders.
 *
 * Depth maps are generated once on construction; call `generate` again to regenerate them (optionally for a different
 * insertion orientation).
 */
export class InsertionDepthGenerator {
    /** Latest successfully generated depth maps, keyed by tooth number and by `upperJaw` / `lowerJaw`. */
    private depthDataMap: InsertionDepthDataMap = new Map();

    /**
     * @param restorativeModels Restorative items to generate per-tooth depth maps for (at least one).
     * @param upperJawGeometry Upper jaw scan geometry, used as an occluder and for the `upperJaw` depth map.
     * @param lowerJawGeometry Lower jaw scan geometry, used as an occluder and for the `lowerJaw` depth map.
     */
    constructor(
        private restorativeModels: ArrayMin1<RestorativeModel>,
        private upperJawGeometry: THREE.BufferGeometry,
        private lowerJawGeometry: THREE.BufferGeometry,
    ) {
        this.generate();
    }

    /**
     * Generates the depth maps
     * @param insertionOrientation If specified, this insertion direction is used instead of the initial ones provided
     *   in `restorativeModels`. The same insertion direction is used for all restorative items - this behavior will
     *   have to be changed to handle multiple restorative items.
     * @returns The depth data map (a copy; mutating it does not affect this generator)
     * @throws Re-throws any error raised while computing a depth map. In that case the stored depth data is left empty
     *   and the temporary renderer is still released.
     */
    generate(insertionOrientation?: THREE.Quaternion): InsertionDepthDataMap {
        // Drop maps from any previous run up front so a failed run never leaves stale data (possibly for a different
        // orientation) behind.
        this.depthDataMap.clear();

        const renderer = new THREE.WebGLRenderer({ antialias: false });
        try {
            // Build into a fresh map and publish it only once every depth map succeeded, so a failure part-way through
            // cannot leave a partially-filled map.
            this.depthDataMap = this.computeDepthData(renderer, this.resolveModels(insertionOrientation));
        } finally {
            // Always release the WebGL context, even if a depth-map computation throws: contexts are a limited browser
            // resource and would otherwise leak on every failed call.
            renderer.forceContextLoss();
            renderer.dispose();
        }

        return this.getDepthDataMap();
    }

    /** Returns a shallow copy of the current depth data map. */
    getDepthDataMap(): InsertionDepthDataMap {
        return new Map(this.depthDataMap);
    }

    /** Returns the models to process, overriding every model's insertion axis when an orientation is supplied. */
    private resolveModels(insertionOrientation?: THREE.Quaternion): ArrayMin1<RestorativeModel> {
        if (!insertionOrientation) {
            return this.restorativeModels;
        }
        const insertionAxis = getInsertionAxisFromOrientation(insertionOrientation);
        return this.restorativeModels.map(el => ({
            ...el,
            insertionAxis,
        })) as ArrayMin1<RestorativeModel>;
    }

    /** Renders the per-tooth and per-jaw depth maps with `renderer` and returns them in a new map. */
    private computeDepthData(
        renderer: THREE.WebGLRenderer,
        restorativeModels: ArrayMin1<RestorativeModel>,
    ): InsertionDepthDataMap {
        type ScanDepthData = {
            box: THREE.Box3;
            orientation: THREE.Quaternion;
            geometry: THREE.BufferGeometry;
            axis?: THREE.Vector3;
        };
        const depthDataMap: InsertionDepthDataMap = new Map();
        const scans: { [K in ScanKey]: ScanDepthData } = {
            upper: { box: new THREE.Box3(), orientation: new THREE.Quaternion(), geometry: this.upperJawGeometry },
            lower: { box: new THREE.Box3(), orientation: new THREE.Quaternion(), geometry: this.lowerJawGeometry },
        };

        restorativeModels.forEach((model, idx) => {
            const scan = scans[ToothUtils.toothIsUpper(model.toothNumber) ? 'upper' : 'lower'];
            // The first restorative on a jaw fixes that jaw's axis/orientation; later restoratives on the same jaw
            // reuse it when contributing to the jaw's bounding box.
            if (!scan.axis) {
                scan.axis = model.insertionAxis;
                scan.orientation.setFromUnitVectors(model.insertionAxis, ZAxis);
            }
            updateOrientedBoundingBox(model.geometry, scan.orientation, scan.box);

            // Geometry rendered into this tooth's depth map: adjacent restoratives (excluding itself) plus its jaw scan.
            const meshes = restorativeModels
                .filter((m, i) => i !== idx && ToothUtils.areAdjacent(model.toothNumber, m.toothNumber))
                .map(m => new THREE.Mesh(m.geometry));
            meshes.push(new THREE.Mesh(scan.geometry));
            const depthData = computeDepthMap(meshes, renderer, camera => setCameraParams(camera, model));
            depthDataMap.set(model.toothNumber, depthData);
        });

        for (const key of scanKeys) {
            const scan = scans[key];
            // A jaw with no restoratives has no axis and gets no depth map.
            if (!scan.axis) {
                continue;
            }
            // Pad the restoratives' footprint, square it in XY, then extend Z to cover jaw vertices inside it.
            scan.box.expandByScalar(DEPTH_MAP_BUFFER_MM);
            makeSquareXY(scan.box);
            expandBoxZ(scan.geometry, scan.orientation, scan.box);
            const depthData = computeDepthMap([new THREE.Mesh(scan.geometry)], renderer, camera =>
                setCameraParamsFromOrientedBox(camera, scan.orientation, scan.box),
            );
            depthDataMap.set(`${key}Jaw`, depthData);
        }

        return depthDataMap;
    }
}

/**
 * Grows `box` in place so it contains every vertex of `geometry` after rotation by `orientation`.
 * The box is not reset first, so repeated calls accumulate one oriented bound over several geometries.
 */
function updateOrientedBoundingBox(geometry: THREE.BufferGeometry, orientation: THREE.Quaternion, box: THREE.Box3) {
    const positions = geometry.getAttribute(AttributeName.Position);
    const rotated = new THREE.Vector3();
    const vertexCount = positions.count;
    for (let i = 0; i < vertexCount; ++i) {
        box.expandByPoint(rotated.fromBufferAttribute(positions, i).applyQuaternion(orientation));
    }
}

/**
 * Frames an orthographic camera on a box expressed in the frame obtained by rotating world space by `orientation`.
 *
 * The camera's rotation is the inverse of `orientation`, so its local axes are the box's axes. Because three.js
 * cameras look down their local -Z, the camera sits at the centre of the box's max-Z face and looks toward min-Z.
 * The frustum is square, using the larger of the box's X/Y extents.
 *
 * @returns `lateralScale` (frustum width/height) and `depthScale` (far - near).
 */
function setCameraParamsFromOrientedBox(
    camera: THREE.OrthographicCamera,
    orientation: THREE.Quaternion,
    box: THREE.Box3,
) {
    camera.quaternion.copy(orientation).invert();

    // Centre of the max-Z face in the oriented frame, rotated back into world space.
    const eye = camera.position;
    box.getCenter(eye);
    eye.z = box.max.z;
    eye.applyQuaternion(camera.quaternion);

    const { min, max } = box;
    const size = Math.max(max.x - min.x, max.y - min.y);
    const half = size / 2;
    camera.left = -half;
    camera.right = half;
    camera.bottom = -half;
    camera.top = half;
    camera.near = 0;
    camera.far = max.z - min.z;

    camera.updateProjectionMatrix();
    camera.updateMatrixWorld();

    const depthScale = camera.far - camera.near;
    return { lateralScale: size, depthScale };
}

/**
 * Frames `camera` on a single restorative model, looking along the model's insertion axis.
 * The framing box is the model's bounding box in the insertion-axis frame, padded by `DEPTH_MAP_BUFFER_MM`.
 */
function setCameraParams(camera: THREE.OrthographicCamera, model: RestorativeModel) {
    // Rotation taking the insertion axis onto +Z.
    const toAxisFrame = new THREE.Quaternion().setFromUnitVectors(model.insertionAxis, ZAxis);
    const modelBox = new THREE.Box3(); // starts empty
    updateOrientedBoundingBox(model.geometry, toAxisFrame, modelBox);
    modelBox.expandByScalar(DEPTH_MAP_BUFFER_MM);
    return setCameraParamsFromOrientedBox(camera, toAxisFrame, modelBox);
}

/**
 * Makes the box's XY footprint square, in place, by symmetrically growing its shorter side (Y or X) to match the longer
 * one. The box's centre and Z range are unchanged.
 */
function makeSquareXY(box: THREE.Box3) {
    const { min, max } = box;
    const width = max.x - min.x;
    const height = max.y - min.y;
    const halfExcess = (width - height) / 2;
    if (halfExcess > 0) {
        // Wider than tall: pad Y on both sides.
        min.y -= halfExcess;
        max.y += halfExcess;
        return;
    }
    // Taller than (or as tall as) wide: halfExcess <= 0 here, so this widens X on both sides.
    min.x += halfExcess;
    max.x -= halfExcess;
}

/**
 * Extends the box's Z range, in place, to cover every vertex of `geometry` (rotated by `orientation`) whose X/Y lies
 * within the box's current XY footprint. X and Y bounds are never changed.
 */
function expandBoxZ(geometry: THREE.BufferGeometry, orientation: THREE.Quaternion, box: THREE.Box3) {
    const positions = geometry.getAttribute(AttributeName.Position);
    const point = new THREE.Vector3();
    const { min, max } = box;
    for (let i = 0; i < positions.count; ++i) {
        point.fromBufferAttribute(positions, i).applyQuaternion(orientation);
        const outsideFootprint = point.x < min.x || max.x < point.x || point.y < min.y || max.y < point.y;
        if (!outsideFootprint) {
            min.z = Math.min(min.z, point.z);
            max.z = Math.max(max.z, point.z);
        }
    }
}

/**
 * Renders a depth map of `geometries` using the shared `renderer`.
 *
 * `setCamera` configures the computer's orthographic camera and reports the scales it used. The returned object holds:
 * a clone of the rendered texture; `axisSpaceMatrix`, the camera's projection * view matrix; the reported scales; and
 * the texture size. The `DepthMapComputer` is always disposed, even if `setCamera` or `compute` throws.
 */
function computeDepthMap(
    geometries: THREE.Mesh[],
    renderer: THREE.WebGLRenderer,
    setCamera: (camera: THREE.OrthographicCamera) => { lateralScale: number; depthScale: number },
) {
    const computer = new DepthMapComputer(geometries, { renderer });
    try {
        const { lateralScale, depthScale } = setCamera(computer.camera);
        computer.compute();
        // Cloned, presumably so the result remains usable after `computer.dispose()` — NOTE(review): confirm against
        // DepthMapComputer.
        const texture = computer.texture.clone();
        const { projectionMatrix, matrixWorldInverse } = computer.camera;
        const axisSpaceMatrix = new THREE.Matrix4().multiplyMatrices(projectionMatrix, matrixWorldInverse);
        const { width, height } = texture.image;
        return {
            texture,
            axisSpaceMatrix,
            lateralScale,
            depthScale,
            texSize: new THREE.Vector2(width, height),
        };
    } finally {
        computer.dispose();
    }
}
