Lower latency associative scan option #10599
oliverdutton asked this question in Ideas
In one of my problems, the implementation was bottlenecked by a cumulative matmul. JAX has a handy implementation of a work-efficient associative scan for this, `lax.associative_scan`, which reduces the procedure from N sequential steps to 2 log2(N) - 2 steps. There is also a work-inefficient implementation that reduces this to log2(N) steps, shown below, which is faster for small problem sizes where the GPU is not saturated. Would anyone else be interested in having a work-inefficient option in `lax.associative_scan`? If so, I can put together a pull request.
See:
- https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-39-parallel-prefix-sum-scan-cuda
- http://www.cs.cmu.edu/~guyb/papers/Ble93.pdf
- https://en.wikipedia.org/wiki/Prefix_sum
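For context, the existing work-efficient path for the cumulative-matmul case looks roughly like this (array shapes are illustrative):

```python
import jax
import jax.numpy as jnp
from jax import lax

# A stack of N square matrices; the scan computes all running products.
key = jax.random.PRNGKey(0)
mats = jax.random.normal(key, (8, 4, 4)) / 4.0

# jnp.matmul broadcasts over the leading (batch) axis, so it serves as
# a binary associative operator for the scan; associativity (not
# commutativity) is all associative_scan requires, so the left-to-right
# product order of the matrices is preserved.
cumulative = lax.associative_scan(jnp.matmul, mats)
```

Here `cumulative[i]` holds `mats[0] @ mats[1] @ ... @ mats[i]`.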