skiplist.pyx

import numpy as np
cimport numpy as np
cimport cython
from cython.parallel import prange
from libc.stdlib cimport malloc, free, calloc

cdef extern from "math.h":
    double log(double) nogil

cimport openmp
cdef extern from "omp.h":
    void omp_set_num_threads(int num_threads)
    ctypedef struct omp_lock_t:
        pass

import time, os

cdef double INFINITY = np.inf
cdef double NEG_INFINITY = -np.inf


cdef class SortedXYZList:
    ''' skiplist-like structure that ensures inserted elements are always sorted and that removes
        the smallest entries if capacity is exceeded
        http://igoro.com/archive/skip-lists-are-fascinating/
        http://code.activestate.com/recipes/576930-efficient-running-median-using-an-indexable-skipli/
    '''
    cdef public:
        double minimum
        char isfilled
    cdef:
        int maxlevels, capacity, insertion_count
        int *nodelinks
        double *nodevalues
        int *xyz
        int fill
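
    # Memory layout (flat C arrays, one slot per node):
    #   nodelinks  - capacity * maxlevels ints; nodelinks[node * maxlevels + level]
    #                holds the index of the next node on that level (0 = no link)
    #   nodevalues - capacity doubles; the sort key of each node
    #   xyz        - capacity * 3 ints; the payload coordinates of each node
    # Node 0 is the head (its links point to the smallest element), node 1 is the
    # tail (value = INFINITY), so real entries start at index 2.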

    def __init__(self, int capacity):
        self.fill = 2  # start at position 2 as the head and the tail are already in the lists
        self.maxlevels = int(1 + log(capacity) / log(2))
        self.capacity = capacity + 2
        self.clear(True)
        self.isfilled = 0

    @cython.cdivision(True)
    cdef void insert(SortedXYZList self, double value, int x, int y, int z) nogil:
        cdef int level, maxlevels, curr, next, new, smallest, next_on_level
        cdef int offset_curr, offset_smallest, offset_xyz
        maxlevels = self.maxlevels
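        # Two cases: while the list is still filling up, the value is appended at the
        # next free slot; once the list is full, a value that is not larger than the
        # current minimum is rejected, otherwise it recycles the slot of the smallest node.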
        if not self.isfilled:
            # insert a new element
            new = self.fill
            self.fill += 1
            if self.fill == self.capacity:
                smallest = self.nodelinks[0]  # the head node points to the smallest element
                offset_smallest = smallest * maxlevels
                # update the minimum to the value of the smallest element
                self.minimum = self.nodevalues[self.nodelinks[offset_smallest]]
                # check if the current value is even smaller
                if value < self.minimum:
                    self.minimum = value
                self.isfilled = 1
            # update minimum
            #~ if value < self.minimum:
            #~     self.minimum = value
        else:
            if value <= self.minimum:
                # nothing to do - we will not insert an element
                # that is too small
                return
            # replace the smallest element
            smallest = self.nodelinks[0]  # the head node points to the smallest element
            offset_smallest = smallest * maxlevels
            # update the minimum to the value of the second-smallest element
            self.minimum = self.nodevalues[self.nodelinks[offset_smallest]]
            if value < self.minimum:
                # the new incoming value is still smaller than the *currently* second-smallest
                self.minimum = value
            # now update all next-links on the head node to the outgoing links of the
            # node to be removed (= the smallest)
            for level in range(maxlevels):
                next_on_level = self.nodelinks[offset_smallest + level]
                if next_on_level > 0:  # will be 0 once the node's "historical" level_for_insert is reached
                    self.nodelinks[level] = next_on_level
                    # clear old entries!
                    self.nodelinks[offset_smallest + level] = 0
                else:
                    break
            # the new node will get the position of the smallest node (in the array)
            new = smallest

        # With a probability of 1/2, make the node a part of the lowest-level list only.
        # With 1/4 probability, the node will be a part of the lowest two lists.
        # With 1/8 probability, the node will be a part of three lists. And so forth.
        cdef int level_for_insert, i
        #~ level_for_insert = min(self.maxlevels, 1 - int(log(random()) / log(2)))
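        # Deterministic stand-in for the coin-flip scheme described above (the random()
        # variant is left commented out): every 2nd insert reaches at least level 2,
        # every 4th at least level 3, and so on, based on the largest power of two
        # that divides insertion_count.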
        for i in range(self.maxlevels - 1, -1, -1):
            if self.insertion_count % (1 << i) == 0:
                level_for_insert = i + 1
                break
        self.insertion_count += 1
        curr = 0  # the head node
        self.nodevalues[new] = value
        # store xyz values
        offset_xyz = new * 3
        self.xyz[offset_xyz] = x
        self.xyz[offset_xyz + 1] = y
        self.xyz[offset_xyz + 2] = z
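        # Search from the highest level down: on each level advance while the next
        # node's value is still smaller, then splice the new node into every list
        # from level_for_insert - 1 down to 0.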
        for level in range(self.maxlevels - 1, -1, -1):
            offset_curr = curr * maxlevels + level
            next = self.nodelinks[offset_curr]
            while value > self.nodevalues[next]:
                curr = next
                offset_curr = curr * maxlevels + level
                next = self.nodelinks[offset_curr]
            if level < level_for_insert:
                self.nodelinks[new * maxlevels + level] = self.nodelinks[offset_curr]
                self.nodelinks[offset_curr] = new

    def getall(self):
        ''' debug function '''
        cdef int curr = 0, next, offset_curr = curr * self.maxlevels
        next = self.nodelinks[offset_curr]
        while self.nodevalues[next] < INFINITY:
            v = self.nodevalues[self.nodelinks[curr * self.maxlevels]]
            curr = next
            print(v, curr, self.nodelinks[curr * self.maxlevels])
            next = self.nodelinks[curr * self.maxlevels]

    @cython.boundscheck(False)
    def xyzvAsArray(self):
        '''transform result to sorted np.array'''
        cdef np.ndarray[np.double_t, ndim=2] result = np.zeros((self.fill - 2, 4), dtype=np.double)
        cdef int i, offset
        cdef int curr = 0, next, offset_curr = curr * self.maxlevels
        next = self.nodelinks[offset_curr]
        i = 0
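        # follow the level-0 links (which connect all nodes in ascending order)
        # until the tail node (value INFINITY) is reached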
        while self.nodevalues[next] < INFINITY:
            offset = self.nodelinks[curr * self.maxlevels] * 3
            result[i, 0] = self.xyz[offset]
            result[i, 1] = self.xyz[offset + 1]
            result[i, 2] = self.xyz[offset + 2]
            result[i, 3] = self.nodevalues[self.nodelinks[curr * self.maxlevels]]
            curr = next
            next = self.nodelinks[curr * self.maxlevels]
            i += 1
        return result

    cdef clear(SortedXYZList self, realloc=False):
        cdef int i
        if realloc:
            self.nodelinks = <int*>calloc(self.capacity * self.maxlevels, sizeof(int))
            self.nodevalues = <double*>calloc(self.capacity, sizeof(double))
            self.xyz = <int*>calloc(self.capacity * 3, sizeof(int))
        else:
            # only set links to 0
            for i in range(self.capacity * self.maxlevels):
                self.nodelinks[i] = 0
        # start at position 2 as the head and the tail are already in the lists
        self.fill = 2
        self.isfilled = 0
        # initialize head node
        for i in range(self.maxlevels):
            self.nodelinks[0 + i] = 1
        # initialize tail node
        self.nodevalues[1] = INFINITY
        self.insertion_count = 1  # used for generating probabilities
        # will be set automatically to valid values once the initial filling happens
        self.minimum = NEG_INFINITY

    def __dealloc__(SortedXYZList self):
        free(self.nodelinks)
        free(self.nodevalues)
        free(self.xyz)  # allocated in clear() alongside the other two arrays


def getData(n, load_stored, seed):
    t = time.time()
    if not load_stored or not os.path.exists('data_%s_%s.npz' % (n, seed)):
        print('building data with SEED %s' % seed)
        np.random.seed(seed)
        values = np.random.random_sample(size=n).astype(np.double)
        xyz = np.random.randint(256, size=n * 3).astype(np.uint8).reshape((n, 3))
        np.savez('data_%s_%s.npz' % (n, seed), values=values, xyz=xyz)
        print('random + store to ./data_%s_%s.npz: %.3f sec' % (n, seed, time.time() - t))
    else:
        npzfile = np.load('data_%s_%s.npz' % (n, seed))
        xyz = npzfile['xyz']
        values = npzfile['values']
        print('load: %.3f sec' % (time.time() - t))
    return values, xyz


@cython.boundscheck(False)
def test(int size, int n, parallel=True, load_stored=False, seed=190180):
    cdef np.ndarray[np.double_t, ndim=1] values
    cdef np.ndarray[np.uint8_t, ndim=2] xyz
    cdef omp_lock_t mylock
    cdef SortedXYZList ll  # typed so the nogil insert() can be called inside prange
    values, xyz = getData(n, load_stored, seed)
    cdef int i
    ll = SortedXYZList(size)
    t = time.time()
    if not parallel:
        for i in range(n):
            #~ print '%s: inserting %s' % (i, xx[i])
            ll.insert(values[i], xyz[i, 0], xyz[i, 1], xyz[i, 2])
    else:
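        # Each thread pre-filters against ll.minimum without holding the lock, so most
        # candidates are rejected cheaply; only values that may actually enter the list
        # serialize on the single OpenMP lock around insert().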
        omp_set_num_threads(8)
        openmp.omp_init_lock(&mylock)  # initialize
        for i in prange(n, nogil=True):
            if values[i] > ll.minimum:
                openmp.omp_set_lock(&mylock)  # acquire
                ll.insert(values[i], xyz[i, 0], xyz[i, 1], xyz[i, 2])
                openmp.omp_unset_lock(&mylock)  # release
        openmp.omp_destroy_lock(&mylock)  # deallocate the lock
    #~ ll.getall()
    print('insert: %.3f secs, %.2fmio per sec' % (time.time() - t, n / 1000.0**2 / (time.time() - t + 0.000000000001)))
    #~ if seed == 190180 and size == 1e4 and n == 1e8:
    #~     # do not compare double values
    #~     res = ll.xyzvAsArray()
    #~     np.savetxt('result_190180_1e4_1e8.txt', res, fmt='%d %d %d %s')
    #~     stored_res = np.loadtxt('result_190180_1e4_1e8.txt')[:, :3]
    #~     assert np.all(np.equal(res[:, :3], stored_res)), 'skiplist result differs'
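

# A minimal build-and-run sketch (not part of the original file): the module and file
# names below are illustrative, and the OpenMP flags assume GCC/Clang.
#
#   # setup.py
#   from setuptools import Extension, setup
#   from Cython.Build import cythonize
#   import numpy as np
#
#   ext = Extension('skiplist', ['skiplist.pyx'],
#                   extra_compile_args=['-fopenmp'],
#                   extra_link_args=['-fopenmp'],
#                   include_dirs=[np.get_include()])
#   setup(ext_modules=cythonize(ext))
#
#   # after 'python setup.py build_ext --inplace':
#   #   python -c "import skiplist; skiplist.test(10000, 1000000)"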