#include "lib/shared/hash-functions.hlsl"
#include "lib/shared/noise-functions.hlsl"
#include "lib/shared/point.hlsl"

// Per-dispatch simulation parameters. Field order defines the constant-buffer layout,
// so it must match whatever the host writes into b0.
cbuffer Params : register(b0)
{
    float time;           // not read by main() in this file
    float dt;             // timestep: scales friction, acceleration and position integration
    float reset;          // nonzero => reinitialize every point near its target instead of simulating
    float forceAvoid;     // weight of the push-apart term between overlapping points
    float forceTarget;    // weight of the pull toward the interpolated target position
    float forceFollow;    // weight of the term that adopts nearby points' velocities
    float rangeFollow;    // follow radius, as a multiple of the two points' combined w
    float frictionScale;  // per-step velocity damping factor is clamp(1 - dt*frictionScale, 0, 1)
    float globalSpeed;    // overall multiplier on acceleration
    float maxSpeed;       // velocity length is clamped to this value
}

// Input points: their w scales each simulated point's size, and they are copied wholesale on reset.
StructuredBuffer<Point> SourcePoints : t0;
// Target points: each simulated point is pulled toward a position interpolated between two of these.
StructuredBuffer<Point> Targets : t1;
// Simulation state, read back every frame. rotation.xyz is used to store the velocity,
// w the current size. Must hold at least as many elements as SourcePoints.
RWStructuredBuffer<Point> ResultPoints : u0;

// One thread per point. Each point is:
//   - pulled toward a position interpolated along the Targets list (by its normalized index),
//   - pushed away from points it overlaps (distance < sum of both w),
//   - nudged to adopt the velocity of points within rangeFollow * (sum of both w).
// Velocity is persisted in rotation.xyz and the point's size in w between frames.
//
// NOTE(review): the neighbour loop reads ResultPoints[] while other threads of the same dispatch
// write ResultPoints[i] at the end, so a neighbour may be observed before or after its update
// this frame (order-dependent, non-deterministic). A proper fix needs a separate read-only copy
// of the previous state bound by the host.
[numthreads(64,1,1)]
void main(uint3 DTid : SV_DispatchThreadID)
{
    uint i = DTid.x;
    uint pointCount, targetCount, stride;
    SourcePoints.GetDimensions(pointCount, stride);
    Targets.GetDimensions(targetCount, stride);
    if (i >= pointCount)
        return;

    Point s = SourcePoints[i];
    Point p = ResultPoints[i];

    // Map this point's normalized index onto the target list. ia/ib are the two bracketing
    // targets (ib wraps to the start) and blend is the fractional position between them.
    // Indices are explicit uints; max(...,1) prevents an integer modulo by zero when Targets
    // is empty (out-of-range reads then return zero, as they effectively did before).
    uint safeTargetCount = max(targetCount, 1u);
    float index = (i / float(pointCount)) * targetCount;
    float blend = frac(index);
    uint ia = (uint)index % safeTargetCount;
    uint ib = (ia + 1) % safeTargetCount;

    Point pa = Targets[ia];
    Point pb = Targets[ib];
    float3 target = lerp(pa.position, pb.position, blend);
    float tw = lerp(pa.w, pb.w, blend);

    // NaN w on a bracketing target (presumably a separator/gap marker — confirm) makes the
    // interpolation meaningless, so such points pick a pseudo-random target instead.
    // The index is clamped in case hash11 can return exactly 1.0.
    bool onGap = isnan(pa.w) || isnan(pb.w);
    if (onGap)
    {
        uint randomIndex = min((uint)(hash11(i + 196) * targetCount), safeTargetCount - 1);
        target = Targets[randomIndex].position;
    }

    if (reset)
    {
        // Start from the source point, placed at its target with a tiny jitter and an initial
        // velocity taken from curl noise at that position.
        Point fresh = s;
        fresh.position = target + (hash31(i) - .5) * .01;
        fresh.rotation.xyz = curlNoise(target);
        ResultPoints[i] = fresh;
        return;
    }

    float3 pos = p.position.xyz;
    float3 velocity = p.rotation.xyz;
    float size = tw * s.w;

    // Neighbour interaction: O(pointCount) per thread.
    float3 avoid = 0;
    float3 follow = 0;
    for (uint ii = 0; ii < pointCount; ++ii)
    {
        Point other = ResultPoints[ii];
        float d = distance(other.position, pos);
        float reach = p.w + other.w;
        if (d > 0) // skips self (and exactly coincident points)
        {
            if (d < reach)
            {
                avoid += normalize(pos - other.position) * (reach - d);
            }
            else if (d < reach * rangeFollow)
            {
                follow += other.rotation.xyz * (reach * rangeFollow - d);
            }
        }
    }

    // Unit direction to the target, faded out over the last 0.1 units to avoid jitter on arrival.
    float targetDist = distance(target, pos);
    float3 toTarget = targetDist > 0 ? normalize(target - pos) * smoothstep(0.0, 0.1, targetDist) : 0;
    float3 move = avoid * forceAvoid + toTarget * forceTarget + follow * forceFollow;

    // Acceleration is divided by size. A zero or NaN size would make it inf/NaN, and that NaN
    // would then stick in the persisted velocity and position until the next reset, so such
    // points receive no acceleration instead.
    float invSize = (!isnan(size) && size != 0) ? 1.0 / size : 0;

    float friction = clamp(1 - dt * frictionScale, 0, 1);
    velocity = velocity * friction + move * dt * 0.1 * globalSpeed * invSize;

    float speed = length(velocity);
    if (speed > maxSpeed)
        velocity = normalize(velocity) * maxSpeed;
    pos += velocity * dt;

    p.position = pos;
    p.rotation.xyz = velocity;
    p.w = size;

    ResultPoints[i] = p;
}