#ifndef TAA_SUPPORT_H
#define TAA_SUPPORT_H

#include <shaders/commons_hlsl.glsl>

// Catmull-Rom (bicubic) texture sampling that covers the 4x4 texel footprint with
// 9 bilinear taps instead of 16 point taps.
// by Matt Pettineo https://gist.github.com/TheRealMJP/c83b8c0f46b63f3a88a5986f4fa982b1
//
// tex     - texture to sample; the merged middle taps only reproduce the Catmull-Rom
//           weights if the sampler performs bilinear filtering
// uv      - normalized sample coordinate
// texSize - texture dimensions in texels
// Returns the filtered value, always read from mip level 0.
float4 SampleTextureCatmullRom(sampler2D tex, in float2 uv, in float2 texSize)
{
	// Snap down to the center of the texel that sits at grid position [1, 1] of the
	// 4x4 footprint ([0, 0] being the top-left texel).
	float2 texelPos = uv * texSize;
	float2 centerTexel = floor(texelPos - 0.5f) + 0.5f;

	// Fractional distance from that texel center to the requested sample position.
	float2 t = texelPos - centerTexel;

	// Closed-form Catmull-Rom weights for the four columns/rows of the footprint.
	// Expanding them per position avoids evaluating the piecewise kernel.
	float2 wA = t * (-0.5f + t * (1.0f - 0.5f * t));
	float2 wB = 1.0f + t * t * (-2.5f + 1.5f * t);
	float2 wC = t * (0.5f + t * (2.0f - 1.5f * t));
	float2 wD = t * t * (-0.5f + 0.5f * t);

	// The two middle texels are read with a single bilinear tap placed between them,
	// so the hardware blend reproduces the wB/wC split with a combined weight wBC.
	float2 wBC = wB + wC;

	// Normalized coordinates of the outer texels and of the merged middle tap.
	float2 uvA = (centerTexel - 1) / texSize;
	float2 uvBC = (centerTexel + wC / wBC) / texSize;
	float2 uvD = (centerTexel + 2) / texSize;

	float4 result = float4(0.0);

	// Top row
	result += textureLod(tex, float2(uvA.x, uvA.y), 0.0f) * wA.x * wA.y;
	result += textureLod(tex, float2(uvBC.x, uvA.y), 0.0f) * wBC.x * wA.y;
	result += textureLod(tex, float2(uvD.x, uvA.y), 0.0f) * wD.x * wA.y;

	// Middle (merged) rows
	result += textureLod(tex, float2(uvA.x, uvBC.y), 0.0f) * wA.x * wBC.y;
	result += textureLod(tex, float2(uvBC.x, uvBC.y), 0.0f) * wBC.x * wBC.y;
	result += textureLod(tex, float2(uvD.x, uvBC.y), 0.0f) * wD.x * wBC.y;

	// Bottom row
	result += textureLod(tex, float2(uvA.x, uvD.y), 0.0f) * wA.x * wD.y;
	result += textureLod(tex, float2(uvBC.x, uvD.y), 0.0f) * wBC.x * wD.y;
	result += textureLod(tex, float2(uvD.x, uvD.y), 0.0f) * wD.x * wD.y;

	return result;
}

// Cubic filter kernel from the Mitchell-Netravali (B, C) family, evaluated at
// distance x from the filter center. The kernel is non-zero only for |x| <= 2.
//
// x    - distance from the filter center. The kernel is symmetric, so only |x| is used.
// B, C - kernel shape parameters (Mitchell() in this file uses B = C = 1/3).
// Returns the filter weight (the polynomials are divided by 6).
float FilterCubic(in float x, in float B, in float C)
{
	// The piecewise polynomials below are only valid for non-negative distances.
	// Without folding, a negative x would always take the |x| < 1 branch and return
	// a weight far outside the kernel (e.g. x = -3 would not be 0).
	float ax = abs(x);
	float ax2 = ax * ax;
	float ax3 = ax2 * ax;

	float y = 0.0f;
	if (ax < 1.0f)
		y = (12 - 9 * B - 6 * C) * ax3 + (-18 + 12 * B + 6 * C) * ax2 + (6 - 2 * B);
	else if (ax <= 2.0f)
		y = (-B - 6 * C) * ax3 + (6 * B + 30 * C) * ax2 + (-12 * B - 48 * C) * ax + (8 * B + 24 * C);

	return y / 6.0f;
}

// Mitchell filter (B = C = 1/3) with its support compressed by half.
// The distance is doubled before evaluating FilterCubic, so FilterCubic's
// non-zero range of x <= 2 corresponds to x <= 1 here.
float Mitchell(in float x)
{
	const float mitchellB = 1.0 / 3.0f;
	const float mitchellC = 1.0 / 3.0f;
	return FilterCubic(2.0f * x, mitchellB, mitchellC);
}

// Mitchell filter (B = C = 1/3) evaluated at the unscaled distance x, i.e. with
// FilterCubic's native support (non-zero up to x = 2).
float Mitchell2(in float x)
{
	return FilterCubic(x, 1.0 / 3.0f, 1.0 / 3.0f);
}

// Small tolerance used by the AABB clipping helpers below (1e-8).
const float CLIP_AABB_FLT_EPS = 1e-8f;

// Clips the segment p -> q against the box [aabb_min, aabb_max] and returns the
// clipped end point. The direction (q - p) is shrunk uniformly, one axis test at a
// time (upper faces first, then lower faces), whenever the end point exceeds a face
// by more than CLIP_AABB_FLT_EPS.
// NOTE(review): this presumably expects p to lie inside the box (e.g. the
// neighborhood average). If p is outside, the scale factors can flip sign or
// divide by zero.
float3 clip_aabb(float3 aabb_min, float3 aabb_max, float3 p, float3 q)
{
	float3 dir = q - p;
	float3 toMax = aabb_max - p;
	float3 toMin = aabb_min - p;

	// Pull the end point back onto any upper face it crosses.
	if (dir.x > toMax.x + CLIP_AABB_FLT_EPS) { dir *= (toMax.x / dir.x); }
	if (dir.y > toMax.y + CLIP_AABB_FLT_EPS) { dir *= (toMax.y / dir.y); }
	if (dir.z > toMax.z + CLIP_AABB_FLT_EPS) { dir *= (toMax.z / dir.z); }

	// Then onto any lower face it crosses.
	if (dir.x < toMin.x - CLIP_AABB_FLT_EPS) { dir *= (toMin.x / dir.x); }
	if (dir.y < toMin.y - CLIP_AABB_FLT_EPS) { dir *= (toMin.y / dir.y); }
	if (dir.z < toMin.z - CLIP_AABB_FLT_EPS) { dir *= (toMin.z / dir.z); }

	return p + dir;
}

// From "Temporal Reprojection Anti-Aliasing"
// https://github.com/playdeadgames/temporal
// Clips the history sample prevSample against the neighborhood box [aabbMin, aabbMax].
// Returns prevSample unchanged when it lies inside the box; otherwise returns the
// clipped point.
// Fast path (#if 1): moves prevSample toward the box center until it touches the box;
//                    `avg` is unused on this path.
// Slow path (#else): clips along the segment avg -> prevSample instead.
float3 ClipAABB(float3 aabbMin, float3 aabbMax, float3 prevSample, float3 avg)
{
    #if 1
        // note: only clips towards aabb center (but fast!)
        float3 p_clip = 0.5 * (aabbMax + aabbMin);
        // The epsilon keeps every half-extent non-zero, as in the Playdead reference.
        // Flat neighborhoods (min == max on an axis) would otherwise divide by zero
        // below. 0/0 yields NaN, which makes max() and the `> 1.0` test unreliable and
        // can let an out-of-box history sample through unclipped.
        float3 e_clip = 0.5 * (aabbMax - aabbMin) + CLIP_AABB_FLT_EPS;

        float3 v_clip = prevSample - p_clip;
        float3 v_unit = v_clip / e_clip;
        float3 a_unit = abs(v_unit);
        // Largest per-axis distance in half-extent units; > 1 means outside the box.
        float ma_unit = max(a_unit.x, max(a_unit.y, a_unit.z));

        if (ma_unit > 1.0)
            return p_clip + v_clip / ma_unit;
        else
            return prevSample;// point inside aabb
    #else
        float3 r = prevSample - avg;
        float3 rmax = aabbMax - avg.xyz;
        float3 rmin = aabbMin - avg.xyz;

        const float eps = 0.000001f;

        if (r.x > rmax.x + eps)
            r *= (rmax.x / r.x);
        if (r.y > rmax.y + eps)
            r *= (rmax.y / r.y);
        if (r.z > rmax.z + eps)
            r *= (rmax.z / r.z);

        if (r.x < rmin.x - eps)
            r *= (rmin.x / r.x);
        if (r.y < rmin.y - eps)
            r *= (rmin.y / r.y);
        if (r.z < rmin.z - eps)
            r *= (rmin.z / r.z);

        return avg + r;
    #endif
}

// Scalar (single-channel, e.g. alpha) version of ClipAABB: clips the history value
// prevSample against the neighborhood range [aabbMin, aabbMax].
// Returns prevSample unchanged when it lies inside the range; otherwise returns the
// clipped value.
// Fast path (#if 1): moves prevSample toward the range center; `avg` is unused.
// Slow path (#else): clips along the segment avg -> prevSample instead.
float ClipAABBAlpha(float aabbMin, float aabbMax, float prevSample, float avg)
{
    #if 1
        // note: only clips towards aabb center (but fast!)
        float p_clip = 0.5 * (aabbMax + aabbMin);
        // The epsilon keeps the half-extent non-zero. A constant-valued neighborhood
        // (min == max) would otherwise divide by zero below, making the result depend
        // on Inf/NaN behaviour that shader precision rules do not guarantee.
        float e_clip = 0.5 * (aabbMax - aabbMin) + CLIP_AABB_FLT_EPS;

        float v_clip = prevSample - p_clip;
        // Distance from the center in half-extent units; > 1 means outside the range.
        float ma_unit = abs(v_clip / e_clip);

        if (ma_unit > 1.0)
            return p_clip + v_clip / ma_unit;
        else
            return prevSample;// point inside aabb
    #else
        float r = prevSample - avg;
        float rmax = aabbMax - avg;
        float rmin = aabbMin - avg;

        const float eps = 0.000001f;

        if (r > rmax + eps)
            r *= (rmax / r);

        if (r < rmin - eps)
            r *= (rmin / r);

        return avg + r;
    #endif
}

#endif
