#version 460

#ifndef SPIRV_ENABLED
#extension GL_NV_gpu_shader5 : enable
#extension GL_NV_shader_thread_group : enable
#extension GL_ARB_shader_ballot : enable
#else
#extension GL_EXT_shader_explicit_arithmetic_types : enable
#extension GL_ARB_shader_ballot : enable
#extension GL_KHR_shader_subgroup_ballot : enable
#define HAS_16BIT_TYPES
#endif

layout(early_fragment_tests) in;

#define RT_READ_ONLY 1

uniform sampler2DArray s_BlueNoise;

#include <shaders/materials/commons.glsl>
#include <shaders/commons_hlsl.glsl>
#include <shaders/materials/commons_sphere_sampling.glsl>
#include <shaders/materials/raytrace_buffers.glsl>
#include <shaders/materials/raytrace_commons.glsl>
#include <shaders/deferred/lighting_support.glsl>

#include <shaders/materials/mapping/triplanar.glsl>

layout(std140, row_major) uniform TransformParamsBuffer{
	EntityTransformParams transform_params;
};

//#pragma optionNV(fastmath on)
//#pragma optionNV(fastprecision off)
//#pragma optionNV(ifcvt none)
//#pragma optionNV(inline all)
//#pragma optionNV(strict on)
//#pragma optionNV(unroll 5)

#ifndef SPIRV_VULKAN
in Vertex
{
	vec3 vCoords;
	f16vec3 vNorm;
	f16vec3 vWorldNorm;
	vec3 vLocalPos;
	vec3 vWorldPos;
	f16vec4 vColor;
	f16vec2 vUV0;
} vtx_input;
#else
layout(location = 1) in struct
{
	vec3 vCoords;
	f16vec3 vNorm;
	f16vec3 vWorldNorm;
	vec3 vLocalPos;
	vec3 vWorldPos;
	f16vec4 vColor;
	f16vec2 vUV0;
} vtx_input;
#endif

// Marks rays which didn't hit solid object
//#define MARK_UNFINISHED 1
//#define VISUALIZE_GRID 1
//#define VISUALIZE_HEATMAP
//#define INNER_REFLECTION
#define INTERPOLATE_NORMALS

#ifndef MAX_BOUNCES
#define MAX_BOUNCES 4
#endif

#ifndef MAX_TRACE_LENGTH
#define MAX_TRACE_LENGTH 1024
#endif

#ifndef INITIAL_FACE_START_DISTANCE
#define INITIAL_FACE_START_DISTANCE 7.5
#endif

//in int gl_PrimitiveID;

uniform samplerCube s_reflection;
uniform usampler3D  s_grid_marker;

// Raytracing setup parameters. An instance is exposed to the shader as
// `rt_setup` through the RTSetupBuffer uniform block.
struct RTSetup
{
	mat4 mat_projection;
	mat4 mat_model;
	vec3 camera_position;
	int  screen_sampling_scale;
	vec4 camera_projection_params;
	vec4 near_far_plane;           // packed coefficients consumed by linearizeDepth()

	float trace_range;             // read as the ray tmax by the (disabled) findClosestDDAMultibounce
	float roughness_clamp;

	int  lights_num;               // number of lights.light_properties entries iterated by evaluate_lighting()
};

layout (std140, row_major) uniform RTSetupBuffer
{
	RTSetup rt_setup;
};

// outputs depending on the pass
// -- RaytracePass
#if defined(RAYTRACE_PASS)

#ifndef GLASS_REFLECTION_PASS
layout(location = 0) out vec4 outAlbedo;
uniform sampler2D sFresnelReflection;
#else
layout(location = 0) out vec4 outFresnelReflection;
#endif // GLASS_REFLECTION_PASS

layout(r32ui) uniform readonly uimage2D imPrimitiveId;
layout(r32ui) uniform readonly uimage2D imNormalMaterial;
layout(rgba16ui) uniform readonly uimage2D imMetalnessRoughnessMaterialTags;
layout(rgba16f) uniform readonly image2D imAlbedo;
uniform sampler2D sTextureDepth;

// Converts the depth value `d` using the packed coefficients stored in
// rt_setup.near_far_plane:  z / (y + x - d * w).
float linearizeDepth(in float d)
{
	vec4 nf = rt_setup.near_far_plane;
	float denominator = nf.y + nf.x - d * nf.w;
	return nf.z / denominator;
}

// Scales the direction vector `vDirection` by `depth` and returns the result.
vec3 positionFromDepth(vec3 vDirection, float depth)
{
	vec3 scaled = vDirection.xyz * depth;
	return scaled;
}


#endif // RAYTRACE_PASS
//

//
#define GRID_SIZE in_bbox_data.grid_size_raytrace

//

// Fetches the three corner positions of triangle `idx`.
// The triangle's vertex indices are read from transformed_data_indices at
// [idx * 3 .. idx * 3 + 2] and resolved through rt_get_vertex().
void build_triangle(uint idx, out vec3 p0, out vec3 p1, out vec3 p2)
{
	uvec3 corner_indices = uvec3(
		transformed_data_indices[idx * 3 + 0],
		transformed_data_indices[idx * 3 + 1],
		transformed_data_indices[idx * 3 + 2]);

	p0 = rt_get_vertex(corner_indices.x);
	p1 = rt_get_vertex(corner_indices.y);
	p2 = rt_get_vertex(corner_indices.z);
}

// Returns the result of rt_barycentric_xyz() for point `p` against the three
// corners of face `idx`, narrowed to half precision.
f16vec3 barycentric_for_face(int idx, vec3 p)
{
	uint ia = transformed_data_indices[idx * 3 + 0];
	uint ib = transformed_data_indices[idx * 3 + 1];
	uint ic = transformed_data_indices[idx * 3 + 2];

	vec3 corner_a = rt_get_vertex(ia);
	vec3 corner_b = rt_get_vertex(ib);
	vec3 corner_c = rt_get_vertex(ic);

	f16vec3 weights = f16vec3(rt_barycentric_xyz(p, corner_a, corner_b, corner_c));
	return weights;
}

// Two-component variant of barycentric_for_face(): returns the result of
// rt_barycentric_yz() for point `p` against face `idx`, in half precision.
f16vec2 barycentric_for_face_yz(int idx, vec3 p)
{
	uint ia = transformed_data_indices[idx * 3 + 0];
	uint ib = transformed_data_indices[idx * 3 + 1];
	uint ic = transformed_data_indices[idx * 3 + 2];

	vec3 corner_a = rt_get_vertex(ia);
	vec3 corner_b = rt_get_vertex(ib);
	vec3 corner_c = rt_get_vertex(ic);

	f16vec2 weights = f16vec2(rt_barycentric_yz(p, corner_a, corner_b, corner_c));
	return weights;
}

// Geometric (flat) unit normal of face `fi`: normalize(cross(p1 - p0, p2 - p0)),
// narrowed to half precision.
f16vec3 build_normal(int fi)
{
	vec3 a, b, c;
	build_triangle(fi, a, b, c);

	vec3 face_normal = normalize(cross(b - a, c - a));
	return f16vec3(face_normal);
}

// Refracts direction `v` at a surface with normal `n` using a fixed eta of
// 1.0 / 1.25. The normal is flipped, if needed, so that it opposes `v`.
// When refract() yields the zero vector the incoming direction is returned
// unchanged.
vec3 glass_refract(vec3 v, f16vec3 n)
{
	// TODO: add param for air->glass vs glass->air
	// The value used here is for air->glass.
	float orientation = dot(f16vec3(v), n) < 0.0 ? 1.0 : -1.0;
	vec3 bent = refract(v, n * orientation, 1.0 / 1.25);

	bool no_refraction = dot(bent, bent) == 0.0;
	return no_refraction ? v : bent;
}

//---------------------------------------------------------------------------
// NOTE: We flip by the sign of the dot product with ref_normal because this
// interpolator produces a normal that follows the face winding, while the
// intersection code always gives a normal pointing towards the origin. This unifies the two.
// NOTE2: This is also used to disable interpolation of the attributes on demand
// Blends the three vertex normals of face `fi` with weights
// (1 - bc_yz.x - bc_yz.y, bc_yz.x, bc_yz.y), flips the result so that it
// agrees in sign with `ref_normal`, and normalizes it.
// Without INTERPOLATE_NORMALS the function simply returns `ref_normal`.
f16vec3 interpolate_normal_from_bc_yz(int fi, f16vec2 bc_yz, f16vec3 ref_normal)
{
#ifdef INTERPOLATE_NORMALS
	uint ia = transformed_data_indices[fi * 3 + 0];
	uint ib = transformed_data_indices[fi * 3 + 1];
	uint ic = transformed_data_indices[fi * 3 + 2];

	f16vec3 na = f16vec3(rt_get_vertex_normal(ia));
	f16vec3 nb = f16vec3(rt_get_vertex_normal(ib));
	f16vec3 nc = f16vec3(rt_get_vertex_normal(ic));

	f16vec3 blended = na * (float16_t(1.0) - bc_yz.x - bc_yz.y) + nb * bc_yz.x + nc * bc_yz.y;

	// Make the blended normal face the same side as the reference normal.
	return f16vec3(normalize(blended * (dot(blended, ref_normal) < 0.0 ? -1.0 : 1.0)));
#else
	return ref_normal;
#endif
}

// Blends the uv0 attribute of the three vertices of face `fi` with weights
// (1 - bc_yz.x - bc_yz.y, bc_yz.x, bc_yz.y).
f16vec2 interpolate_uv_from_bc_yz(int fi, f16vec2 bc_yz)
{
	uint ia = transformed_data_indices[fi * 3 + 0];
	uint ib = transformed_data_indices[fi * 3 + 1];
	uint ic = transformed_data_indices[fi * 3 + 2];

	f16vec2 uv_a = f16vec2(rt_get_vertex_uv0(ia));
	f16vec2 uv_b = f16vec2(rt_get_vertex_uv0(ib));
	f16vec2 uv_c = f16vec2(rt_get_vertex_uv0(ic));

	return uv_a * (float16_t(1.0) - bc_yz.x - bc_yz.y) + uv_b * bc_yz.x + uv_c * bc_yz.y;
}

// Ray/triangle intersection test for the ray orig + t * dir against (v0, v1, v2).
//
// Returns true when the ray hits the triangle at t > 1e-8. On a hit
// `intersection_t` holds the ray parameter and `out_normal` the normalized
// cross(e1, e2) of the triangle.
// On a miss `out_normal` is the zero vector. `intersection_t` is 0 unless
// the only failed test was the final t > 1e-8 check, in which case it holds
// the rejected t.
bool intersectTriangle(vec3 orig, vec3 dir, vec3 v0, vec3 v1, vec3 v2, out float intersection_t, out f16vec3 out_normal)
{
	// `out` parameters are undefined unless written, so give both of them a
	// defined value before any early return.
	intersection_t = 0.0;
	out_normal = f16vec3(0.0);

	vec3 e1 = v1 - v0;
	vec3 e2 = v2 - v0;
	vec3 pvec = cross(dir, e2);
	float det = dot(e1, pvec);

	// Ray is parallel to the triangle plane
	if (det < 1e-8 && det > -1e-8)
	{
		return false;
	}

	float inv_det = 1.0 / det;
	vec3 tvec = orig - v0;
	float u = dot(tvec, pvec) * inv_det;
	if (u < 0.0 || u > 1.0)
	{
		return false;
	}

	vec3 qvec = cross(tvec, e1);
	float v = dot(dir, qvec) * inv_det;
	if (v < 0.0 || u + v > 1.0)
	{
		return false;
	}

	intersection_t = dot(e2, qvec) * inv_det;
	if (intersection_t > 1e-8)
	{
		out_normal = f16vec3(normalize(cross(e1, e2)));	// TODO: remove normalization and reuse above calcs
		return true;
	}
	return false;
}

// NOTE: trying to workaround compiler issues here....

// Result of a ray/triangle test, returned by value from intersectTriangle2/3.
// `t` is initialised to -1.0 by those functions and only overwritten on a hit;
// the remaining fields are only written on a hit.
struct intersection
{
	float t;            // ray parameter of the hit, or -1.0 when there is no hit
	f16vec3 normal;     // normalized triangle normal, cross(e2, e1) in intersectTriangle2/3
	float16_t denom;    // dot(normal, dir); only written by intersectTriangle3
	f16vec2 bc;			// barycentrics. only two included
};

// Ray/triangle test returning an `intersection`.
// `it.t` stays at -1.0 unless every test passes AND dot(normal, dir) > 1e-7,
// in which case it is set to the ray/plane distance
// dot(v0 - orig, normal) / dot(normal, dir).
// `it.normal` is written whenever the barycentric tests pass with t > 1e-7;
// `it.denom` and `it.bc` are never written by this function.
intersection intersectTriangle2(vec3 orig, vec3 dir, vec3 v0, vec3 v1, vec3 v2)
{
	intersection it;
	it.t = -1.0;

	vec3 e1 = v1 - v0;
	vec3 e2 = v2 - v0;
	// Cross product used both for the determinant and for u
	vec3 pvec = cross(dir, e2);
	float det = dot(e1, pvec);

	// Ray is parallel to plane
	if (abs(det) < 1e-7)
	{
		return it;
	}

	float inv_det = 1.0 / det;
	vec3 tvec = orig - v0;
	float u = dot(tvec, pvec) * inv_det;
	if (u < 0.0 || u > 1.0)
	{
		return it;
	}

	vec3 qvec = cross(tvec, e1);
	float v = dot(dir, qvec) * inv_det;
	if (v < 0.0 || u + v > 1.0)
	{
		return it;
	}

	// NOTE: the original author suspected a compiler miscompile around this code.
	float t = dot(e2, qvec) * inv_det;
	if (t > 1e-7)
	{
		// float denom = dot(normalize(cross(e1, e2)), dir);
		it.normal = f16vec3(normalize(cross(e2, e1)));	// TODO: remove normalization and reuse above calcs
		// NOTE: Assigning the t computed above directly misbehaved for an unknown reason,
		// so t is recomputed below as a ray/plane intersection instead.
		//it.t = t;
		#if 1
		float denom = dot(vec3(it.normal), dir); 
		if (denom > 1e-7)
		{
			vec3 p0l0 = v0 - orig; 
			t = dot(p0l0, vec3(it.normal)) / denom; 
			it.t = t;
		}
		#endif
	}
	return it;
}

// Ray/triangle test returning an `intersection`; the variant used by the
// bucket traversal functions.
// `it.t` stays at -1.0 on a miss. On a hit (all barycentric tests pass and
// t > 1e-7) it writes:
//   it.normal = normalize(cross(e2, e1))
//   it.denom  = dot(normal, dir)
//   it.t      = dot(v0 - orig, normal) / denom   (ray/plane distance)
//   it.bc     = (u, v) / denom
// NOTE(review): unlike intersectTriangle2 the `denom > 1e-7` guard is
// commented out, so both signs of denom are accepted.
// NOTE(review): u and v are divided by denom before being stored in it.bc;
// confirm that this scaling is intended.
intersection intersectTriangle3(vec3 orig, vec3 dir, vec3 v0, vec3 v1, vec3 v2)
{
	intersection it;
	it.t = -1.0;

	vec3 e1 = v1 - v0;
	vec3 e2 = v2 - v0;
	// Cross product used both for the determinant and for u
	vec3 pvec = cross(dir, e2);
	float det = dot(e1, pvec);

	// Ray is parallel to plane
	if (abs(det) < 1e-7)
	{
		return it;
	}

	float inv_det = 1.0 / det;
	vec3 tvec = orig - v0;
	float u = dot(tvec, pvec) * inv_det;
	if (u < 0.0 || u > 1.0)
	{
		return it;
	}

	vec3 qvec = cross(tvec, e1);
	float v = dot(dir, qvec) * inv_det;
	if (v < 0.0 || u + v > 1.0)
	{
		return it;
	}

	// NOTE: the original author suspected a compiler miscompile around this code.
	float t = dot(e2, qvec) * inv_det;
	if (t > 1e-7)
	{
		// float denom = dot(normalize(cross(e1, e2)), dir);
		it.normal = f16vec3(normalize(cross(e2, e1)));	// TODO: remove normalization and reuse above calcs
		//it.hit = true;
		// NOTE: Assigning the t computed above directly misbehaved for an unknown reason,
		// so t is recomputed below as a ray/plane intersection instead.
		//it.t = t;
		#if 1
		float denom = dot(vec3(it.normal), dir); 
		
		//if (denom > 1e-7)
		{
			vec3 p0l0 = v0 - orig; 
			t = dot(p0l0, vec3(it.normal)) / denom; 
			it.t = t;
			it.denom = float16_t(denom);

			it.bc.x = float16_t(u / denom);
			it.bc.y = float16_t(v / denom);
		}
		#endif
	}
	return it;
}

#if 0
// Brute-force closest-hit search: tests every face in [0, numFaces) except
// `skip_fi` and keeps the nearest hit. Returns the index of the closest face
// or -1 when nothing was hit; `closest_it` starts at 10000.0.
// This function is compiled out by the enclosing `#if 0`.
int findClosestNaive(vec3 origin, vec3 dir, int skip_fi, out int closest_fi, out float closest_it, out f16vec3 closest_norm)
{
	closest_fi = -1;
	closest_it = 10000.0;

	for(int fi = 0; fi < numFaces; fi++)
	{
		if (fi == skip_fi)
			continue;

		intersection it1, it2;

		vec3 p0, p1, p2;
		build_triangle(fi, p0, p1, p2);
		
		#if 0

		// Two-pass variant: tests both windings with intersectTriangle2
		#if 1 // backfaces
		it1 = intersectTriangle2(origin, dir, p0, p1, p2);
		if (it1.t > 1e-8)
		{
			if (it1.t < closest_it)
			{
				closest_fi = fi;
				closest_it = it1.t;
				closest_norm = -it1.normal;
			}
		}
		#endif
		#if 1
		it2 = intersectTriangle2(origin, dir, p2, p1, p0);
		if (it2.t > 1e-8)
		{
			if (it2.t < closest_it)
			{
				closest_fi = fi;
				closest_it = it2.t;
				closest_norm = it2.normal;
			}
		}
		#endif

		#else // optimized

		// Single test with reversed vertex order (p2, p1, p0)
		it1 = intersectTriangle3(origin, dir, p2, p1, p0);
		if (it1.t > 1e-7)
		{
			if (it1.t < closest_it)
			{
				closest_fi = fi;
				closest_it = it1.t;
				closest_norm = it1.normal;

				#ifdef INNER_REFLECTION
				closest_norm = it1.denom > 0.0 ? it1.normal : -it1.normal;
				#endif
			}
		}

		#endif
	}

	return closest_fi;
}
#endif

#ifndef USE_LINKED_LISTS
// Finds the closest triangle hit among the faces stored in grid bucket
// `bucket` (array-based bucket layout: in_buckets.offsets / sizes / indices).
//
//   bucket       index into in_buckets.offsets / in_buckets.sizes
//   origin, dir  ray definition
//   skip_fi      face index to ignore
//   closest_fi   out: index of the nearest face hit, or -1 when nothing was hit
//   closest_it   out: ray parameter of the nearest hit (1000000.0 when nothing was hit)
//   closest_norm out: normal of the nearest hit (zero vector when nothing was hit)
void findClosestBucket(int bucket, vec3 origin, vec3 dir, int skip_fi, out int closest_fi, out float closest_it, out f16vec3 closest_norm)
{
	closest_fi = -1;
	closest_it = 1000000.0;
	// Give the out parameter a defined value even when no face is hit.
	closest_norm = f16vec3(0.0);

	int bucket_offset = int(in_buckets.offsets[bucket]);
	int bucket_size = int(in_buckets.sizes[bucket]);

	for(int fi_idx = 0; fi_idx < bucket_size; fi_idx++)
	{
		int fi = int(in_buckets.indices[bucket_offset + fi_idx]);
		if (fi == skip_fi)
			continue;

		vec3 p0, p1, p2;
		build_triangle(fi, p0, p1, p2);

		// `intersection` has no `hit` member: a miss is signalled by t == -1.0,
		// so a hit is detected with t > 0.0 (same test as findClosestBucket2).
		intersection it1 = intersectTriangle3(origin, dir, p2, p1, p0);
		if (it1.t > 0.0 && it1.t < closest_it)
		{
			closest_fi = fi;
			closest_it = it1.t;
			#ifdef INNER_REFLECTION
			closest_norm = it1.denom > 0.0 ? it1.normal : -it1.normal;
			#else
			closest_norm = it1.normal;
			#endif
		}
	}
}

// Finds the closest triangle hit among `bucket_size` faces whose indices are
// stored at in_buckets.indices[bucket_offset ...] (array-based bucket layout).
//
//   origin, dir  ray definition
//   skip_fi      face index to ignore
//   max_t        hits at or beyond this ray parameter are rejected
//   closest_fi   out: index of the nearest face hit, or -1 when nothing was hit
//   closest_it   out: ray parameter of the nearest hit (max_t when nothing was hit)
//   closest_norm out: normal of the nearest hit (zero vector when nothing was hit)
//   closest_bc   out: it.bc of the nearest hit as produced by intersectTriangle3
//                     (zero vector when nothing was hit)
void findClosestBucket2(uint bucket_offset, int bucket_size, vec3 origin, vec3 dir, int skip_fi, float max_t, out int closest_fi, out float closest_it, out f16vec3 closest_norm, out f16vec2 closest_bc)
{
	closest_fi = -1;
	closest_it = max_t;
	// Give every out parameter a defined value even when no face is hit.
	closest_norm = f16vec3(0.0);
	closest_bc = f16vec2(0.0);

	// try this to remove one if from the loop
	//origin += dir * 1e-4;

	for(int fi_idx = 0; fi_idx < bucket_size; fi_idx++)
	{
		int fi = int(in_buckets.indices[bucket_offset + fi_idx]);

		if (fi == skip_fi)
			continue;

		vec3 p0, p1, p2;
		build_triangle(fi, p0, p1, p2);

		intersection it1 = intersectTriangle3(origin, dir, p2, p1, p0);
		if (it1.t > 0.0 && it1.t < closest_it)
		{
			closest_fi = fi;
			closest_it = it1.t;
			#ifdef INNER_REFLECTION
			closest_norm = it1.denom > 0.0 ? it1.normal : -it1.normal;
			#else
			closest_norm = it1.normal;
			#endif

			// Report the barycentrics, as the linked-list variant does.
			closest_bc = it1.bc;
		}
	}
}

#else // linked list version

// Linked-list variant: finds the closest triangle hit among the faces stored
// in the list that starts at in_faces_list_tails_data[list_index]. The list is
// walked through in_faces_list_data.node_buffer[].next until the end-of-list
// marker 0xFFFFFFFF (uint(-1)) is reached.
//
//   bucket_size  unused by this variant (kept for signature parity with the
//                array-based version)
//   origin, dir  ray definition
//   skip_fi      face index to ignore
//   max_t        hits beyond this ray parameter are rejected
//   closest_fi   out: index of the nearest face hit, or -1 when nothing was hit
//   closest_it   out: ray parameter of the nearest hit (max_t when nothing was hit)
//   closest_norm out: normal of the nearest hit (zero vector when nothing was hit)
//   closest_bc   out: it.bc of the nearest hit as produced by intersectTriangle3
//                     (zero vector when nothing was hit)
void findClosestBucket2(uint list_index, int bucket_size, vec3 origin, vec3 dir, int skip_fi, float max_t, out int closest_fi, out float closest_it, out f16vec3 closest_norm, out f16vec2 closest_bc)
{
	closest_fi = -1;
	closest_it = max_t;
	// Give every out parameter a defined value even when no face is hit.
	closest_norm = f16vec3(0.0);
	closest_bc = f16vec2(0.0);

	// try this to remove one if from the loop
	//origin += dir * 1e-4;

	uint idx = in_faces_list_tails_data[list_index];
	while(idx != 0xFFFFFFFFu)
	{
		int fi = int(in_faces_list_data.node_buffer[idx].value);
		// Advance before the skip test so that `continue` cannot stall the walk.
		idx = in_faces_list_data.node_buffer[idx].next;

		if (fi == skip_fi)
			continue;

		vec3 p0, p1, p2;
		build_triangle(fi, p0, p1, p2);

		intersection it1 = intersectTriangle3(origin, dir, p2, p1, p0);
		if (it1.t >= 0.0 && it1.t <= closest_it)
		{
			closest_fi = fi;
			closest_it = it1.t;
			#ifdef INNER_REFLECTION
			closest_norm = it1.denom > 0.0 ? it1.normal : -it1.normal;
			#else
			closest_norm = it1.normal;
			#endif

			closest_bc = it1.bc;
		}
	}
}

#endif

#if 0
// Debug helper: counts how many ray/triangle hits occur over all faces in
// [0, numFaces) except `skip_fi`, testing each face with both vertex orders.
// The closest hit is tracked locally but only the count is returned.
// This function is compiled out by the enclosing `#if 0`.
int countClosest(vec3 origin, vec3 dir, int skip_fi)
{
	int closest_fi = -1;
	float closest_it = 1000000.0;
	int count = 0;
	vec3 closest_norm;

	for(int fi = 0; fi < numFaces; fi++)
	{
		if (fi == skip_fi)
			continue;

		float it = 0.0;

		vec3 p0, p1, p2;
		build_triangle(fi, p0, p1, p2);

		// winding p0, p1, p2
		if (intersectTriangle(origin, dir, p0, p1, p2, it, closest_norm))
		{
			if (it < closest_it)
			{
				closest_fi = fi;
				closest_it = it;
			}
			count++;
		}

		// reversed winding p2, p1, p0
		if (intersectTriangle(origin, dir, p2, p1, p0, it, closest_norm))
		{
			if (it < closest_it)
			{
				closest_fi = fi;
				closest_it = it;
			}
			count++;
		}

	}

	return count;
}
#endif

// Debug helper: counts the faces (other than `skip_fi`) hit by the ray while
// walking `bucket_size` nodes of in_faces_list_data.node_buffer starting at
// node `list_index`.
// NOTE(review): this treats `list_index` directly as the first node index,
// whereas the linked-list findClosestBucket2 starts from
// in_faces_list_tails_data[list_index] — confirm which one is intended.
int countClosestBucket2(int list_index, int bucket_size, vec3 origin, vec3 dir, int skip_fi)
{
    int cnt = 0;

    #if 1
	uint idx = list_index;  // this is our head
	for(int fi_idx = 0; fi_idx < bucket_size; fi_idx++)
	{
		int fi = int(in_faces_list_data.node_buffer[idx].value);
		idx = in_faces_list_data.node_buffer[idx].next;
    #else
    // Alternative loop header: brute force over all faces
    for(int fi = 0; fi < numFaces; fi++)
    {
    #endif
		if (fi == skip_fi)
			continue;

		vec3 p0, p1, p2;
		build_triangle(fi, p0, p1, p2);

		// A miss leaves it1.t at -1.0, so t > 0.0 identifies a hit
		intersection it1 = intersectTriangle3(origin, dir, p2, p1, p0);
		if (it1.t > 0.0)
		{
			cnt++;
		}
	}

    return cnt;
}

//--------------- DDA Intersection support ----------------------------------

// Mutable state of a single traced ray, updated on every surface hit by
// evaluate_material().
struct ray_state
{
	f16vec3 color;            // accumulated colour; blended with the shaded hit colour in evaluate_material()
	f16vec3 normal;           // surface normal at the current hit; used for shading and reflect/refract
	vec3 dir;                 // current ray direction; replaced by the reflected/refracted direction after a hit
	vec3 origin;              // current ray origin; evaluate_material() shades at this position
	float16_t transparency;   // blend weight applied to the hit colour; attenuated when a transparent surface is hit
	bool running;             // cleared by evaluate_material() when the ray terminates
	int16_t bounces;
	int16_t material;         // material index of the last surface hit (written by evaluate_material())
	int16_t tests;
	bool hit;                 // set by evaluate_material() when the ray terminates on a surface
	bool left;
	uint  active_threads_factor;
	uint  active_threads_samples;
};

// Samples the s_reflection cube map along `r` at LOD 0.5 + roughness * 15
// and clamps the result to be non-negative.
f16vec4 sample_env_map(in vec3 r, float roughness)
{
	float lod = 0.5f + roughness * 15.0f;
	vec4 env = textureLod(s_reflection, r, lod);
	vec4 non_negative = max(vec4(0.0f), env);
	return f16vec4(non_negative);
}

// Diffuse-only term: the light's diffuse colour scaled by NdotL, with
// negative NdotL treated as zero. `pos`, `normal` and `light_pos` are unused.
vec3 calculate_lighting_world(LightProperties light, in vec3 pos, in f16vec3 normal, in vec3 light_pos, in float16_t NdotL)
{
	float16_t lambert = (NdotL < 0.0) ? float16_t(0.0) : NdotL;
	return light.diffuse.xyz * lambert;
}

// VERY simple version
#if 0
// Simplified lighting: sums calculate_lighting_world() (diffuse only) over
// all rt_setup.lights_num lights and adds the emissive contribution.
// This version is compiled out by the enclosing `#if 0`.
vec3 evaluate_lighting(vec3 p, vec3 o, f16vec3 n, int material, f16vec3 albedo, f16vec3 emissive)
{
	vec3 c = vec3(0.0);
	vec3 world = p;

	f16vec3 f0 = f16vec3(0.04, 0.04, 0.04);
	float16_t metallic = float16_t(materials.material_properties[material].metalness);
	f16vec3 baseColor = f16vec3(materials.material_properties[material].diffuse.rgb) * albedo.rgb;
	f16vec3 diffuseColor = baseColor.rgb * (f16vec3(1.0) - f0) * (1.0 - metallic);
	f16vec3 specularColor = mix(f0, baseColor.rgb, metallic);

	// iterate through lights. no specular component, only diffuse
	for(int light_idx = 0; light_idx < rt_setup.lights_num; light_idx++)
	{
		LightProperties light = lights.light_properties[light_idx];

		vec3 light_color = vec3(0.0);
		vec4 projector_color = vec4(0.0);

		float16_t NdotL = dot(n, f16vec3(normalize(light.position.xyz - world.xyz)));

		// For transparent materials we assume the light passes through, so
		// back-facing light contributes as well.
		if ((materials.material_properties[material].flags & MaterialFlag_Transparent) != 0)
			NdotL = abs(NdotL);
		vec3 lighting = calculate_lighting_world(light, world, n, light.position.xyz, NdotL);
		
		c += lighting;
	}

	c += emissive * materials.material_properties[material].emissive_factor;

	return c;
}

#else

// Shades the surface point `p` as seen from `o`.
//
//   p         shaded position
//   o         position the point is viewed from; the view vector is normalize(o - p)
//   n         surface normal
//   material  index into materials.material_properties
//   albedo    albedo multiplied into the material's diffuse colour
//   emissive  emissive colour, scaled by the material's emissive_factor and added
//
// Sums getPointShade() over all spot and directional lights that face the
// surface (NdotL > 0) and returns the result in half precision. Lights of
// other types are skipped.
f16vec3 evaluate_lighting(vec3 p, vec3 o, f16vec3 n, int material, f16vec3 albedo, f16vec3 emissive)
{
	vec3 c = vec3(0.0);
	vec3 world = p;
	vec3 view = normalize(o - p);

	// Metallic/roughness split with a fixed 4% base reflectance (f0)
	f16vec3 f0 = f16vec3(0.04f, 0.04f, 0.04f);
	float16_t metallic = float16_t(materials.material_properties[material].metalness);
	f16vec3 baseColor = f16vec3(clamp(materials.material_properties[material].diffuse.rgb * albedo.rgb, vec3(0.0), vec3(1.0)));
	f16vec3 diffuseColor = baseColor.rgb * (f16vec3(1.0f) - f0) * (float16_t(1.0) - metallic);
	f16vec3 specularColor = f16vec3(mix(f0, baseColor.rgb, metallic));

	f16vec3 specularEnvironmentR0 = specularColor.rgb;
    // Anything less than 2% is physically impossible and is instead considered to be shadowing. Compare to "Real-Time-Rendering" 4th editon on page 325.
    float reflectance = float16_t(max(max(specularColor.r, specularColor.g), specularColor.b));
    vec3 specularEnvironmentR90 = vec3(1.0, 1.0, 1.0) * clamp(reflectance * 25.0, 0.0, 1.0);

	float roughness = materials.material_properties[material].roughness;
	float alphaRoughness = roughness * roughness;

	MaterialInfo materialInfo =
    {
        roughness,
        specularEnvironmentR0,
        alphaRoughness,
        diffuseColor,
        specularEnvironmentR90,
        specularColor
    };

	// iterate through lights; only spot and directional lights contribute
	for(int light_idx = 0; light_idx < rt_setup.lights_num; light_idx++)
	{
		LightProperties light = lights.light_properties[light_idx];
		if ((light.type & (LightType_Spot| LightType_Directional)) != 0)
		//if ((light.type & (LightType_Spot)) != 0)
		{
			vec3 pointToLight = light.position.xyz - world.xyz;
			// directional lights use the negated light direction instead
			if ((light.type & LightType_Directional) != 0)
				pointToLight = -light.direction.xyz;

			float NdotL = dot(vec3(n), normalize(pointToLight));

			vec3 lighting = vec3(0.0);
			if (NdotL > 0.0)
			{
				float attenuation = 1.0;
				if ((light.type & LightType_Spot) != 0)
					attenuation = light_calculate_spot_attenuation(light, world.xyz);
			
				if (attenuation > 0.0)
				{
					lighting = getPointShade(pointToLight, materialInfo, n, view) * (light.intensity * attenuation) * light.diffuse.rgb;
					c += lighting;
				}
			}
		}
	}

	c += emissive * materials.material_properties[material].emissive_factor;
	//c = world * 0.0001;
	//c = n * 0.5 + 0.5;

	return f16vec3(c);
	//return f16vec3(baseColor);
}

#endif


// Returns the albedo (rgba) of material `materialIndex` on face `fi` at the
// barycentric position `bc`. Materials without an albedo sampler
// (albedo_sampler < 0) yield opaque white.
f16vec4 evaluate_albedo(int16_t materialIndex, int fi, f16vec2 bc)
{
	if (materials.material_properties[materialIndex].albedo_sampler < 0)
		return f16vec4(1.0);

	f16vec2 uv = interpolate_uv_from_bc_yz(fi, bc);
	return f16vec4(texture(material_textures[materials.material_properties[materialIndex].albedo_sampler], uv).rgba);
}

// Returns the emissive texture colour of material `materialIndex` on face
// `fi` at the barycentric position `bc`. Materials without an emissive
// sampler (emissive_sampler < 0) yield white.
f16vec3 evaluate_emissive(int16_t materialIndex, int fi, f16vec2 bc)
{
	if (materials.material_properties[materialIndex].emissive_sampler < 0)
		return f16vec3(1.0);

	f16vec2 uv = interpolate_uv_from_bc_yz(fi, bc);
	return f16vec3(texture(material_textures[materials.material_properties[materialIndex].emissive_sampler], uv).rgb);
}

// Maps (x, y, z) = ycc to (x - y + z, x + y, x - y - z).
vec3 yCgCo2rgb(vec3 ycc)
{
	float x_minus_y = ycc.x - ycc.y;
	return vec3(x_minus_y + ycc.z, ycc.x + ycc.y, x_minus_y - ycc.z);
}

// Spectral tint for parameter `t`: builds (1, 0, -1.25 * t) and converts it
// with yCgCo2rgb(). Alternative palettes tried by the author:
//   vec3( 1.0,  1.5*t, 0.0 )  green-pink
//   vec3( 1.0, -1.5*t, 0.0 )  green-purple
//   vec3( 1.0, 0.0,  1.5*t )  brownyellow-blue
vec3 spectrum_offset_ycgco( float t )
{
	float third = -1.25*t; // cyan-orange
	return yCgCo2rgb( vec3( 1.0, 0.0, third ) );
}

// RGB spectral ramp for parameter `t`: with s = 3 * t - 1.5 the result is
// clamp((-s, 1 - |s|, s), 0, 1).
vec3 spectrum_offset_rgb( float t )
{
	float s = 3.0 * t - 1.5;
	vec3 ramp = vec3( -s, 1.0-abs(s), s );
	return clamp( ramp, vec3(0.0), vec3(1.0) );
}

// Applies the effect of hitting a surface to the ray state.
//
//   state              ray state; state.origin and state.normal must already
//                      describe the hit point. On return colour, transparency,
//                      direction, running/hit flags and material are updated.
//   prev_state_origin  position the hit point is viewed from (passed to
//                      evaluate_lighting as the viewer position)
//   hit_material       material index of the surface that was hit
//   hit_face           face index of the surface that was hit
//   bc                 barycentric coordinates of the hit, used for uv lookup
//
// state.material still refers to the material the ray came from until the
// last statement, where it is replaced by `hit_material`.
void evaluate_material(inout ray_state state, in vec3 prev_state_origin, int16_t hit_material, int hit_face, f16vec2 bc)
{
	f16vec4 hit_albedo = evaluate_albedo(hit_material, hit_face, bc);
	f16vec3 hit_emissive = evaluate_emissive(hit_material, hit_face, bc) * f16vec3(materials.material_properties[hit_material].emissive.rgb);

	f16vec3 hit_color = evaluate_lighting(state.origin, prev_state_origin, state.normal, hit_material, hit_albedo.rgb, hit_emissive);

	// NOTE: Why not using _Reflective property? Because this instruments render passes if the material should
	// be raytraced and we might not want that. Maybe it should be separated into two fields actually?

	bool hit_material_reflective = materials.material_properties[hit_material].roughness < 1.0;
	bool hit_material_transparent = (materials.material_properties[hit_material].flags & MaterialFlag_Transparent) != 0;

	// --- colour / transparency accumulation ---------------------------------
	if ((materials.material_properties[state.material].flags & MaterialFlag_Transparent) == 0)
	{
		// Coming from an opaque material: modulate by the previous colour,
		// weighted by the previous material's roughness.
		hit_color = f16vec3(mix(state.color * hit_color, state.color, float16_t(materials.material_properties[state.material].roughness)));
		state.color = f16vec3(mix(state.color, hit_color, float16_t(state.transparency)));
		// NOTE: don't modulate transparency as we want to keep current one for this ray
	}
	else if (hit_material_reflective == false && !hit_material_transparent) // diffuse, no reflection, abort
	{
		state.color = f16vec3(mix(state.color, hit_color, state.transparency));
		state.running = false;
		state.hit = true;
	}
	else if (hit_material_reflective == true && !hit_material_transparent) // reflective
	{
		state.color = f16vec3(mix(state.color, hit_color, state.transparency));
	}
	else
	{
		// Transparent surface: attenuate the ray's transparency by the
		// material transparency combined with the albedo alpha.
		float hit_transparency = 1.0 - (1.0 - materials.material_properties[hit_material].transparency) * hit_albedo.a;
		state.transparency = state.transparency * float16_t(hit_transparency);
		state.color = f16vec3(mix(state.color, hit_color, state.transparency));
	}

	// --- new ray direction ----------------------------------------------------
	if (!hit_material_transparent)
	{
		state.dir = reflect(state.dir, vec3(state.normal));
	}
	else
	{
		vec3 refracted_dir = glass_refract(state.dir, state.normal);
		float angle = dot(refracted_dir, state.dir);

		#if 1
		// Tint the colour depending on how far the refracted direction
		// deviates from the incoming one.
		angle = 1.0 - abs(angle);
		{
			float factor = clamp(angle * 5.0, 0.0, 1.0);
			angle = angle * 10.5;
			// The blend is evaluated in full precision and narrowed explicitly,
			// like every other assignment to state.color in this function.
			state.color = f16vec3(mix(vec3(state.color), spectrum_offset_ycgco(angle), factor));
		}
		#endif
		state.dir = refracted_dir;
	}

	// Surfaces that are neither reflective nor transparent terminate the ray.
	if ((materials.material_properties[hit_material].flags & MaterialFlag_Reflective) == 0
		&& !hit_material_transparent)
	{
		state.running = false;
		state.hit = true;
	}

	state.material = hit_material;
}

// State of a DDA traversal of the raytrace grid. Initialised by build_dda()
// and advanced by get_current_intersection_dda() / step_dda().
struct dda
{
	bool     is_high;       // reset to false by build_dda(); presumably marks the high-resolution level — TODO confirm
	vec3     res;           // current cell size per axis (GRID_SIZE, rescaled by dda_scale_grid_res())
	vec3     res_rcp;       // 1.0 / res
	vec3     ro;            // ray origin relative to in_bbox_data.bbox_raytrace_min
	vec3     ird;           // 1.0 / ray direction
	vec3     delta;         // per-axis ray-parameter increment for crossing one cell
	vec3     t_max;         // per-axis ray parameter of the next cell boundary
	float    prev_next_t;   // next_t as it was before the last step_dda()
	vec3     prev_t_max;    // t_max as it was before the last step_dda()
	float    next_t;        // min component of t_max: ray parameter of the next boundary crossing

	vec3     pf;            // sample position inside the current cell (midpoint of the [prev_next_t, next_t] segment)
	float    tmin;          // ray range start passed to build_dda()
    float    tmax;          // ray range end; dda_is_abort() reports true once next_t >= tmax
};

// True when all three cell coordinates have no bits set outside
// (GRID_RES - 1). This rejects both negative values and values above
// GRID_RES - 1 in a single mask test.
bool is_pos_inside_grid(i16vec3 icell)
{
	int16_t combined_bits = icell.x | icell.y | icell.z;
	return (combined_bits & (~(GRID_RES-1))) == 0;
}

// Initialises (or re-initialises) the DDA traversal state for a ray.
//
//   dda      traversal state to set up
//   ro       ray origin; stored relative to in_bbox_data.bbox_raytrace_min
//   rd       ray direction
//   tmin     ray range start (stored only)
//   tmax     ray range end (tested by dda_is_abort())
//   restart  when true the current cell size (res / res_rcp), is_high and the
//            sample position pf are kept from the previous state
void build_dda(inout dda dda, vec3 ro, vec3 rd, float tmin, float tmax, bool restart)
{
	if (!restart)
	{
		// Keep the cell size in full precision: grd_icell_dda() derives the
		// cell index from vec3(GRID_SIZE), so rounding the size through half
		// precision here would make the stepping and the cell lookup disagree.
		dda.res = vec3(GRID_SIZE);
		dda.res_rcp = 1.0 / dda.res;
		dda.is_high = false;
	}
	dda.ro = ro - in_bbox_data.bbox_raytrace_min.xyz;

//	ivec3 p = ivec3(floor(dda.ro));
//	if (is_pos_inside_grid(p))
//		dda.res = is_high_level_cell_occupied(p.x, p.y, p.z) != 0 ? 4 : 1;

	dda.ird = vec3(1.0) / rd;
	vec3 s = step(vec3(0.0), rd);    // 1.0 for non-negative direction components, 0.0 otherwise

	// delta: ray-parameter distance between consecutive cell boundaries per axis
	dda.delta = (s * 2.0 - 1.0) * dda.res * dda.ird;
	// t_max: ray parameter of the first cell boundary crossed on each axis
	dda.t_max = ((floor(dda.ro * dda.res_rcp) + s) * dda.res - dda.ro) * dda.ird;

	dda.prev_next_t = 0.0f;
	dda.prev_t_max = dda.t_max;

	if (!restart)
		dda.pf = dda.ro;

	dda.tmin = tmin;
	dda.tmax = tmax;
}

// Integer grid cell containing the current sample position dda.pf,
// computed as floor(pf / GRID_SIZE).
i16vec3 grd_icell_dda(inout dda dda)
{
	vec3 cell_f = floor(dda.pf / vec3(GRID_SIZE));
	return i16vec3(cell_f);
}

// True when the cell of the current sample position lies inside the grid.
bool is_inside_grid_dda(inout dda dda)
{
	return is_pos_inside_grid(grd_icell_dda(dda));
}

// True when the grid marker read for the current cell is zero.
// The active path reads the high-resolution marker; the disabled path reads
// the low-resolution marker at a rescaled position.
bool is_high_level_empty_dda(inout dda dda)
{
	#if 1
	ivec3 cell = grd_icell_dda(dda);
	return rt_read_grid_marker_high_res(s_grid_marker, cell).r == 0;
	#else
	ivec3 cell = ivec3(floor(dda.pf * (dda.res_rcp * (dda.is_high ? 1.0 : 0.25))));
	return !(rt_read_grid_marker_low_res(s_grid_marker, cell).r > 0);
	#endif
}


// Computes the ray parameter of the next cell boundary (min component of
// t_max), stores it in dda.next_t, and moves the sample position dda.pf to
// the midpoint of the segment [prev_next_t, next_t] along state.dir.
// Returns the new next_t.
float get_current_intersection_dda(inout dda dda, in ray_state state)
{
	float nearest = min(dda.t_max.x, min(dda.t_max.y, dda.t_max.z));
	dda.next_t = nearest;
	dda.pf = dda.ro + state.dir * (dda.prev_next_t + nearest) * 0.5;
	return nearest;
}

// True once the next boundary crossing lies at or beyond the ray's tmax.
bool dda_is_abort(in dda dda)
{
	bool past_range_end = dda.next_t >= dda.tmax;
	return past_range_end;
}

// Advances the traversal by one cell: remembers the previous next_t / t_max
// and adds delta to t_max on the axis (or axes) selected by the step()
// comparison mask.
void step_dda(inout dda dda)
{
	vec3 t_before = dda.t_max;

	// 1.0 on an axis whose t_max is <= both comparison partners, else 0.0
	vec3 advance_mask = step(t_before.xyz, t_before.yxy) * step(t_before.xyz, t_before.zzx);

	dda.prev_next_t = dda.next_t;
	dda.prev_t_max = t_before;
	dda.t_max = t_before + advance_mask * dda.delta;
}

// Rescales the traversal's cell size by `scale` and recomputes t_max /
// next_t for the new resolution, restarting from the point where the
// previous boundary crossing happened. Returns the new next_t.
//
// NOTE: This can still fail occasionally. When we step exactly onto two
//       edges at once, floor() might move us back one cell.
float dda_scale_grid_res(inout dda dda, in ray_state state, float scale)
{
	dda.res *= scale;
	dda.res_rcp *= 1.0 / scale;
	dda.delta *= scale;

	// Ray parameter of the previous boundary crossing.
	float prev_t_min = min(dda.prev_t_max.x, min(dda.prev_t_max.y, dda.prev_t_max.z));

	// 1.0 on the axis (or axes) on which that crossing happened, 0.0 elsewhere.
	vec3 crossed_axis = vec3(equal(vec3(prev_t_min), dda.prev_t_max));

	// Position of the previous crossing, relative to the grid origin.
	vec3 next_ro = dda.ro + state.dir * prev_t_min;

	// Distance to the next boundary at the new resolution, per axis.
	vec3 s = step(vec3(0.0), state.dir);
	vec3 to_next_boundary = ((floor(next_ro * dda.res_rcp) + s) * dda.res - next_ro) * dda.ird;

	// On a crossed axis we sit exactly on a boundary, so a full (rescaled)
	// delta is used there; other axes use the recomputed distance.
	dda.t_max = mix(to_next_boundary, dda.delta, crossed_axis);

	dda.t_max += prev_t_min;
	dda.next_t = min(dda.t_max.x, min(dda.t_max.y, dda.t_max.z));

	if (scale < 1.0) // no need to calculate sampling cube more precisely when going to high res
		dda.pf = dda.ro + state.dir * (dda.prev_next_t + dda.next_t) * 0.5;

	return dda.next_t;
}

// Reads the grid marker for cell `icell`. mip 0 reads the high-resolution
// marker; any other mip reads the low-resolution marker at icell >> 2.
int fetch_grid_marker_for_cell(ivec3 icell, int mip)
{
	return (mip == 0)
		? int(rt_read_grid_marker_high_res(s_grid_marker, icell))
		: int(rt_read_grid_marker_low_res(s_grid_marker, icell >> 2));
}


#if 0  // version without multiresolution grid
// Single-resolution DDA traversal. Compiled out by the surrounding '#if 0' and kept for reference.
// Walks the uniform grid cell by cell along the ray held in 'state', runs the bucket search on every
// non-empty cell and, on a hit, calls evaluate_material and recomputes the stepping parameters
// from the hit point.
//   state       - ray state, modified in place (origin, normal, bounces, tests, running, left)
//   skip_fi     - face index forwarded to the bucket search (presumably the face to ignore);
//                 replaced by the hit face after every hit
//   closest_fi  - out: set to -1 on entry, then written by the bucket search
//   closest_it  - out: set to 1000000.0 on entry, then written by the bucket search
//   max_bounces - the trace stops once state.bounces reaches this value
// Always returns 0.
// NOTE(review): the findClosestBucket2 calls below use argument lists that differ from the active
// variant, and in the non-USE_LINKED_LISTS path closest_normal is read without being passed to the
// search - verify both before re-enabling this code.
int findClosestDDAMultibounce(inout ray_state state, int skip_fi, out int closest_fi, out float closest_it, int max_bounces)
{ 
    //state.dir = normalize(state.dir);

	closest_fi = -1;
	closest_it = 1000000.0;

	vec3 cellDimension = vec3(GRID_SIZE);

	float tmin = 0.0;
	float tmax = rt_setup.trace_range;

	ivec3 icell_step;
    ivec3 icell;
    
	vec3 deltaT, nextCrossingT; 

	// ray origin relative to the grid's minimum corner, and the cell containing it
	vec3 ro_cell = state.origin - in_bbox_data.bbox_raytrace_min.xyz;
	icell = ivec3(floor(ro_cell / cellDimension));

	{
		vec3 s = step(vec3(0.0), state.dir);       // 0.0 or 1.0
		vec3 sgn = s * 2.0 - 1.0;                  // -1.0 or 1.0 

		deltaT = sgn * cellDimension / state.dir;    // same as 'dir'
		nextCrossingT = tmin + ((floor(ro_cell / cellDimension) + s) * cellDimension - ro_cell) / state.dir;
		icell_step = ivec3(sgn);
	}
 
	// walk through each cell of the grid and test for an intersection if
	// current cell contains geometry
	float rt = tmin;
	int inside = 0;

	int max_iter = MAX_TRACE_LENGTH;                        // this includes all bounces we are now tracking

	while(rt < tmax && max_iter >= 0)
    //while(max_iter > 0)
	{
		bool hit = false;
		max_iter--;
        state.tests++;

		if (state.running == false)
			break;
		//if (state.tests > )
		//	break;

        // t for the next cell boundary crossing
        rt = tmin + min(nextCrossingT.x, min(nextCrossingT.y, nextCrossingT.z));

#if 0
		if (fetch_grid_marker_for_cell(icell, 0) > 0)
#else
		if (is_pos_inside_grid(i16vec3(icell)))
#endif
		{
			int icell_idx = icell.z * GRID_RES * GRID_RES + icell.y * GRID_RES + icell.x;
			inside = 1;

#if 1
			// query the level 2 marker first; the level 0 marker is read only when level 2 is non-zero
			int bucket_size_4 = fetch_grid_marker_for_cell(icell, 2);
			int bucket_size = 0;
			if (bucket_size_4 > 0)
				bucket_size = fetch_grid_marker_for_cell(icell, 0);
#else
			int bucket_size = fetch_grid_bucket_size_for_cell(icell);
#endif
			if (bucket_size > 0)
			{
				f16vec2 closest_bc;
				f16vec3 closest_normal;

				//state.tests += bucket_size;

				#ifndef USE_LINKED_LISTS
				int bucket_offset = int(in_buckets.offsets[icell_idx]);
				findClosestBucket2(bucket_offset, bucket_size, state.origin, state.dir, skip_fi, closest_fi, closest_it, state.normal);
				//findClosestBucket(icell_idx, state.origin, state.dir, skip_fi, closest_fi, closest_it, state.normal);
				#else
				findClosestBucket2(icell_idx, bucket_size, state.origin, state.dir, skip_fi, rt, closest_fi, closest_it, closest_normal, closest_bc);
				#endif
				
				if (closest_fi != -1)
				{
					state.bounces += int16_t(1);
					state.normal = closest_normal;

                    #if 0
					// debug visualisation of the first hit (disabled)
					if (state.bounces == 1)
                    {
                        int cnt = countClosestBucket2(icell_idx, bucket_size, state.origin, state.dir, skip_fi);
                        state.color = vec3(cnt) * 0.2;
                        //state.color.r = cnt > 1 ? 0.0 : 0.4;
                        if (false)
                        {
                            state.origin = state.origin + state.dir * closest_it;
                            state.material = rt_get_triangle_material(closest_fi);
                            evaluate_material(state);
                            state.color.rgb = state.color.rgb * 0.1 + 0.5 + 0.5 * sin(vec3(20.0, 0.0, 0.0)* vec3(icell) / GRID_RES);
                        }

						state.running = false;
                        state.hit = true;
                        //break;
                    }
                    #endif

					// stop after this hit when the bounce budget is used up or the material terminates rays
					int hit_material = rt_get_triangle_material(closest_fi);
					if (state.bounces >= max_bounces || (materials.material_properties[hit_material].flags & MaterialFlag_RaytraceTerminate) != 0)
                    {
                        state.running = false;
                        //state.hit = true;
                        //break;
                    }

                    skip_fi = closest_fi;
					hit = true;

					// either reflect or refract
					//vec3 bc = vec3(closest_bc.x, closest_bc.y, 1.0 - closest_bc.x - closest_bc.y); //barycentric_for_face(closest_fi, state.origin + state.dir * closest_it);
					f16vec2 bc = barycentric_for_face_yz(closest_fi, state.origin + state.dir * closest_it);

					// materials without the Flat flag get a normal interpolated from the barycentrics
					if ((materials.material_properties[hit_material].flags & MaterialFlag_Flat) == 0)
						state.normal = interpolate_normal_from_bc_yz(closest_fi, bc, state.normal);

					// calculate new origin and recalculate tracing parameters for the bounce
					vec3 prev_state_origin = state.origin; // for lighting calculation
					state.origin = state.origin + state.dir * closest_it;

					// we hit solid object which is not perfectly rough, reflect
                    evaluate_material(state, prev_state_origin, int16_t(hit_material), closest_fi, bc);
					
					// rebuild stepping parameters TODO: factor this out
                    
					vec3 ird = 1.0 / state.dir;
					ro_cell = state.origin - in_bbox_data.bbox_raytrace_min.xyz;
					//cell = floor(ro_cell / cellDimension);

					vec3 s = step(vec3(0.0), state.dir);            // 0.0 or 1.0
					vec3 sgn = s * 2.0 - 1.0;                       // -1.0 or 1.0 

					deltaT = sgn * cellDimension * ird;         // same as 'dir'
					nextCrossingT = tmin + ((floor(ro_cell / cellDimension) + s) * cellDimension - ro_cell) * ird;
					icell_step = ivec3(sgn);
                    
					// walk through each cell of the grid and test for an intersection if
					// current cell contains geometry
					rt = tmin;
				}
			}
		}
		else
		{
			// the ray was inside the grid earlier and has now stepped out of it: stop tracing
			if (inside == 1)
            {
				state.running = false;
                state.left = true;
            }
		}

		if (!hit)
		{
			// all components of minimum mask (i.e. x <= y && x <= z, y <= x && y <= z, z <= y && z <= x) 
			// are false except for the corresponding smallest component of dt (if no mask), which 
			// is the axis along which the ray should be incremented
			// stolen from https://github.com/guozhou/voxelizer/blob/master/raycasting_fs.glsl
            // NOTE: nextCrossingT == dt
            
		    vec3 mm = step(nextCrossingT.xyz, nextCrossingT.yxy) * step(nextCrossingT.xyz, nextCrossingT.zzx);
		    icell += ivec3(mm) * icell_step;
		    nextCrossingT += mm * deltaT;
		}
	} 

	return 0;
} 
#else

// Returns how many invocations of the current warp / subgroup pass 'v' as true.
// Only invocations that are active at the call site take part in the ballot.
//   non-Vulkan path  : GL_NV_shader_thread_group, ballotThreadNV returns one 32-bit mask
//   SPIRV_VULKAN path: GL_KHR_shader_subgroup_ballot
uint ballot_count(bool v)
{
	#ifndef SPIRV_VULKAN
	uint cnt = bitCount(ballotThreadNV(v));
	#else
	// subgroupBallotBitCount looks at every component of the ballot (limited to gl_SubgroupSize bits),
	// so the count stays correct for subgroups wider than 64 invocations; summing only .x and .y
	// silently dropped lanes 64..127.
	uint cnt = subgroupBallotBitCount(subgroupBallot(v));
	#endif

	return cnt;
}

// Returns the bucket size of grid cell 'icell' using the two-level marker lookup: the level 2
// marker is queried first and the level 0 marker is fetched only when level 2 is non-zero.
// (A direct level 0 read is the alternative; the level 2 pre-check avoids it when level 2 is zero.)
int fetch_bucket_size_two_level(ivec3 icell)
{
	if (fetch_grid_marker_for_cell(icell, 2) <= 0)
		return 0;

	return fetch_grid_marker_for_cell(icell, 0);
}

// Traces the ray held in 'state' through the grid with a DDA walk, following it across bounces.
// Every non-empty cell is handed to findClosestBucket2; on a hit the ray state is moved to the hit
// point, evaluate_material is called and (while the ray keeps running) the DDA is rebuilt from there.
//   state       - ray state, modified in place (origin, normal, bounces, tests, running, left,
//                 active_threads_factor / active_threads_samples statistics)
//   skip_fi     - face index forwarded to the bucket search (presumably the face to ignore);
//                 replaced by the hit face after every hit
//   closest_fi  - out: set to -1 on entry, then written by the bucket search
//   closest_it  - out: set to 1000000.0 on entry, then written by the bucket search
//   max_bounces - the trace stops once state.bounces reaches this value
// The loop ends when state.running becomes false or the MAX_TRACE_LENGTH iteration budget runs out.
// Always returns 0.
int findClosestDDAMultibounce(inout ray_state state, int skip_fi, out int closest_fi, out float closest_it, int max_bounces)
{
	dda dda;
	build_dda(dda, state.origin, state.dir, 0.0, rt_setup.trace_range, false);

	closest_fi = -1;
	closest_it = 1000000.0;

	// 'inside' tells whether the current cell lies in the grid; 'prev_inside' remembers that the ray
	// has been inside at some point, so that stepping out of the grid can end the trace
	bool inside = false;
	bool prev_inside = inside;

	int max_iter = MAX_TRACE_LENGTH;                        // iteration budget shared by all bounces of this ray

	while(max_iter >= 0)
	{
		bool hit = false;
		max_iter--;
		state.tests += int16_t(1);

		if (state.running == false)
			break;

		float rt = get_current_intersection_dda(dda, state);
		prev_inside = prev_inside || inside;
		inside = is_inside_grid_dda(dda);

		// disabled experiment ('false &&'): switch the DDA between grid resolutions when at least
		// 16 threads of the warp agree on going to the high level
		if (false && max_bounces > 1)
		{
			bool change_to_high = (inside == false || (inside == true && dda.is_high == false && is_high_level_empty_dda(dda)));
			bool change_to_low = (inside == true && (dda.is_high == true && is_high_level_empty_dda(dda) == false));

			uint change_to_high_ballot_cnt = ballot_count(change_to_high);

			// change to low is always executed!
			if (change_to_high_ballot_cnt < 16)
				change_to_high = false;

			if (change_to_high)
			{
				state.active_threads_factor += ballot_count(true);
				state.active_threads_samples += 1;

				// go to high res
				dda.is_high = true;
				rt = dda_scale_grid_res(dda, state, 4.0);
			}
			else if (change_to_low)
			{
				state.active_threads_factor += ballot_count(true);
				state.active_threads_samples += 1;

				// go to low res
				dda.is_high = false;
				rt = dda_scale_grid_res(dda, state, 0.25);
			}
		}

		bool search_bucket = false;
		ivec3 icell = ivec3(0);
		uint icell_idx = 0;
		int bucket_size = 0;

		// calculate cell index and check for intersections if valid (we don't do grid-box intersection yet)
		if (inside && dda.is_high == false)
		{
			icell = grd_icell_dda(dda);
			icell_idx = icell.z * GRID_RES * GRID_RES + icell.y * GRID_RES + icell.x;
			bucket_size = fetch_bucket_size_two_level(icell);

			if (bucket_size > 0)
				search_bucket = true;
		}

		// try to advance also other threads to increase coherency. this is mostly to skip empty space and sync threads:
		// threads that have nothing to search step up to 4 more cells while fewer than 32 threads
		// have either found a non-empty cell or left the grid
		if (true)
		{
			bool done_criteria = (search_bucket == true) || (prev_inside == true && inside == false);
			uint hit_threads = ballot_count(done_criteria);
			int ii = 0;
			while(hit_threads < 32 && ii < 4)
			{
				if (search_bucket == false)
				{
					max_iter--;
					step_dda(dda);
					rt = get_current_intersection_dda(dda, state); // this not only calculates 'rt' but updates prev_t and pf for cell calculation

					inside = is_inside_grid_dda(dda);
					prev_inside = prev_inside || inside;

					if (inside)
					{
						icell = grd_icell_dda(dda);
						icell_idx = icell.z * GRID_RES * GRID_RES + icell.y * GRID_RES + icell.x;
						bucket_size = fetch_bucket_size_two_level(icell);

						if (bucket_size > 0)
							search_bucket = true;
					}
				}

				ii++;
				done_criteria = (search_bucket == true) || (prev_inside == true && inside == false);
				hit_threads = ballot_count(done_criteria);
			}
		}

		if (search_bucket)
		{
			f16vec2 closest_bc;
			f16vec3 closest_normal;

			// statistics: how many threads of the warp take part in this bucket search
			state.active_threads_factor += ballot_count(true);
			state.active_threads_samples += 1;

			#ifndef USE_LINKED_LISTS
			int bucket_offset = int(in_buckets.offsets[icell_idx]);
			findClosestBucket2(bucket_offset, bucket_size, state.origin, state.dir, rt, skip_fi, closest_fi, closest_it, closest_normal);
			//findClosestBucket(icell_idx, state.origin, state.dir, skip_fi, closest_fi, closest_it, state.normal);
			#else
			findClosestBucket2(icell_idx, bucket_size, state.origin, state.dir, skip_fi, rt, closest_fi, closest_it, closest_normal, closest_bc);
			#endif

			if (closest_fi != -1)
			{
				state.bounces += int16_t(1);
				state.normal = closest_normal;

				hit = true;
				skip_fi = closest_fi;

				int hit_material = rt_get_triangle_material(closest_fi);

				// stop after this hit when the bounce budget is used up or the material terminates rays
				if (state.bounces >= max_bounces || (materials.material_properties[hit_material].flags & MaterialFlag_RaytraceTerminate) != 0)
				{
					state.running = false;
					//state.hit = true;
					//break;
				}

				// materials without the Flat flag get a normal interpolated from the barycentrics
				f16vec2 bc = barycentric_for_face_yz(closest_fi, state.origin + state.dir * closest_it);
				if ((materials.material_properties[hit_material].flags & MaterialFlag_Flat) == 0)
					state.normal = interpolate_normal_from_bc_yz(closest_fi, bc, state.normal);

				// calculate new origin and recalculate tracing parameters for the bounce
				vec3 prev_state_origin = state.origin; // for lighting calculation
				state.origin = state.origin + state.dir * closest_it;

				// we hit solid object which is not perfectly rough, reflect
				// (hit_material is reused here instead of looking the triangle material up a second time)
				evaluate_material(state, prev_state_origin, int16_t(hit_material), closest_fi, bc);

				//state.origin += state.dir * 50.2;

				// restart the walk from the hit point, but only when another bounce will be traced
				if (max_bounces > 1 && state.running)
					build_dda(dda, state.origin, state.dir, 0.0, rt_setup.trace_range, true);
			}
		}

		// the ray was inside the grid earlier and has now stepped out of it: stop tracing
		if (inside == false && prev_inside == true)
		{
			state.running = false;
			state.left = true;
		}

		if (!hit)
		{
			step_dda(dda);
			if (dda_is_abort(dda))
				state.running = false;
		}
	}

	return 0;
}
#endif

//#define findClosest findClosestNaive
#define findClosest findClosestDDA

// Maps x (clamped to [0, 1]) to an RGB colour by evaluating one fifth-degree polynomial per channel.
vec3 TurboColormap(in float x)
{
  // coefficients for the powers 1, x, x^2, x^3 of each channel
  const vec4 redLow   = vec4(0.13572138, 4.61539260, -42.66032258, 132.13108234);
  const vec4 greenLow = vec4(0.09140261, 2.19418839, 4.84296658, -14.18503333);
  const vec4 blueLow  = vec4(0.10667330, 12.64194608, -60.58204836, 110.36276771);
  // coefficients for the powers x^4, x^5 of each channel
  const vec2 redHigh   = vec2(-152.94239396, 59.28637943);
  const vec2 greenHigh = vec2(4.27729857, 2.82956604);
  const vec2 blueHigh  = vec2(-89.90310912, 27.34824973);

  float t  = clamp(x, 0.0, 1.0);
  float t2 = t * t;
  float t3 = t2 * t;

  vec4 lowPowers  = vec4(1.0, t, t2, t3);
  vec2 highPowers = vec2(t2 * t2, t3 * t2);

  float r = dot(lowPowers, redLow)   + dot(highPowers, redHigh);
  float g = dot(lowPowers, greenLow) + dot(highPowers, greenHigh);
  float b = dot(lowPowers, blueLow)  + dot(highPowers, blueHigh);

  return vec3(r, g, b);
}

// Fragment entry point. Without RAYTRACE_PASS the body is empty.
// With RAYTRACE_PASS it:
//   1. reads the material tag, depth, normal and albedo of the pixel and reconstructs the position
//      of the primary hit from the depth value,
//   2. builds the secondary ray at that point (reflection, or glass_refract for materials flagged
//      MaterialFlag_Transparent) and traces it with findClosestDDAMultibounce,
//   3. writes the traced colour plus the stored fresnel reflection to outAlbedo.
// With GLASS_REFLECTION_PASS additionally defined only transparent materials are traced (one bounce)
// and the result is written to outFresnelReflection instead.
void main() {

#ifndef RAYTRACE_PASS
#else

	#ifndef GLASS_REFLECTION_PASS
	//outAlbedo = vec4(1.0);
	//return;
	#endif

	// scaled_sample_pos addresses the images read with imageLoad below, native_sample_pos the
	// textures read with texelFetch
	ivec2 scaled_sample_pos = ivec2(gl_FragCoord.xy) * rt_setup.screen_sampling_scale;
	ivec2 native_sample_pos =  ivec2(gl_FragCoord.xy);

#ifdef SPIRV_VULKAN
	//scaled_sample_pos.y = 1080 - scaled_sample_pos.y;
	//native_sample_pos.y = 1080 - native_sample_pos.y;
#endif

	vec4 color = vec4(1.0);

	// for now also trace to the point we just hit. we would need some way to identify the face we are rasterizing
	// 

	// NOTE: this 'dir' is replaced further down once the position has been reconstructed from depth
	vec3 dir = -normalize(transform_params.vCameraPosition.xyz - vtx_input.vWorldPos.xyz);
	vec3 origin = transform_params.vCameraPosition.xyz;
	f16vec3 normal;

	int closest_fi = -1;
	bool have_initial_face_index = false; // never set to true in this function
	int16_t material = int16_t(0);

	// if not running with stencil we simply discard based on material
	{
		float metalness;
		float roughness;
		uint materialIndex;
		decode_metalness_roughness_material(imageLoad(imMetalnessRoughnessMaterialTags, scaled_sample_pos).rg, metalness, roughness, materialIndex);
		//MaterialPropertiesGPU material = materials.material_properties[materialIndex];
		if (materialIndex == 1) // || materials.material_properties[materialIndex].raytrace == 0.0)
		{
			//discard;
			//return;
		}

		material = int16_t(materialIndex);
	}

	// NOTE(review): 1920.0 / 1080.0 are hard-coded - presumably the resolution addressed by
	// scaled_sample_pos; confirm before running at any other resolution
	vec3 view_direction;
	view_direction.x = -rt_setup.camera_projection_params.x * 0.5 + rt_setup.camera_projection_params.x * scaled_sample_pos.x / 1920.0;
	view_direction.y = -rt_setup.camera_projection_params.y * 0.5 + rt_setup.camera_projection_params.y * scaled_sample_pos.y / 1080.0;
	view_direction.z = 1.0;

#ifdef SPIRV_VULKAN
	view_direction.y = -view_direction.y;
#endif

	// position of the primary hit: rebuilt from depth, then transformed by rt_setup.mat_model
	float depth = linearizeDepth(texelFetch(sTextureDepth, native_sample_pos, 0).r);
	vec3 view_coords = positionFromDepth(view_direction, depth);
	view_coords = (rt_setup.mat_model * vec4(view_coords, 1.0)).xyz;

	// direction from the camera towards the reconstructed position
	dir = -normalize(transform_params.vCameraPosition.xyz - view_coords.xyz);

	//outAlbedo.rgb = vec3(fract(view_coords.xyz * 0.1));
	//outAlbedo.rgb = vec3(fract(dir.xyz));
	//outAlbedo.rgb = vec3(TurboColormap(fract(float(closest_fi) * 0.001)));
	//return;

	float closest_it;

	// TODO: optimize with simple length

	f16vec3 worldNorm;
	{
		// distance from the camera to the primary hit, and its normal decoded from imNormalMaterial
		closest_it = length(origin - view_coords.xyz);
		uint encoded_normal_material = imageLoad(imNormalMaterial, scaled_sample_pos).r;
		normal = f16vec3(normalize(decode_normal(encoded_normal_material)));
		worldNorm = normal;

		//outNormalMaterial = encode_normal_material(normalize(normal), 0);
		//outAlbedo.rgb = normal.rgb * 2.5 + 0.5;
		//return;
	}

#ifndef GLASS_REFLECTION_PASS
	// debug leftover: computes the grid cell of the primary hit; the result is not used
	{
		vec3 ro = origin + dir * closest_it - in_bbox_data.bbox_raytrace_min.xyz;
		ivec3 icell = ivec3(floor(ro / vec3(GRID_SIZE)));
		if (icell.x >=0 && icell.y >= 0 && icell.z >= 0 && icell.x < GRID_RES && icell.y < GRID_RES && icell.z < GRID_RES)
		{
			uint icell_idx = icell.z * (GRID_RES * GRID_RES) + icell.y * GRID_RES + icell.x;
			//outAlbedo.rgb = vec3(in_buckets.sizes[icell_idx]) / 10.0;
			//return;
		}
		else
		{

		}
	}
#endif

#ifdef GLASS_REFLECTION_PASS
	// default output for every early return of the glass pass
	outFresnelReflection = vec4(0.0);
#endif

	{
		// move the ray origin from the camera to the primary hit
		origin = origin + dir * closest_it;

		// NOTE: Start with albedo color. lighting for the primary hit is calculated in normal lighting pass
		// not here, so it will be multiplied later

		ray_state state;

		{
			// albedo holds opacity. 0 = 100% transparent
			// a material transparency of exactly 0.0 is treated as 1.0 here
			vec4 albedo_value = imageLoad(imAlbedo, scaled_sample_pos).rgba;
			state.color = f16vec3(albedo_value.rgb);
			state.transparency = float16_t(materials.material_properties[material].transparency == 0.0 ? 1.0 : materials.material_properties[material].transparency);
			state.transparency = 1.0 - (1.0 - state.transparency) * albedo_value.a;
		}
		state.normal = normal;
		state.material = material;
		state.running = true;
		state.dir = dir;
		state.origin = origin;
		state.bounces = int16_t(1);
		state.hit = false;
		state.left = false;
		state.tests = int16_t(0);
		state.active_threads_factor = 0;
		state.active_threads_samples = 0;

		// roughness of exactly 1.0 on a non-transparent material: nothing is traced, the albedo stays as colour
		if (materials.material_properties[material].roughness == 1.0 && (materials.material_properties[material].flags & MaterialFlag_Transparent) == 0)
		{
			dir = reflect(dir, vec3(normal));
			//state.color = materials.material_properties[material].diffuse.rgb;
			state.hit = true;
		}
		else
		{
			// glass pass:
			// only one look for the reflection. we will limit fresnel so it is not too visible
			if ((materials.material_properties[material].flags & MaterialFlag_Transparent) != 0)
			{

#ifdef GLASS_REFLECTION_PASS

				float fresnel = pow(1.0f - max(0.0f, dot(vec3(worldNorm), normalize(transform_params.vCameraPosition - view_coords.xyz))), 4.0f);

				vec3 glass_reflection = vec3(0.0f);
				const float fresnel_threshold = 0.0f;
				if (fresnel > fresnel_threshold) // no point tracing below that
				{
					vec3 ref_dir = reflect(dir, vec3(normal));
					int ref_closest_fi;
					float ref_closest_it;

					state.dir = ref_dir;
					// no face index to skip is known, so start the ray a fixed distance away from the surface
					if (!have_initial_face_index)
						state.origin += state.dir * INITIAL_FACE_START_DISTANCE;

					findClosestDDAMultibounce(state, closest_fi, ref_closest_fi, ref_closest_it, 1);

					// NOTE(review): on a miss the environment map is sampled along 'normal', not 'ref_dir' - confirm this is intended
					if (ref_closest_fi != -1)
						glass_reflection = state.color;
					else
						glass_reflection = sample_env_map(normal, materials.material_properties[state.material].roughness).rgb;

					fresnel = (fresnel - fresnel_threshold) / (1.0f - fresnel_threshold); // normalize
				}

				outFresnelReflection = vec4(glass_reflection.rgb, fresnel);
				return;
#endif

			}

#ifdef GLASS_REFLECTION_PASS
			// non-transparent materials produce no output in the glass pass
			return;
#endif

			if ((materials.material_properties[state.material].flags & MaterialFlag_Transparent) == 0)
			{
				// this is initial sample and this is only place where for now we apply roughness
				state.dir = reflect(state.dir, vec3(state.normal));

				#if 1
				// NOTE: This is REALLY costly when divergence goes to hell, so for now because we don't cluster
				// rays just try to limit the roughness...
				if (materials.material_properties[state.material].roughness > 0.0f)
				{
					const float golden_ratio = 1.61803398875;
					int frame = globals.monotonic;
					float clamped_roughness = materials.material_properties[state.material].roughness;
					//clamped_roughness = clamped_roughness * clamped_roughness;
					clamped_roughness = min(rt_setup.roughness_clamp, clamped_roughness * clamped_roughness);

					// per-pixel noise from a 128x128 tile of s_BlueNoise, shifted every frame by the golden ratio
					ivec2 screen_pos = native_sample_pos % ivec2(128);
					vec2 hash = fract(texelFetch(s_BlueNoise, ivec3(screen_pos.xy, 0), 0).rg + frame * golden_ratio);
					vec3 d = CosineSampleHemisphere(hash.x * clamped_roughness, hash.y);
					mat3 vecSpace = matrixFromVector(state.dir);
					state.dir = vecSpace * d;
				}
				#endif

				#ifndef GLASS_REFLECTION_PASS
				{
					//vec3 ro = origin + dir * closest_it - in_bbox_data.bbox_raytrace_min.xyz;
					//ivec3 icell = ivec3(floor(ro / vec3(GRID_SIZE)));
					//if (icell.x >=0 && icell.y >= 0 && icell.z >= 0 && icell.x < GRID_RES && icell.y < GRID_RES && icell.z < GRID_RES)
					//{
						//uint icell_idx = icell.z * (GRID_RES * GRID_RES) + icell.y * GRID_RES + icell.x;
						//outAlbedo.rgb = vec3(in_buckets.sizes[icell_idx]) / 10.0f;
						//return;
					//}
					//outAlbedo.rgb = sample_env_map(state.dir, 1.0f).rgb;
					//outAlbedo.rgb = state.dir;
					//return;
				}
				#endif

			}
			else //if (materials.material_properties[state.material].transparency > 0.0f)
			{
				// transparent material: continue with the refracted direction
				state.dir = glass_refract(state.dir, -state.normal);
			}

			{
				//vec3 color = TurboColormap(float(state.tests) / 256.0f);
				//outAlbedo.rgb = color;
				//return;
			}

			// remember the initial secondary direction; it is used for the environment lookup below
            dir = state.dir;

			// no face index to skip is known, so start the ray a fixed distance away from the surface
			if (!have_initial_face_index)
				state.origin += state.dir * INITIAL_FACE_START_DISTANCE;

            findClosestDDAMultibounce(state, closest_fi, closest_fi, closest_it, MAX_BOUNCES);

			// bounces still at its initial value of 1 and no hit flagged: blend with the environment map
            if (state.hit == false && state.bounces <= int16_t(1))
            {
				float current_roughness = materials.material_properties[material].roughness;
                if ((materials.material_properties[material].flags & MaterialFlag_Transparent) != 0)
                {
                    //state.color = vec4(state.transparency); //mix(state.color, sample_env_map(state.dir), mix(factor, 1.0f, state.transparency));
                    //state.color = state.color;//mix(state.color, sample_env_map(state.dir), mix(factor, 1.0f, state.transparency));
                    //state.color = mix(state.color, mix(sample_env_map(state.dir), state.color, materials.material_properties[material].roughness), state.transparency);
					state.color = mix(state.color, sample_env_map(state.dir, current_roughness).rgb, state.transparency);
                }
                else
                {
                   //state.color = vec4(state.transparency); //mix(state.color, sample_env_map(state.dir), mix(factor, 1.0f, state.transparency));
                   
					// material diffuse colour, tinted by the environment map less and less as roughness grows
					f16vec3 mat_color = f16vec3(mix(
						materials.material_properties[material].diffuse.rgb * sample_env_map(dir, current_roughness).rgb,
						materials.material_properties[material].diffuse.rgb,
						materials.material_properties[material].roughness
					));
                   //state.color = mix(state.color, state.color * sample_env_map(state.dir, current_roughness).rgb, 1.0f - state.transparency);
                   state.color = f16vec3(mix(state.color, state.color * sample_env_map(state.dir, current_roughness).rgb, state.transparency));
                   state.color = f16vec3(mix(mat_color, state.color, state.transparency));
                }
            }
            else
            {
                //state.color = vec4(1.0f, 0.0f, 1.0f, 0.0f);
                //state.color = mix(state.color, materials.material_properties[material].diffuse, 1.0f - materials.material_properties[material].transparency);
            }

			// state.color = vec4(float(state.bounces) / 10.0f);
			//state.color = mix(state.color, materials.material_properties[material].diffuse, 1.0f - materials.material_properties[material].transparency);

		}

		color.rgb = state.color;
		//color.rgb = vec3(state.transparency);
		//color.rgb = state.normal;// * 0.5 + 0.5;
		//color.rgb = TurboColormap(fract(float(state.material) / 10.0f));

		//#ifdef VISUALIZE_HEATMAP
		
		//color.rgb = TurboColormap(float(state.tests) / 256.0f);
		//if (state.tests < 32)
		{
			//color.rgb = vec3(0.0f, 1.0f, 0.0f);
		}
		
		//#endif

		//color.rgb = TurboColormap(float(state.active_threads_factor) / (state.active_threads_samples * 32));

#if 0
		// debug: colour-code the number of bounces
		int bi = state.bounces;
		if (bi == 0)
			color.rgb = vec3(1.0f);
		if (bi == 1)
			color.rgb = vec3(1.0f, 0.0f, 0.0f);
		if (bi == 2)
			color.rgb = vec3(0.0f, 1.0f, 0.0f);
		if (bi == 3)
			color.rgb = vec3(0.0f, 0.0f, 1.0f);
		if (bi > 3)
			color.rgb = vec3(0.5f, 0.0f, 1.0f);
#endif

		//color.rgb = vec3(state.transparency);
		#ifndef GLASS_REFLECTION_PASS

		// alpha output: 0 for materials whose transparency is exactly 0.0, the traced transparency otherwise
		if (materials.material_properties[material].transparency == 0.0f)
			outAlbedo.a = 0.0f;
		else
			outAlbedo.a = state.transparency;

		#endif
	}

#ifndef GLASS_REFLECTION_PASS

	// add the reflection stored in sFresnelReflection, weighted by the fresnel factor kept in its alpha channel
	vec4 glass_reflection = texelFetch(sFresnelReflection, native_sample_pos, 0);
	//color.rgb = mix(color.rgb, glass_reflection.rgb, vec3(glass_reflection.a));
	color.rgb = color.rgb + glass_reflection.rgb * vec3(glass_reflection.a);
	outAlbedo.rgb = color.rgb;
#endif

#endif

}

