#version 410 core

// Depth-of-field blur pass: reads far/near colour layers plus a circle-of-confusion
// (CoC) texture and writes one blurred result per layer.

// Divides sample offsets to turn them into UV offsets
// (presumably the render-target size in pixels -- confirm against the host code).
uniform vec2 resolution;
// Colour source for the far-field gather and for the unblurred current pixel.
uniform sampler2D farColorTex;
// Colour source for the near-field gather.
uniform sampler2D nearColorTex;
// CoC value is read from the .x channel. In main(), negative values select the
// far-field gather; values >= 0 pass the far colour through unchanged.
// NOTE(review): looks like a signed, normalised CoC -- confirm range with the producer pass.
uniform sampler2D cocTex;

// Number of polygon sides, forwarded to unitSquareToNGon() as `n`.
uniform float polygonSides;
// Blend factor forwarded to unitSquareToNGon() as `amount` (0 = disk, 1 = full polygon).
uniform float polygonAmount;
// Scales every sample offset before the division by `resolution`
// (presumably the maximum blur radius in pixels -- TODO confirm).
uniform float maxCoc;

// UV coordinates of the current fragment, used as the centre of every gather.
in vec2 texCoords;

// Far-field result: rgb = blurred (or passthrough) colour, a = 1.0.
layout(location = 0) out vec4 outFarColor;
// Near-field result: rgb = averaged near colour, a = accumulated coverage.
layout(location = 1) out vec4 outNearColor;

// Maps a point of the unit square onto a disk using a concentric square-to-disk
// mapping, then optionally pulls the disk towards a regular n-sided polygon.
//
//   p      - point in [0, 1]^2.
//   n      - number of polygon sides. Polygon shaping is only applied for n > 2;
//            for smaller values (including an unset uniform, which reads as 0)
//            the plain disk is returned instead of NaN / degenerate radii.
//   amount - blend between disk (0.0) and polygon (1.0).
//
// Returns vec3(u, v, circleRadius): (u, v) is the shaped and rotated offset, and
// circleRadius is the disk radius in [0, 1] before the polygon shaping.
// TODO: Proper struct return
vec3 unitSquareToNGon(vec2 p, float n, float amount)
{
    const float PI = 3.14159265358979;

    // Remap to [-1, 1]^2.
    float a = p.x * 2.0 - 1.0;
    float b = p.y * 2.0 - 1.0;

    // Each of the four triangular quadrants of the square maps to a quarter of
    // the disk. The branch conditions guarantee a non-zero divisor everywhere
    // except at the exact centre, which is handled explicitly.
    float r, theta;
    if (a > -b)
    {
        if (a > b)
        {
            r = a;
            theta = (PI / 4.0) * (b / a);
        }
        else
        {
            r = b;
            theta = (PI / 4.0) * (2.0 - (a / b));
        }
    }
    else
    {
        if (a < b)
        {
            r = -a;
            theta = (PI / 4.0) * (4.0 + (b / a));
        }
        else
        {
            r = -b;
            if (b != 0.0)
            {
                theta = (PI / 4.0) * (6.0 - (a / b));
            }
            else
            {
                // a == b == 0: centre of the square, angle is irrelevant.
                theta = 0.0;
            }
        }
    }

    float circleRadius = r;

    // Shrink the radius towards the polygon edge. With n <= 2 the expression
    // below divides by zero or by a near-zero cosine, so leave the disk as is.
    if (n > 2.0)
    {
        // Angle relative to the centre of the polygon wedge containing theta;
        // always within [-PI / n, PI / n], so its cosine stays positive.
        float wedgeAngle = theta - (2.0 * PI / n) * floor((n * theta + PI) / (2.0 * PI));
        r *= mix(1.0, cos(PI / n) / cos(wedgeAngle), amount);
    }

    // This is just so that the shape isn't aligned to an axis, which looks a bit nicer
    theta += .6;

    float u = r * cos(theta);
    float v = r * sin(theta);
    return vec3(u, v, circleRadius);
}

// Fragment entry point. Produces two outputs for the current pixel:
//   outFarColor  - far-field colour: a CoC-weighted gather when the pixel's CoC
//                  is negative, otherwise the unblurred far colour.
//   outNearColor - near-field colour averaged over samples whose CoC reaches the
//                  sample ring, with an accumulated coverage value in alpha.
//
// NOTE(review): several texture() calls below sit inside per-pixel branches,
// where implicit derivatives are undefined. This is harmless if the inputs have
// no mipmaps -- confirm, or switch those fetches to textureLod(..., 0.0).
void main()
{
    vec2 currentPixelUv = texCoords;
    float currentPixelCoc = texture(cocTex, currentPixelUv).x;
    vec3 currentPixelColor = texture(farColorTex, currentPixelUv).xyz;

    // TODO: Uniforms, f-stop etc
    const int blurRes = 10;

    // Far field
    // Start from the unblurred colour. It is the result for pixels with
    // CoC >= 0 and also the fallback when the gather below finds no usable sample.
    vec3 farColor = currentPixelColor;

    if (currentPixelCoc < 0.0)
    {
        // Full 2D gather (blurRes x blurRes taps) with weighted CoC
        vec3 farColorSum = vec3(0.0);
        float farColorDiv = 0.0;
        for (int y = 0; y < blurRes; y++)
        {
            for (int x = 0; x < blurRes; x++)
            {
                vec2 unitSquareOffset = vec2(float(x), float(y)) / float(blurRes - 1);
                vec3 nGonOffsetAndRadius = unitSquareToNGon(unitSquareOffset, polygonSides, polygonAmount);
                // Kernel size scales with this pixel's own (negative) CoC.
                vec2 offset = nGonOffsetAndRadius.xy * -currentPixelCoc * maxCoc * .6;
                vec2 uv = texCoords + offset / resolution;

                float coc = texture(cocTex, uv).x;
                // Only other far-field samples may contribute.
                if (coc < 0.0)
                {
                    // Samples at least as blurred as this pixel count fully;
                    // sharper ones are weighted by their own CoC magnitude.
                    float weight = coc <= currentPixelCoc ? 1.0 : -coc;
                    farColorSum += texture(farColorTex, uv).xyz * weight;
                    farColorDiv += weight;
                }
            }
        }
        // The tap grid has no sample at the centre for even blurRes, so every
        // tap can be rejected; keep the unblurred colour instead of writing black.
        if (farColorDiv > 0.0)
            farColor = farColorSum / farColorDiv;
    }

    // Near field
    vec3 nearColor = vec3(0.0);

    // Scatter-as-gather approximation
    float nearColorDiv = 0.0;
    float nearColorAlpha = 0.0;
    for (int y = 0; y < blurRes; y++)
    {
        for (int x = 0; x < blurRes; x++)
        {
            vec2 unitSquareOffset = vec2(float(x), float(y)) / float(blurRes - 1);
            vec3 nGonOffsetAndRadius = unitSquareToNGon(unitSquareOffset, polygonSides, polygonAmount);
            // Always search the maximum kernel: any neighbour may be blurred
            // enough to spill onto this pixel.
            vec2 offset = nGonOffsetAndRadius.xy * maxCoc * .6;
            float radius = nGonOffsetAndRadius.z;
            vec2 uv = texCoords + offset / resolution;

            float coc = texture(cocTex, uv).x;
            // Accept the sample only if its CoC reaches this pixel's ring.
            if (coc >= radius)
            {
                nearColor += texture(nearColorTex, uv).xyz;
                nearColorDiv += 1.0;
                // Larger CoC spreads the sample's energy thinner, so it adds less coverage.
                nearColorAlpha += clamp(1.0 / coc / float(blurRes * blurRes), 0.0, 1.0);
            }
        }
    }
    if (nearColorDiv > 0.0)
        nearColor /= nearColorDiv;

    outFarColor = vec4(farColor, 1.0);
    outNearColor = vec4(nearColor, nearColorAlpha);
}
