import {useArrayBuffer, useAttribLocation, useBindVertexArribArray, useElementArrayBuffer, useProgram, useShader, useUniformLocation, useVertexBuffer} from "#lib/gl-react/index.ts";
import React, {useMemo} from "react";
import {Color, HSLA, Transform} from "common/types/index.ts";
import {Matrix4f} from "#lib/math/index.ts";
import {Line} from "../../../viewport/common/node/layer-view/scene-view.tsx";


/**
 * Vertex stage of the area-line program.
 *
 * Transforms `a_position` by `u_projection * u_view * u_model`. It forwards the per-vertex
 * `a_tex_coord` and the line's end points (`a_start`, `a_end`) unchanged to the fragment stage.
 * The `#version` directive must stay on the very first line of the source.
 */
const vertexShader = `#version 300 es
precision highp float;

in vec2 a_position;
in vec2 a_tex_coord;
in vec2 a_start;
in vec2 a_end;

uniform mat4 u_projection;
uniform mat4 u_view;
uniform mat4 u_model;

out vec2 v_tex_coord;
out vec2 v_start;
out vec2 v_end;

void main()
{
  gl_Position = u_projection * u_view * u_model * vec4(a_position, 0, 1);
  v_tex_coord = a_tex_coord;
  v_start = a_start;
  v_end = a_end;
}
`;

/**
 * Fragment stage of the area-line program.
 *
 * Fills a fragment with `u_color` in two cases:
 * - it is within `radius` of either end point (round caps);
 * - it projects onto the segment and its perpendicular distance to the segment is below `radius`.
 *
 * Distances are divided by `u_scale` before being compared with `radius`. All other fragments
 * receive transparent black. Every path writes `outColor`, because an unwritten fragment output
 * is undefined in GLSL ES 3.00.
 *
 * NOTE(review): writing transparent black (instead of `discard`) only leaves overlapping
 * geometry untouched when alpha blending is enabled. Confirm the blend state at the call site.
 */
const fragmentShader = `#version 300 es
precision highp float;

in vec2 v_tex_coord;
in vec2 v_start;
in vec2 v_end;

uniform float u_scale;
uniform vec4 u_color;

out vec4 outColor;

// Stroke half-thickness and end-cap radius, compared against distances divided by u_scale.
const float radius = 2.;

void main() {
  // Default: transparent. Guarantees outColor is written on every path.
  outColor = vec4(0.);

  // Round caps: fragments close to either end point are always filled.
  float d1 = distance(v_start, v_tex_coord) / u_scale;
  float d2 = distance(v_end, v_tex_coord) / u_scale;
  if (d1 <= radius || d2 <= radius) {
    outColor = u_color;
    return;
  }

  // A zero-length segment has no direction; only the caps above are drawn.
  vec2 lineDir = v_end - v_start;
  float lineLength = length(lineDir);
  if (lineLength == 0.) {
    return;
  }

  // Project the fragment onto the segment; fragments beyond either end stay transparent.
  vec2 dir = lineDir / lineLength;
  float t = dot(dir, v_tex_coord - v_start);
  if (t < 0. || t > lineLength) {
    return;
  }

  // Fill fragments whose perpendicular distance to the segment is below the radius.
  vec2 foot = v_start + t * dir;
  float p = distance(foot, v_tex_coord) / u_scale;
  if (p < radius) {
    outColor = u_color;
  }
}
`;

/**
 * Props accepted by the `AreaLineShader` component.
 */
export type AreaLineShaderProps = {
  /** Line segments to draw; each one is rendered as a single padded quad. */
  lines: Line[];
  /**
   * Uploaded as `u_scale`. Quad padding is multiplied by it, and fragment distances
   * are divided by it.
   */
  scale: number;
  /** Stroke color; converted with `Color.toRGBA` and uploaded as `u_color`. */
  color: HSLA;

  /** Projection matrix, uploaded (transposed) as `u_projection`. */
  projection: Matrix4f;
  /** View transform; converted with `Matrix4f.transform` and uploaded as `u_view`. */
  view: Transform;
  /** Model transform; converted with `Matrix4f.transform` and uploaded as `u_model`. */
  model: Transform;
};

/** Floats per interleaved vertex: a_position (2) + a_tex_coord (2) + a_start (2) + a_end (2). */
const FLOATS_PER_VERTEX = 8;
/** Byte distance between consecutive vertices in the interleaved buffer. */
const VERTEX_STRIDE = FLOATS_PER_VERTEX * Float32Array.BYTES_PER_ELEMENT;
/** Every line is drawn as one quad made of four vertices... */
const VERTICES_PER_QUAD = 4;
/** ...and two triangles. */
const INDICES_PER_QUAD = 6;
/** Padding (before multiplying by `scale`) added on every side of a line's bounding box. */
const QUAD_PADDING = 16;

/**
 * Builds the interleaved vertex data: one axis-aligned quad per line.
 *
 * Each quad covers the line's bounding box grown by `QUAD_PADDING * scale` on every side.
 * Each vertex stores its corner position twice (position and texture coordinate), followed by
 * the line's start and end points.
 *
 * @param lines Segments to convert.
 * @param scale Multiplier applied to the padding.
 * @returns `lines.length * 4` vertices of `FLOATS_PER_VERTEX` floats each.
 */
function buildLineQuadVertices(lines: Line[], scale: number): Float32Array {
  const data = new Float32Array(lines.length * VERTICES_PER_QUAD * FLOATS_PER_VERTEX);
  const pad = QUAD_PADDING * scale;
  let offset = 0;
  for (const {start, end} of lines) {
    const [sx, sy] = start;
    const [ex, ey] = end;
    const minX = Math.min(sx, ex) - pad;
    const maxX = Math.max(sx, ex) + pad;
    const minY = Math.min(sy, ey) - pad;
    const maxY = Math.max(sy, ey) + pad;
    // Corner order (minX,minY), (maxX,minY), (maxX,maxY), (minX,maxY) is what buildQuadIndices assumes.
    data.set([minX, minY, minX, minY, sx, sy, ex, ey], offset);
    data.set([maxX, minY, maxX, minY, sx, sy, ex, ey], offset + FLOATS_PER_VERTEX);
    data.set([maxX, maxY, maxX, maxY, sx, sy, ex, ey], offset + 2 * FLOATS_PER_VERTEX);
    data.set([minX, maxY, minX, maxY, sx, sy, ex, ey], offset + 3 * FLOATS_PER_VERTEX);
    offset += VERTICES_PER_QUAD * FLOATS_PER_VERTEX;
  }
  return data;
}

/**
 * Builds the triangle indices (0,1,2) and (2,3,0) for each of `quadCount` consecutive quads.
 *
 * NOTE(review): indices are 16-bit, so only 16384 quads (65536 vertices) are addressable.
 * Larger counts wrap around and draw the wrong vertices. If larger inputs are expected, switch
 * to `Uint32Array` + `UNSIGNED_INT`, after confirming `useElementArrayBuffer` accepts it.
 */
function buildQuadIndices(quadCount: number): Uint16Array {
  const indices = new Uint16Array(quadCount * INDICES_PER_QUAD);
  for (let quad = 0; quad < quadCount; quad++) {
    const v = quad * VERTICES_PER_QUAD;
    indices.set([v, v + 1, v + 2, v + 2, v + 3, v], quad * INDICES_PER_QUAD);
  }
  return indices;
}

/**
 * Draws `lines` with the area-line program.
 *
 * Each line becomes one padded quad. The fragment shader fills only the round end caps and
 * the band around the segment, with widths measured in distances divided by `scale`.
 *
 * The index buffer depends only on the number of lines. It is therefore memoized on
 * `lines.length`, so changing `scale` rebuilds only the vertex buffer.
 */
export function AreaLineShader({lines, projection, model, view, scale, color}: AreaLineShaderProps) {
  const program = useProgram(
    useShader(WebGL2RenderingContext.VERTEX_SHADER, vertexShader),
    useShader(WebGL2RenderingContext.FRAGMENT_SHADER, fragmentShader)
  );
  const projectionLocation = useUniformLocation(program, "u_projection");
  const viewLocation = useUniformLocation(program, "u_view");
  const modelLocation = useUniformLocation(program, "u_model");
  const scaleLocation = useUniformLocation(program, "u_scale");
  const colorLocation = useUniformLocation(program, "u_color");

  const vboArray = useMemo(() => buildLineQuadVertices(lines, scale), [lines, scale]);
  const vbo = useArrayBuffer(vboArray);
  const vao = useVertexBuffer();

  const positionLocation = useAttribLocation(program, "a_position");
  const texCoordLocation = useAttribLocation(program, "a_tex_coord");
  const startLocation = useAttribLocation(program, "a_start");
  const endLocation = useAttribLocation(program, "a_end");
  const FLOAT_BYTES = Float32Array.BYTES_PER_ELEMENT;
  useBindVertexArribArray(vao, positionLocation, vbo, 2, WebGL2RenderingContext.FLOAT, false, VERTEX_STRIDE, 0);
  useBindVertexArribArray(vao, texCoordLocation, vbo, 2, WebGL2RenderingContext.FLOAT, false, VERTEX_STRIDE, 2 * FLOAT_BYTES);
  useBindVertexArribArray(vao, startLocation, vbo, 2, WebGL2RenderingContext.FLOAT, false, VERTEX_STRIDE, 4 * FLOAT_BYTES);
  useBindVertexArribArray(vao, endLocation, vbo, 2, WebGL2RenderingContext.FLOAT, false, VERTEX_STRIDE, 6 * FLOAT_BYTES);

  // Indices only depend on how many quads exist, not on their positions or on `scale`.
  const lineCount = lines.length;
  const eboArray = useMemo(() => buildQuadIndices(lineCount), [lineCount]);
  const ebo = useElementArrayBuffer(eboArray);

  const projectionMatrix4f = useMemo(() => new Float32Array(projection), [projection]);
  const viewMatrix4f = useMemo(() => new Float32Array(Matrix4f.transform(view)), [view]);
  const modelMatrix4f = useMemo(() => new Float32Array(Matrix4f.transform(model)), [model]);
  const color4f = useMemo(() => new Float32Array(Color.toRGBA(color)), [color]);

  return (<>
      <program value={program}>
        <uniformMat4fv location={projectionLocation} transpose data={projectionMatrix4f}/>
        <uniformMat4fv location={viewLocation} transpose data={viewMatrix4f}/>
        <uniformMat4fv location={modelLocation} transpose data={modelMatrix4f}/>
        <uniform1f location={scaleLocation} data={scale}/>
        <uniform4fv location={colorLocation} data={color4f}/>

        <vertexArray value={vao}>
          <elementArrayBuffer value={ebo}>
            <drawElements mode={WebGL2RenderingContext.TRIANGLES} type={WebGL2RenderingContext.UNSIGNED_SHORT} offset={0} count={eboArray.length}/>
          </elementArrayBuffer>
        </vertexArray>
      </program>
    </>
  );
}