diff --git a/onert-micro/onert-micro/include/OMConfig.h b/onert-micro/onert-micro/include/OMConfig.h index 0ffbf024b1f..cce3022d6f7 100644 --- a/onert-micro/onert-micro/include/OMConfig.h +++ b/onert-micro/onert-micro/include/OMConfig.h @@ -52,6 +52,7 @@ enum OMLoss CROSS_ENTROPY, MSE, MAE, + SPARSE_CROSS_ENTROPY, }; /* diff --git a/onert-micro/onert-micro/include/train/losses_functions/SparseCrossEntropy.h b/onert-micro/onert-micro/include/train/losses_functions/SparseCrossEntropy.h new file mode 100644 index 00000000000..198cbf35dc9 --- /dev/null +++ b/onert-micro/onert-micro/include/train/losses_functions/SparseCrossEntropy.h @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ONERT_MICRO_TRAIN_LOSSES_FUNCTIONS_SPARSE_CROSS_ENTROPY_H +#define ONERT_MICRO_TRAIN_LOSSES_FUNCTIONS_SPARSE_CROSS_ENTROPY_H + +#include "OMStatus.h" + +#include + +namespace onert_micro +{ +namespace train +{ +namespace losses_functions +{ + +// Cross Entropy +struct SparseCrossEntropy +{ + // Calculate sparse cross entropy error backpropagation between calculated and target data + static void calculateErrorBackpropagation(const uint32_t flat_size, const float *calculated_data, + const float *target_data, float *output_grad); +}; + +} // namespace losses_functions +} // namespace train +} // namespace onert_micro + +#endif // ONERT_MICRO_TRAIN_LOSSES_FUNCTIONS_SPARSE_CROSS_ENTROPY_H \ No newline at end of file diff --git a/onert-micro/onert-micro/src/core/train/OMTrainingHandler.cpp b/onert-micro/onert-micro/src/core/train/OMTrainingHandler.cpp index 24535625a02..097a39d3ba5 100644 --- a/onert-micro/onert-micro/src/core/train/OMTrainingHandler.cpp +++ b/onert-micro/onert-micro/src/core/train/OMTrainingHandler.cpp @@ -19,6 +19,7 @@ #include "core/train/OMTrainingHandler.h" #include "train/losses_functions/MSE.h" #include "train/losses_functions/CrossEntropy.h" +#include "train/losses_functions/SparseCrossEntropy.h" #include "train/metrics/MSE.h" #include "train/metrics/CrossEntropy.h" #include "train/metrics/Accuracy.h" @@ -56,11 +57,18 @@ OMStatus OMTrainingHandler::handleError(const OMConfig &config, OMRuntimeStorage OMStatus status = forward_storage.getDataByTensorIndex(&calculated_data, forward_output_index); assert(calculated_data != nullptr); + OMLoss loss_type = config.training_context.loss; + // Get target data auto data_type_size = sizeof(core::OMDataType(forward_output_tensor->type())); size_t offset = batch_num * data_type_size * flat_size; + + // Need to check loss type to control proper offset. + if (loss_type == SPARSE_CROSS_ENTROPY) + { + offset = batch_num * data_type_size; + } uint8_t *target_data = _training_storage.getTargetData(i) + offset; - OMLoss loss_type = config.training_context.loss; // Allocate data for error gradient for current calculated data and target data uint8_t *output_grad_data; @@ -85,6 +93,13 @@ OMStatus OMTrainingHandler::handleError(const OMConfig &config, OMRuntimeStorage reinterpret_cast(target_data), reinterpret_cast(output_grad_data)); break; } + case SPARSE_CROSS_ENTROPY: + { + losses_functions::SparseCrossEntropy::calculateErrorBackpropagation( + flat_size, reinterpret_cast(calculated_data), + reinterpret_cast(target_data), reinterpret_cast(output_grad_data)); + break; + } default: { assert(false && "Unsupported loss type"); diff --git a/onert-micro/onert-micro/src/train/CMakeLists.txt b/onert-micro/onert-micro/src/train/CMakeLists.txt index 6374a9dfc9e..baf3e1034b1 100644 --- a/onert-micro/onert-micro/src/train/CMakeLists.txt +++ b/onert-micro/onert-micro/src/train/CMakeLists.txt @@ -14,6 +14,7 @@ set(SOURCES OMBackpropExecutionBuilder.cpp losses_functions/MSE.cpp losses_functions/CrossEntropy.cpp + losses_functions/SparseCrossEntropy.cpp metrics/CrossEntropy.cpp metrics/MAE.cpp metrics/MSE.cpp diff --git a/onert-micro/onert-micro/src/train/losses_functions/SparseCrossEntropy.cpp b/onert-micro/onert-micro/src/train/losses_functions/SparseCrossEntropy.cpp new file mode 100644 index 00000000000..08292b680de --- /dev/null +++ b/onert-micro/onert-micro/src/train/losses_functions/SparseCrossEntropy.cpp @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "train/losses_functions/SparseCrossEntropy.h" +#include + +using namespace onert_micro; +using namespace onert_micro::train; +using namespace onert_micro::train::losses_functions; + +/* + * dE/dZi = (dE/dy) * (dy / dZi) + * where Z - vector of logits, + * y - probaility of target. + * + * Since dE/dy = -(1/y), + * (true label) if i == y : dE/dZi = py - 1 = y - 1 + * (wrong label) else : dE/dZi = pj + * + */ +void SparseCrossEntropy::calculateErrorBackpropagation(const uint32_t flat_size, + const float *calculated_data, + const float *target_data, float *output_grad) +{ + uint32_t label = static_cast(target_data[0]); + + for (uint32_t i = 0; i < flat_size; ++i) + { + output_grad[i] = (calculated_data[i] + float(10.0e-32)) - (i == label); + } +} \ No newline at end of file