#version 450

// Interpolated per-fragment coordinate. cameraRay() scales its x by the aspect
// ratio and uses it directly as the xy of the camera-space ray.
// NOTE(review): presumably spans [-1, 1] on both axes - confirm against the vertex stage.
layout(location = 0) in vec2 FragCoord;
// Final pixel color; main() writes the shaded radiance with alpha 1.
layout(location = 0) out vec4 FragColor;

// Per-frame parameters. Despite the name this is declared as an ordinary
// uniform block (descriptor set 3, binding 0), not with a push_constant layout.
layout(set = 3, binding = 0) uniform PushConstants {
    float u_Time;       // Animation clock; divided by EPOCH to get the scene iteration. Units not visible here.
    vec2 u_Resolution;  // Render target size; only the x / y ratio is used (aspectRatio()).
};

// Length of one scene "iteration" in u_Time units: iteration() = floor(u_Time / EPOCH).
#define EPOCH 128.
// Small positive value used as the primary ray's first march parameter and as
// the lower clamp for cosine terms in the BRDF.
#define EPSILON 0.001
#define PI 3.14159265
// Camera field of view in degrees (cameraRay() converts with PI / 180).
#define FOV 90.

// Refractive indices used by water() for Fresnel reflectance and refraction.
#define RI_VACUUM 1.
#define RI_WATER 1.33

// iteration() value at which render() switches from the single fixed
// directional light to the 16 animated point lights.
#define LIGHT_TRANSITION 16.
// Sky gradient end colors (see sky()). Parenthesized so the macros expand to a
// single operand regardless of the surrounding expression.
#define SKY_COLOR1 (vec3(0.4, 0.2, 0.1) * 10.)
#define SKY_COLOR2 (vec3(1.5, 1.4, 1.3) * 3.)

#include <noise.glsl>
#include <rotation.glsl>
#include <sdf.glsl>

// Per-material base colors, indexed by material ID (light() indexes with the
// int(hit.y) that render() passes in).
//   [0]      water: sdBlob() reports ID 0; water() reads this entry directly for absorption
//   [1], [2] wall:  sdWall() reports stripes + 1., i.e. an ID in [1, 2]
// NOTE(review): presumably march() forwards the sdf()'s .y as hit.y - confirm in march.glsl.
const vec3 MTL_COLORS[] = vec3[](
        vec3(0.1, 0.2, 0.4), // Water attenuation
        vec3(0.9),
        vec3(0.56, 0.57, 0.58)
    );

// Per-material shading parameters, same indexing as MTL_COLORS.
// x = roughness, y = metalness, z = reflectance
// light() clamps these to [0, 1] before use; water() reads entry 0's x unclamped.
const vec3 MTL_PARAMS[] = vec3[](
        vec3(6.), // Water attenuation distance multiplier
        vec3(0.1, 0., 0.9),
        vec3(0.9, 0., 0.1)
    );

// Number of whole EPOCH-long intervals of u_Time that have elapsed.
// Drives every timed scene change in this shader.
float iteration() {
    float epochs = u_Time / EPOCH;
    return floor(epochs);
}

// World-space point the camera looks at; sdBlob() also centers the blob here.
vec3 cam_target() {
    const vec3 target = vec3(0., 0., -2.);
    return target;
}

// World-space camera position: a per-iteration anchor plus a slow continuous drift.
vec3 cam_pos() {
    float epoch = iteration();
    float phase = 0.3 * u_Time / EPOCH;

    // Anchor changes only when the iteration changes. Its xy spread grows with
    // the iteration (capped at 4, and zero during iteration 0); z is pushed back by 3.
    vec3 anchor = vec3(
            noise(epoch),
            noise(epoch + 3458.),
            noise(epoch + 45.)
        ) * 2. - 1.;
    anchor.xy *= min(epoch, 4.);
    anchor.z -= 3.;

    // Continuous sway, remapped like the anchor and then halved.
    vec3 drift = vec3(
            fbm(phase),
            fbm(phase + 5.),
            fbm(phase + 11.)
        ) * 2. - 1.;
    drift *= 0.5;

    return anchor + drift;
}

// Camera field of view in degrees (currently the constant FOV).
float cam_fov() {
    float degrees = FOV;
    return degrees;
}

// Width-over-height ratio of the render target.
float aspectRatio() {
    vec2 size = u_Resolution;
    return size.x / size.y;
}

// Camera basis as matrix columns (side, up, forward), so that
// viewMatrix() * v takes a camera-space direction into world space.
mat3 viewMatrix() {
    const vec3 worldUp = vec3(0., 1., 0.);
    vec3 forward = normalize(cam_target() - cam_pos());
    vec3 side = -normalize(cross(forward, worldUp));
    vec3 up = cross(side, forward);
    return mat3(side, up, forward);
}

// Normalized camera-space ray for this fragment.
// The z component is tan(90deg - fov/2), i.e. the distance to the image plane
// for the configured field of view; x is stretched by the aspect ratio.
vec3 cameraRay() {
    float halfComplementDeg = 90. - cam_fov() / 2.;
    float depth = tan(halfComplementDeg * (PI / 180.));
    vec2 plane = FragCoord * vec2(aspectRatio(), 1.);
    return normalize(vec3(plane, depth));
}

// Distance to the repeated octahedra that decorate the wall.
// During iterations 13..15 a second, rotating octahedron is merged in.
float sdWallDisplacement(vec3 p) {
    float epoch = iteration();
    vec2 cell = p.xy / 8.;

    // Offset along z by 1 while iteration() <= 1, fading to 0 by iteration 2.
    p.z -= clamp(2. - epoch, 0., 1.);

    // Tile in xy; the period is re-rolled each iteration.
    vec3 period = vec3(noise(epoch), noise(epoch), 1.);
    p.xy = repetition(vec3(cell, p.z), period).xy;
    p.z -= noise(cell + epoch * 12.);

    float dist = sdOctahedron(p, 1.);

    bool spinning = epoch > 12. && epoch < 16.;
    if (!spinning) {
        return dist;
    }

    float spinSize = fbm(u_Time + floor(cell.x)) * 2.;
    return min(dist, sdOctahedron(pitch(u_Time) * (p / 1.4), spinSize));
}

// Wall: the XY plane merged with its displacement geometry.
// Returns (distance, material ID); the ID is 1 or 2 in a checker pattern,
// produced by hard-thresholding a product of sines.
vec2 sdWall(vec3 p) {
    float k = 2. * PI * 0.25;
    float wave = sin(p.y * k) * sin(p.x * k);
    // The huge gain turns the smooth wave into an (almost) binary 0/1 mask.
    float material = clamp(wave * 1024. * 1024., 0.0, 1.) + 1.;
    float dist = min(sdPlaneXY(p), sdWallDisplacement(p));
    return vec2(dist, material);
}

// Water blob: a unit sphere at the camera target with noise-warped coordinates.
// Returns (distance, material ID 0).
vec2 sdBlob(vec3 p) {
    float epoch = iteration();

    p.x += noise(p + u_Time * 0.01);
    p -= cam_target();

    // Hidden (moved 2048 units along y) while iteration() <= 3.
    p.y -= 2048. * step(-3., -epoch);

    // Timing hack: squash the y coordinate during iterations 8 and 9.
    bool squashed = epoch > 7. && epoch < 10.;
    if (squashed) {
        p.y /= 128.;
    }

    // Extra vertical wobble from iteration 13 on.
    bool wobbling = epoch > 12.;
    if (wobbling) {
        p.y += noise(p + u_Time * 0.006);
    }

    float dist = sdSphere(p, 1.);
    return vec2(dist, 0.);
}

// Whole scene: union of the blob and the wall, as (distance, material ID).
vec2 sdf(vec3 p) {
    vec2 blob = sdBlob(p);
    vec2 wall = sdWall(p);
    return opUnion(blob, wall);
}

#include <march.glsl>

// Sky radiance for direction v: a vertical gradient between the two sky colors.
// Once iteration() exceeds LIGHT_TRANSITION + 4 the gradient is flipped and
// dimmed by a factor of 1000.
vec3 sky(vec3 v) {
    float blend = v.y * 0.5 + 0.5;

    bool night = iteration() - 4. > LIGHT_TRANSITION;
    if (night) {
        vec3 low = SKY_COLOR2 * 0.001;
        vec3 high = SKY_COLOR1 * 0.001;
        return mix(low, high, blend);
    }

    vec3 low = SKY_COLOR1;
    vec3 high = SKY_COLOR2;
    return mix(low, high, blend);
}

// Schlick's Fresnel approximation.
// cosTheta: cosine of the angle used for the falloff
// F0:       reflectance at normal incidence
vec3 fresnelSchlick(float cosTheta, vec3 F0) {
    float grazing = pow(1.0 - cosTheta, 5.);
    return F0 + (1. - F0) * grazing;
}

// GGX normal distribution term, with alpha = roughness^2.
float D_GGX(float NdotH, float roughness) {
    float a = roughness * roughness;
    float a2 = a * a;
    float denom = (NdotH * NdotH) * (a2 - 1.) + 1.;
    return a2 / (PI * denom * denom);
}

// Schlick-GGX geometry term for a single direction.
// The numerator is floored at EPSILON; the denominator uses NdotV as given.
float G1_GGX_Schlick(float NdotV, float k) {
    float numerator = max(NdotV, EPSILON);
    float denominator = NdotV * (1. - k) + k;
    return numerator / denominator;
}

// Smith geometry term: product of the Schlick-GGX terms for the light and
// view directions, with k = roughness^2 / 2.
float G_Smith(float NdotV, float NdotL, float roughness) {
    float k = (roughness * roughness) / 2.;
    float lightTerm = G1_GGX_Schlick(NdotL, k);
    float viewTerm = G1_GGX_Schlick(NdotV, k);
    return lightTerm * viewTerm;
}

// Cook-Torrance microfacet BRDF: the sum of a diffuse and a specular part.
// Specular is the product of Fresnel reflectance, a normal distribution
// function and a geometry term (microfacet shadowing), divided by
// 4 * (n dot l) * (n dot v). Diffuse is an Oren-Nayar style rough lobe.
//
// L, V, N:     directions towards the light, towards the viewer, and the surface normal
// metallic:    0 = dielectric, 1 = metal (f0 taken from baseColor, no diffuse)
// roughness:   drives both the GGX lobe and the Oren-Nayar sigma
// baseColor:   albedo
// reflectance: dielectric reflectance, remapped to f0 = 0.16 * reflectance^2
vec3 brdf(vec3 L, vec3 V, vec3 N, float metallic, float roughness, vec3 baseColor, float reflectance) {
    vec3 H = normalize(V + L);
    // Cosines feeding the specular term are floored at EPSILON to keep the
    // divisions below finite.
    float NdotV = clamp(dot(N, V), EPSILON, 1.0);
    float NdotL = clamp(dot(N, L), EPSILON, 1.0);
    float NdotH = clamp(dot(N, H), EPSILON, 1.0);
    float VdotH = clamp(dot(V, H), EPSILON, 1.0);
    // The Oren-Nayar azimuth term needs the *signed* cosine between L and V:
    // flooring it at EPSILON would report a small positive retro-reflection
    // for light/view pairs that actually face away from each other.
    float LdotV = clamp(dot(L, V), -1.0, 1.0);

    vec3 f0 = vec3(0.16 * (reflectance * reflectance));
    f0 = mix(f0, baseColor, metallic);

    vec3 F = fresnelSchlick(VdotH, f0);
    float D = D_GGX(NdotH, roughness);
    float G = G_Smith(NdotV, NdotL, roughness);

    // NdotV and NdotL are already >= EPSILON, so the denominator cannot be zero.
    vec3 spec = (F * D * G) / (4. * NdotV * NdotL);

    // Energy not reflected specularly is available for diffuse; metals have none.
    vec3 rhoD = baseColor;
    rhoD *= vec3(1.) - F;
    rhoD *= (1. - metallic);
    //https://github.com/ranjak/opengl-tutorial/blob/master/shaders/illumination/diramb_orennayar_pcn.vert
    float sigma = roughness;
    float sigma2 = sigma * sigma;
    float termA = 1.0 - 0.5 * sigma2 / (sigma2 + 0.57);
    float termB = 0.45 * sigma2 / (sigma2 + 0.09);
    float cosAzimuthSinaTanb = (LdotV - NdotV * NdotL) / max(NdotV, NdotL);
    vec3 diff = rhoD * (termA + termB * max(0.0, cosAzimuthSinaTanb)) / PI;
    //vec3 diff = rhoD / PI;

    return diff + spec;
}

// Compute the light leaving a surface point towards the viewer for one light.
// Rendering equation:
//   radiance out to view = emitted radiance to view
//   + integral (a sum, in effect) over the whole hemisphere of
//     brdf(v, l) * incoming irradiance (radiance per area)
// Only the single direction l is evaluated here, plus a small sky ambient term.
//
// pos:   shaded position (currently unused)
// dir:   incoming view ray direction (the BRDF's view vector is -dir)
// n:     surface normal
// l:     direction from the surface towards the light
// lc:    light color
// ga:    attenuation applied to all received light, ambient included
// mtlID: index into MTL_COLORS / MTL_PARAMS
vec3 light(vec3 pos, vec3 dir, vec3 n, vec3 l, vec3 lc, vec3 ga, int mtlID) {
    // Indexing a constant array out of range is undefined, and the ID comes
    // from a float produced by the marcher, so pin it to the table bounds.
    int id = clamp(mtlID, 0, MTL_COLORS.length() - 1);

    // No emissive surfaces
    vec3 albedo = MTL_COLORS[id];
    vec3 params = clamp(MTL_PARAMS[id], 0., 1.);
    // Light received by the surface
    vec3 irradiance = max(dot(l, n), 0.) * lc;
    irradiance += sky(dir) * 0.1;
    // Attenuate irradiance
    irradiance *= ga;
    // Compute BRDF (params: x = roughness, y = metalness, z = reflectance)
    vec3 brd = brdf(l, -dir, n, params.y, params.x, albedo, params.z);
    return irradiance * brd;
}

// Shade the point a marched ray ended on and fade it into the sky with distance.
//
// origin: ray start
// dir:    ray direction
// hit:    march() result; .x is used as the distance along the ray and .y as
//         the material ID
// Returns the radiance seen along the ray.
vec3 render(vec3 origin, vec3 dir, vec3 hit) {
    vec3 pos = origin + dir * hit.x;
    vec3 n = normal(pos);

    // Compute a mask for parts of the image that should be sky (ray didn't hit)
    // Ideally, this would be done with the shadow parameter (hit.z)
    // but a fog works well in this case for now
    float mask = clamp((hit.x + 256.) / 256. - 1., 0., 1.);

    // SYNC hack: before LIGHT_TRANSITION the scene is lit by one fixed
    // directional light instead of the animated point lights. The test does not
    // depend on the loop index, so evaluate it once.
    bool syncPhase = iteration() < LIGHT_TRANSITION;

    // NOTE(review): a `mask = 0.` override for iteration() + 4 < LIGHT_TRANSITION
    // could only ever sit behind the sync-phase break below, where it can never
    // execute (its condition implies syncPhase). It is deliberately not applied
    // here so the image is unchanged; if fog was meant to be disabled during the
    // early iterations, set mask before the loop instead.

    vec3 radiance = vec3(0.);
    for (int i = 0; i < 16; i++) {
        float size = .5;
        vec3 lightColor = vec3(
                3. + noise(float(i + 435)) * 2.,
                3. + noise(float(i + 239)) * 1.4,
                3. + noise(float(i + 123)) * 1.
            ) * 4.;

        if (syncPhase) {
            // Single unattenuated directional light using light 0's color;
            // the point-light position is not needed at all.
            vec3 sunDir = normalize(vec3(2., 0., 1.));
            radiance += light(pos, dir, n, -sunDir, lightColor, vec3(1.), int(hit.y));
            break;
        }

        float t = u_Time * 0.003;
        float spread = 14.;
        vec3 lightPos = vec3(
                (fbm(float(i * 123) + t) * 2. - 1.) * spread,
                (fbm(float(i + 1) + t * 0.33) * 2. - 1.) * spread,
                fbm(float(i + 2) + t * 0.23) * -2. - 1.
            );

        // Inverse-square falloff from the point light.
        vec3 lightToPos = pos - lightPos;
        float lightDistance = length(lightToPos);
        vec3 lightDir = normalize(lightToPos);
        float attenuation = size / (lightDistance * lightDistance);

        radiance += light(pos, dir, n, -lightDir, lightColor, vec3(attenuation), int(hit.y));

        // Show the point light in the air: a tiny hard-edged disc where the
        // view ray points (almost) straight at the light.
        vec3 lightToOrigin = lightPos - origin;
        float distanceToOrigin = length(lightToOrigin);
        radiance += clamp((max(dot(dir, normalize(lightToOrigin)), 0.) - max(1. - (size * 0.01)
                                / pow(distanceToOrigin, 2.), 0.)) * float(1 << 18), 0., 1.)
                * lightColor * lightColor * 0.01;
    }
    return mix(radiance, sky(dir), mask);
}

// Shade a hit on the water blob: Fresnel-weighted reflection plus light
// refracted through the body and absorbed along the way.
//
// origin, dir: the ray that hit the blob
// hit:         march() result for that ray; only .x (distance) is used
// NOTE(review): the vec3 and trailing sign passed to march() are defined in
// march.glsl; the -1. presumably marches through the inside of the surface - confirm.
vec3 water(vec3 origin, vec3 dir, vec3 hit) {
    vec3 surface = origin + dir * hit.x;
    vec3 n = normal(surface);

    // Normal-incidence reflectance from the refractive index, then Schlick.
    float riLess = RI_WATER - 1.;
    float riMore = RI_WATER + 1.;
    float r0 = (riLess * riLess) / (riMore * riMore);
    vec3 fresnel = fresnelSchlick(dot(dir, -n), vec3(r0));

    // Mirror bounce off the surface.
    vec3 bounceDir = reflect(dir, n);
    vec3 bounceHit = march(surface, bounceDir, vec3(0.1, 1024., 20.), 1.);
    vec3 reflected = render(surface, bounceDir, bounceHit) * fresnel;

    // Bend into the body and travel to the far side.
    vec3 insideDir = refract(dir, n, RI_VACUUM / RI_WATER);
    vec3 insideHit = march(surface, insideDir, vec3(0.1, 1024., 20.), -1.);

    // Bend back out at the exit point.
    vec3 exitPos = surface + insideDir * insideHit.x;
    vec3 exitDir = refract(insideDir, -normal(exitPos), RI_WATER / RI_VACUUM);
    // refract() yields the zero vector on total internal reflection; in that
    // case keep travelling along the interior direction.
    float exitLenSq = dot(exitDir, exitDir);
    exitDir += step(-EPSILON, -exitLenSq) * insideDir;
    vec3 exitHit = march(exitPos, exitDir, vec3(2., 1024., 20.), 1.);

    vec3 refracted = render(exitPos, exitDir, exitHit);

    // Exponential absorption over the distance travelled inside the body.
    vec3 transmittance = exp((1. - MTL_COLORS[0]) * 0.15 * -(insideHit.x * MTL_PARAMS[0].x));
    return refracted * transmittance + reflected;
}

// Fragment entry point: trace the primary ray, shade it, apply the global
// fade in/out and write the result.
void main() {
    vec3 origin = cam_pos();
    vec3 dir = viewMatrix() * cameraRay();

    // Spheretrace all surfaces in view
    vec3 hit = march(origin, dir, vec3(EPSILON, 1024., 20.), 1.);

    // Material IDs below 1 are the water blob; everything else is shaded
    // directly. Only one of the two shading paths is evaluated per pixel.
    vec3 radiance;
    if (hit.y < 1.) {
        radiance = water(origin, dir, hit);
    } else {
        radiance = render(origin, dir, hit);
    }

    // Progress in units of two epochs.
    float fade = u_Time / (EPOCH * 2.);

    // Fade in over the first two epochs
    radiance *= clamp(fade, 0., 1.);

    // Fade out over the last two epochs before fade == 15. Clamped at 0 so the
    // factor cannot go negative (and flip the sign of the color) once the
    // clock runs past the end.
    radiance *= clamp(15. - fade, 0., 1.);

    FragColor = vec4(radiance, 1.);
}
