Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

initial SIMD acceleration for the XOF (AVX-512-only, Unix-only) #418

Merged
merged 14 commits into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions benches/bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -515,3 +515,46 @@ fn bench_two_updates(b: &mut Bencher) {
hasher.finalize()
});
}

/// Shared driver for the XOF benchmarks.
///
/// Finalizes a hasher that was given no input into an output reader, then
/// repeatedly fills the first `len` bytes of a `64 * BLOCK_LEN`-byte stack
/// buffer from that reader. The same reader is reused across iterations, so
/// every iteration continues the output stream where the previous one stopped.
///
/// Panics if `len` is larger than the buffer (`64 * BLOCK_LEN`).
fn bench_xof(b: &mut Bencher, len: usize) {
    // Report throughput in terms of the bytes produced per iteration.
    b.bytes = len as u64;
    let mut buffer = [0u8; 64 * BLOCK_LEN];
    let target = &mut buffer[..len];
    let mut reader = blake3::Hasher::new().finalize_xof();
    b.iter(|| reader.fill(target));
}

/// Runs the XOF benchmark with an output of one block (`BLOCK_LEN` bytes).
#[bench]
fn bench_xof_01_block(b: &mut Bencher) {
    let output_len = BLOCK_LEN;
    bench_xof(b, output_len);
}

/// Runs the XOF benchmark with an output of two blocks (`2 * BLOCK_LEN` bytes).
#[bench]
fn bench_xof_02_blocks(b: &mut Bencher) {
    let output_len = 2 * BLOCK_LEN;
    bench_xof(b, output_len);
}

/// Runs the XOF benchmark with an output of four blocks (`4 * BLOCK_LEN` bytes).
#[bench]
fn bench_xof_04_blocks(b: &mut Bencher) {
    let output_len = 4 * BLOCK_LEN;
    bench_xof(b, output_len);
}

/// Runs the XOF benchmark with an output of eight blocks (`8 * BLOCK_LEN` bytes).
#[bench]
fn bench_xof_08_blocks(b: &mut Bencher) {
    let output_len = 8 * BLOCK_LEN;
    bench_xof(b, output_len);
}

/// Runs the XOF benchmark with an output of sixteen blocks (`16 * BLOCK_LEN` bytes).
#[bench]
fn bench_xof_16_blocks(b: &mut Bencher) {
    let output_len = 16 * BLOCK_LEN;
    bench_xof(b, output_len);
}

/// Runs the XOF benchmark with an output of thirty-two blocks (`32 * BLOCK_LEN` bytes).
#[bench]
fn bench_xof_32_blocks(b: &mut Bencher) {
    let output_len = 32 * BLOCK_LEN;
    bench_xof(b, output_len);
}

/// Runs the XOF benchmark with an output of sixty-four blocks (`64 * BLOCK_LEN`
/// bytes), which is the full size of the buffer used by `bench_xof`.
#[bench]
fn bench_xof_64_blocks(b: &mut Bencher) {
    let output_len = 64 * BLOCK_LEN;
    bench_xof(b, output_len);
}
34 changes: 20 additions & 14 deletions c/blake3.c
Original file line number Diff line number Diff line change
Expand Up @@ -88,24 +88,30 @@ INLINE void output_chaining_value(const output_t *self, uint8_t cv[32]) {

// Writes `out_len` bytes of root output to `out`, starting `seek` bytes into
// the output stream. The stream is produced in 64-byte blocks, one per value
// of the output block counter, so the work is split into three phases:
//   1. the tail of a partially consumed first block (when `seek` is not a
//      multiple of 64),
//   2. every whole 64-byte block, written directly into `out` in a single
//      blake3_xof_many() call,
//   3. the head of a final partial block.
// Phases 1 and 3 go through a 64-byte scratch buffer because the caller's
// buffer cannot hold the unused part of those blocks.
INLINE void output_root_bytes(const output_t *self, uint64_t seek, uint8_t *out,
                              size_t out_len) {
  if (out_len == 0) {
    return;
  }
  uint64_t output_block_counter = seek / 64;
  const size_t offset_within_block = seek % 64;
  uint8_t wide_buf[64];

  // Phase 1: finish the block that `seek` lands in the middle of.
  if (offset_within_block) {
    blake3_compress_xof(self->input_cv, self->block, self->block_len,
                        output_block_counter, self->flags | ROOT, wide_buf);
    const size_t available_bytes = 64 - offset_within_block;
    const size_t bytes = out_len > available_bytes ? available_bytes : out_len;
    memcpy(out, wide_buf + offset_within_block, bytes);
    out += bytes;
    out_len -= bytes;
    output_block_counter += 1;
  }

  // Phase 2: whole blocks, with no intermediate copy.
  const size_t whole_blocks = out_len / 64;
  if (whole_blocks) {
    blake3_xof_many(self->input_cv, self->block, self->block_len,
                    output_block_counter, self->flags | ROOT, out,
                    whole_blocks);
    output_block_counter += whole_blocks;
    out += whole_blocks * 64;
    out_len -= whole_blocks * 64;
  }

  // Phase 3: the remaining (< 64) bytes come from the front of one more block.
  if (out_len) {
    blake3_compress_xof(self->input_cv, self->block, self->block_len,
                        output_block_counter, self->flags | ROOT, wide_buf);
    memcpy(out, wide_buf, out_len);
  }
}

Expand Down
178 changes: 173 additions & 5 deletions c/blake3_avx512.c
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,27 @@
_mm_shuffle_ps(_mm_castsi128_ps(a), _mm_castsi128_ps(b), (c))))

// Loads 16 bytes from `src` into a 128-bit vector; `src` does not need to be
// aligned. The cast goes through `const void *` so that the const qualifier
// of `src` is not discarded.
INLINE __m128i loadu_128(const uint8_t src[16]) {
  return _mm_loadu_si128((const void *)src);
}

// Loads 32 bytes from `src` into a 256-bit vector; `src` does not need to be
// aligned. The cast goes through `const void *` so that the const qualifier
// of `src` is not discarded.
INLINE __m256i loadu_256(const uint8_t src[32]) {
  return _mm256_loadu_si256((const void *)src);
}

// Loads 64 bytes from `src` into a 512-bit vector; `src` does not need to be
// aligned. The cast goes through `const void *` so that the const qualifier
// of `src` is not discarded.
INLINE __m512i loadu_512(const uint8_t src[64]) {
  return _mm512_loadu_si512((const void *)src);
}

// Stores the 16 bytes of `src` to `dest`; `dest` does not need to be aligned.
INLINE void storeu_128(__m128i src, uint8_t dest[16]) {
  _mm_storeu_si128((void *)dest, src);
}

// Stores the 32 bytes of `src` to `dest`; `dest` does not need to be aligned.
// The array bound in the parameter is documentation only (the parameter is an
// ordinary pointer), so it states the number of bytes actually written.
INLINE void storeu_256(__m256i src, uint8_t dest[32]) {
  _mm256_storeu_si256((void *)dest, src);
}

// Stores the 64 bytes of `src` to `dest`; `dest` does not need to be aligned.
// The array bound in the parameter is documentation only (the parameter is an
// ordinary pointer), so it states the number of bytes actually written.
INLINE void storeu_512(__m512i src, uint8_t dest[64]) {
  _mm512_storeu_si512((void *)dest, src);
}

// Adds the packed 32-bit integers of `a` and `b` lane by lane.
INLINE __m128i add_128(__m128i a, __m128i b) {
  const __m128i sum = _mm_add_epi32(a, b);
  return sum;
}
Expand Down Expand Up @@ -550,6 +554,54 @@ void blake3_hash4_avx512(const uint8_t *const *inputs, size_t blocks,
storeu_128(h_vecs[7], &out[7 * sizeof(__m128i)]);
}

// Runs the compression function in four 32-bit lanes at once and writes
// 4 * 64 bytes to `out`. Every lane receives the same chaining value, message
// block, block length and flags; only the counter words differ, as filled in
// by load_counters4() (called with increment enabled, so presumably lane i
// uses `counter + i` -- confirm against load_counters4).
static void blake3_xof4_avx512(const uint32_t cv[8],
                               const uint8_t block[BLAKE3_BLOCK_LEN],
                               uint8_t block_len, uint64_t counter,
                               uint8_t flags, uint8_t out[4 * 64]) {
  // Broadcast each chaining-value word to all lanes. Kept separately from the
  // working state because it is needed again for the final XOR.
  __m128i cv_lanes[8];
  for (size_t w = 0; w < 8; w++) {
    cv_lanes[w] = set1_128(cv[w]);
  }

  // Broadcast each message word to all lanes.
  uint32_t words[16];
  load_block_words(block, words);
  __m128i msg_lanes[16];
  for (size_t w = 0; w < 16; w++) {
    msg_lanes[w] = set1_128(words[w]);
  }

  __m128i ctr_lo, ctr_hi;
  load_counters4(counter, true, &ctr_lo, &ctr_hi);

  // Working state: cv words, four IV words, counter low/high, length, flags.
  __m128i v[16];
  for (size_t w = 0; w < 8; w++) {
    v[w] = cv_lanes[w];
  }
  for (size_t w = 0; w < 4; w++) {
    v[8 + w] = set1_128(IV[w]);
  }
  v[12] = ctr_lo;
  v[13] = ctr_hi;
  v[14] = set1_128(block_len);
  v[15] = set1_128(flags);

  // Seven rounds, written out so each call sees a constant round index.
  round_fn4(v, msg_lanes, 0);
  round_fn4(v, msg_lanes, 1);
  round_fn4(v, msg_lanes, 2);
  round_fn4(v, msg_lanes, 3);
  round_fn4(v, msg_lanes, 4);
  round_fn4(v, msg_lanes, 5);
  round_fn4(v, msg_lanes, 6);

  // Low half becomes low ^ high; high half becomes high ^ input cv. Both
  // results are computed from the pre-XOR values of the state.
  for (size_t w = 0; w < 8; w++) {
    const __m128i low = v[w];
    const __m128i high = v[w + 8];
    v[w] = xor_128(low, high);
    v[w + 8] = xor_128(high, cv_lanes[w]);
  }

  // Transpose each group of four vectors in place.
  for (size_t group = 0; group < 16; group += 4) {
    transpose_vecs_128(&v[group]);
  }

  // Output block `blk` is assembled from vector `blk` of each transposed
  // group, in group order, 16 bytes at a time.
  for (size_t blk = 0; blk < 4; blk++) {
    uint8_t *dst = out + blk * (4 * sizeof(__m128i));
    for (size_t quarter = 0; quarter < 4; quarter++) {
      storeu_128(v[4 * quarter + blk], dst + quarter * sizeof(__m128i));
    }
  }
}

/*
* ----------------------------------------------------------------------------
* hash8_avx512
Expand Down Expand Up @@ -802,6 +854,50 @@ void blake3_hash8_avx512(const uint8_t *const *inputs, size_t blocks,
storeu_256(h_vecs[7], &out[7 * sizeof(__m256i)]);
}

// Runs the compression function in eight 32-bit lanes at once and writes
// 8 * 64 bytes to `out`. Every lane receives the same chaining value, message
// block, block length and flags; only the counter words differ, as filled in
// by load_counters8() (called with increment enabled, so presumably lane i
// uses `counter + i` -- confirm against load_counters8).
static void blake3_xof8_avx512(const uint32_t cv[8],
                               const uint8_t block[BLAKE3_BLOCK_LEN],
                               uint8_t block_len, uint64_t counter,
                               uint8_t flags, uint8_t out[8 * 64]) {
  // Broadcast each chaining-value word to all lanes. Kept separately from the
  // working state because it is needed again for the final XOR.
  __m256i cv_lanes[8];
  for (size_t w = 0; w < 8; w++) {
    cv_lanes[w] = set1_256(cv[w]);
  }

  // Broadcast each message word to all lanes.
  uint32_t words[16];
  load_block_words(block, words);
  __m256i msg_lanes[16];
  for (size_t w = 0; w < 16; w++) {
    msg_lanes[w] = set1_256(words[w]);
  }

  __m256i ctr_lo, ctr_hi;
  load_counters8(counter, true, &ctr_lo, &ctr_hi);

  // Working state: cv words, four IV words, counter low/high, length, flags.
  __m256i v[16];
  for (size_t w = 0; w < 8; w++) {
    v[w] = cv_lanes[w];
  }
  for (size_t w = 0; w < 4; w++) {
    v[8 + w] = set1_256(IV[w]);
  }
  v[12] = ctr_lo;
  v[13] = ctr_hi;
  v[14] = set1_256(block_len);
  v[15] = set1_256(flags);

  // Seven rounds, written out so each call sees a constant round index.
  round_fn8(v, msg_lanes, 0);
  round_fn8(v, msg_lanes, 1);
  round_fn8(v, msg_lanes, 2);
  round_fn8(v, msg_lanes, 3);
  round_fn8(v, msg_lanes, 4);
  round_fn8(v, msg_lanes, 5);
  round_fn8(v, msg_lanes, 6);

  // Low half becomes low ^ high; high half becomes high ^ input cv. Both
  // results are computed from the pre-XOR values of the state.
  for (size_t w = 0; w < 8; w++) {
    const __m256i low = v[w];
    const __m256i high = v[w + 8];
    v[w] = xor_256(low, high);
    v[w + 8] = xor_256(high, cv_lanes[w]);
  }

  // Transpose each half of the state in place.
  transpose_vecs_256(&v[0]);
  transpose_vecs_256(&v[8]);

  // Output block `blk` is vector `blk` of the first half followed by vector
  // `blk` of the second half, 32 bytes each.
  for (size_t blk = 0; blk < 8; blk++) {
    uint8_t *dst = out + blk * (2 * sizeof(__m256i));
    storeu_256(v[blk], dst);
    storeu_256(v[blk + 8], dst + sizeof(__m256i));
  }
}

/*
* ----------------------------------------------------------------------------
* hash16_avx512
Expand Down Expand Up @@ -1146,6 +1242,48 @@ void blake3_hash16_avx512(const uint8_t *const *inputs, size_t blocks,
_mm256_mask_storeu_epi32(&out[15 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[15]));
}

// Runs the compression function in sixteen 32-bit lanes at once and writes
// 16 * 64 bytes to `out`. Every lane receives the same chaining value, message
// block, block length and flags; only the counter words differ, as filled in
// by load_counters16() (called with increment enabled, so presumably lane i
// uses `counter + i` -- confirm against load_counters16).
static void blake3_xof16_avx512(const uint32_t cv[8],
                                const uint8_t block[BLAKE3_BLOCK_LEN],
                                uint8_t block_len, uint64_t counter,
                                uint8_t flags, uint8_t out[16 * 64]) {
  // Broadcast each chaining-value word to all lanes. Kept separately from the
  // working state because it is needed again for the final XOR.
  __m512i cv_lanes[8];
  for (size_t w = 0; w < 8; w++) {
    cv_lanes[w] = set1_512(cv[w]);
  }

  // Broadcast each message word to all lanes.
  uint32_t words[16];
  load_block_words(block, words);
  __m512i msg_lanes[16];
  for (size_t w = 0; w < 16; w++) {
    msg_lanes[w] = set1_512(words[w]);
  }

  __m512i ctr_lo, ctr_hi;
  load_counters16(counter, true, &ctr_lo, &ctr_hi);

  // Working state: cv words, four IV words, counter low/high, length, flags.
  __m512i v[16];
  for (size_t w = 0; w < 8; w++) {
    v[w] = cv_lanes[w];
  }
  for (size_t w = 0; w < 4; w++) {
    v[8 + w] = set1_512(IV[w]);
  }
  v[12] = ctr_lo;
  v[13] = ctr_hi;
  v[14] = set1_512(block_len);
  v[15] = set1_512(flags);

  // Seven rounds, written out so each call sees a constant round index.
  round_fn16(v, msg_lanes, 0);
  round_fn16(v, msg_lanes, 1);
  round_fn16(v, msg_lanes, 2);
  round_fn16(v, msg_lanes, 3);
  round_fn16(v, msg_lanes, 4);
  round_fn16(v, msg_lanes, 5);
  round_fn16(v, msg_lanes, 6);

  // Low half becomes low ^ high; high half becomes high ^ input cv. Both
  // results are computed from the pre-XOR values of the state.
  for (size_t w = 0; w < 8; w++) {
    const __m512i low = v[w];
    const __m512i high = v[w + 8];
    v[w] = xor_512(low, high);
    v[w + 8] = xor_512(high, cv_lanes[w]);
  }

  // Transpose the whole state in place; afterwards each vector is written out
  // as one 64-byte output block, in order.
  transpose_vecs_512(v);
  for (size_t blk = 0; blk < 16; blk++) {
    storeu_512(v[blk], out + blk * sizeof(__m512i));
  }
}

/*
* ----------------------------------------------------------------------------
* hash_many_avx512
Expand Down Expand Up @@ -1218,3 +1356,33 @@ void blake3_hash_many_avx512(const uint8_t *const *inputs, size_t num_inputs,
out = &out[BLAKE3_OUT_LEN];
}
}

// Writes `outblocks` consecutive output blocks of BLAKE3_BLOCK_LEN bytes each
// to `out`, starting at block counter `counter`. The widest kernel that still
// fits the remaining block count is used first (16, then 8, then 4 blocks per
// call); whatever is left is produced one block at a time. After each call the
// counter and the output pointer advance by the number of blocks just written.
void blake3_xof_many_avx512(const uint32_t cv[8],
                            const uint8_t block[BLAKE3_BLOCK_LEN],
                            uint8_t block_len, uint64_t counter, uint8_t flags,
                            uint8_t* out, size_t outblocks) {
  for (; outblocks >= 16; outblocks -= 16) {
    blake3_xof16_avx512(cv, block, block_len, counter, flags, out);
    counter += 16;
    out += 16 * BLAKE3_BLOCK_LEN;
  }
  for (; outblocks >= 8; outblocks -= 8) {
    blake3_xof8_avx512(cv, block, block_len, counter, flags, out);
    counter += 8;
    out += 8 * BLAKE3_BLOCK_LEN;
  }
  for (; outblocks >= 4; outblocks -= 4) {
    blake3_xof4_avx512(cv, block, block_len, counter, flags, out);
    counter += 4;
    out += 4 * BLAKE3_BLOCK_LEN;
  }
  // Fewer than four blocks remain: fall back to the single-block kernel.
  for (; outblocks > 0; outblocks -= 1) {
    blake3_compress_xof_avx512(cv, block, block_len, counter, flags, out);
    counter += 1;
    out += BLAKE3_BLOCK_LEN;
  }
}
Loading