import {
  useArrayBuffer,
  useAttribLocation,
  useBindVertexArribArray,
  useElementArrayBuffer,
  useProgram,
  useShader,
  useUniformLocation,
  useVertexBuffer
} from "#lib/gl-react/index.ts";
import React, {useMemo} from "react";
import {Color, HSLA, Point} from "common/types/index.ts";
import {Matrix4f} from "#lib/math/index.ts";
import {Spline, SplineFn} from "common/types/generic/spline/index.ts";
import {Vector2} from "common/math/vector/vector2.ts";
import {usePVM} from "../../context/pvm-context.ts";
import {useRenderPassTexture} from "../../context/render-pass-texture-context.ts";


/**
 * Vertex stage of the line light.
 *
 * Transforms the model-space quad corner to clip space. It passes the world-space position,
 * the interpolated tex coord and the segment endpoints through to the fragment stage.
 */
const vertexShader = `#version 300 es
precision highp float;

in vec2 a_position;
in vec2 a_tex_coord;
in vec2 a_start;
in vec2 a_end;

uniform mat4 u_projection;
uniform mat4 u_view;
uniform mat4 u_model;

out vec2 v_tex_coord;
out vec2 v_normal_coord;
out vec2 v_world_pos;
out vec2 v_start;
out vec2 v_end;

void main()
{
  // World-space position, reused for both the clip position and v_world_pos.
  vec4 worldPos = u_model * vec4(a_position, 0.0, 1.0);
  gl_Position = u_projection * u_view * worldPos;

  // Clip xy remapped from [-1, 1] to [0, 1] to sample the screen-aligned normal texture.
  // There is no divide by gl_Position.w, so this is only exact when w == 1
  // (NOTE(review): presumably an orthographic projection - confirm).
  v_normal_coord = (gl_Position.xy + 1.0) / 2.0;
  v_world_pos = worldPos.xy;

  v_tex_coord = a_tex_coord;
  v_start = a_start;
  v_end = a_end;
}
`;

/**
 * Fragment stage of the line light.
 *
 * A fragment's light depends on its distance to the segment [v_start, v_end]. Distance is
 * measured in the same space as a_position and normalized by u_scale. The alpha is
 * u_color.a * pow(1 - dist, falloffStrength), and is zero once dist >= 1. The alpha is
 * further weighted by a normal-map term that compares the sampled surface normal with the
 * direction from the fragment to u_origin.
 */
const fragmentShader = `#version 300 es
precision highp float;
precision highp sampler2DArray;

uniform vec2 u_origin;
uniform float u_scale;
uniform vec4 u_color;
uniform float falloffStrength;
uniform sampler2D normalTexture;

in vec2 v_world_pos;
in vec2 v_normal_coord;
in vec2 v_tex_coord;
in vec2 v_start;
in vec2 v_end;

layout(location = 0) out vec4 outColor;
// Bound to draw buffer 1 but not written by this shader.
// NOTE(review): its contents are undefined if that buffer is attached.
layout(location = 1) out vec4 normal;

// normalize() is undefined for a zero vector. Return zero instead, so the
// dot products below become 0 rather than NaN.
vec2 safeNormalize(vec2 v) {
  float len = length(v);
  return len > 0.0 ? v / len : vec2(0.0);
}

void main() {
  // Find the closest point on the segment. Clamping the projection makes the
  // end caps measure distance to the nearer endpoint: v_start before the
  // segment, v_end past it. A zero-length segment degrades to a point light.
  vec2 lineDir = v_end - v_start;
  float segmentLength = length(lineDir);
  vec2 dir = safeNormalize(lineDir);
  float t = clamp(dot(dir, v_tex_coord - v_start), 0.0, segmentLength);
  vec2 closest = v_start + t * dir;
  float dist = distance(closest, v_tex_coord) / u_scale;

  vec2 toOrigin = safeNormalize(u_origin - v_world_pos);
  vec3 surfaceNormal = texture(normalTexture, v_normal_coord).rgb * 2.0 - 1.0;
  float n = mix(1.0, max(dot(safeNormalize(surfaceNormal.xy), toOrigin), 0.0), max(0.0, 1.0 - surfaceNormal.z));

  // pow() is undefined for a negative base, so fragments at or beyond the
  // light radius get zero alpha explicitly.
  if (dist < 1.0) {
    outColor = vec4(u_color.rgb, u_color.a * pow(1.0 - dist, falloffStrength) * n);
  } else {
    outColor = vec4(u_color.rgb, 0.0);
  }
}
`;

/** Floats per interleaved vertex: a_position, a_tex_coord, a_start, a_end (vec2 each). */
const FLOATS_PER_VERTEX = 8;
/** Byte stride between consecutive interleaved vertices. */
const VERTEX_STRIDE = FLOATS_PER_VERTEX * Float32Array.BYTES_PER_ELEMENT;
/** Each line segment is drawn as one quad: 4 vertices, 2 triangles (6 indices). */
const VERTICES_PER_QUAD = 4;
const INDICES_PER_QUAD = 6;
const FLOATS_PER_QUAD = FLOATS_PER_VERTEX * VERTICES_PER_QUAD;
/** UNSIGNED_SHORT indices can address at most 65536 vertices, i.e. this many quads. */
const MAX_UINT16_QUADS = Math.floor(0x10000 / VERTICES_PER_QUAD);

/**
 * Renders a light that glows along the polyline of `spline`.
 *
 * Each consecutive pair of points from `SplineFn.getLines(spline)` becomes one quad. The quad
 * covers the segment's bounding box padded by `fallout`. The fragment shader then fades the
 * light with distance from the segment and weights it by the render pass's normal texture.
 *
 * @param spline - Spline whose line segments carry the light.
 * @param origin - Subtracted from every segment point before upload. The shader's `u_origin`
 *   is model-space (0, 0) transformed by the model matrix.
 * @param color - Light color, converted with `Color.toRGBA`.
 * @param fallout - Light radius, in the same units as the segment points. It is passed to the
 *   shader as `u_scale` (the light reaches zero at this distance) and is also the quad padding.
 * @param falloffStrength - Exponent applied to `1 - distance / fallout`.
 */
export function DSFLineLightShader({spline, origin, color, fallout, falloffStrength}:  {
  origin: Point;

  fallout: number;
  falloffStrength: number;
  spline: Spline;
  color: HSLA;
}) {
  const {projection, view, model} = usePVM();
  const lines = useMemo(() => SplineFn.getLines(spline), [spline]);

  const program = useProgram(
    useShader(WebGL2RenderingContext.VERTEX_SHADER, vertexShader),
    useShader(WebGL2RenderingContext.FRAGMENT_SHADER, fragmentShader)
  );

  const vboArray = useMemo(() => {
    const segmentCount = Math.max(0, lines.length - 1);
    const data = new Float32Array(segmentCount * FLOATS_PER_QUAD);
    // The shader's light is zero beyond `fallout` (u_scale) from the segment. Padding the
    // bounding box by exactly `fallout` therefore covers every lit fragment; a larger pad
    // would only add fragments the shader cannot light.
    const pad = fallout;
    let offset = 0;
    for (let i = 0; i < segmentCount; i++) {
      const [sx, sy] = Vector2.subtract(lines[i], origin);
      const [ex, ey] = Vector2.subtract(lines[i + 1], origin);
      const left = Math.min(sx, ex) - pad;
      const right = Math.max(sx, ex) + pad;
      const bottom = Math.min(sy, ey) - pad;
      const top = Math.max(sy, ey) + pad;
      // Corner order (left,bottom) -> (right,bottom) -> (right,top) -> (left,top) matches the
      // index pattern below. a_tex_coord repeats a_position, and every vertex also carries
      // both segment endpoints.
      const corners = [[left, bottom], [right, bottom], [right, top], [left, top]] as const;
      for (const [x, y] of corners) {
        data.set([x, y, x, y, sx, sy, ex, ey], offset);
        offset += FLOATS_PER_VERTEX;
      }
    }
    return data;
  }, [lines, fallout, origin]);

  const vbo = useArrayBuffer(vboArray);
  const vao = useVertexBuffer();
  const f32 = Float32Array.BYTES_PER_ELEMENT;
  useBindVertexArribArray(vao, useAttribLocation(program, "a_position"), vbo, 2, WebGL2RenderingContext.FLOAT, false, VERTEX_STRIDE, 0);
  useBindVertexArribArray(vao, useAttribLocation(program, "a_tex_coord"), vbo, 2, WebGL2RenderingContext.FLOAT, false, VERTEX_STRIDE, 2 * f32);
  useBindVertexArribArray(vao, useAttribLocation(program, "a_start"), vbo, 2, WebGL2RenderingContext.FLOAT, false, VERTEX_STRIDE, 4 * f32);
  useBindVertexArribArray(vao, useAttribLocation(program, "a_end"), vbo, 2, WebGL2RenderingContext.FLOAT, false, VERTEX_STRIDE, 6 * f32);

  const quadCount = vboArray.length / FLOATS_PER_QUAD;
  const eboArray = useMemo(() => {
    // Indices are UNSIGNED_SHORT. Quads past MAX_UINT16_QUADS would wrap onto earlier
    // vertices and draw stray triangles, so they are left out of the index buffer.
    const drawnQuads = Math.min(quadCount, MAX_UINT16_QUADS);
    const indices = new Uint16Array(drawnQuads * INDICES_PER_QUAD);
    for (let q = 0; q < drawnQuads; q++) {
      const base = q * VERTICES_PER_QUAD;
      indices.set([base, base + 1, base + 2, base + 2, base + 3, base], q * INDICES_PER_QUAD);
    }
    return indices;
  }, [quadCount]);
  const ebo = useElementArrayBuffer(eboArray);

  const projectionMatrix4f = useMemo(() => new Float32Array(projection), [projection]);
  const viewMatrix4f = useMemo(() => new Float32Array(Matrix4f.transform(view)), [view]);
  const modelMatrix4f = useMemo(() => new Float32Array(Matrix4f.transform(model)), [model]);
  const color4f = useMemo(() => new Float32Array(Color.toRGBA(color)), [color]);
  // The light origin in world space. Vertex data is already relative to `origin`, so that
  // point is model-space (0, 0).
  const origin2f = useMemo(() => new Float32Array(Vector2.multiplyTransform([0, 0], model)), [model]);
  const [, normalTexture] = useRenderPassTexture();

  return (<program value={program}>
    <uniformMat4fv location={useUniformLocation(program, "u_projection")} transpose={true} data={projectionMatrix4f}/>
    <uniformMat4fv location={useUniformLocation(program, "u_view")} transpose={true} data={viewMatrix4f}/>
    <uniformMat4fv location={useUniformLocation(program, "u_model")} transpose={true} data={modelMatrix4f}/>
    <uniform1f location={useUniformLocation(program, "u_scale")} data={fallout}/>
    <uniform1f location={useUniformLocation(program, "falloffStrength")} data={falloffStrength}/>
    <uniform4fv location={useUniformLocation(program, "u_color")} data={color4f}/>
    <uniform2fv location={useUniformLocation(program, "u_origin")} data={origin2f} />
    <texture2d value={normalTexture}>
      <vertexArray value={vao}>
        <elementArrayBuffer value={ebo}>
          <drawElements mode={WebGL2RenderingContext.TRIANGLES} type={WebGL2RenderingContext.UNSIGNED_SHORT} offset={0} count={eboArray.length}/>
        </elementArrayBuffer>
      </vertexArray>
    </texture2d>
  </program>);
}