// Kernel entry points. The identifier following each kernel name is a macro that is
// defined only while that kernel is compiled; it selects the matching #ifdef section below.
#pragma kernel CS_Shadowmap SHADOWMAP
#pragma kernel CS_Lightmap LIGHTMAP
#pragma kernel CS_Illumination ILLUMINATION
#pragma kernel CS_CopyDepth COPY_DEPTH
// Build variants with and without INPUT_2D_ARRAY; the keyword switches
// _CameraDepthTexture between Texture2D and Texture2DArray in the COPY_DEPTH section.
#pragma multi_compile_local __ INPUT_2D_ARRAY
// Do not compile this shader for the gles / gles3 renderers.
#pragma exclude_renderers gles gles3

#include <RenderingUtils.cginc>

// NOTE(review): MAX_UINT16 and PARTICLE_THREADS are not referenced by any kernel
// visible in this file — confirm whether they are still needed.
#define MAX_UINT16 0xffffu
#define PARTICLE_THREADS 256
// Thread-group edge length (X and Y) used by CS_CopyDepth.
#define DEPTH_COPY_WORKGROUP 16
// Thread-group dimensions shared by the three volume kernels
// (CS_Shadowmap, CS_Lightmap, CS_Illumination): 8 x 8 x 6 threads per group.
#define WORKGROUP_SIZE_X 8
#define WORKGROUP_SIZE_Y 8
#define WORKGROUP_SIZE_Z 6

#ifdef SHADOWMAP

// Output volume written by CS_Shadowmap: one scalar per shadow-grid cell.
RWTexture3D<float> ShadowmapOUT;

// Fills ShadowmapOUT with the value returned by TraceShadow along LightDirWorld
// for every cell of the shadow grid. One thread handles one cell.
//
// id - dispatch-wide thread index, used directly as the texel coordinate.
[numthreads(WORKGROUP_SIZE_X,WORKGROUP_SIZE_Y,WORKGROUP_SIZE_Z)]
void CS_Shadowmap(uint3 id : SV_DispatchThreadID)
{
    // Discard threads of partially filled groups. For an integer-valued grid size,
    // (id + 0.5 > size) is the same test as (id >= size).
    if (any((float3(id) + 0.5) > ShadowGridSize))
    {
        return;
    }

    // Mean container extent over the three axes.
    float simulationScale =
        (1.0f / 3.0f) * (ContainerScale.x + ContainerScale.y + ContainerScale.z);
    // Cell size is derived from the X axis only.
    float cellSize = ContainerScale.x / GridSize.x;

    // Step length handed to TraceShadow, and the same step relative to simulationScale.
    float dt = ShadowStepSize * cellSize;
    float stepScale = dt / simulationScale;

    int3 ShadowmapSize;
    ShadowmapOUT.GetDimensions(ShadowmapSize.x, ShadowmapSize.y, ShadowmapSize.z);

    // Map the texel index to [-0.5, 0.5) and then into the container's box.
    float3 normalizedPos = float3(id) / float3(ShadowmapSize) - 0.5;
    float3 samplingPos = ContainerScale * normalizedPos + ContainerPosition;
    float opticalDepth = 0.0;
    // Out-parameter of TraceShadow; its value is not used by this kernel.
    float edgeDist;
    // Skip the trace when the main light is effectively black; the cell then stores 0.
    if(Sum(LightColor) > RENDERING_EPS)
        opticalDepth = TraceShadow(samplingPos, LightDirWorld, dt, stepScale, FAR_DISTANCE, edgeDist);

    ShadowmapOUT[id] = opticalDepth;
}

#endif

#ifdef LIGHTMAP

// Output volume written by CS_Lightmap:
//   xyz - summed contribution of all array lights,
//   w   - brightness-weighted average of exp(-depth) over those lights.
RWTexture3D<float4> LightmapOUT;

// Accumulates, for every cell of the light grid, the shadowed contribution of the
// LightCount lights stored in LightPositionArray / LightColorArray.
// One thread handles one cell.
//
// id - dispatch-wide thread index, used directly as the texel coordinate.
[numthreads(WORKGROUP_SIZE_X,WORKGROUP_SIZE_Y,WORKGROUP_SIZE_Z)]
void CS_Lightmap(uint3 id : SV_DispatchThreadID)
{
    // Discard threads of partially filled groups. For an integer-valued grid size,
    // (id + 0.5 > size) is the same test as (id >= size).
    if (any((float3(id) + 0.5) > LightGridSize))
    {
        return;
    }

    // Mean container extent over the three axes.
    float simulationScale =
        (1.0f / 3.0f) * (ContainerScale.x + ContainerScale.y + ContainerScale.z);
    // Cell size is derived from the X axis only.
    float cellSize = ContainerScale.x / GridSize.x;

    // Step length handed to TraceShadow, and the same step relative to simulationScale.
    float dt = ShadowStepSize * cellSize;
    float stepScale = dt / simulationScale;

    int3 LightmapSize;
    LightmapOUT.GetDimensions(LightmapSize.x, LightmapSize.y, LightmapSize.z);

    // Map the texel index to [-0.5, 0.5) and then into the container's box.
    float3 normalizedPos = float3(id) / float3(LightmapSize) - 0.5;
    float3 samplingPos = ContainerScale * normalizedPos + ContainerPosition;

    float3 illumination = 0.0;
    // (weighted sum of exp(-depth), sum of weights). Both start at RENDERING_EPS so the
    // final ratio stays defined (and equals 1) when no light contributes.
    float2 averageShadow = RENDERING_EPS;
    for (int i = 0; i < LightCount; i++)
    {
        float3 toLight = LightPositionArray[i].xyz - samplingPos;
        // Clamp the distance so a light located exactly on the sample point yields a
        // zero direction instead of a 0/0 NaN that would be written to the lightmap.
        float dist = max(length(toLight), RENDERING_EPS);
        float3 dir = toLight / dist;
        float edgeDist;
        float depth = ShadowIntensity * TraceShadow(samplingPos, dir, dt, stepScale, dist, edgeDist);
        // NOTE(review): LightPositionArray[i].w is forwarded to GetLightAttenuation as its
        // second argument; its exact meaning is defined in that helper — confirm there.
        float3 lightBrightness = LightColorArray[i].xyz * GetLightAttenuation(dist, LightPositionArray[i].w).x;
        // Weight grows as edgeDist (reported by TraceShadow) gets small.
        float shadowWeight = Sum(lightBrightness) * (1 + 0.1*simulationScale / (edgeDist + RENDERING_EPS));
        averageShadow += float2(exp(-depth), 1.0) * shadowWeight;
        float3 scatter = ShadowedScattering(depth);
        illumination += scatter * lightBrightness;
    }
    LightmapOUT[id] = float4(illumination, averageShadow.x/averageShadow.y);
}

#endif

#ifdef ILLUMINATION

// Per-cell lit colour produced by CS_Illumination.
RWTexture3D<float3> IlluminationOUT;

// Fire-mode parameters.
float ReactionSpeed;
float TempThreshold;
// NOTE(review): not referenced by the kernel below — confirm whether it is still needed.
float TemperatureDensityDependence;

// Computes the lit colour of every simulation cell: main-light scattering, plus the
// lightmap contribution when array lights exist, plus emission in fire mode.
// Cells whose density is below 1e-4 are written as black.
//
// id - dispatch-wide thread index, used directly as the cell coordinate.
[numthreads(WORKGROUP_SIZE_X,WORKGROUP_SIZE_Y,WORKGROUP_SIZE_Z)]
void CS_Illumination(uint3 id : SV_DispatchThreadID)
{
    // Threads past the grid extent have nothing to do.
    if (any((float3(id) + 0.5) > GridSize))
        return;

    float3 worldPos = Simulation2World(id);
    float meanContainerScale = (1.0f / 3.0f) * (ContainerScale.x + ContainerScale.y + ContainerScale.z);

    float rho = Density[id];
    if (rho < 1e-4)
    {
        // Near-empty cell: store black and stop.
        IlluminationOUT[id] = float3(0.0f, 0.0f, 0.0f);
        return;
    }

    bool coloredSmokeMode = (SimulationMode == SIMULATION_MODE_COLORED_SMOKE);
    bool fireMode = (SimulationMode == SIMULATION_MODE_FIRE);

    // Default: grey tint and density both taken from rho.
    float3 tint = float3(rho, rho, rho);
    float3 cellData = float3(rho, rho, rho);

    if (coloredSmokeMode)
    {
        cellData = float3(rho/SmokeDensity, Color[id]);
        tint = Density2RGB(cellData);
    }

    if (fireMode)
    {
        cellData = float3(rho, Color[id]);
    }

    // Scale the tint so its largest channel is 1 (the divisor is floored at 1e-3).
    float3 unitTint = tint.xyz / max(1e-3, Max(tint.xyz));

    // Shadow-map lookup only when MainLightMode is 1; otherwise a fixed depth of 100.
    float mainLightDepth = 100.0;
    if (MainLightMode == 1)
    {
        mainLightDepth = SampleShadowmapSmooth(worldPos, meanContainerScale);
    }

    float3 litColor = LightColor * unitTint * ShadowedScattering(mainLightDepth);

    if (LightCount > 0)
    {
        float3 lightmapRGB = LightmapSample(float3(id)/GridSize).xyz;
        litColor += unitTint * lightmapRGB;
    }

    if (fireMode)
    {
        float reactionRate = ReactionSpeed * max(sqrt(cellData.z) - TempThreshold, 0.0);

        // Black-body term plus flame term.
        litColor +=
            BlackBodyBrightness * rho * BlackBodyRadiation(0.5 * cellData.z) +
            FireBrightness * FireColor.rgb * cellData.y * reactionRate / (rho + 1.0);
    }

    IlluminationOUT[id] = litColor;
}

#endif

#ifdef COPY_DEPTH

// Destination of CS_CopyDepth: one depth value per pixel.
RWTexture2D<float> DepthDest;

// Source depth texture. The INPUT_2D_ARRAY variant binds it as a texture array,
// of which LoadCameraDepth reads slice 0.
#if defined(INPUT_2D_ARRAY)
Texture2DArray _CameraDepthTexture;
#else
Texture2D _CameraDepthTexture;
#endif

// Reads the depth at texel `pos` (mip 0; slice 0 in the array variant) from
// _CameraDepthTexture. When UNITY_REVERSED_Z is not defined the value is
// flipped to (1 - depth) before being returned.
float LoadCameraDepth(float2 pos)
{
#if defined(INPUT_2D_ARRAY)
    // (x, y, array slice, mip level)
    int4 texel = int4(pos, 0, 0);
#else
    // (x, y, mip level)
    int3 texel = int3(pos, 0);
#endif
    float rawDepth = _CameraDepthTexture.Load(texel).x;
#if defined(UNITY_REVERSED_Z)
    return rawDepth;
#else
    return 1.0 - rawDepth;
#endif
}

// Copies the camera depth into DepthDest, one thread per pixel.
//
// id - dispatch-wide thread index; id.xy is the pixel coordinate.
[numthreads(DEPTH_COPY_WORKGROUP, DEPTH_COPY_WORKGROUP, 1)]
void CS_CopyDepth(uint3 id : SV_DispatchThreadID)
{
    // Valid pixel indices are 0 .. resolution - 1, so the test must be >=.
    // A plain > would let the id == resolution row and column through,
    // reading and writing one texel past the edge.
    if (any(id.xy >= (uint2)OriginalCameraResolution))
    {
        return;
    }

    DepthDest[id.xy] = LoadCameraDepth(id.xy);
}

#endif