Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sum_nd_f32|f16|qs8|qu8 operator. #7210

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions build_srcs.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ SUBGRAPH_SRCS = [
"src/subgraph/static-mean.c",
"src/subgraph/static-resize-bilinear-2d.c",
"src/subgraph/static-slice.c",
"src/subgraph/static-sum.c",
"src/subgraph/static-transpose.c",
"src/subgraph/tanh.c",
"src/subgraph/unpooling-2d.c",
Expand Down
103 changes: 103 additions & 0 deletions include/xnnpack.h
Original file line number Diff line number Diff line change
Expand Up @@ -1334,6 +1334,26 @@ enum xnn_status xnn_define_static_mean(
uint32_t output_id,
uint32_t flags);

/// Define a Sum Node and add it to a Subgraph.
///
/// The Sum Node computes the sum of the input tensor's elements along the given reduction axes.
///
/// @param subgraph - a Subgraph object that will own the created Node.
/// @param num_reduction_axes - number of axes along which sum is computed.
/// @param reduction_axes - axes along which sum is computed. NOTE(review): presumably must contain
///                         @a num_reduction_axes unique values in [0, input rank) — confirm whether
///                         duplicates/out-of-range axes are rejected at define time.
/// @param input_id - Value ID for the input tensor. The input tensor must be a dense tensor with at least
///                   @a num_reduction_axes dimensions defined in the @a subgraph.
/// @param output_id - Value ID for the output tensor. The output tensor must be a dense tensor defined in the
///                    @a subgraph with @a num_reduction_axes fewer dimensions than the input tensor (if
///                    XNN_FLAG_KEEP_DIMS is not specified), or has same dimension rank but the dimension at
///                    @a reduction_axes reduced to 1 (if XNN_FLAG_KEEP_DIMS is specified).
/// @param flags - binary features of the Sum Node. The only currently supported value is XNN_FLAG_KEEP_DIMS.
///
/// @returns an xnn_status code; NOTE(review): the specific success/error codes are not visible here —
///          presumably xnn_status_success on success, matching the other define functions in this header.
enum xnn_status xnn_define_static_sum(
  xnn_subgraph_t subgraph,
  size_t num_reduction_axes,
  const size_t* reduction_axes,
  uint32_t input_id,
  uint32_t output_id,
  uint32_t flags);

/// Define a 2-Input Concatenate Node and add it to a Subgraph.
///
/// The 2-Input Concatenate Node concatenates two tensors along a specified axis.
Expand Down Expand Up @@ -5080,6 +5100,89 @@ enum xnn_status xnn_setup_mean_nd_qu8(
const void* input,
void* output);

/// Create an N-dimensional Sum operator for FP16 (half-precision) data.
///
/// @param flags - binary features of the operator. NOTE(review): supported flag values are not
///                visible here — presumably XNN_FLAG_KEEP_DIMS, as in xnn_define_static_sum.
/// @param sum_op_out - pointer to the variable that receives the created operator object.
enum xnn_status xnn_create_sum_nd_f16(
  uint32_t flags,
  xnn_operator_t* sum_op_out);

/// Create an N-dimensional Sum operator for FP32 (single-precision) data.
///
/// @param flags - binary features of the operator.
/// @param sum_op_out - pointer to the variable that receives the created operator object.
enum xnn_status xnn_create_sum_nd_f32(
  uint32_t flags,
  xnn_operator_t* sum_op_out);

/// Create an N-dimensional Sum operator for signed 8-bit quantized (QS8) data.
///
/// @param scale - quantization scale. NOTE(review): can't tell from the declaration whether this is
///                the combined input/output rescale factor or a single tensor's scale — confirm
///                against the implementation / the analogous mean_nd_qs8 operator.
/// @param input_zero_point - zero point of the quantized input tensor.
/// @param output_zero_point - zero point of the quantized output tensor.
/// @param flags - binary features of the operator.
/// @param sum_op_out - pointer to the variable that receives the created operator object.
enum xnn_status xnn_create_sum_nd_qs8(
  float scale,
  int8_t input_zero_point,
  int8_t output_zero_point,
  uint32_t flags,
  xnn_operator_t* sum_op_out);

/// Create an N-dimensional Sum operator for unsigned 8-bit quantized (QU8) data.
///
/// @param scale - quantization scale (see the QS8 variant's note above regarding its exact meaning).
/// @param input_zero_point - zero point of the quantized input tensor.
/// @param output_zero_point - zero point of the quantized output tensor.
/// @param flags - binary features of the operator.
/// @param sum_op_out - pointer to the variable that receives the created operator object.
enum xnn_status xnn_create_sum_nd_qu8(
  float scale,
  uint8_t input_zero_point,
  uint8_t output_zero_point,
  uint32_t flags,
  xnn_operator_t* sum_op_out);

/// Reshape an FP16 N-dimensional Sum operator to the given input shape and reduction axes.
///
/// @param sum_op - the Sum operator previously created by xnn_create_sum_nd_f16.
/// @param num_reduction_axes - number of axes along which sum is computed.
/// @param reduction_axes - axes along which sum is computed.
/// @param num_input_dims - number of dimensions in @a input_shape.
/// @param input_shape - dimensions of the input tensor.
/// @param workspace_size - receives the size in bytes of the scratch workspace the caller must
///                         provide to xnn_setup_sum_nd_f16.
/// @param workspace_alignment - receives the required alignment of that workspace.
/// @param threadpool - thread pool used for parallelization, or NULL.
///                     NOTE(review): NULL-tolerance is presumed from the pthreadpool API — confirm.
enum xnn_status xnn_reshape_sum_nd_f16(
  xnn_operator_t sum_op,
  size_t num_reduction_axes,
  const size_t* reduction_axes,
  size_t num_input_dims,
  const size_t* input_shape,
  size_t* workspace_size,
  size_t* workspace_alignment,
  pthreadpool_t threadpool);

/// Reshape an FP32 N-dimensional Sum operator to the given input shape and reduction axes.
///
/// Unlike the f16/qs8/qu8 variants declared here, the f32 variant takes no workspace out-parameters
/// and its setup takes no workspace pointer. NOTE(review): this asymmetry mirrors the declaration
/// shapes visible in this header — confirm it is intentional (f32 needing no scratch space) rather
/// than an omission in this new API.
///
/// @param sum_op - the Sum operator previously created by xnn_create_sum_nd_f32.
/// @param num_reduction_axes - number of axes along which sum is computed.
/// @param reduction_axes - axes along which sum is computed.
/// @param num_input_dims - number of dimensions in @a input_shape.
/// @param input_shape - dimensions of the input tensor.
/// @param threadpool - thread pool used for parallelization, or NULL.
enum xnn_status xnn_reshape_sum_nd_f32(
  xnn_operator_t sum_op,
  size_t num_reduction_axes,
  const size_t* reduction_axes,
  size_t num_input_dims,
  const size_t* input_shape,
  pthreadpool_t threadpool);

/// Reshape a QS8 N-dimensional Sum operator to the given input shape and reduction axes.
///
/// @param sum_op - the Sum operator previously created by xnn_create_sum_nd_qs8.
/// @param num_reduction_axes - number of axes along which sum is computed.
/// @param reduction_axes - axes along which sum is computed.
/// @param num_input_dims - number of dimensions in @a input_shape.
/// @param input_shape - dimensions of the input tensor.
/// @param workspace_size - receives the size in bytes of the scratch workspace the caller must
///                         provide to xnn_setup_sum_nd_qs8.
/// @param workspace_alignment - receives the required alignment of that workspace.
/// @param threadpool - thread pool used for parallelization, or NULL.
enum xnn_status xnn_reshape_sum_nd_qs8(
  xnn_operator_t sum_op,
  size_t num_reduction_axes,
  const size_t* reduction_axes,
  size_t num_input_dims,
  const size_t* input_shape,
  size_t* workspace_size,
  size_t* workspace_alignment,
  pthreadpool_t threadpool);

/// Reshape a QU8 N-dimensional Sum operator to the given input shape and reduction axes.
///
/// @param sum_op - the Sum operator previously created by xnn_create_sum_nd_qu8.
/// @param num_reduction_axes - number of axes along which sum is computed.
/// @param reduction_axes - axes along which sum is computed.
/// @param num_input_dims - number of dimensions in @a input_shape.
/// @param input_shape - dimensions of the input tensor.
/// @param workspace_size - receives the size in bytes of the scratch workspace the caller must
///                         provide to xnn_setup_sum_nd_qu8.
/// @param workspace_alignment - receives the required alignment of that workspace.
/// @param threadpool - thread pool used for parallelization, or NULL.
enum xnn_status xnn_reshape_sum_nd_qu8(
  xnn_operator_t sum_op,
  size_t num_reduction_axes,
  const size_t* reduction_axes,
  size_t num_input_dims,
  const size_t* input_shape,
  size_t* workspace_size,
  size_t* workspace_alignment,
  pthreadpool_t threadpool);

/// Bind input/output buffers (and workspace) to a reshaped FP16 Sum operator.
///
/// @param sum_op - the Sum operator, already reshaped via xnn_reshape_sum_nd_f16.
/// @param workspace - scratch buffer of at least the size/alignment reported by the reshape call.
/// @param input - pointer to the input tensor data.
/// @param output - pointer to the output tensor data.
enum xnn_status xnn_setup_sum_nd_f16(
  xnn_operator_t sum_op,
  void* workspace,
  const void* input,
  void* output);

/// Bind input/output buffers to a reshaped FP32 Sum operator (no workspace — see the f32 reshape note).
///
/// @param sum_op - the Sum operator, already reshaped via xnn_reshape_sum_nd_f32.
/// @param input - pointer to the input tensor data.
/// @param output - pointer to the output tensor data.
enum xnn_status xnn_setup_sum_nd_f32(
  xnn_operator_t sum_op,
  const float* input,
  float* output);

/// Bind input/output buffers (and workspace) to a reshaped QS8 Sum operator.
///
/// @param sum_op - the Sum operator, already reshaped via xnn_reshape_sum_nd_qs8.
/// @param workspace - scratch buffer of at least the size/alignment reported by the reshape call.
/// @param input - pointer to the input tensor data.
/// @param output - pointer to the output tensor data.
enum xnn_status xnn_setup_sum_nd_qs8(
  xnn_operator_t sum_op,
  void* workspace,
  const void* input,
  void* output);

/// Bind input/output buffers (and workspace) to a reshaped QU8 Sum operator.
///
/// @param sum_op - the Sum operator, already reshaped via xnn_reshape_sum_nd_qu8.
/// @param workspace - scratch buffer of at least the size/alignment reported by the reshape call.
/// @param input - pointer to the input tensor data.
/// @param output - pointer to the output tensor data.
enum xnn_status xnn_setup_sum_nd_qu8(
  xnn_operator_t sum_op,
  void* workspace,
  const void* input,
  void* output);

/// Create a Negate operator for FP16 (half-precision) data with NC layout.
///
/// @param flags - binary features of the operator. NOTE(review): supported flag values are not
///                visible from this declaration.
/// @param negate_op_out - pointer to the variable that receives the created operator object.
enum xnn_status xnn_create_negate_nc_f16(
  uint32_t flags,
  xnn_operator_t* negate_op_out);
Expand Down
Loading