import React, {useMemo} from "react";
import {
  useArrayBuffer,
  useAttribLocation,
  useBindVertexArribArray,
  useBindVertexElementArrayBuffer,
  useElementArrayBuffer,
  useProgram,
  useShader,
  useUniformLocation,
  useVertexBuffer
} from "#lib/gl-react/index.ts";
import {Color, HSLA, Transform} from "common/types/index.ts";
import {GridNode} from "common/legends/index.ts";
import {Matrix4f} from "#lib/math/index.ts";
import {PerlinNoiseShaderFragment} from "./noise.ts";
import {useGrid} from "../../context/grid-context.ts";

/**
 * GLSL ES 3.00 vertex shader for a full-screen quad.
 *
 * `a_position` is forwarded unchanged as the clip-space `gl_Position`. The component
 * below feeds it the corners (-1,-1)..(+1,+1), so the quad covers the whole viewport.
 *
 * The same clip-space point is also multiplied by `u_model * u_view * u_projection`,
 * and its x/y is passed to the fragment shader as `fragCoord`. The component uploads
 * the *inverses* of the model, view and projection matrices (`Matrix4f.invert(...)`),
 * so this product unprojects each corner back into the grid's model space. The
 * rasteriser then interpolates `fragCoord` across the quad.
 */
const vertexShaderSource = `#version 300 es
precision highp float;

in vec2 a_position;

uniform mat4 u_projection;
uniform mat4 u_view;
uniform mat4 u_model;

out vec2 fragCoord;

void main()
{
  fragCoord = (u_model * u_view * u_projection * vec4(a_position, 0, 1)).xy;
  gl_Position = vec4(a_position, 0, 1);
}
`;

/**
 * GLSL ES 3.00 fragment shader that draws the edges of a hexagonal grid.
 *
 * Layout: `coord()` converts a model-space point to axial hex coordinates (q, r)
 * using the cell width `u_gridSize.x` and cell height `u_gridSize.y`. With that
 * conversion, rows of cells are 3/4 of the cell height apart, and each successive
 * row is shifted by half a cell width (the axial / "pointy-top" hexagon layout).
 * `cube_round()` snaps fractional cube coordinates to the nearest cell. It recomputes
 * the component with the largest rounding error so that q + r + s stays 0.
 *
 * Edge test in `main()`:
 * - The line thickness is `u_thickness` plus Perlin noise sampled at
 *   `fragCoord / u_gridSize`, scaled by `u_noise`. `perlin` comes from the
 *   interpolated `PerlinNoiseShaderFragment`.
 * - Thickness is a fraction of the centre-to-centre distance between neighbouring
 *   cells. Each of the three offset vectors used below has length `u_gridSize.x`
 *   when thickness is 1.
 * - For each of the three neighbour directions (q, r and s), two points are sampled
 *   on either side of the pixel along that direction. They are separated by the
 *   thickness plus one pixel footprint, `fwidth(fragCoord)`, so lines never vanish
 *   below one pixel. The pixel counts as on an edge when the two points land in
 *   cells that differ the way crossing an edge in that direction would.
 *
 * Uniforms:
 * - `u_limits`: if either component is -1, the grid is unbounded. Otherwise only the
 *   rectangle from `-u_gridSize/2` to `u_limits - u_gridSize/2` is drawn.
 * - `u_color`: used for edge pixels. Everything else gets `vec4(0.)`.
 *
 * Outputs:
 * - Location 0 (`outColor`) receives the colour described above.
 * - Location 1 (`normal`) always receives (0.5, 0.5, 1.0, 1.0). This is presumably an
 *   encoded flat, up-facing normal — confirm against the consuming pass.
 *
 * NOTE(review): `u_zoomLevel` and the `sampler2DArray` precision qualifier are not
 * referenced by the code written here. They may be dead, or may be needed by
 * `PerlinNoiseShaderFragment`; confirm before removing.
 */
const fragmentShaderSource = `#version 300 es

precision highp float;
precision highp sampler2DArray;

in vec2 fragCoord;

layout(location = 0) out vec4 outColor;
layout(location = 1) out vec4 normal;

uniform float u_zoomLevel;
uniform vec2 u_gridSize;
uniform vec2 u_limits;
uniform float u_thickness;
uniform float u_noise;
uniform vec4 u_color;

${PerlinNoiseShaderFragment}

struct Cube {
  float q;
  float r;
  float s;
};

Cube cube_round(Cube frac) {
  float q = round(frac.q);
  float r = round(frac.r);
  float s = round(frac.s);
  float q_diff = abs(q - frac.q);
  float r_diff = abs(r - frac.r);
  float s_diff = abs(s - frac.s);
  if (q_diff > r_diff && q_diff > s_diff) {
      q = -r - s;
  } else if (r_diff > s_diff) {
      r = -q - s;
  } else {
      s = -q - r;
  }

  return Cube(q, r, s);
}

vec2 coord( vec2 p ) {
  vec2 q = vec2(
     p.x / u_gridSize.x - p.y / ((3./2.)*u_gridSize.y),
     p.y / ((3./4.)*u_gridSize.y)
  );
  Cube c = cube_round(Cube(q.x, q.y, -q.x-q.y));
  return vec2(c.q, c.r);
}

void main() {
  float n = perlin(fragCoord/u_gridSize)*u_noise;
  float thickness = u_thickness + n;

  vec2 du = fwidth(fragCoord);

  // grid.width, 0
  vec2 qThickness = vec2(thickness * u_gridSize.x, 0.);
  vec2 qDu = vec2(du.x, 0.);
  vec2 qP1 = coord(fragCoord + qThickness/2.);
  vec2 qP2 = coord(fragCoord - qThickness/2. - qDu);
  bool q = floor(qP1.x) > floor(qP2.x) && floor(qP1.y) == floor(qP2.y);

  // 1/2*grid.width, 3/4*grid.height
  vec2 rThickness = vec2(thickness * u_gridSize.x * 1./2., thickness * u_gridSize.y * 3./4.);
  vec2 rDu = vec2(du.x * 1./2., du.y * 3./4.);
  vec2 rP1 = coord(fragCoord + rThickness/2.);
  vec2 rP2 = coord(fragCoord - rThickness/2. - rDu);
  bool r = floor(rP1.x) == floor(rP2.x) && floor(rP1.y) > floor(rP2.y);

  // 1/2*grid.width, -3/4*grid.height
  vec2 sThickness = vec2(thickness * u_gridSize.x * 1./2., -thickness * u_gridSize.y * 3./4.);
  vec2 sDu = vec2(du.x * 1./2., -du.y * 3./4.);
  vec2 sP1 = coord(fragCoord + sThickness/2.);
  vec2 sP2 = coord(fragCoord - sThickness/2. - sDu);
  bool s = floor(sP1.x) > floor(sP2.x) && floor(sP1.y) < floor(sP2.y);

  if (
    (u_limits.x == -1. || u_limits.y == -1.) || (
      fragCoord.x > -u_gridSize.x/2.0 && fragCoord.x < u_limits.x - u_gridSize.x/2.0 &&
      fragCoord.y > -u_gridSize.y/2.0 && fragCoord.y < u_limits.y - u_gridSize.y/2.0
  )) {
    outColor = q || r || s ? u_color : vec4(0.);
  } else {
    outColor = vec4(0.);
  }
  normal = vec4(0.5, 0.5, 1.0, 1.0);
}
`;

/** Props accepted by the hexagonal (vertical) grid shader component. */
export type HexagonVerticalGridViewProps = {
  /** Grid node supplying colour, opacity, line thickness, noise amount and the optional size limit. */
  value: GridNode;
  /** Projection matrix; its inverse is uploaded as `u_projection`. */
  projection: Matrix4f;
  /** View transform; converted with `Matrix4f.transform`, inverted, and uploaded as `u_view`. */
  view: Transform;
  /** Model transform of the grid; converted with `Matrix4f.transform`, inverted, and uploaded as `u_model`. */
  model: Transform;
};

/**
 * Renders the edges of a hexagonal grid as a full-screen quad.
 *
 * The quad covers the whole viewport. The shaders map each pixel back into the grid's
 * model space through the inverted projection, view and model matrices uploaded here,
 * then decide per pixel whether it lies on a hexagon edge.
 *
 * Cell size comes from the grid context (`useGrid()`). Everything else comes from `value`.
 *
 * @param props.value - Grid node. Its `color` and `opacity` are combined into the line
 *   colour, and its `thickness` and `noise` are forwarded as uniforms. Its optional
 *   `size` bounds the drawn area; when `size` is absent, `[-1, -1]` is sent, which the
 *   shader treats as unbounded.
 * @param props.projection - Projection matrix. Its inverse is uploaded.
 * @param props.view - View transform. It is converted to a matrix and inverted.
 * @param props.model - Model transform of the grid. It is converted to a matrix and inverted.
 * @returns The `<program>` element tree that binds the uniforms and draws the quad.
 */
export function HexagonVerticalGridShader({value, projection, view, model}: HexagonVerticalGridViewProps): JSX.Element {
  const grid = useGrid();
  const program = useProgram(
    useShader(WebGL2RenderingContext.VERTEX_SHADER, vertexShaderSource),
    useShader(WebGL2RenderingContext.FRAGMENT_SHADER, fragmentShaderSource)
  );
  const projectionLocation = useUniformLocation(program, "u_projection");
  const viewLocation = useUniformLocation(program, "u_view");
  const modelLocation = useUniformLocation(program, "u_model");
  const limitsLocation = useUniformLocation(program, "u_limits");
  const gridSizeLocation = useUniformLocation(program, "u_gridSize");
  const colorLocation = useUniformLocation(program, "u_color");
  const thicknessLocation = useUniformLocation(program, "u_thickness");
  const noiseLocation = useUniformLocation(program, "u_noise");

  const positionLocation = useAttribLocation(program, "a_position");

  // Full-screen quad: four clip-space corners, drawn as two triangles via an index buffer.
  const vao = useVertexBuffer();
  const position = useArrayBuffer(useMemo((): Float32Array => new Float32Array([-1, -1, -1, +1, +1, +1, +1, -1]), []));
  useBindVertexArribArray(vao, positionLocation, position, 2, WebGL2RenderingContext.FLOAT, false, 0, 0);
  // This buffer holds element indices, not vertices.
  const indexBuffer = useElementArrayBuffer(useMemo((): Uint16Array => new Uint16Array([0, 1, 2, 2, 3, 0]), []));
  useBindVertexElementArrayBuffer(vao, indexBuffer);

  // The vertex shader unprojects clip space into model space, so it needs the inverses.
  const projectionMatrix4f = useMemo(() => new Float32Array(Matrix4f.invert(projection)), [projection]);
  const viewMatrix4f = useMemo(() => new Float32Array(Matrix4f.invert(Matrix4f.transform(view))), [view]);
  const modelMatrix4f = useMemo(() => new Float32Array(Matrix4f.invert(Matrix4f.transform(model))), [model]);

  const gridSize = useMemo(() => new Float32Array([grid.width, grid.height]), [grid.width, grid.height]);
  const color = useMemo(() => new Float32Array(Color.toRGBA([...value.color, value.opacity] as HSLA)), [value.color, value.opacity]);
  // Only a missing size (null/undefined) selects the "unbounded" sentinel the shader checks for.
  const limits = useMemo(() => new Float32Array(value.size ?? [-1, -1]), [value.size]);

  return (
    <program value={program}>
      <uniformMat4fv location={projectionLocation} transpose data={projectionMatrix4f} />
      <uniformMat4fv location={viewLocation} transpose data={viewMatrix4f} />
      <uniformMat4fv location={modelLocation} transpose data={modelMatrix4f} />
      <uniform2fv location={limitsLocation} data={limits} />
      <uniform1f location={thicknessLocation} data={value.thickness} />
      <uniform1f location={noiseLocation} data={value.noise} />

      <uniform2fv location={gridSizeLocation} data={gridSize} />
      <uniform4fv location={colorLocation} data={color} />
      <vertexArray value={vao}>
        <drawElements mode={WebGL2RenderingContext.TRIANGLES} type={WebGL2RenderingContext.UNSIGNED_SHORT} offset={0} count={6} />
      </vertexArray>
    </program>
  );
}

