#include "data\\shaders\\common.h"
#include "data\\shaders\\input_formats.h"

// Shader inputs for the DoF filter pass. The register slots are bound by the host.
// g_coc        - CoC texture; main() reads .r (near-field CoC) and .g (far-field CoC).
// g_near_field - near-field image that main() box-filters into SV_Target0.
// g_far_field  - far-field image that main() filters into SV_Target1.
Texture2D g_coc : register(t0);
Texture2D g_near_field : register(t1);
Texture2D g_far_field : register(t2);

// Point sampler; main() uses it for all g_coc lookups so CoC values are not interpolated.
SamplerState g_sam_point : register(s1);
// Linear sampler; main() uses it for the g_near_field / g_far_field taps.
SamplerState g_sam_linear : register(s3);

// Pixel-shader output of main(): one render target per DoF layer.
struct PSOut
{
    // Near-field result, written to render target 0.
    float4 near_field			: SV_Target0;
    // Far-field result, written to render target 1.
    float4 far_field			: SV_Target1;
};

// Depth-of-field filter pass. For each layer, gathers a 3x3 grid of taps whose
// spacing is scaled by the per-pixel CoC stored in g_coc:
//   .r -> near field (plain box filter)
//   .g -> far field  (taps rejected when their own CoC is much smaller than
//                     this pixel's, to limit bleeding of sharper content)
//
// pin    - interpolated vertex output; only pin.uv is read.
// return - near_field (SV_Target0) and far_field (SV_Target1). When a layer's
//          CoC is <= 0 its unfiltered sample at pin.uv is passed through;
//          otherwise the normalized filtered value is written with alpha = 1.
PSOut main(VertexTOut pin)
{
  // NOTE(review): texel size is taken from g_coc but used to offset taps into
  // g_near_field / g_far_field; this assumes all three share one resolution.
  float2 dim;
  g_coc.GetDimensions(dim.x, dim.y);
  float2 texel_size = 1.0f.xx / dim;

  PSOut output;
  float2 coc = g_coc.SampleLevel(g_sam_point, pin.uv, 0).rg;

  // ---- Near field ----------------------------------------------------------
  if (coc.r > 0.0f)
  {
    // Accumulate from zero: the (0,0) tap already covers the centre pixel, so
    // seeding with the centre sample would count it twice (10 samples / 9).
    float4 near_sum = 0.0f.xxxx;
    for (int i = -1; i <= 1; i++)
    {
      for (int j = -1; j <= 1; j++)
      {
        float2 tap_uv = pin.uv + float2(i, j) * texel_size * coc.r;
        near_sum += g_near_field.SampleLevel(g_sam_linear, tap_uv, 0);
      }
    }
    output.near_field = near_sum / 9.0f;
    output.near_field.a = 1.0f;
  }
  else
  {
    output.near_field = g_near_field.SampleLevel(g_sam_linear, pin.uv, 0);
  }

  // ---- Far field -----------------------------------------------------------
  if (coc.g > 0.0f)
  {
    // As above, start from zero so the weight sum matches the accumulated taps.
    float4 far_sum = 0.0f.xxxx;
    float total_weight = 0.0f;
    for (int i = -1; i <= 1; i++)
    {
      for (int j = -1; j <= 1; j++)
      {
        float2 tap_uv = pin.uv + float2(i, j) * texel_size * coc.g;
        float4 tap = g_far_field.SampleLevel(g_sam_linear, tap_uv, 0);
        float tap_coc = g_coc.SampleLevel(g_sam_point, tap_uv, 0).g;
        // Skip taps whose CoC is at least 0.5 smaller than this pixel's.
        if ((coc.g - tap_coc) < 0.5f)
        {
          far_sum += tap;
          total_weight += 1.0f;
        }
      }
    }
    // The centre tap always passes the test above, so total_weight >= 1 in
    // practice; the max() only guards against a divide by zero.
    output.far_field = far_sum / max(total_weight, 1.0f);
    output.far_field.a = 1.0f;
  }
  else
  {
    output.far_field = g_far_field.SampleLevel(g_sam_linear, pin.uv, 0);
  }

  return output;
}