#version 330

#include light.ss;

// Textures, all sampled at fTexCoord in main().
// uDiffuseMap supplies the base color that the final lighting sum is multiplied by.
uniform sampler2D uDiffuseMap;
// Only the first (.x) channel of the roughness and metallic maps is read.
uniform sampler2D uRoughnessMap;
uniform sampler2D uMetallicMap;

// Interpolated inputs from the previous shader stage.
// NOTE(review): the w-prefixed values are combined with uCameraPosition in main(),
// so they are presumably world space; the space of the f-prefixed values is not
// visible from this file - confirm against the vertex shader.
in vec3 fNormal;
in vec3 wNormal;
in vec3 fPosition;
in vec3 wPosition;
in vec2 fTexCoord;

/*
Ray tracing to lights

Every light except light 0 is treated as a sphere and tested against a ray
that starts at the fragment and follows the refracted view direction.
*/

// Sphere radius per light. Light 0 is directional and has no entry, so the
// radius of light i is stored at index i-1.
uniform float ucLightSizes[LIGHT_AMOUNT-1];

// Returns the attenuated color of the nearest light sphere hit by the ray
// (black if nothing is hit); alpha is always 1.0.
vec4 raytraceLights(vec3 rayStart, vec3 rayDir);
// Returns distance, -1 if no hit
// (alternative intersection routine; its only call site is commented out)
float rayToSphere(vec3 rayStart, vec3 rayDir, vec3 spherePos, float radius);
// Returns the ray parameter of the nearer intersection with the sphere, or -1.0
// when the ray's line misses it. The value can be negative when that
// intersection lies behind rayStart; callers must treat negatives as a miss.
float hitRay(vec3 sphereCenter, float radius, vec3 rayStart, vec3 rayDir);

// Final fragment color written by main().
out vec4 fragColor;

/*
Fragment entry point.

Combines three contributions and modulates their sum by the diffuse texture:
  - regular lighting from calculateAllLighting (called with the negated fNormal),
  - the color of the nearest light sphere hit by the refracted view ray,
  - the refraction color returned by calculateRefractionColor.
*/
void main()
{
	vec4 textureColor = texture(uDiffuseMap, fTexCoord);

	// Only the first channel of the roughness / metallic maps is used.
	float roughness = texture(uRoughnessMap, fTexCoord).x;
	float metallic = texture(uMetallicMap, fTexCoord).x;

	Material mat = getMaterial(metallic, roughness);

	// Ratio of refraction indices (eta) for a ray passing from air into diamond.
	float diamondRefractionIndex = 2.42;
	float airRefractionIndex = 1.0f;
	float refractionIndex = airRefractionIndex / diamondRefractionIndex;

	// refract() requires both the incident vector and the normal to be unit
	// length. An interpolated normal generally is not, so normalize it here.
	// A zero-length normal is passed through unchanged to avoid producing NaNs.
	vec3 unitNormal = wNormal;
	if (dot(wNormal, wNormal) > 0.0)
	{
		unitNormal = normalize(wNormal);
	}

	vec3 toFragment = normalize(wPosition - uCameraPosition);
	vec3 toCube = refract(toFragment, unitNormal, refractionIndex);
	vec3 rayDir = normalize(toCube);

	// Follow the refracted ray from the fragment and pick up the nearest light sphere.
	vec4 lighting = raytraceLights(wPosition, rayDir);

	// The helpers below receive the same arguments as before; whether they
	// normalize the normal themselves is not visible from this file.
	vec4 refraction = calculateRefractionColor(
		  wPosition
		, wNormal
		, uCameraPosition
		, uCubeMap
		, uSkyboxRotationMatrix
		, refractionIndex);

	// Lighting from inside to out, flip normal
	vec4 normalLights = calculateAllLighting(wPosition
	, fPosition
	, wNormal
	, -fNormal
	, textureColor
	, mat);

	fragColor = textureColor * (normalLights + lighting + refraction);
}

/*
Finds the nearest spherical light intersected by the given ray.

rayStart - origin of the ray
rayDir   - direction of the ray

Returns the color of the nearest hit light scaled by its attenuation over the
hit distance, or black when no light is hit. Alpha is always 1.0.
Hits at a distance of 10000 or more are ignored.
*/
vec4 raytraceLights(vec3 rayStart, vec3 rayDir)
{
	vec3 nearestColor = vec3(0,0,0);
	float nearestHit = 10000.0;

	// Index 0 is the directional light, which has no sphere; start at 1.
	for (int lightIndex = 1; lightIndex < LIGHT_AMOUNT; lightIndex++)
	{
		// Sphere radii are stored without an entry for light 0.
		float hitDistance = hitRay(
			uLights[lightIndex].positionDirection,
			ucLightSizes[lightIndex-1],
			rayStart,
			rayDir);

		// Skip misses (negative result) and hits that are not nearer
		// than the best one found so far.
		bool isNearerHit = hitDistance >= 0 && hitDistance < nearestHit;
		if (!isNearerHit)
		{
			continue;
		}

		nearestHit = hitDistance;

		// Fade the light's color with the distance travelled to the hit.
		float fade = getAttenuation(
			hitDistance,
			uLights[lightIndex].linearAttenuation,
			uLights[lightIndex].quadraticAttenuation);
		nearestColor = uLights[lightIndex].color * fade;
	}

	return vec4(nearestColor, 1.0);
}

/*
Analytic ray/sphere intersection via the quadratic formula.

Solves |rayStart + t * rayDir - sphereCenter|^2 = radius^2 for t.

Returns the smaller root t, or -1.0 when the equation has no real solution.
The returned root can be negative when that intersection lies behind
rayStart, so callers must treat any negative value as a miss.
*/
float hitRay(vec3 sphereCenter, float radius, vec3 rayStart, vec3 rayDir)
{
	vec3 centerToStart = rayStart - sphereCenter;

	// Coefficients of quadA * t^2 + quadB * t + quadC = 0
	float quadA = dot(rayDir, rayDir);
	float quadB = 2.0 * dot(centerToStart, rayDir);
	float quadC = dot(centerToStart, centerToStart) - radius * radius;

	float discriminant = quadB*quadB - 4*quadA*quadC;

	// No real roots: the ray's line never touches the sphere.
	if (discriminant < 0)
	{
		return -1.0;
	}

	return (-quadB - sqrt(discriminant)) / (2.0 * quadA);
}

/*
Geometric ray/sphere intersection.

rayStart  - origin of the ray
rayDir    - direction of the ray (does not have to be unit length)
spherePos - center of the sphere
radius    - radius of the sphere

Returns the distance from rayStart to the nearest intersection that lies in
front of the ray origin, or -1.0 if there is no such intersection (the ray's
line misses the sphere, the sphere is entirely behind the origin, or rayDir
has zero length). When rayStart is inside the sphere the exit point is used.
*/
float rayToSphere(vec3 rayStart, vec3 rayDir, vec3 spherePos, float radius)
{
	// C - rayStart
	// P - sphere center
	// T - closest point to P on the ray's line

	float dirLength = length(rayDir);
	if (dirLength <= 0.0)
	{
		// A zero-length direction does not define a ray.
		return -1.0;
	}
	vec3 rayUnit = rayDir / dirLength;
	vec3 CP = spherePos - rayStart;

	// Signed length of CT: projection of CP onto the unit ray direction.
	float distCT = dot(CP, rayUnit);

	// Squared length of PT from Pythagoras (|CP|^2 = |CT|^2 + |PT|^2).
	// Clamped at zero to absorb rounding error when P lies on the line.
	float distPTSquared = max(dot(CP, CP) - distCT * distCT, 0.0);
	float radiusSquared = radius * radius;

	if (distPTSquared > radiusSquared)
	{
		// The line passes the sphere at a distance larger than its radius.
		return -1.0;
	}

	// Half the length of the chord the line cuts through the sphere; the two
	// intersections lie this far on either side of T.
	float halfChord = sqrt(radiusSquared - distPTSquared);

	float nearHit = distCT - halfChord;
	if (nearHit >= 0.0)
	{
		return nearHit;
	}

	// Entry point is behind the origin: use the exit point if it is in front.
	float farHit = distCT + halfChord;
	if (farHit >= 0.0)
	{
		return farHit;
	}

	// Both intersections are behind the ray origin.
	return -1.0;
}

