Skip to content

Commit

Permalink
Added quint8 and qint8 for mean_nd in xnnpack_delegate.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 675950736
  • Loading branch information
Misha Gutman authored and xnnpack-bot committed Sep 26, 2024
1 parent 2286715 commit d655271
Show file tree
Hide file tree
Showing 2 changed files with 370 additions and 2 deletions.
35 changes: 34 additions & 1 deletion src/subgraph/static-mean.c
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,20 @@ static enum xnn_status create_mean_operator(
const int8_t output_zero_point = (int8_t) values[output_id].quantization.zero_point;

status = xnn_create_mean_nd_qs8(
input_scale * output_scale, input_zero_point, output_zero_point,
input_scale / output_scale, input_zero_point, output_zero_point,
node->flags,
&opdata->operator_objects[0]);
break;
}
case xnn_datatype_quint8:
{
const float input_scale = values[input_id].quantization.scale;
const float output_scale = values[output_id].quantization.scale;
const uint8_t input_zero_point = (uint8_t) values[input_id].quantization.zero_point;
const uint8_t output_zero_point = (uint8_t) values[output_id].quantization.zero_point;

status = xnn_create_mean_nd_qu8(
input_scale / output_scale, input_zero_point, output_zero_point,
node->flags,
&opdata->operator_objects[0]);
break;
Expand Down Expand Up @@ -126,6 +139,17 @@ static enum xnn_status reshape_mean_operator(
&opdata->workspace_alignment,
threadpool);
break;
case xnn_operator_type_mean_nd_qu8:
status = xnn_reshape_mean_nd_qu8(
opdata->operator_objects[0],
opdata->num_reduction_axes,
opdata->reduction_axes,
input_value->shape.num_dims,
input_value->shape.dim,
&opdata->workspace_size,
&opdata->workspace_alignment,
threadpool);
break;
default:
XNN_UNREACHABLE;
}
Expand Down Expand Up @@ -212,6 +236,11 @@ static enum xnn_status setup_mean_operator(
opdata->operator_objects[0],
opdata->workspace,
input_data, output_data);
case xnn_operator_type_mean_nd_qu8:
return xnn_setup_mean_nd_qu8(
opdata->operator_objects[0],
opdata->workspace,
input_data, output_data);
default:
XNN_UNREACHABLE;
}
Expand Down Expand Up @@ -245,6 +274,7 @@ enum xnn_status xnn_define_static_mean(
case xnn_datatype_fp16:
case xnn_datatype_fp32:
case xnn_datatype_qint8:
case xnn_datatype_quint8:
break;
default:
xnn_log_error(
Expand Down Expand Up @@ -276,6 +306,9 @@ enum xnn_status xnn_define_static_mean(
case xnn_datatype_qint8:
compute_type = xnn_compute_type_qs8;
break;
case xnn_datatype_quint8:
compute_type = xnn_compute_type_qu8;
break;
default:
xnn_log_error(
"failed to define %s operator with output ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
Expand Down
Loading

0 comments on commit d655271

Please sign in to comment.