From 2c0a33ad514763b081be8fc5e7be55d7b2d05cc1 Mon Sep 17 00:00:00 2001 From: Xu Jun Date: Sat, 14 Sep 2024 16:16:43 +0800 Subject: [PATCH] implement qs8 x8c8 pack using avxvnniint8 --- cmake/gen/avxvnniint8_microkernels.cmake | 3 +- gen/avxvnniint8_microkernels.bzl | 1 + scripts/generate-x8-packw.sh | 2 + src/configs/hardware-config.c | 1 + .../gen/qs8-packw-x8c8-gemm-goi-avxvnniint8.c | 441 ++++++++++++++++++ src/qs8-packw/qs8-packw.h | 2 + src/x8-packw/kr-avxvnniint8.c.in | 398 ++++++++++++++++ src/xnnpack/hardware-config.h | 1 + 8 files changed, 848 insertions(+), 1 deletion(-) create mode 100644 src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avxvnniint8.c create mode 100644 src/x8-packw/kr-avxvnniint8.c.in diff --git a/cmake/gen/avxvnniint8_microkernels.cmake b/cmake/gen/avxvnniint8_microkernels.cmake index 60e0412af5f..ee2ac244a81 100644 --- a/cmake/gen/avxvnniint8_microkernels.cmake +++ b/cmake/gen/avxvnniint8_microkernels.cmake @@ -15,6 +15,7 @@ SET(PROD_AVXVNNIINT8_MICROKERNEL_SRCS src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-1x8c8-minmax-fp32-avxvnniint8-prfm.c src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-5x8c8-minmax-fp32-avxvnniint8-prfm.c) -SET(NON_PROD_AVXVNNIINT8_MICROKERNEL_SRCS) +SET(NON_PROD_AVXVNNIINT8_MICROKERNEL_SRCS + src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avxvnniint8.c) SET(ALL_AVXVNNIINT8_MICROKERNEL_SRCS ${PROD_AVXVNNIINT8_MICROKERNEL_SRCS} + ${NON_PROD_AVXVNNIINT8_MICROKERNEL_SRCS}) diff --git a/gen/avxvnniint8_microkernels.bzl b/gen/avxvnniint8_microkernels.bzl index 97f1657ee4a..a1b149b017c 100644 --- a/gen/avxvnniint8_microkernels.bzl +++ b/gen/avxvnniint8_microkernels.bzl @@ -13,6 +13,7 @@ PROD_AVXVNNIINT8_MICROKERNEL_SRCS = [ ] NON_PROD_AVXVNNIINT8_MICROKERNEL_SRCS = [ + "src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avxvnniint8.c", ] ALL_AVXVNNIINT8_MICROKERNEL_SRCS = PROD_AVXVNNIINT8_MICROKERNEL_SRCS + NON_PROD_AVXVNNIINT8_MICROKERNEL_SRCS diff --git a/scripts/generate-x8-packw.sh b/scripts/generate-x8-packw.sh index 281bba59237..69e9b42ee1c 100755 --- a/scripts/generate-x8-packw.sh +++ b/scripts/generate-x8-packw.sh @@ -23,4 +23,6 @@ tools/xngen src/x8-packw/kr-scalar.c.in -D NR=16 -D KR=4 -D TYPE=int8_t -o src/q tools/xngen src/x8-packw/kr-scalar.c.in -D NR=32 -D KR=4 -D TYPE=int8_t -o src/qs8-packw/gen/qs8-packw-x32c4-gemm-goi-scalar.c & tools/xngen src/x8-packw/kr-scalar.c.in -D NR=64 -D KR=4 -D TYPE=int8_t -o src/qs8-packw/gen/qs8-packw-x64c4-gemm-goi-scalar.c & +tools/xngen src/x8-packw/kr-avxvnniint8.c.in -D NR=8 -D KR=8 -D TYPE=int8_t -o src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avxvnniint8.c & + wait diff --git a/src/configs/hardware-config.c b/src/configs/hardware-config.c index 573c132d06d..8954b6b8d69 100644 --- a/src/configs/hardware-config.c +++ b/src/configs/hardware-config.c @@ -324,6 +324,7 @@ static void init_hardware_config(void) { if (hardware_config.use_x86_avx256skx) hardware_config.arch_flags |= xnn_arch_x86_avx256skx; if (hardware_config.use_x86_avx256vnni) hardware_config.arch_flags |= xnn_arch_x86_avx256vnni; if (hardware_config.use_x86_avx256vnnigfni) hardware_config.arch_flags |= xnn_arch_x86_avx256vnnigfni; + if (hardware_config.use_x86_avxvnniint8) hardware_config.arch_flags |= xnn_arch_x86_avxvnniint8; #endif #if XNN_ARCH_RISCV if (hardware_config.use_riscv_vector) hardware_config.arch_flags |= xnn_arch_riscv_vector; diff --git a/src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avxvnniint8.c b/src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avxvnniint8.c new file mode 100644 index 00000000000..807b27ae4fa --- /dev/null +++ 
b/src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avxvnniint8.c @@ -0,0 +1,441 @@ +// Auto-generated file. Do not edit! +// Template: src/x8-packw/kr-avxvnniint8.c.in +// Generator: tools/xngen +// +// Copyright 2023 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + + +#include +#include +#include + +#include + +#include "xnnpack/packw.h" + +void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avxvnniint8( + size_t g, + size_t nc, + size_t kc, + size_t nr, + size_t kr, + size_t sr, + const int8_t* weights, + const int32_t* bias, + const void* scale, + int8_t* packed_weights, + size_t extra_bytes, + const void* params) +{ + assert(g != 0); + assert(nc != 0); + assert(kc != 0); + assert(nr == 8); + assert(kr == 8); + assert(sr == 1); + assert(weights != NULL); + assert(packed_weights != NULL); + + int8_t* out = (int8_t*) packed_weights; + const int32_t* b = (const int32_t*) bias; + const int8_t izp = params ? ((const struct xnn_qs8_packw_params*) params)->input_zero_point : 0; + __m256i vzeropoint = _mm256_set1_epi8(izp); + + do { + // NC main loop multiple of 8 + const int8_t* w0 = (const int8_t*) weights; + size_t n = nc; + for (;n >= 8; n -= 8) { + __m256i vacc0124x8 = _mm256_setzero_si256(); + __m256i vacc4567x8 = _mm256_setzero_si256(); + + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + const __m256i vb = _mm256_loadu_si256((const __m256i*) b); + _mm256_storeu_si256((__m256i*) out, vb); + b += 8; + } else { + _mm256_storeu_si256((__m256i*) out, _mm256_setzero_si256()); + } + out += 8 * sizeof(int32_t); + + const int8_t* w1 = w0 + kc; + const int8_t* w2 = w1 + kc; + const int8_t* w3 = w2 + kc; + const int8_t* w4 = w3 + kc; + const int8_t* w5 = w4 + kc; + const int8_t* w6 = w5 + kc; + const int8_t* w7 = w6 + kc; + + // KC main loop multiple of 8x8 + size_t k = kc; + for (; k >= 8; k -= 8) { + __m256i v0123x8 = _mm256_castsi128_si256(_mm_loadl_epi64((const __m128i *)w0)); + v0123x8 = _mm256_insert_epi64(v0123x8, *(int64_t *)w1, 1); + v0123x8 = _mm256_insert_epi64(v0123x8, *(int64_t *)w2, 2); + v0123x8 = _mm256_insert_epi64(v0123x8, *(int64_t *)w3, 3); + + __m256i v4567x8 = _mm256_castsi128_si256(_mm_loadl_epi64((const __m128i *)w4)); + v4567x8 = _mm256_insert_epi64(v4567x8, *(int64_t *)w5, 1); + v4567x8 = _mm256_insert_epi64(v4567x8, *(int64_t *)w6, 2); + v4567x8 = _mm256_insert_epi64(v4567x8, *(int64_t *)w7, 3); + + vacc0124x8 = _mm256_dpbssd_epi32(vacc0124x8, v0123x8, vzeropoint); + vacc4567x8 = _mm256_dpbssd_epi32(vacc4567x8, v4567x8, vzeropoint); + + _mm256_storeu_si256((__m256i *)&out[0], v0123x8); + _mm256_storeu_si256((__m256i *)&out[32], v4567x8); + + w0 += 8; + w1 += 8; + w2 += 8; + w3 += 8; + w4 += 8; + w5 += 8; + w6 += 8; + w7 += 8; + out += 64; + } + + // KC remainder 1..KR-1 + if (k != 0) { + __m256i v0123x8 = vzeropoint; + __m256i v4567x8 = vzeropoint; + + if (k & 4) { + v0123x8 = _mm256_insert_epi32(v0123x8, *(int32_t *)w0, 0); + v0123x8 = _mm256_insert_epi32(v0123x8, *(int32_t *)w1, 2); + v0123x8 = _mm256_insert_epi32(v0123x8, *(int32_t *)w2, 4); + v0123x8 = _mm256_insert_epi32(v0123x8, *(int32_t *)w3, 6); + + v4567x8 = _mm256_insert_epi32(v4567x8, *(int32_t *)w4, 0); + v4567x8 = _mm256_insert_epi32(v4567x8, *(int32_t *)w5, 2); + v4567x8 = _mm256_insert_epi32(v4567x8, *(int32_t *)w6, 4); + v4567x8 = _mm256_insert_epi32(v4567x8, *(int32_t *)w7, 6); + w0 += 4; + w1 += 4; + w2 += 4; + w3 += 4; + w4 += 4; + w5 += 4; + w6 += 4; + w7 += 4; + } + if (k & 2) { + if (k & 4) { + 
v0123x8 = _mm256_insert_epi16(v0123x8, *(int16_t *)w0, 2); + v0123x8 = _mm256_insert_epi16(v0123x8, *(int16_t *)w1, 6); + v0123x8 = _mm256_insert_epi16(v0123x8, *(int16_t *)w2, 10); + v0123x8 = _mm256_insert_epi16(v0123x8, *(int16_t *)w3, 14); + + v4567x8 = _mm256_insert_epi16(v4567x8, *(int16_t *)w4, 2); + v4567x8 = _mm256_insert_epi16(v4567x8, *(int16_t *)w5, 6); + v4567x8 = _mm256_insert_epi16(v4567x8, *(int16_t *)w6, 10); + v4567x8 = _mm256_insert_epi16(v4567x8, *(int16_t *)w7, 14); + } else { + v0123x8 = _mm256_insert_epi16(v0123x8, *(int16_t *)w0, 0); + v0123x8 = _mm256_insert_epi16(v0123x8, *(int16_t *)w1, 4); + v0123x8 = _mm256_insert_epi16(v0123x8, *(int16_t *)w2, 8); + v0123x8 = _mm256_insert_epi16(v0123x8, *(int16_t *)w3, 12); + + v4567x8 = _mm256_insert_epi16(v4567x8, *(int16_t *)w4, 0); + v4567x8 = _mm256_insert_epi16(v4567x8, *(int16_t *)w5, 4); + v4567x8 = _mm256_insert_epi16(v4567x8, *(int16_t *)w6, 8); + v4567x8 = _mm256_insert_epi16(v4567x8, *(int16_t *)w7, 12); + } + + w0 += 2; + w1 += 2; + w2 += 2; + w3 += 2; + w4 += 2; + w5 += 2; + w6 += 2; + w7 += 2; + } + if (k & 1) { + if ((k & 4) && (k & 2)) { + v0123x8 = _mm256_insert_epi8(v0123x8, *(int8_t *)w0, 6); + v0123x8 = _mm256_insert_epi8(v0123x8, *(int8_t *)w1, 14); + v0123x8 = _mm256_insert_epi8(v0123x8, *(int8_t *)w2, 22); + v0123x8 = _mm256_insert_epi8(v0123x8, *(int8_t *)w3, 30); + + v4567x8 = _mm256_insert_epi8(v4567x8, *(int8_t *)w4, 6); + v4567x8 = _mm256_insert_epi8(v4567x8, *(int8_t *)w5, 14); + v4567x8 = _mm256_insert_epi8(v4567x8, *(int8_t *)w6, 22); + v4567x8 = _mm256_insert_epi8(v4567x8, *(int8_t *)w7, 30); + } + else if (k & 4) { + v0123x8 = _mm256_insert_epi8(v0123x8, *(int8_t *)w0, 4); + v0123x8 = _mm256_insert_epi8(v0123x8, *(int8_t *)w1, 12); + v0123x8 = _mm256_insert_epi8(v0123x8, *(int8_t *)w2, 20); + v0123x8 = _mm256_insert_epi8(v0123x8, *(int8_t *)w3, 28); + + v4567x8 = _mm256_insert_epi8(v4567x8, *(int8_t *)w4, 4); + v4567x8 = _mm256_insert_epi8(v4567x8, *(int8_t *)w5, 12); + v4567x8 = _mm256_insert_epi8(v4567x8, *(int8_t *)w6, 20); + v4567x8 = _mm256_insert_epi8(v4567x8, *(int8_t *)w7, 28); + } + else if (k & 2) { + v0123x8 = _mm256_insert_epi8(v0123x8, *(int8_t *)w0, 2); + v0123x8 = _mm256_insert_epi8(v0123x8, *(int8_t *)w1, 10); + v0123x8 = _mm256_insert_epi8(v0123x8, *(int8_t *)w2, 18); + v0123x8 = _mm256_insert_epi8(v0123x8, *(int8_t *)w3, 26); + + v4567x8 = _mm256_insert_epi8(v4567x8, *(int8_t *)w4, 2); + v4567x8 = _mm256_insert_epi8(v4567x8, *(int8_t *)w5, 10); + v4567x8 = _mm256_insert_epi8(v4567x8, *(int8_t *)w6, 18); + v4567x8 = _mm256_insert_epi8(v4567x8, *(int8_t *)w7, 26); + } + else { + v0123x8 = _mm256_insert_epi8(v0123x8, *(int8_t *)w0, 0); + v0123x8 = _mm256_insert_epi8(v0123x8, *(int8_t *)w1, 8); + v0123x8 = _mm256_insert_epi8(v0123x8, *(int8_t *)w2, 16); + v0123x8 = _mm256_insert_epi8(v0123x8, *(int8_t *)w3, 24); + + v4567x8 = _mm256_insert_epi8(v4567x8, *(int8_t *)w4, 0); + v4567x8 = _mm256_insert_epi8(v4567x8, *(int8_t *)w5, 8); + v4567x8 = _mm256_insert_epi8(v4567x8, *(int8_t *)w6, 16); + v4567x8 = _mm256_insert_epi8(v4567x8, *(int8_t *)w7, 24); + } + + w0 += 1; + w1 += 1; + w2 += 1; + w3 += 1; + w4 += 1; + w5 += 1; + w6 += 1; + w7 += 1; + } + + vacc0124x8 = _mm256_dpbssd_epi32(vacc0124x8, v0123x8, vzeropoint); + vacc4567x8 = _mm256_dpbssd_epi32(vacc4567x8, v4567x8, vzeropoint); + + _mm256_storeu_si256((__m256i *)&out[0], v0123x8); + _mm256_storeu_si256((__m256i *)&out[32], v4567x8); + + out += 64; + } + + __m256i vksum = _mm256_hadd_epi32(vacc0124x8, vacc4567x8); + vksum = 
_mm256_permute4x64_epi64(vksum, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i vpack = _mm256_loadu_si256((const __m256i*) packed_b); + _mm256_storeu_si256((__m256i *)packed_b, _mm256_sub_epi32(vpack, vksum)); + out = (int8_t*) ((uintptr_t) out + extra_bytes); + w0 = w7; + } + + // NC remainder (1..7) + if XNN_UNLIKELY(n != 0) { + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + size_t nb = n; + do { + *((int32_t*) out) = *b++; + out += sizeof(int32_t); + } while (--nb != 0); + } else { + size_t nb = n; + do { + *((int32_t*) out) = 0; + out += sizeof(int32_t); + } while (--nb != 0); + } + out += (8 - n) * sizeof(int32_t); + + const int8_t* w1 = w0 + kc; + if XNN_UNPREDICTABLE(n < 2) { + w1 = w0; + } + const int8_t* w2 = w1 + kc; + if XNN_UNPREDICTABLE(n <= 2) { + w2 = w1; + } + const int8_t* w3 = w2 + kc; + if XNN_UNPREDICTABLE(n < 4) { + w3 = w2; + } + const int8_t* w4 = w3 + kc; + if XNN_UNPREDICTABLE(n <= 4) { + w4 = w3; + } + const int8_t* w5 = w4 + kc; + if XNN_UNPREDICTABLE(n < 6) { + w5 = w4; + } + const int8_t* w6 = w5 + kc; + if XNN_UNPREDICTABLE(n <= 6) { + w6 = w5; + } + const int8_t* w7 = w6 + kc; + if XNN_UNPREDICTABLE(n < 8) { + w7 = w6; + } + + __m256i vacc0124x8 = _mm256_setzero_si256(); + __m256i vacc4567x8 = _mm256_setzero_si256(); + + // KC main loop multiple of 8x8 + size_t k = kc; + for (; k >= 8; k -= 8) { + __m256i v0123x8 = _mm256_castsi128_si256(_mm_loadl_epi64((const __m128i *)w0)); + v0123x8 = _mm256_insert_epi64(v0123x8, *(int64_t *)w1, 1); + v0123x8 = _mm256_insert_epi64(v0123x8, *(int64_t *)w2, 2); + v0123x8 = _mm256_insert_epi64(v0123x8, *(int64_t *)w3, 3); + + __m256i v4567x8 = _mm256_castsi128_si256(_mm_loadl_epi64((const __m128i *)w4)); + v4567x8 = _mm256_insert_epi64(v4567x8, *(int64_t *)w5, 1); + v4567x8 = _mm256_insert_epi64(v4567x8, *(int64_t *)w6, 2); + v4567x8 = _mm256_insert_epi64(v4567x8, *(int64_t *)w7, 3); + + vacc0124x8 = _mm256_dpbssd_epi32(vacc0124x8, v0123x8, vzeropoint); + vacc4567x8 = _mm256_dpbssd_epi32(vacc4567x8, v4567x8, vzeropoint); + + _mm256_storeu_si256((__m256i *)&out[0], v0123x8); + _mm256_storeu_si256((__m256i *)&out[32], v4567x8); + + w0 += 8; + w1 += 8; + w2 += 8; + w3 += 8; + w4 += 8; + w5 += 8; + w6 += 8; + w7 += 8; + out += 64; + } + + // KC remainder of 1..7 + if (k != 0) { + __m256i v0123x8 = vzeropoint; + __m256i v4567x8 = vzeropoint; + + if (k & 4) { + v0123x8 = _mm256_insert_epi32(v0123x8, *(int32_t *)w0, 0); + v0123x8 = _mm256_insert_epi32(v0123x8, *(int32_t *)w1, 2); + v0123x8 = _mm256_insert_epi32(v0123x8, *(int32_t *)w2, 4); + v0123x8 = _mm256_insert_epi32(v0123x8, *(int32_t *)w3, 6); + + v4567x8 = _mm256_insert_epi32(v4567x8, *(int32_t *)w4, 0); + v4567x8 = _mm256_insert_epi32(v4567x8, *(int32_t *)w5, 2); + v4567x8 = _mm256_insert_epi32(v4567x8, *(int32_t *)w6, 4); + v4567x8 = _mm256_insert_epi32(v4567x8, *(int32_t *)w7, 6); + w0 += 4; + w1 += 4; + w2 += 4; + w3 += 4; + w4 += 4; + w5 += 4; + w6 += 4; + w7 += 4; + } + if (k & 2) { + if (k & 4) { + v0123x8 = _mm256_insert_epi16(v0123x8, *(int16_t *)w0, 2); + v0123x8 = _mm256_insert_epi16(v0123x8, *(int16_t *)w1, 6); + v0123x8 = _mm256_insert_epi16(v0123x8, *(int16_t *)w2, 10); + v0123x8 = _mm256_insert_epi16(v0123x8, *(int16_t *)w3, 14); + + v4567x8 = _mm256_insert_epi16(v4567x8, *(int16_t *)w4, 2); + v4567x8 = _mm256_insert_epi16(v4567x8, *(int16_t *)w5, 6); + v4567x8 = _mm256_insert_epi16(v4567x8, *(int16_t *)w6, 10); + v4567x8 = _mm256_insert_epi16(v4567x8, *(int16_t *)w7, 14); + } else { + v0123x8 = _mm256_insert_epi16(v0123x8, *(int16_t *)w0, 0); + 
v0123x8 = _mm256_insert_epi16(v0123x8, *(int16_t *)w1, 4); + v0123x8 = _mm256_insert_epi16(v0123x8, *(int16_t *)w2, 8); + v0123x8 = _mm256_insert_epi16(v0123x8, *(int16_t *)w3, 12); + + v4567x8 = _mm256_insert_epi16(v4567x8, *(int16_t *)w4, 0); + v4567x8 = _mm256_insert_epi16(v4567x8, *(int16_t *)w5, 4); + v4567x8 = _mm256_insert_epi16(v4567x8, *(int16_t *)w6, 8); + v4567x8 = _mm256_insert_epi16(v4567x8, *(int16_t *)w7, 12); + } + + w0 += 2; + w1 += 2; + w2 += 2; + w3 += 2; + w4 += 2; + w5 += 2; + w6 += 2; + w7 += 2; + } + if (k & 1) { + if ((k & 4) && (k & 2)) { + v0123x8 = _mm256_insert_epi8(v0123x8, *(int8_t *)w0, 6); + v0123x8 = _mm256_insert_epi8(v0123x8, *(int8_t *)w1, 14); + v0123x8 = _mm256_insert_epi8(v0123x8, *(int8_t *)w2, 22); + v0123x8 = _mm256_insert_epi8(v0123x8, *(int8_t *)w3, 30); + + v4567x8 = _mm256_insert_epi8(v4567x8, *(int8_t *)w4, 6); + v4567x8 = _mm256_insert_epi8(v4567x8, *(int8_t *)w5, 14); + v4567x8 = _mm256_insert_epi8(v4567x8, *(int8_t *)w6, 22); + v4567x8 = _mm256_insert_epi8(v4567x8, *(int8_t *)w7, 30); + } + else if (k & 4) { + v0123x8 = _mm256_insert_epi8(v0123x8, *(int8_t *)w0, 4); + v0123x8 = _mm256_insert_epi8(v0123x8, *(int8_t *)w1, 12); + v0123x8 = _mm256_insert_epi8(v0123x8, *(int8_t *)w2, 20); + v0123x8 = _mm256_insert_epi8(v0123x8, *(int8_t *)w3, 28); + + v4567x8 = _mm256_insert_epi8(v4567x8, *(int8_t *)w4, 4); + v4567x8 = _mm256_insert_epi8(v4567x8, *(int8_t *)w5, 12); + v4567x8 = _mm256_insert_epi8(v4567x8, *(int8_t *)w6, 20); + v4567x8 = _mm256_insert_epi8(v4567x8, *(int8_t *)w7, 28); + } + else if (k & 2) { + v0123x8 = _mm256_insert_epi8(v0123x8, *(int8_t *)w0, 2); + v0123x8 = _mm256_insert_epi8(v0123x8, *(int8_t *)w1, 10); + v0123x8 = _mm256_insert_epi8(v0123x8, *(int8_t *)w2, 18); + v0123x8 = _mm256_insert_epi8(v0123x8, *(int8_t *)w3, 26); + + v4567x8 = _mm256_insert_epi8(v4567x8, *(int8_t *)w4, 2); + v4567x8 = _mm256_insert_epi8(v4567x8, *(int8_t *)w5, 10); + v4567x8 = _mm256_insert_epi8(v4567x8, *(int8_t *)w6, 18); + v4567x8 = _mm256_insert_epi8(v4567x8, *(int8_t *)w7, 26); + } + else { + v0123x8 = _mm256_insert_epi8(v0123x8, *(int8_t *)w0, 0); + v0123x8 = _mm256_insert_epi8(v0123x8, *(int8_t *)w1, 8); + v0123x8 = _mm256_insert_epi8(v0123x8, *(int8_t *)w2, 16); + v0123x8 = _mm256_insert_epi8(v0123x8, *(int8_t *)w3, 24); + + v4567x8 = _mm256_insert_epi8(v4567x8, *(int8_t *)w4, 0); + v4567x8 = _mm256_insert_epi8(v4567x8, *(int8_t *)w5, 8); + v4567x8 = _mm256_insert_epi8(v4567x8, *(int8_t *)w6, 16); + v4567x8 = _mm256_insert_epi8(v4567x8, *(int8_t *)w7, 24); + } + + w0 += 1; + w1 += 1; + w2 += 1; + w3 += 1; + w4 += 1; + w5 += 1; + w6 += 1; + w7 += 1; + } + + vacc0124x8 = _mm256_dpbssd_epi32(vacc0124x8, v0123x8, vzeropoint); + vacc4567x8 = _mm256_dpbssd_epi32(vacc4567x8, v4567x8, vzeropoint); + + _mm256_storeu_si256((__m256i *)&out[0], v0123x8); + _mm256_storeu_si256((__m256i *)&out[32], v4567x8); + + out += 64; + } + + __m256i vksum = _mm256_hadd_epi32(vacc0124x8, vacc4567x8); + vksum = _mm256_permute4x64_epi64(vksum, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i vpack = _mm256_loadu_si256((const __m256i*) packed_b); + _mm256_storeu_si256((__m256i *)packed_b, _mm256_sub_epi32(vpack, vksum)); + out = (int8_t*) ((uintptr_t) out + extra_bytes); + } + weights += nc * kc; + } while (--g != 0); +} diff --git a/src/qs8-packw/qs8-packw.h b/src/qs8-packw/qs8-packw.h index d0d55e50e31..0037eb49ac8 100644 --- a/src/qs8-packw/qs8-packw.h +++ b/src/qs8-packw/qs8-packw.h @@ -21,6 +21,8 @@ XNN_QS8_UKERNEL(0, xnn_qs8_packw_gemm_goi_ukernel_x16c4__scalar, 16, 4, 1, 4, 1) 
XNN_QS8_UKERNEL(0, xnn_qs8_packw_gemm_goi_ukernel_x32c4__scalar, 32, 4, 1, 4, 1) XNN_QS8_UKERNEL(0, xnn_qs8_packw_gemm_goi_ukernel_x64c4__scalar, 64, 4, 1, 4, 1) +XNN_QS8_UKERNEL(xnn_arch_x86_avxvnniint8, xnn_qs8_packw_gemm_goi_ukernel_x8c8__avxvnniint8, 8, 8, 1, 8, 1) + +#ifdef XNN_DEFINED_UKERNEL_WITH_PARAMS +#undef XNN_DEFINED_UKERNEL_WITH_PARAMS +#undef XNN_QS8_UKERNEL_WITH_PARAMS diff --git a/src/x8-packw/kr-avxvnniint8.c.in b/src/x8-packw/kr-avxvnniint8.c.in new file mode 100644 index 00000000000..eb666a5dc6c --- /dev/null +++ b/src/x8-packw/kr-avxvnniint8.c.in @@ -0,0 +1,398 @@ +// Copyright 2023 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +$assert NR > 1 +$assert KR > 1 +$assert TYPE in ["int8_t"] + +#include <assert.h> +#include <stddef.h> +#include <stdint.h> + +#include <immintrin.h> + +#include "xnnpack/packw.h" + +$BITS = {"int8_t": 8, "uint16_t": 16, "uint32_t": 32, "float": 32}[TYPE] +$BTYPE = {"int8_t": "int32_t", "uint16_t": "uint16_t", "uint32_t": "uint32_t", "float": "float"}[TYPE] +$WTYPE = {"int8_t": "int8_t", "uint16_t": "uint16_t", "uint32_t": "uint32_t", "float": "uint32_t"}[TYPE] +void xnn_qs${BITS}_packw_gemm_goi_ukernel_x${NR}c${KR}__avxvnniint8( + size_t g, + size_t nc, + size_t kc, + size_t nr, + size_t kr, + size_t sr, + const ${WTYPE}* weights, + $if BITS == 8: + const int32_t* bias, + $else: + const ${WTYPE}* bias, + const void* scale, + ${WTYPE}* packed_weights, + size_t extra_bytes, + const void* params) +{ + assert(g != 0); + assert(nc != 0); + assert(kc != 0); + assert(nr == ${NR}); + assert(kr == ${KR}); + assert(sr == 1); + assert(weights != NULL); + assert(packed_weights != NULL); + + ${TYPE}* out = (${TYPE}*) packed_weights; + const ${BTYPE}* b = (const ${BTYPE}*) bias; + $if BITS == 8: + const int8_t izp = params ?
((const struct xnn_qs8_packw_params*) params)->input_zero_point : 0; + __m256i vzeropoint = _mm256_set1_epi8(izp); + + do { + // NC main loop multiple of ${NR} + const ${TYPE}* w0 = (const ${TYPE}*) weights; + size_t n = nc; + for (;n >= ${NR}; n -= ${NR}) { + __m256i vacc0124x8 = _mm256_setzero_si256(); + __m256i vacc4567x8 = _mm256_setzero_si256(); + + $if BITS == 8: + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + const __m256i vb = _mm256_loadu_si256((const __m256i*) b); + _mm256_storeu_si256((__m256i*) out, vb); + b += ${NR}; + } else { + _mm256_storeu_si256((__m256i*) out, _mm256_setzero_si256()); + } + $if BTYPE == TYPE: + out += ${NR}; + $else: + out += ${NR} * sizeof(${BTYPE}); + + $for N in range(1, NR): + const ${TYPE}* w${N} = w${N-1} + kc; + + // KC main loop multiple of ${NR}x${KR} + size_t k = kc; + for (; k >= ${KR}; k -= ${KR}) { + __m256i v0123x8 = _mm256_castsi128_si256(_mm_loadl_epi64((const __m128i *)w0)); + v0123x8 = _mm256_insert_epi64(v0123x8, *(int64_t *)w1, 1); + v0123x8 = _mm256_insert_epi64(v0123x8, *(int64_t *)w2, 2); + v0123x8 = _mm256_insert_epi64(v0123x8, *(int64_t *)w3, 3); + + __m256i v4567x8 = _mm256_castsi128_si256(_mm_loadl_epi64((const __m128i *)w4)); + v4567x8 = _mm256_insert_epi64(v4567x8, *(int64_t *)w5, 1); + v4567x8 = _mm256_insert_epi64(v4567x8, *(int64_t *)w6, 2); + v4567x8 = _mm256_insert_epi64(v4567x8, *(int64_t *)w7, 3); + + $if BITS == 8: + vacc0124x8 = _mm256_dpbssd_epi32(vacc0124x8, v0123x8, vzeropoint); + vacc4567x8 = _mm256_dpbssd_epi32(vacc4567x8, v4567x8, vzeropoint); + + _mm256_storeu_si256((__m256i *)&out[0], v0123x8); + _mm256_storeu_si256((__m256i *)&out[${4 * KR}], v4567x8); + + $for N in range(NR): + w${N} += ${KR}; + out += ${NR*KR}; + } + + // KC remainder 1..KR-1 + if (k != 0) { + __m256i v0123x8 = vzeropoint; + __m256i v4567x8 = vzeropoint; + + if (k & 4) { + v0123x8 = _mm256_insert_epi32(v0123x8, *(int32_t *)w0, 0); + v0123x8 = _mm256_insert_epi32(v0123x8, *(int32_t *)w1, 2); + v0123x8 = _mm256_insert_epi32(v0123x8, *(int32_t *)w2, 4); + v0123x8 = _mm256_insert_epi32(v0123x8, *(int32_t *)w3, 6); + + v4567x8 = _mm256_insert_epi32(v4567x8, *(int32_t *)w4, 0); + v4567x8 = _mm256_insert_epi32(v4567x8, *(int32_t *)w5, 2); + v4567x8 = _mm256_insert_epi32(v4567x8, *(int32_t *)w6, 4); + v4567x8 = _mm256_insert_epi32(v4567x8, *(int32_t *)w7, 6); + $for N in range(NR): + w${N} += 4; + } + if (k & 2) { + if (k & 4) { + v0123x8 = _mm256_insert_epi16(v0123x8, *(int16_t *)w0, 2); + v0123x8 = _mm256_insert_epi16(v0123x8, *(int16_t *)w1, 6); + v0123x8 = _mm256_insert_epi16(v0123x8, *(int16_t *)w2, 10); + v0123x8 = _mm256_insert_epi16(v0123x8, *(int16_t *)w3, 14); + + v4567x8 = _mm256_insert_epi16(v4567x8, *(int16_t *)w4, 2); + v4567x8 = _mm256_insert_epi16(v4567x8, *(int16_t *)w5, 6); + v4567x8 = _mm256_insert_epi16(v4567x8, *(int16_t *)w6, 10); + v4567x8 = _mm256_insert_epi16(v4567x8, *(int16_t *)w7, 14); + } else { + v0123x8 = _mm256_insert_epi16(v0123x8, *(int16_t *)w0, 0); + v0123x8 = _mm256_insert_epi16(v0123x8, *(int16_t *)w1, 4); + v0123x8 = _mm256_insert_epi16(v0123x8, *(int16_t *)w2, 8); + v0123x8 = _mm256_insert_epi16(v0123x8, *(int16_t *)w3, 12); + + v4567x8 = _mm256_insert_epi16(v4567x8, *(int16_t *)w4, 0); + v4567x8 = _mm256_insert_epi16(v4567x8, *(int16_t *)w5, 4); + v4567x8 = _mm256_insert_epi16(v4567x8, *(int16_t *)w6, 8); + v4567x8 = _mm256_insert_epi16(v4567x8, *(int16_t *)w7, 12); + } + + $for N in range(NR): + w${N} += 2; + } + if (k & 1) { + if ((k & 4) && (k & 2)) { + v0123x8 = 
_mm256_insert_epi8(v0123x8, *(int8_t *)w0, 6); + v0123x8 = _mm256_insert_epi8(v0123x8, *(int8_t *)w1, 14); + v0123x8 = _mm256_insert_epi8(v0123x8, *(int8_t *)w2, 22); + v0123x8 = _mm256_insert_epi8(v0123x8, *(int8_t *)w3, 30); + + v4567x8 = _mm256_insert_epi8(v4567x8, *(int8_t *)w4, 6); + v4567x8 = _mm256_insert_epi8(v4567x8, *(int8_t *)w5, 14); + v4567x8 = _mm256_insert_epi8(v4567x8, *(int8_t *)w6, 22); + v4567x8 = _mm256_insert_epi8(v4567x8, *(int8_t *)w7, 30); + } + else if (k & 4) { + v0123x8 = _mm256_insert_epi8(v0123x8, *(int8_t *)w0, 4); + v0123x8 = _mm256_insert_epi8(v0123x8, *(int8_t *)w1, 12); + v0123x8 = _mm256_insert_epi8(v0123x8, *(int8_t *)w2, 20); + v0123x8 = _mm256_insert_epi8(v0123x8, *(int8_t *)w3, 28); + + v4567x8 = _mm256_insert_epi8(v4567x8, *(int8_t *)w4, 4); + v4567x8 = _mm256_insert_epi8(v4567x8, *(int8_t *)w5, 12); + v4567x8 = _mm256_insert_epi8(v4567x8, *(int8_t *)w6, 20); + v4567x8 = _mm256_insert_epi8(v4567x8, *(int8_t *)w7, 28); + } + else if (k & 2) { + v0123x8 = _mm256_insert_epi8(v0123x8, *(int8_t *)w0, 2); + v0123x8 = _mm256_insert_epi8(v0123x8, *(int8_t *)w1, 10); + v0123x8 = _mm256_insert_epi8(v0123x8, *(int8_t *)w2, 18); + v0123x8 = _mm256_insert_epi8(v0123x8, *(int8_t *)w3, 26); + + v4567x8 = _mm256_insert_epi8(v4567x8, *(int8_t *)w4, 2); + v4567x8 = _mm256_insert_epi8(v4567x8, *(int8_t *)w5, 10); + v4567x8 = _mm256_insert_epi8(v4567x8, *(int8_t *)w6, 18); + v4567x8 = _mm256_insert_epi8(v4567x8, *(int8_t *)w7, 26); + } + else { + v0123x8 = _mm256_insert_epi8(v0123x8, *(int8_t *)w0, 0); + v0123x8 = _mm256_insert_epi8(v0123x8, *(int8_t *)w1, 8); + v0123x8 = _mm256_insert_epi8(v0123x8, *(int8_t *)w2, 16); + v0123x8 = _mm256_insert_epi8(v0123x8, *(int8_t *)w3, 24); + + v4567x8 = _mm256_insert_epi8(v4567x8, *(int8_t *)w4, 0); + v4567x8 = _mm256_insert_epi8(v4567x8, *(int8_t *)w5, 8); + v4567x8 = _mm256_insert_epi8(v4567x8, *(int8_t *)w6, 16); + v4567x8 = _mm256_insert_epi8(v4567x8, *(int8_t *)w7, 24); + } + + $for N in range(NR): + w${N} += 1; + } + + $if BITS == 8: + vacc0124x8 = _mm256_dpbssd_epi32(vacc0124x8, v0123x8, vzeropoint); + vacc4567x8 = _mm256_dpbssd_epi32(vacc4567x8, v4567x8, vzeropoint); + + _mm256_storeu_si256((__m256i *)&out[0], v0123x8); + _mm256_storeu_si256((__m256i *)&out[${4 * KR}], v4567x8); + + out += ${NR*KR}; + } + + $if BITS == 8: + __m256i vksum = _mm256_hadd_epi32(vacc0124x8, vacc4567x8); + vksum = _mm256_permute4x64_epi64(vksum, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i vpack = _mm256_loadu_si256((const __m256i*) packed_b); + _mm256_storeu_si256((__m256i *)packed_b, _mm256_sub_epi32(vpack, vksum)); + out = (${TYPE}*) ((uintptr_t) out + extra_bytes); + w0 = w${NR-1}; + } + + // NC remainder (1..${NR-1}) + if XNN_UNLIKELY(n != 0) { + $if BITS == 8: + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + size_t nb = n; + do { + $if BTYPE == TYPE: + *out++ = *b++; + $else: + *((${BTYPE}*) out) = *b++; + out += sizeof(${BTYPE}); + } while (--nb != 0); + } else { + size_t nb = n; + do { + $if BTYPE == TYPE: + *out++ = 0; + $else: + *((${BTYPE}*) out) = 0; + out += sizeof(${BTYPE}); + } while (--nb != 0); + } + $if BTYPE == TYPE: + out += (${NR} - n); + $else: + out += (${NR} - n) * sizeof(${BTYPE}); + + $if NR > 2: + $for N in range(1, NR): + const ${TYPE}* w${N} = w${N-1} + kc; + $if N % 2 == 0: + if XNN_UNPREDICTABLE(n <= ${N}) { + w${N} = w${N-1}; + } + $else: + if XNN_UNPREDICTABLE(n < ${N+1}) { + w${N} = w${N-1}; + } + + $if BITS == 8: + __m256i vacc0124x8 = _mm256_setzero_si256(); + __m256i vacc4567x8 = 
_mm256_setzero_si256(); + + // KC main loop multiple of ${NR}x${KR} + size_t k = kc; + for (; k >= ${KR}; k -= ${KR}) { + __m256i v0123x8 = _mm256_castsi128_si256(_mm_loadl_epi64((const __m128i *)w0)); + v0123x8 = _mm256_insert_epi64(v0123x8, *(int64_t *)w1, 1); + v0123x8 = _mm256_insert_epi64(v0123x8, *(int64_t *)w2, 2); + v0123x8 = _mm256_insert_epi64(v0123x8, *(int64_t *)w3, 3); + + __m256i v4567x8 = _mm256_castsi128_si256(_mm_loadl_epi64((const __m128i *)w4)); + v4567x8 = _mm256_insert_epi64(v4567x8, *(int64_t *)w5, 1); + v4567x8 = _mm256_insert_epi64(v4567x8, *(int64_t *)w6, 2); + v4567x8 = _mm256_insert_epi64(v4567x8, *(int64_t *)w7, 3); + + $if BITS == 8: + vacc0124x8 = _mm256_dpbssd_epi32(vacc0124x8, v0123x8, vzeropoint); + vacc4567x8 = _mm256_dpbssd_epi32(vacc4567x8, v4567x8, vzeropoint); + + _mm256_storeu_si256((__m256i *)&out[0], v0123x8); + _mm256_storeu_si256((__m256i *)&out[${4 * KR}], v4567x8); + + $for N in range(NR): + w${N} += ${KR}; + out += ${NR*KR}; + } + + // KC remainder of 1..${KR-1} + if (k != 0) { + __m256i v0123x8 = vzeropoint; + __m256i v4567x8 = vzeropoint; + + if (k & 4) { + v0123x8 = _mm256_insert_epi32(v0123x8, *(int32_t *)w0, 0); + v0123x8 = _mm256_insert_epi32(v0123x8, *(int32_t *)w1, 2); + v0123x8 = _mm256_insert_epi32(v0123x8, *(int32_t *)w2, 4); + v0123x8 = _mm256_insert_epi32(v0123x8, *(int32_t *)w3, 6); + + v4567x8 = _mm256_insert_epi32(v4567x8, *(int32_t *)w4, 0); + v4567x8 = _mm256_insert_epi32(v4567x8, *(int32_t *)w5, 2); + v4567x8 = _mm256_insert_epi32(v4567x8, *(int32_t *)w6, 4); + v4567x8 = _mm256_insert_epi32(v4567x8, *(int32_t *)w7, 6); + $for N in range(NR): + w${N} += 4; + } + if (k & 2) { + if (k & 4) { + v0123x8 = _mm256_insert_epi16(v0123x8, *(int16_t *)w0, 2); + v0123x8 = _mm256_insert_epi16(v0123x8, *(int16_t *)w1, 6); + v0123x8 = _mm256_insert_epi16(v0123x8, *(int16_t *)w2, 10); + v0123x8 = _mm256_insert_epi16(v0123x8, *(int16_t *)w3, 14); + + v4567x8 = _mm256_insert_epi16(v4567x8, *(int16_t *)w4, 2); + v4567x8 = _mm256_insert_epi16(v4567x8, *(int16_t *)w5, 6); + v4567x8 = _mm256_insert_epi16(v4567x8, *(int16_t *)w6, 10); + v4567x8 = _mm256_insert_epi16(v4567x8, *(int16_t *)w7, 14); + } else { + v0123x8 = _mm256_insert_epi16(v0123x8, *(int16_t *)w0, 0); + v0123x8 = _mm256_insert_epi16(v0123x8, *(int16_t *)w1, 4); + v0123x8 = _mm256_insert_epi16(v0123x8, *(int16_t *)w2, 8); + v0123x8 = _mm256_insert_epi16(v0123x8, *(int16_t *)w3, 12); + + v4567x8 = _mm256_insert_epi16(v4567x8, *(int16_t *)w4, 0); + v4567x8 = _mm256_insert_epi16(v4567x8, *(int16_t *)w5, 4); + v4567x8 = _mm256_insert_epi16(v4567x8, *(int16_t *)w6, 8); + v4567x8 = _mm256_insert_epi16(v4567x8, *(int16_t *)w7, 12); + } + + $for N in range(NR): + w${N} += 2; + } + if (k & 1) { + if ((k & 4) && (k & 2)) { + v0123x8 = _mm256_insert_epi8(v0123x8, *(int8_t *)w0, 6); + v0123x8 = _mm256_insert_epi8(v0123x8, *(int8_t *)w1, 14); + v0123x8 = _mm256_insert_epi8(v0123x8, *(int8_t *)w2, 22); + v0123x8 = _mm256_insert_epi8(v0123x8, *(int8_t *)w3, 30); + + v4567x8 = _mm256_insert_epi8(v4567x8, *(int8_t *)w4, 6); + v4567x8 = _mm256_insert_epi8(v4567x8, *(int8_t *)w5, 14); + v4567x8 = _mm256_insert_epi8(v4567x8, *(int8_t *)w6, 22); + v4567x8 = _mm256_insert_epi8(v4567x8, *(int8_t *)w7, 30); + } + else if (k & 4) { + v0123x8 = _mm256_insert_epi8(v0123x8, *(int8_t *)w0, 4); + v0123x8 = _mm256_insert_epi8(v0123x8, *(int8_t *)w1, 12); + v0123x8 = _mm256_insert_epi8(v0123x8, *(int8_t *)w2, 20); + v0123x8 = _mm256_insert_epi8(v0123x8, *(int8_t *)w3, 28); + + v4567x8 = _mm256_insert_epi8(v4567x8, 
*(int8_t *)w4, 4); + v4567x8 = _mm256_insert_epi8(v4567x8, *(int8_t *)w5, 12); + v4567x8 = _mm256_insert_epi8(v4567x8, *(int8_t *)w6, 20); + v4567x8 = _mm256_insert_epi8(v4567x8, *(int8_t *)w7, 28); + } + else if (k & 2) { + v0123x8 = _mm256_insert_epi8(v0123x8, *(int8_t *)w0, 2); + v0123x8 = _mm256_insert_epi8(v0123x8, *(int8_t *)w1, 10); + v0123x8 = _mm256_insert_epi8(v0123x8, *(int8_t *)w2, 18); + v0123x8 = _mm256_insert_epi8(v0123x8, *(int8_t *)w3, 26); + + v4567x8 = _mm256_insert_epi8(v4567x8, *(int8_t *)w4, 2); + v4567x8 = _mm256_insert_epi8(v4567x8, *(int8_t *)w5, 10); + v4567x8 = _mm256_insert_epi8(v4567x8, *(int8_t *)w6, 18); + v4567x8 = _mm256_insert_epi8(v4567x8, *(int8_t *)w7, 26); + } + else { + v0123x8 = _mm256_insert_epi8(v0123x8, *(int8_t *)w0, 0); + v0123x8 = _mm256_insert_epi8(v0123x8, *(int8_t *)w1, 8); + v0123x8 = _mm256_insert_epi8(v0123x8, *(int8_t *)w2, 16); + v0123x8 = _mm256_insert_epi8(v0123x8, *(int8_t *)w3, 24); + + v4567x8 = _mm256_insert_epi8(v4567x8, *(int8_t *)w4, 0); + v4567x8 = _mm256_insert_epi8(v4567x8, *(int8_t *)w5, 8); + v4567x8 = _mm256_insert_epi8(v4567x8, *(int8_t *)w6, 16); + v4567x8 = _mm256_insert_epi8(v4567x8, *(int8_t *)w7, 24); + } + + $for N in range(NR): + w${N} += 1; + } + + $if BITS == 8: + vacc0124x8 = _mm256_dpbssd_epi32(vacc0124x8, v0123x8, vzeropoint); + vacc4567x8 = _mm256_dpbssd_epi32(vacc4567x8, v4567x8, vzeropoint); + + _mm256_storeu_si256((__m256i *)&out[0], v0123x8); + _mm256_storeu_si256((__m256i *)&out[${4 * KR}], v4567x8); + + out += ${NR*KR}; + } + + $if BITS == 8: + __m256i vksum = _mm256_hadd_epi32(vacc0124x8, vacc4567x8); + vksum = _mm256_permute4x64_epi64(vksum, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i vpack = _mm256_loadu_si256((const __m256i*) packed_b); + _mm256_storeu_si256((__m256i *)packed_b, _mm256_sub_epi32(vpack, vksum)); + out = (${TYPE}*) ((uintptr_t) out + extra_bytes); + } + weights += nc * kc; + } while (--g != 0); +} diff --git a/src/xnnpack/hardware-config.h b/src/xnnpack/hardware-config.h index 1268cab5932..2114acae934 100644 --- a/src/xnnpack/hardware-config.h +++ b/src/xnnpack/hardware-config.h @@ -51,6 +51,7 @@ enum xnn_arch_flags { xnn_arch_x86_avx256skx = 1 << 14, xnn_arch_x86_avx256vnni = 1 << 15, xnn_arch_x86_avx256vnnigfni = 1 << 16, + xnn_arch_x86_avxvnniint8 = 1 << 17, #endif #if XNN_ARCH_RISCV xnn_arch_riscv_vector = 1 << 0,
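
Note (for reviewers, not part of the patch): below is a minimal scalar sketch of the buffer layout that xnn_qs8_packw_gemm_goi_ukernel_x8c8__avxvnniint8 produces, to make the generated intrinsics above easier to follow. It assumes a single group, extra_bytes == 0, and nc and kc both multiples of 8; the real kernel additionally handles nc and kc remainders (padding trailing k bytes with the input zero point and reusing the last row pointer for missing output channels). The function name pack_x8c8_goi_reference is illustrative only and does not exist in XNNPACK.

#include <stddef.h>
#include <stdint.h>
#include <string.h>

// Per block of 8 output channels: 8 x int32 bias (adjusted by the input zero
// point), then for every 8 input channels one 64-byte block that holds 8
// consecutive weights from each of the 8 rows.
static void pack_x8c8_goi_reference(
    size_t nc, size_t kc,              // assumed: nc % 8 == 0, kc % 8 == 0
    const int8_t* weights,             // [nc][kc], row-major (GOI, g == 1)
    const int32_t* bias,               // [nc] or NULL
    int8_t input_zero_point,
    int8_t* out)
{
  for (size_t n = 0; n < nc; n += 8) {
    // Bias, with -izp * sum(w[n+i][:]) folded in so that the GEMM's plain
    // sum(x * w) plus this value equals sum((x - izp) * w) + bias.
    int32_t packed_b[8];
    for (size_t i = 0; i < 8; i++) {
      int32_t ksum = 0;
      for (size_t k = 0; k < kc; k++) {
        ksum += (int32_t) weights[(n + i) * kc + k];
      }
      packed_b[i] = (bias != NULL ? bias[n + i] : 0)
                    - (int32_t) input_zero_point * ksum;
    }
    memcpy(out, packed_b, sizeof(packed_b));
    out += sizeof(packed_b);

    // Weights, interleaved 8 rows at a time in chunks of 8 input channels.
    for (size_t k = 0; k < kc; k += 8) {
      for (size_t i = 0; i < 8; i++) {
        memcpy(out, &weights[(n + i) * kc + k], 8);
        out += 8;
      }
    }
  }
}

The vector kernel arrives at the same layout in one pass: it streams two 32-byte stores per k step and accumulates the row sums on the fly with _mm256_dpbssd_epi32 against a vector of input-zero-point bytes, then subtracts them from the bias slots at the end of each column block.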