// #define TEST
// Debug switch: when TEST is defined, DynCont (log-free branch of pass 3)
// printf's the sample located at global id (X, Y).
#define X	100
#define Y	100

// Precision switch: uncomment USE_DBL to process double samples instead of float.
// Each sample is stored bit-packed across the four uint channels of one image
// pixel: MASK selects the bits carried by one channel and SHIFT is that bit
// count (4 x 16 = 64 bits for double, 4 x 8 = 32 bits for float). TYPE_INT is
// the integer type of the same width as TYPE_FLT (OpenCL long = 64 bits,
// int = 32 bits) used when packing/unpacking those bits.
//#define USE_DBL
#ifdef USE_DBL
#define TYPE_FLT  double
#define TYPE_INT  long
#define MASK      0xFFFF
#define SHIFT     16
// Smallest positive subnormal double. NOTE(review): name is presumably a typo
// for EPSILON and it is not referenced anywhere in this file — confirm before
// renaming or removing.
#define EPSILSON  4.94065645841247E-324
#else
#define TYPE_FLT  float
#define TYPE_INT  int
#define MASK      0x00FF
#define SHIFT     8
// Smallest positive subnormal float (see the note on the double variant).
#define EPSILSON  1.401298E-45
#endif

/*
 * Sampler used by read_data(): unnormalised integer pixel coordinates,
 * clamped to the image edge.
 *
 * The filter must be NEAREST. read_imageui() only supports nearest filtering
 * (with CLK_FILTER_LINEAR the returned values are undefined). In addition, each
 * pixel holds raw bit-packed fragments of one TYPE_FLT value, so interpolating
 * neighbouring pixels would corrupt the reconstructed number.
 */
constant sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE |
                             CLK_ADDRESS_CLAMP_TO_EDGE |
                             CLK_FILTER_NEAREST;


// Parameters passed by value to the DynCont kernel. Struct kernel arguments are
// copied byte-for-byte from the host, so this layout must match the host-side
// definition exactly (do not reorder or retype fields).
struct input {
  int state;              // pass selector: 1 = per-work-group max, 2 = reduce maxima into scratch_pad[0], 3 = compress and write output
  int gain;               // gain in dB: added after 20*log10 in log mode, applied as 10^(gain/20) in linear mode
  int dyn_cont_selector;  // displayed dynamic range in dB is 60 + dyn_cont_selector
  int compression_type;   // 1 = logarithmic (20*log10) compression; any other value = linear
};

/*
 * Reads one bit-packed sample from input_frame at integer coordinates (x, z).
 *
 * Each TYPE_FLT value is split across the four uint channels of one pixel.
 * Channel .x holds the least-significant SHIFT bits and channel .w the most
 * significant. The bits are reassembled with unsigned arithmetic, so no shift
 * can overflow into a sign bit (undefined for signed types). They are then
 * reinterpreted with the as_<type>() builtins instead of a pointer cast, which
 * would violate strict aliasing.
 *
 * input_frame : image written with the same packing scheme.
 * x, z        : pixel coordinates; out-of-range values are clamped to the edge
 *               by `sampler`.
 * returns     : the reconstructed floating-point sample.
 */
TYPE_FLT read_data(image2d_t input_frame, int x, int z)
{
  uint4 pixel = read_imageui(input_frame, sampler, (int2)(x, z));
#ifdef USE_DBL
  ulong bits =  (ulong)(pixel.x & MASK)
             | ((ulong)(pixel.y & MASK) << SHIFT)
             | ((ulong)(pixel.z & MASK) << (SHIFT * 2))
             | ((ulong)(pixel.w & MASK) << (SHIFT * 3));
  return as_double(bits);
#else
  uint bits =  (pixel.x & MASK)
            | ((pixel.y & MASK) << SHIFT)
            | ((pixel.z & MASK) << (SHIFT * 2))
            | ((pixel.w & MASK) << (SHIFT * 3));
  return as_float(bits);
#endif
}


/*
 * Work-group max-reduction used by DynCont passes 1 and 2.
 *
 * Every work-item of the group must call this, because it contains barriers.
 * Each item contributes `value`. The reduction works in
 * scratch[0 .. group_size-1] (local memory) indexed by `local_id`.
 *
 * Returns the group maximum in work-item 0. Other work-items get back their own
 * `value`, because only item 0 can read the final slot without another barrier.
 *
 * fmax() is used instead of a `>` ternary so that a NaN contribution is
 * ignored rather than propagated depending on its position in the group.
 */
TYPE_FLT dyn_cont_group_max(local TYPE_FLT* scratch, TYPE_FLT value,
                            uint local_id, uint group_size)
{
  scratch[local_id] = value;

  // `count` live values remain in scratch[0 .. count-1]. Every work-item
  // computes the same sequence of counts, so all of them reach each barrier.
  for (uint count = group_size; count > 1; count /= 2)
  {
    barrier(CLK_LOCAL_MEM_FENCE);

    uint half_count = count / 2;
    if (local_id < half_count)
    {
      scratch[local_id] = fmax(scratch[local_id], scratch[local_id + half_count]);
    }

    // Odd count: the last element has no partner, so work-item 0 folds it in.
    // scratch[count - 1] is not written by any work-item during this step.
    if ((count % 2) != 0 && local_id == 0)
    {
      scratch[0] = fmax(scratch[0], scratch[count - 1]);
    }
  }

  return (local_id == 0) ? scratch[0] : value;
}

/*
 * Maps one normalised sample to the 0..255 display range (DynCont pass 3).
 *
 * ratio         : sample divided by the frame maximum.
 * dynamic_range : displayed range in dB (60 + params.dyn_cont_selector).
 * gain          : params.gain in dB.
 * log_compress  : nonzero selects 20*log10 compression, zero selects linear.
 * returns       : value in [0, 255]. A NaN ratio maps to the bottom of the range.
 *
 * All arithmetic stays in TYPE_FLT. The former `/ 20.0` double literals forced
 * fp64 math even in the float build.
 */
TYPE_FLT dyn_cont_compress(TYPE_FLT ratio, TYPE_FLT dynamic_range,
                           TYPE_FLT gain, int log_compress)
{
  if (log_compress)
  {
    // 20*log10(0) is -inf, and log10 of a negative or NaN ratio is NaN.
    // fmax maps both to the floor of the displayed range.
    TYPE_FLT db = fmax(20 * log10(ratio), -dynamic_range);
    db = fmin(db + gain, (TYPE_FLT)0);   // cap at 0 dB after applying gain ...
    db = fmax(db, -dynamic_range);       // ... then floor again
    return 255 * (db + dynamic_range) / dynamic_range;
  }

  TYPE_FLT linear = ratio;
  // Suppress anything below the dynamic-range floor (10^(-DR/20) in amplitude).
  if (linear < exp10(-dynamic_range / 20))
  {
    linear = 0;
  }
  linear *= exp10(gain / 20);
  linear = fmin(fmax(linear, (TYPE_FLT)0), (TYPE_FLT)1);  // clamp; NaN -> 0
  return 255 * linear;
}

/*
 * Dynamic-range control, run as three passes selected by params.state.
 *
 * 1: each work-group computes the max of its input samples into
 *    scratch_pad[get_group_id(1)].
 * 2: the per-group maxima in scratch_pad are reduced into scratch_pad[0].
 * 3: every sample is normalised by scratch_pad[0], compressed (log or linear)
 *    to 0..255, and written to output_frame with the same bit-packing that
 *    read_data() decodes.
 *
 * `max` is local scratch for the reductions. It needs at least
 * get_local_size(0) elements.
 */
kernel void DynCont(read_only image2d_t input_frame, read_write image2d_t output_frame, local TYPE_FLT* max, global TYPE_FLT* scratch_pad, struct input params)
{
  int2 gid = (int2)(get_global_id(0), get_global_id(1));
  uint local_id = get_local_id(0);
  uint group_size = get_local_size(0);

  if (params.state == 1)
  {
    // NOTE(review): the reduction runs over dimension 0 only, and the result is
    // stored per get_group_id(1). This assumes each work-group spans one full
    // row (local size (width, 1)) — TODO confirm against the host launch.
    TYPE_FLT group_max = dyn_cont_group_max(max, read_data(input_frame, gid.x, gid.y),
                                            local_id, group_size);
    if (local_id == 0)
    {
      scratch_pad[get_group_id(1)] = group_max;
    }
  }
  else if (params.state == 2)
  {
    // NOTE(review): reads scratch_pad[gid.x] and writes scratch_pad[0] with no
    // cross-group synchronisation, so this presumably runs as a single
    // work-group covering every pass-1 result — TODO confirm.
    TYPE_FLT frame_max = dyn_cont_group_max(max, scratch_pad[gid.x], local_id, group_size);
    if (local_id == 0)
    {
      scratch_pad[0] = frame_max;
    }
  }
  else if (params.state == 3)
  {
    TYPE_FLT input = read_data(input_frame, gid.x, gid.y);
    TYPE_FLT frame_max = scratch_pad[0];
    TYPE_FLT dynamic_range = (TYPE_FLT)(60 + params.dyn_cont_selector);

    // A zero frame maximum gives a NaN or infinite ratio. dyn_cont_compress
    // clamps those into [0, 255] instead of writing NaN bits to the output.
    TYPE_FLT output_data = dyn_cont_compress(input / frame_max, dynamic_range,
                                             (TYPE_FLT)params.gain,
                                             params.compression_type == 1);

#ifdef TEST
    if(gid.x == X && gid.y == Y)
      printf("%.6f %.6f %.6f ----", output_data, input, frame_max);
#endif

    // Reinterpret the value's bits with as_uint/as_ulong (no pointer-cast
    // aliasing) and split them across the four channels, low bits in .x.
#ifdef USE_DBL
    ulong bits = as_ulong(output_data);
#else
    uint bits = as_uint(output_data);
#endif
    uint4 pixel = (uint4)((uint)(bits & MASK),
                          (uint)((bits >> SHIFT) & MASK),
                          (uint)((bits >> (SHIFT * 2)) & MASK),
                          (uint)((bits >> (SHIFT * 3)) & MASK));

    write_imageui(output_frame, gid, pixel);
  }
}