diff --git a/onert-micro/onert-micro/include/OMConfig.h b/onert-micro/onert-micro/include/OMConfig.h index 0ffbf024b1f..2038b57a1e7 100644 --- a/onert-micro/onert-micro/include/OMConfig.h +++ b/onert-micro/onert-micro/include/OMConfig.h @@ -41,6 +41,7 @@ enum OMMetrics MAE_METRICS, CROSS_ENTROPY_METRICS, ACCURACY, + SPARSE_CROSS_ENTROPY_ACCURACY, }; /* diff --git a/onert-micro/onert-micro/include/train/metrics/SparseCrossEntropyAccuracy.h b/onert-micro/onert-micro/include/train/metrics/SparseCrossEntropyAccuracy.h new file mode 100644 index 00000000000..64c73ae20ea --- /dev/null +++ b/onert-micro/onert-micro/include/train/metrics/SparseCrossEntropyAccuracy.h @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ONERT_MICRO_TRAIN_METRICS_SPARSE_CROSS_ENTROPY_ACCURACY_H +#define ONERT_MICRO_TRAIN_METRICS_SPARSE_CROSS_ENTROPY_ACCURACY_H + +#include "OMStatus.h" + +#include + +namespace onert_micro +{ +namespace train +{ +namespace metrics +{ + +// Accuracy metric +struct SparseCrossEntropyAccuracy +{ + // Calculate accuracy metric between calculated and target data + static float calculateValue(const uint32_t flat_size, const float *calculated_data, + const float *target_data); +}; + +} // namespace metrics +} // namespace train +} // namespace onert_micro + +#endif // ONERT_MICRO_TRAIN_METRICS_SPARSE_CROSS_ENTROPY_ACCURACY_H diff --git a/onert-micro/onert-micro/src/core/train/OMTrainingHandler.cpp b/onert-micro/onert-micro/src/core/train/OMTrainingHandler.cpp index 24535625a02..8ab4a9eb60b 100644 --- a/onert-micro/onert-micro/src/core/train/OMTrainingHandler.cpp +++ b/onert-micro/onert-micro/src/core/train/OMTrainingHandler.cpp @@ -23,6 +23,7 @@ #include "train/metrics/CrossEntropy.h" #include "train/metrics/Accuracy.h" #include "train/metrics/MAE.h" +#include "train/metrics/SparseCrossEntropyAccuracy.h" using namespace onert_micro::core::train; using namespace onert_micro::core; @@ -222,6 +223,12 @@ OMStatus OMTrainingHandler::evaluateMetric(OMMetrics metric, void *metric_val, assert(calculated_data != nullptr); // Get target data + /** NOTE: + * This offset will always return 0 if the MODEL OUTPUT is returning 1 value of prediction. + * (forward_output->size() == length of output vector.) + * one-hot: size == target_numbers + * Sparse cross : size == 1 + */ size_t offset = batch_num * sizeof(core::OMDataType(forward_output_tensor->type())) * flat_size; uint8_t *target_data = _training_storage.getTargetData(i) + offset; @@ -261,6 +268,14 @@ OMStatus OMTrainingHandler::evaluateMetric(OMMetrics metric, void *metric_val, reinterpret_cast(target_data)); break; } + case SPARSE_CROSS_ENTROPY_ACCURACY: + { + // Note: sum up new calculated value for current sample + *f_metric_val += metrics::SparseCrossEntropyAccuracy::calculateValue( + flat_size, reinterpret_cast(calculated_data), + reinterpret_cast(target_data)); + break; + } default: { assert(false && "Unsupported loss type"); diff --git a/onert-micro/onert-micro/src/train/CMakeLists.txt b/onert-micro/onert-micro/src/train/CMakeLists.txt index 6374a9dfc9e..1df3a864ef3 100644 --- a/onert-micro/onert-micro/src/train/CMakeLists.txt +++ b/onert-micro/onert-micro/src/train/CMakeLists.txt @@ -18,6 +18,7 @@ set(SOURCES metrics/MAE.cpp metrics/MSE.cpp metrics/Accuracy.cpp + metrics/SparseCrossEntropyAccuracy.cpp train_optimizers/SGD.cpp train_optimizers/Adam.cpp ) diff --git a/onert-micro/onert-micro/src/train/metrics/SparseCrossEntropyAccuracy.cpp b/onert-micro/onert-micro/src/train/metrics/SparseCrossEntropyAccuracy.cpp new file mode 100644 index 00000000000..260e66ba2fe --- /dev/null +++ b/onert-micro/onert-micro/src/train/metrics/SparseCrossEntropyAccuracy.cpp @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "train/metrics/SparseCrossEntropyAccuracy.h" + +using namespace onert_micro; +using namespace onert_micro::train; +using namespace onert_micro::train::metrics; + +/* + * return 1.0 if predicted class equals to target + * return 0.0 otherwise + */ +float SparseCrossEntropyAccuracy::calculateValue(const uint32_t flat_size, + const float *calculated_data, + const float *target_data) +{ + // Find target class + uint32_t target_class = static_cast(target_data[0]); + + // Find predicted class + float pred_class = 0.f; + float pred_max_val = calculated_data[0]; + for (uint32_t i = 0; i < flat_size; ++i) + { + if (pred_max_val < calculated_data[i]) + { + pred_max_val = calculated_data[i]; + pred_class = static_cast(i); + } + } + + return pred_class == target_class ? 1.0f : 0.0f; +}