#include "data\\shaders\\common.h"
#include "data\\shaders\\input_formats.h"

// Global scale applied to every tap offset in Near() and Far(). A tap lands at
// g_max_radius * coc * offsets[i] * pixelSize from the centre UV, so with 1.0
// the ring radii in `offsets` are used as-is, multiplied by the pixel's CoC.
static const float g_max_radius = 1.0f;

// Gather kernel: 48 tap directions arranged in three concentric rings around
// the centre texel (the centre itself is sampled separately by the callers).
// Values are in texel units before scaling by g_max_radius * coc * pixelSize.
// The loops in Near()/Far() iterate exactly 48 entries, so adding or removing
// taps here requires updating those loop bounds (and Near()'s divisor).
static const float2 offsets[] =
{
	// Ring 1: 8 taps at radius 2, 45-degree steps.
	2.0f * float2(1.000000f, 0.000000f),
	2.0f * float2(0.707107f, 0.707107f),
	2.0f * float2(-0.000000f, 1.000000f),
	2.0f * float2(-0.707107f, 0.707107f),
	2.0f * float2(-1.000000f, -0.000000f),
	2.0f * float2(-0.707106f, -0.707107f),
	2.0f * float2(0.000000f, -1.000000f),
	2.0f * float2(0.707107f, -0.707107f),
	
	// Ring 2: 16 taps at radius 4, 22.5-degree steps.
	4.0f * float2(1.000000f, 0.000000f),
	4.0f * float2(0.923880f, 0.382683f),
	4.0f * float2(0.707107f, 0.707107f),
	4.0f * float2(0.382683f, 0.923880f),
	4.0f * float2(-0.000000f, 1.000000f),
	4.0f * float2(-0.382684f, 0.923879f),
	4.0f * float2(-0.707107f, 0.707107f),
	4.0f * float2(-0.923880f, 0.382683f),
	4.0f * float2(-1.000000f, -0.000000f),
	4.0f * float2(-0.923879f, -0.382684f),
	4.0f * float2(-0.707106f, -0.707107f),
	4.0f * float2(-0.382683f, -0.923880f),
	4.0f * float2(0.000000f, -1.000000f),
	4.0f * float2(0.382684f, -0.923879f),
	4.0f * float2(0.707107f, -0.707107f),
	4.0f * float2(0.923880f, -0.382683f),

	// Ring 3: 24 taps at radius 6, 15-degree steps.
	6.0f * float2(1.000000f, 0.000000f),
	6.0f * float2(0.965926f, 0.258819f),
	6.0f * float2(0.866025f, 0.500000f),
	6.0f * float2(0.707107f, 0.707107f),
	6.0f * float2(0.500000f, 0.866026f),
	6.0f * float2(0.258819f, 0.965926f),
	6.0f * float2(-0.000000f, 1.000000f),
	6.0f * float2(-0.258819f, 0.965926f),
	6.0f * float2(-0.500000f, 0.866025f),
	6.0f * float2(-0.707107f, 0.707107f),
	6.0f * float2(-0.866026f, 0.500000f),
	6.0f * float2(-0.965926f, 0.258819f),
	6.0f * float2(-1.000000f, -0.000000f),
	6.0f * float2(-0.965926f, -0.258820f),
	6.0f * float2(-0.866025f, -0.500000f),
	6.0f * float2(-0.707106f, -0.707107f),
	6.0f * float2(-0.499999f, -0.866026f),
	6.0f * float2(-0.258819f, -0.965926f),
	6.0f * float2(0.000000f, -1.000000f),
	6.0f * float2(0.258819f, -0.965926f),
	6.0f * float2(0.500000f, -0.866025f),
	6.0f * float2(0.707107f, -0.707107f),
	6.0f * float2(0.866026f, -0.499999f),
	6.0f * float2(0.965926f, -0.258818f),
};

// Colour gathered by Near(); also point-sampled by main() as the unblurred
// near-field fallback. Its dimensions define the texel size for both gathers.
Texture2D g_color : register(t0);
// Colour gathered by Far(). NOTE(review): Far() divides its unweighted sum by
// the CoC sum, which suggests this is premultiplied by far CoC — confirm.
Texture2D g_color_far : register(t1);
// Circle of confusion: main() reads .r as near CoC and .g as far CoC; Far()
// reads .y (the same far channel) at every tap as the normalisation weight.
Texture2D g_coc : register(t2);

// Filtering is configured by the application; the names presumably reflect
// point and bilinear filtering respectively — verify against the sampler setup.
SamplerState g_sam_point : register(s1);
SamplerState g_sam_linear : register(s3);

// Returns the squared luminance of `c`, using Luminance() from an included
// header. Not called anywhere in this file as shown.
//
// Computed as luma * luma rather than pow(luma, 2.0f): HLSL pow() is undefined
// for a negative base (fxc warns with X3571). The product is well-defined for
// any input and is never more expensive.
//
// NOTE(review): a "Karis weight" is conventionally 1 / (1 + luma), used to
// suppress fireflies. Squared luma emphasises bright pixels instead. Confirm
// which one callers expect before wiring this in.
float KarisWeight(float3 c)
{
	float luma = Luminance(c);
	return luma * luma;
}

// Near-field gather: returns the plain (unweighted) mean of the centre texel
// of g_color and the 48 ring taps in `offsets`. Each tap is displaced from
// `uv` by g_max_radius * coc * offsets[i] texels; `pixelSize` converts texels
// to UV units. All samples use g_sam_linear.
float4 Near(float2 uv, float2 pixelSize, float coc)
{
	const int tapCount = 48;
	// Loop-invariant part of the offset, grouped exactly as in the per-tap
	// product so the arithmetic is unchanged.
	const float radius = g_max_radius * coc;

	// The centre texel counts as one extra tap in the average.
	float4 accum = g_color.Sample(g_sam_linear, uv);
	for (int tap = 0; tap < tapCount; ++tap)
	{
		const float2 tapUv = uv + radius * offsets[tap] * pixelSize;
		accum += g_color.Sample(g_sam_linear, tapUv);
	}

	return accum / float(tapCount + 1);
}

// Far-field gather over the centre texel plus the 48 ring taps in `offsets`.
// Each tap is displaced from `uv` by g_max_radius * coc * offsets[i] texels,
// with `pixelSize` converting texels to UV units.
// Sums the g_color_far samples and divides by the sum of the far CoC (g_coc.y)
// sampled at the same positions. All samples use g_sam_linear.
//
// NOTE(review): the colour samples are not multiplied by their CoC here. The
// division therefore yields a CoC-weighted average only if g_color_far was
// already premultiplied by far CoC in an earlier pass — confirm against it.
//
// Returns transparent black when the accumulated CoC weight is not positive,
// instead of the NaN/Inf a division by zero would produce. This matches
// main()'s 0.0f.xxxx output for pixels without far CoC. Assumes far CoC is
// non-negative — TODO confirm.
float4 Far(float2 uv, float2 pixelSize, float coc)
{
	// Loop-invariant part of the offset, grouped as in the per-tap product.
	const float radius = g_max_radius * coc;

	float4 colorSum = g_color_far.Sample(g_sam_linear, uv);
	float weightSum = g_coc.Sample(g_sam_linear, uv).y;
	for (int i = 0; i < 48; i++)
	{
		const float2 tapUv = uv + radius * offsets[i] * pixelSize;
		weightSum += g_coc.Sample(g_sam_linear, tapUv).y;
		colorSum += g_color_far.Sample(g_sam_linear, tapUv);
	}

	return (weightSum > 0.0f) ? colorSum / weightSum : 0.0f.xxxx;
}

// Two render-target output of main(): the near-field layer goes to target 0
// and the far-field layer goes to target 1.
struct PSOut
{
    float4 near_field			: SV_Target0;
    float4 far_field			: SV_Target1;
};

// Pixel-shader entry point. Splits the image into near- and far-field layers
// using the CoC point-sampled from g_coc (.r = near, .g = far):
//  - near_field: Near() gather where near CoC > 0, else the unblurred g_color.
//  - far_field:  Far() gather where far CoC > 0, else transparent black.
// Both gathers use the texel size of g_color.
PSOut main(VertexTOut pin)
{
	const float2 uv = pin.uv;
	const float2 cocNearFar = g_coc.Sample(g_sam_point, uv).rg;

	float width, height;
	g_color.GetDimensions(width, height);
	const float2 texel = 1.0f.xx / float2(width, height);

	// Sampled unconditionally (outside the branches), as before.
	const float4 sharpColor = g_color.Sample(g_sam_point, uv);

	PSOut result;

	// Default each target, then override where that field has CoC coverage.
	// The `> 0` tests are kept so NaN CoC still takes the default path.
	result.near_field = sharpColor;
	if (cocNearFar.r > 0.0f)
		result.near_field = Near(uv, texel, cocNearFar.r);

	result.far_field = 0.0f.xxxx;
	if (cocNearFar.g > 0.0f)
		result.far_field = Far(uv, texel, cocNearFar.g);

	return result;
}