Skip to content

Commit

Permalink
unify generating interfaces
Browse files Browse the repository at this point in the history
ONE-DCO-1.0-Signed-off-by: ragmani <[email protected]>
  • Loading branch information
ragmani committed Aug 8, 2024
1 parent 917edfc commit 39ce0e7
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 98 deletions.
84 changes: 36 additions & 48 deletions runtime/onert/backend/train/BackendContext.cc
Original file line number Diff line number Diff line change
Expand Up @@ -136,22 +136,50 @@ getDisposableBackPropTensorList(const ir::train::TrainableGraph &tgraph,
}
} // namespace

backend::ITensorRegistry *BackendContext::genTensors()
FunctionMap BackendContext::gen()
{
planForwardTensors();
planBackwardTensors();

_tensor_builder->allocate();
_tensor_builder->allocateBackward();

return _tensor_registry.get();
}
auto codes = generateFunctionMap();

backend::train::ITensorRegistry *BackendContext::genTrainingTensors()
{
planBackwardTensors();
// Initialize TrainableTensors
trainable_graph()->operands().iterate(
[&](const ir::OperandIndex &ind, const ir::Operand &operand) {
if (external_operands().contains(ind) || !operand.isConstant())
return;

_tensor_builder->allocateBackward();
auto tensor = tensor_registry()->getNativeITensor(ind);
assert(tensor != nullptr);

VERBOSE(FillOperandData) << "Fill data for " << ind << std::endl;

auto data = operand.shareData();
assert(data && data->base());
auto trainable_tensor = dynamic_cast<TrainableTensor *>(tensor);

return _tensor_registry.get();
if (trainable_tensor == nullptr)
throw std::runtime_error{"This tensor is not trainable tensor"};

trainable_tensor->fillBuffer(data);
});

// NOTE For memory optimization, we want to free some operand data
const_cast<ir::train::TrainableGraph &>(*_tdata->tgraph)
.operands()
.iterate([&](const ir::OperandIndex &, ir::Operand &obj) { obj.releaseData(); });

// TODO Enable
// for (auto &&it : ret)
// {
// auto &fn_seq = it.second;
// fn_seq->iterate([&](exec::IFunction &ifunc) { ifunc.prepare(); });
// }

return codes;
}

void BackendContext::planForwardTensors()
Expand Down Expand Up @@ -209,46 +237,6 @@ void BackendContext::planBackwardTensors()
tensor_planner.planDisposableBackPropTensors(tensor_builder.get());
}

FunctionMap BackendContext::genKernels()
{
  // Generate the function sequence for every operation up front.
  auto fn_map = generateFunctionMap();

  // Copy each constant operand's data into its backing TrainableTensor.
  trainable_graph()->operands().iterate(
    [&](const ir::OperandIndex &ind, const ir::Operand &operand) {
      // Only locally-owned constant operands need their buffers filled.
      if (external_operands().contains(ind) || !operand.isConstant())
        return;

      auto *native_tensor = tensor_registry()->getNativeITensor(ind);
      assert(native_tensor != nullptr);

      VERBOSE(FillOperandData) << "Fill data for " << ind << std::endl;

      auto data = operand.shareData();
      assert(data && data->base());

      // Constants in a trainable graph must live in trainable tensors.
      auto *trainable = dynamic_cast<TrainableTensor *>(native_tensor);
      if (!trainable)
        throw std::runtime_error{"This tensor is not trainable tensor"};
      trainable->fillBuffer(data);
    });

  // NOTE For memory optimization, we want to free some operand data
  const_cast<ir::train::TrainableGraph &>(*_tdata->tgraph)
    .operands()
    .iterate([&](const ir::OperandIndex &, ir::Operand &obj) { obj.releaseData(); });

  // TODO Enable
  // for (auto &&it : fn_map)
  // {
  //   auto &fn_seq = it.second;
  //   fn_seq->iterate([&](exec::IFunction &ifunc) { ifunc.prepare(); });
  // }

  return fn_map;
}

FunctionMap BackendContext::generateFunctionMap()
{
train::FunctionMap ret;
Expand Down
5 changes: 1 addition & 4 deletions runtime/onert/backend/train/BackendContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,16 +68,13 @@ class BackendContext : public onert::backend::train::TrainableBackendContext
BackendContext &operator=(const BackendContext &) = delete;

public:
backend::ITensorRegistry *genTensors() override;
backend::train::ITensorRegistry *genTrainingTensors() override;
FunctionMap gen() override;

private:
void planForwardTensors();
void planBackwardTensors();

public:
FunctionMap genKernels() override;

std::shared_ptr<ExternalContext> external_context() { return _external_context; }

const exec::train::optimizer::Optimizer *optimizer() const { return _optimizer.get(); }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,7 @@ class TrainableBackendContext

std::shared_ptr<ITensorRegistry> tensor_registry() { return _tensor_registry; }

virtual ITensorRegistry *genTrainingTensors() = 0;
virtual backend::ITensorRegistry *genTensors() = 0;
virtual FunctionMap genKernels() = 0;
virtual FunctionMap gen() = 0;

private:
const ITrainableBackend *_backend{nullptr};
Expand Down
20 changes: 5 additions & 15 deletions runtime/onert/core/src/backend/builtin/train/BackendContext.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,29 +28,19 @@ namespace builtin
namespace train
{

backend::ITensorRegistry *BackendContext::genTensors()
backend::train::FunctionMap BackendContext::gen()
{
// For now, there is no need to generate tensors for forwarding.
// For now, there is no need to generate tensors for forwarding and backwarding.
// builtin train backend handles 3 operators: `Permute`, `IF`, `WHILE`.
// `Permute`: Tensor generation is not required.
// `IF`, `WHILE`: Not supported yet
return tensor_registry().get();
}

backend::train::ITensorRegistry *BackendContext::genTrainingTensors()
{
// For now, there is no need to generate tensors for backwarding.
return tensor_registry().get();
}

backend::train::FunctionMap BackendContext::genKernels()
{
backend::train::FunctionMap ret;
backend::train::FunctionMap codes;

for (auto &&op_ind : _tdata->op_order)
{
auto tn_seq = kernel_gen->generate(op_ind);
ret.emplace(op_ind, std::move(tn_seq));
codes.emplace(op_ind, std::move(tn_seq));
}

trainable_graph()->operands().iterate(
Expand All @@ -69,7 +59,7 @@ backend::train::FunctionMap BackendContext::genKernels()
// fn_seq->iterate([&](exec::IFunction &ifunc) { ifunc.prepare(); });
// }

return ret;
return codes;
}

} // namespace train
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,7 @@ class BackendContext : public backend::train::TrainableBackendContext
{
}

backend::ITensorRegistry *genTensors() override;
backend::train::ITensorRegistry *genTrainingTensors() override;

public:
backend::train::FunctionMap genKernels() override;
backend::train::FunctionMap gen() override;

std::shared_ptr<ExternalContext> external_context() { return _external_context; }

Expand Down
55 changes: 32 additions & 23 deletions runtime/onert/core/src/compiler/ExecutorFactory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,21 @@ std::deque<std::pair<const backend::Backend *, Context *>> orderBackendContext(
return ordered_contexts;
}

// Move every generated function sequence out of `codes` into `code_map`,
// pairing it with its operation and per-operation lower info.
void extractCodes(backend::train::FunctionMap &codes,
                  const compiler::train::LoweredTrainableGraph *lowered_graph,
                  compiler::train::TrainableCodeMap &code_map)
{
  for (auto &&entry : codes)
  {
    const auto op_ind = entry.first;
    auto &tn_seq = entry.second;

    auto &op = lowered_graph->trainable_graph().operation(op_ind);
    const auto backend = lowered_graph->lower_info().operation.at(op_ind);

    // Each operation must be registered at most once.
    assert(code_map.find(op_ind) == code_map.end());
    code_map.emplace(
      op_ind, compiler::train::TrainableCodeAndInfo{op_ind, &op, backend, std::move(tn_seq)});
  }
}

} // namespace
} // namespace onert

Expand Down Expand Up @@ -741,15 +756,16 @@ exec::IExecutor *ExecutorFactory::createTrainableExecutor(
VERBOSE(ExecutorFactory) << "Linearize for backwarding order" << std::endl;
Linear::dump(*lowered_graph, backward_order);

for (auto &&pair : tbackend_contexts)
train::TrainableCodeMap code_map;
// Generate tensors and kernels
for (auto &&[backend, context] : tbackend_contexts)
{
pair.second->genTensors();
}
// builtin backend's kernel generator requires access to tensors in other backends.
if (backend->config()->id() == "builtin")
continue;

for (auto &&pair : tbackend_contexts)
{
auto tctx = pair.second.get();
tctx->genTrainingTensors();
auto codes = context->gen();
extractCodes(codes, lowered_graph.get(), code_map);
}

prepareMigrantTensors(*lowered_graph, tbackend_contexts);
Expand All @@ -767,6 +783,15 @@ exec::IExecutor *ExecutorFactory::createTrainableExecutor(
}
}

for (auto &&[backend, context] : tbackend_contexts)
{
if (backend->config()->id() == "builtin")
{
auto codes = context->gen();
extractCodes(codes, lowered_graph.get(), code_map);
}
}

// Adjust the order of backends for the upcoming iteration
auto ordered_contexts =
onert::orderBackendContext<backend::train::TrainableBackendContext>(tbackend_contexts);
Expand Down Expand Up @@ -845,22 +870,6 @@ exec::IExecutor *ExecutorFactory::createTrainableExecutor(
}));
}

train::TrainableCodeMap code_map;
// Generate kernels
for (auto &&pair : ordered_contexts)
{
auto codes = pair.second->genKernels();
for (auto &&[op_ind, tn_seq] : codes)
{
auto &op = lowered_graph->trainable_graph().operation(op_ind);
const auto backend = lowered_graph->lower_info().operation.at(op_ind);

assert(code_map.find(op_ind) == code_map.end());
code_map.insert(
{op_ind, train::TrainableCodeAndInfo{op_ind, &op, backend, std::move(tn_seq)}});
}
}

if (order.size() != code_map.size())
{
throw std::runtime_error("ExecutorFactory: Some kernels are not generated");
Expand Down

0 comments on commit 39ce0e7

Please sign in to comment.