from shortcrust.shader import ShaderProgram
from shortcrust.mesh import Mesh
from shortcrust.gl2 import *
from shortcrust.texture import Texture
from shortcrust.matrix import rotate_z, glmatrix, identity
import math


class RainbowShader(ShaderProgram):
	"""Shader that draws a textured quad, rotated and placed in clip space.

	Vertex stage: builds a scale matrix from ``viewOrientation`` (x scale
	``viewOrientation * 1.3``; a negative value mirrors horizontally) and a
	fixed y scale of ``1.333 * 1.3`` (presumably a 4:3 aspect correction --
	confirm). It applies ``view * rotation`` to the vertex, then subtracts
	``xOffset`` from x and 1.0 from y. The untransformed vertex position is
	passed through as ``texPosition``.

	Fragment stage: samples ``uSampler`` at ``texPosition.xy``.

	NOTE(review): the fragment shader declares ``PI`` and ``BEAT_LEN`` but
	never uses them, and its ``precision mediump float;`` line is commented
	out. A strict GLSL ES 1.00 compiler rejects float declarations in a
	fragment shader that has no default float precision. Confirm that the
	target driver, or ShaderProgram, tolerates or supplies this.
	"""

	vertex_shader = """
		attribute vec4 vPosition;
		varying vec4 texPosition;
		uniform mat4 rotation;
		uniform float viewOrientation;
		uniform float xOffset;

		void main()
		{
			mat4 view = mat4(
				viewOrientation * 1.3, 0, 0, 0,
				0, 1.333 * 1.3, 0, 0,
				0, 0, 1, 0,
				0, 0, 0, 1
			);
			texPosition = vPosition;
			gl_Position = view * rotation * vPosition - vec4(xOffset, 1.0, 0.0, 0.0);
		}
	"""

	fragment_shader = """
		/* precision mediump float; */

		const float PI = 3.1415926536;

		const float BEAT_LEN = 60.0 / 95.0;

		uniform sampler2D uSampler;
		varying vec4 texPosition;

		void main()
		{
			/* gl_FragColor = vec4(0.905, 0.776, 0.616, 1.0); */
			gl_FragColor = texture2D(uSampler, texPosition.xy);
		}
	"""

	def __init__(self):
		"""Build the program via ShaderProgram and cache its GL locations.

		Sets ``vertex_position_attr`` plus one ``*_unif`` attribute per
		uniform. ``check_gl_error()`` runs before and after each step so
		that any GL error can be traced to the call that caused it
		(presumably it raises or logs; see shortcrust.gl2).
		"""
		check_gl_error()
		super(RainbowShader, self).__init__()
		check_gl_error()

		program = self.program_object
		self.vertex_position_attr = glGetAttribLocation(program, 'vPosition')
		check_gl_error()

		# (python attribute name, GLSL uniform name), queried in this order.
		uniform_slots = (
			('rotation_unif', 'rotation'),
			('view_unif', 'viewOrientation'),
			('sampler_unif', 'uSampler'),
			('x_offset_unif', 'xOffset'),
		)
		for attr_name, glsl_name in uniform_slots:
			setattr(self, attr_name, glGetUniformLocation(program, glsl_name))
			check_gl_error()


class RainbowMesh(Mesh):
	"""Unit quad spanning (0, 0)-(1, 1) in the XY plane at depth ``z``.

	Drawn as a triangle strip with vertex order
	(0, 0), (0, 1), (1, 0), (1, 1).
	"""
	mode = GL_TRIANGLE_STRIP

	def __init__(self, z=-0.90):
		# x varies slowest and y fastest, which reproduces the strip order
		# listed in the class docstring.
		corners = (0.0, 1.0)
		# Assigned before the base initializer runs (presumably Mesh.__init__
		# reads self.vertices -- keep this ordering).
		self.vertices = [(x, y, z) for x in corners for y in corners]
		super(RainbowMesh, self).__init__()


class RainbowLayer(object):
	"""Draw a rotating, horizontally mirrored rainbow texture for a beat window.

	The layer is visible while
	``START_BEAT <= time / app.beat_length <= START_BEAT + DURATION_BEATS``.
	Outside that window ``draw`` returns without issuing any GL calls.
	"""

	# Beat (counted from time 0) at which the layer first appears.
	START_BEAT = 18
	# Number of beats the layer stays visible after START_BEAT.
	DURATION_BEATS = 8

	def __init__(self, app):
		"""Create the shader, the two quads and the rainbow texture.

		:param app: owning application; ``draw`` reads ``app.beat_length``.
		"""
		self.app = app

		self.shader = RainbowShader()
		# One quad per half of the mirrored drawing, placed at slightly
		# different depths (presumably to avoid z-fighting -- confirm).
		self.mesh1 = RainbowMesh()
		self.mesh2 = RainbowMesh(-0.91)
		self.texture = Texture('data/rainbow.png', transparency=True, flipped=False)

	def draw(self, time):
		"""Render the layer at ``time`` (same units as ``app.beat_length``).

		Returns None. The layer draws nothing outside its beat window.
		"""
		beat = time / self.app.beat_length - self.START_BEAT
		if beat < 0 or beat > self.DURATION_BEATS:
			return

		self.shader.use()

		self.texture.activate(GL_TEXTURE0)
		glUniform1i(self.shader.sampler_unif, 0)
		check_gl_error()

		# Over beats 0..8 this sweeps the angle from -pi to +pi (one full turn).
		rota = rotate_z(beat * math.pi * 0.25 - math.pi)
		# `count` is the number of matrices, not the number of floats.
		# `rotation` is a single non-array mat4, so it must be 1. A count > 1
		# on a non-array uniform raises GL_INVALID_OPERATION in GLES 2.0 and
		# leaves the uniform unset.
		glUniformMatrix4fv(self.shader.rotation_unif, 1, GL_FALSE, glmatrix(rota))

		# First copy: normal orientation, shifted by +0.5 in the shader.
		glUniform1f(self.shader.view_unif, 1.0)
		glUniform1f(self.shader.x_offset_unif, 0.5)

		self.mesh1.draw(self.shader.vertex_position_attr)

		# Second copy: x scale negated (mirrored), shifted the other way.
		glUniform1f(self.shader.view_unif, -1.0)
		glUniform1f(self.shader.x_offset_unif, -0.5)

		self.mesh2.draw(self.shader.vertex_position_attr)
