import * as THREE from 'three'
import { extend } from 'react-three-fiber'
import distortion from './Distortion'

// Fragment shader for the light sticks: blends the interpolated stick colour
// (vColor) toward a fixed dark colour (colorB) by mixAmount and writes vAlpha as
// opacity. All three varyings are written by the vertex shader in this file.
const fragmentShader = `
varying float vAlpha;
varying vec3 vColor;
varying float mixAmount;

// Dark colour the sticks are mixed toward as mixAmount approaches 1.0.
const vec3 colorB = vec3(0.0, 0.035, 0.067);

void main() {
  vec3 color = mix(vColor, colorB, mixAmount);
  gl_FragColor = vec4(color, vAlpha);
}
`

const vertexShader = `
attribute float aOffset;
attribute vec3 aColor;

attribute vec2 aMetrics;

uniform float uTravelLength;
uniform float uTime;
uniform float uSpeed;

varying float vAlpha;
varying vec3 vColor;
varying float mixAmount;
mat4 rotationY( in float angle ) {
	return mat4(	cos(angle),		0,		sin(angle),	0,
			 				0,	1.0,			 0,	0,
					-sin(angle),	0,		cos(angle),	0,
							0, 		0,				0,	1);
}



  #include <getDistortion_vertex>
  void main(){
    vec3 transformed = position.xyz;
    float width = aMetrics.x;
    float height = aMetrics.y;

    transformed.xy *= vec2(width,height);
    float time = mod(uTime  * uSpeed *2. + aOffset , uTravelLength);

    transformed = (rotationY(3.14/2.) * vec4(transformed,1.)).xyz;

    transformed.z +=  - uTravelLength + time;


    float progress = abs(transformed.z / uTravelLength);
    transformed.xyz += getDistortion(progress);

	 	transformed.y += height;
	 	

    transformed.x += -width/2.;
    vec4 mvPosition = modelViewMatrix * vec4(transformed,1.);
    gl_Position = projectionMatrix * mvPosition;
    vColor = aColor;
    float easingProgress = abs(sqrt(1.0 - pow((1.0 - progress)*1. - 1.0, 2.0)));
    vAlpha = abs(min(easingProgress *1., 1.0));
    mixAmount = min(abs( (1.0 - easingProgress) * 2.0), 1.0);
  }
`

/**
 * ShaderMaterial for the light sticks.
 *
 * Combines its own animation uniforms (uTime, uTravelLength, uSpeed) with the
 * uniforms exported by the distortion module. The distortion GLSL is spliced
 * into the vertex shader in onBeforeCompile. Fog is disabled and
 * `transparent` is enabled.
 */
class LightSticksMaterial extends THREE.ShaderMaterial {
	constructor() {
		super({
			// Object.assign copies distortion.uniforms' values by reference. If
			// those values are Uniform objects (as assumed), every instance of this
			// material shares them.
			uniforms: Object.assign(
				{
					uTime: new THREE.Uniform(0),
					uTravelLength: new THREE.Uniform(400),
					uSpeed: new THREE.Uniform(100)
				},
				distortion.uniforms
			),
			vertexShader,
			fragmentShader,
			fog: false,
			transparent: true
		})
	}

	/**
	 * Replaces the `#include <getDistortion_vertex>` placeholder in the vertex
	 * shader source with the GLSL in distortion.getDistortion.
	 *
	 * @param {object} shader - three.js shader descriptor; its vertexShader is rewritten in place.
	 * @param {object} renderer - the renderer compiling this material (forwarded to super).
	 */
	onBeforeCompile(shader, renderer) {
		super.onBeforeCompile(shader, renderer)
		// getDistortion_vertex is presumably not a built-in three.js shader chunk,
		// so it has to be substituted here, before three.js resolves #include lines.
		shader.vertexShader = shader.vertexShader.replace(
			'#include <getDistortion_vertex>',
			distortion.getDistortion
		)
	}

	/** Current value of the shader's uTime uniform. */
	get uTime() {
		return this.uniforms.uTime.value
	}

	/**
	 * Updates the uTime uniform in place.
	 *
	 * The previous implementation allocated a new THREE.Uniform on every
	 * assignment (presumably once per frame). Mutating `.value` avoids that
	 * allocation, and anything holding a reference to the Uniform sees the update.
	 */
	set uTime(v) {
		this.uniforms.uTime.value = v
	}
}

// Register the material with react-three-fiber so it can be used declaratively.
// r3f exposes extended classes as camelCased JSX elements, presumably
// <lightSticksMaterial /> here.
const lightSticksCatalogue = { LightSticksMaterial }
extend(lightSticksCatalogue)
