_gpu.py 2.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. import pyopencl as cl
  2. import numpy
  3. import numpy.linalg as la
  4. from scipy.misc import pilutil
  5. ctx = cl.create_some_context()
  6. queue = cl.CommandQueue(ctx)
  7. def dostuff(rays, map):
  8. ray_floats, ray_ints = rays
  9. num_rays = len(ray_floats)
  10. ray_floats = ray_floats.ravel()
  11. ray_ints = ray_ints.ravel()
  12. print ray_ints[2000 * 4 + 1]
  13. #~ asd
  14. maparray = numpy.zeros(shape=(map.width*map.height*map.numsources, ), dtype=numpy.float32)
  15. mf = cl.mem_flags
  16. ray_floats_buf = cl.Buffer(ctx, mf.READ_ONLY | mf.COPY_HOST_PTR, hostbuf=ray_floats)
  17. ray_ints_buf = cl.Buffer(ctx, mf.COPY_HOST_PTR, hostbuf=ray_ints)
  18. maparray_buf = cl.Buffer(ctx, mf.COPY_HOST_PTR, hostbuf=maparray)
  19. prg = cl.Program(ctx, """
  20. __kernel void sum(__global const float *ray_floats,
  21. __global const int *ray_ints, __global float *map)
  22. {
  23. int gid = get_global_id(0);
  24. float dx = ray_floats[gid * 5 + 0];
  25. float dy = ray_floats[gid * 5 + 1];
  26. float x = ray_floats[gid * 5 + 2];
  27. float y = ray_floats[gid * 5 + 3];
  28. float power = ray_floats[gid * 5 + 4];
  29. int dist = ray_ints[gid * 4 + 0];
  30. int source_id = ray_ints[gid * 4 + 1];
  31. int generation = ray_ints[gid * 4 + 2];
  32. int medium = ray_ints[gid * 4 + 3];
  33. float newpower, currpower;
  34. int xi, yi, mapidx;
  35. int width = 860, height = 300;
  36. int loop;
  37. int raylength = 1000;
  38. float _raylength = raylength;
  39. for(loop=0; loop < raylength; loop++)
  40. {
  41. newpower = power * (1.0 - (loop / _raylength));
  42. if (newpower == 0)
  43. break;
  44. x = x + dx;
  45. y = y + dy;
  46. if (x <= 0 or x >= width - 1 or y <= 0 or y >= height)
  47. break;
  48. xi = x;
  49. yi = y;
  50. mapidx = source_id * width * height + yi * width + xi;
  51. currpower = map[mapidx];
  52. if (newpower > currpower)
  53. {
  54. map[mapidx] = newpower;
  55. }
  56. }
  57. }
  58. """).build()
  59. #~ print map.height, map.width
  60. #~ return
  61. print 'num rays: %s' % num_rays
  62. prg.sum(queue, (num_rays, ), None, ray_floats_buf, ray_ints_buf, maparray_buf)
  63. cl.enqueue_copy(queue, maparray, maparray_buf)
  64. #~ print numpy.max(maparray)
  65. #~ rs = maparray.reshape(map.width, map.height, map.numsources)
  66. #~ a = numpy.transpose(rs, axes=(2,0,1))[0]
  67. #~ pilutil.imsave('d:/out.png', a)
  68. # load data into map
  69. map.fromArray(maparray)