/*
 * One lattice sample.
 *   pos - base position in xyz. The main kernel stores its visibility result in pos.w.
 *   dir - its xyz part is added to pos.xyz by Pos() to form the sample's position.
 */
struct LatticeData
{
	float4 pos;
	float4 dir;
};
typedef struct LatticeData LatticeData;


/*
 * Return (by value) the lattice sample at column x, row y of a row-major lattice
 * with num_x_lattice samples per row.
 * The lattice is only read, so the pointer is const-qualified; callers passing a
 * non-const __global pointer still convert implicitly.
 * No bounds checking is done, and the index is computed in 32-bit uint arithmetic.
 */
LatticeData GetLatticeData(const __global struct LatticeData *inLattice, uint num_x_lattice, uint x, uint y )
{
	return inLattice[x + y * num_x_lattice];
}

/*
 * Return (by value) the lattice sample that corresponds to vertex (x, y).
 * Vertex coordinates are scaled by vertStep to get lattice coordinates
 * (x*vertStep, y*vertStep) in a row-major lattice with num_x_lattice samples per
 * row.
 * The lattice is only read, so the pointer is const-qualified.
 * No bounds checking is done, and the index is computed in 32-bit uint arithmetic.
 */
LatticeData GetLatticeDataStepped(const __global struct LatticeData *inLattice, uint num_x_lattice, uint vertStep, uint x, uint y )
{
	return inLattice[(x*vertStep) + (y*vertStep)*num_x_lattice];
}

/*
 * Position of a lattice sample: pos.xyz + dir.xyz.
 * Both w components are ignored.
 */
float3 Pos( LatticeData datum )
{
	float4 combined = datum.pos + datum.dir;
	return combined.xyz;
}

/*
 * Signed distance from pos to the plane { p : dot(n, p) + d = 0 }, where
 * n = plane.xyz and d = plane.w.
 * The result is positive on the side that n points toward.
 * n need not be unit length because the result is divided by |n|. A zero n
 * therefore divides by zero.
 */
float DistToPlane( float4 plane, float3 pos )
{
	float3 normal = plane.xyz;
	float offset = plane.w;
	float numerator = dot(normal, pos) + offset;
	return numerator / length(normal);
}

/*
 * Per-vertex visibility pass.
 *
 * The NDRange is the vertex grid (get_global_size(0) x get_global_size(1)). The
 * lattice is vertStep times finer and row-major, with
 * num_x_lattice = num_x*vertStep + 1 samples per row.
 *
 * Work-item (x, y) owns lattice sample (x*vertStep, y*vertStep). It:
 *   - fits a plane through that sample's position, using the four diagonal
 *     lattice neighbours for the orientation;
 *   - subtracts kOcclusionScale times each neighbour's positive distance above
 *     the plane from 1.0;
 *   - stores the result in pos.w of its own sample. No other field is written.
 *
 * latticeData - lattice samples; the centre and its 4 diagonal neighbours are read.
 * vertStep    - lattice samples per vertex step. 0 is rejected, because every
 *               work-item would map to sample 0 and the -1 neighbours would wrap.
 */
__kernel void main( __global struct LatticeData *latticeData, 
					uint vertStep
					  )
{
	// How strongly a neighbour rising above the fitted plane reduces visibility.
	const float kOcclusionScale = 20.0f;

	// verts
	uint x = get_global_id(0);
	uint y = get_global_id(1);
	uint num_x = get_global_size(0);
	uint num_y = get_global_size(1);

	if ( vertStep == 0 )
	{
		return;
	}

	// TEMP IGNORE EDGES FOR NOW
	// x == 0 / y == 0 must be skipped: lattice_x-1 / lattice_y-1 would wrap
	// (unsigned) and index far out of bounds. x == num_x / y == num_y can only
	// occur with a non-zero global work offset, because get_global_id includes
	// the offset while get_global_size does not.
	if ( x == 0 || x == num_x || y == 0 || y == num_y ) 
	{
		return;
	}

	// lattice (has finer detail than regular)
	uint lattice_x = x * vertStep;
	uint lattice_y = y * vertStep;
	uint num_x_lattice = num_x*vertStep+1;

	float3 centerPos = Pos( GetLatticeData(latticeData, num_x_lattice, lattice_x, lattice_y ) );

	// Diagonal neighbours, in order: (-1,-1), (+1,-1), (+1,+1), (-1,+1).
	float3 pos[4];
	pos[0] = Pos( GetLatticeData(latticeData, num_x_lattice, lattice_x-1, lattice_y-1 ) );
	pos[1] = Pos( GetLatticeData(latticeData, num_x_lattice, lattice_x+1, lattice_y-1 ) );
	pos[2] = Pos( GetLatticeData(latticeData, num_x_lattice, lattice_x+1, lattice_y+1 ) );
	pos[3] = Pos( GetLatticeData(latticeData, num_x_lattice, lattice_x-1, lattice_y+1 ) );

	// Normals of triangles (0,1,2) and (0,2,3), which share the same winding.
	float3 surfaceNormal0 = normalize(cross( pos[1]-pos[0], pos[2]-pos[0] ));
	float3 surfaceNormal1 = normalize(cross( pos[2]-pos[0], pos[3]-pos[0] ));
	float3 normalSum = surfaceNormal0 + surfaceNormal1;

	float visibility = 1.0f;

	// A degenerate neighbourhood has no usable plane. This covers coincident or
	// collinear points, and two triangle normals that cancel out. In that case
	// DistToPlane would divide by zero, and max() with NaN is undefined, so
	// visibility stays at 1.0. The comparison below is also false for NaN.
	if ( length(normalSum) > 0.0f )
	{
		// Plane through the centre position, oriented along the averaged normal.
		float3 planeNormal = normalize( normalSum );
		float planeD = -dot(planeNormal, centerPos);
		float4 plane = (float4)(planeNormal, planeD);

		// Neighbours on the positive (normal) side of the plane reduce visibility
		// in proportion to their height above it; those below contribute nothing.
		for ( uint i = 0; i < 4; ++i )
		{
			float distToPt = DistToPlane( plane, pos[i] );
			visibility -= max(0.0f, distToPt*kOcclusionScale);
		}
	}

	// Write only pos.w. With vertStep == 1 this sample is also a diagonal neighbour
	// of other work-items, which read its pos.xyz/dir.xyz. Storing the whole struct
	// back would rewrite those bytes while they are being read. Any pos.w those
	// work-items load is discarded by Pos().
	uint writeIndex = lattice_x + lattice_y*num_x_lattice;
	latticeData[writeIndex].pos.w = visibility;
}