-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
(this PR is just for my personal self-learning) [PaddlePaddle Hackathon 2 No.22] add paddle.index_add to Paddle #42475
Changes from 30 commits
94fb159
b4442be
f60ad37
9adf9cc
65cfb27
50e2cd1
9b8380b
65050b4
b3cf6be
04a3fce
5c3f5d1
f6cda0f
9415651
75c1869
35da3f7
88af7e7
2e2cf4c
98ab985
7b2dbef
d0a2631
b546b05
41d7be0
62a7f1c
959dbc2
00b25c2
04491ff
181665f
cfd71bb
88d1295
0e62f69
700dbee
75af4a1
3838749
dc8af85
aaaba0c
c017dac
2bcd776
78ce053
930e314
ae72a00
f66dc0c
33bb556
9de7ee3
43fc9f6
28839dd
b90a58c
a02247d
8fdae4a
c6cefc1
fbabcce
32ef9af
6a6f698
f68f280
354a3de
60454ed
418f613
b714d86
3b203b7
fbc056f
5ef0319
50037cb
66defb6
2e30e48
d59fe50
0362dae
c1b8a22
ddf255e
45c1e81
a0413f7
176d6bd
ceb519f
1b24d8b
a9c1617
01981d6
bfc0e35
d9cde21
33f0bcc
821eee6
d1aeda9
791c825
407673e
3b11f18
a113244
df2ee5d
b2d345b
2483e07
9fbde57
14786cf
ff15e2d
24d3827
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
/*Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include <memory> | ||
#include "paddle/fluid/framework/infershape_utils.h" | ||
#include "paddle/fluid/framework/op_registry.h" | ||
#include "paddle/fluid/framework/op_version_registry.h" | ||
#include "paddle/phi/core/infermeta_utils.h" | ||
#include "paddle/phi/infermeta/binary.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
class IndexAddOp : public framework::OperatorWithKernel { | ||
public: | ||
using framework::OperatorWithKernel::OperatorWithKernel; | ||
|
||
framework::OpKernelType GetExpectedKernelType( | ||
const framework::ExecutionContext& ctx) const override { | ||
return framework::OpKernelType( | ||
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); | ||
} | ||
}; | ||
|
||
class IndexAddOpMaker : public framework::OpProtoAndCheckerMaker { | ||
public: | ||
void Make() override { | ||
AddInput("X", | ||
"(Tensor, default input Tensor<float>), " | ||
"the input feature data of IndexAddOp, dtype should be" | ||
"int32, int64, float16, float32, float64."); | ||
AddInput("Index", | ||
"(Tensor, default 1-d Tensor<int>), " | ||
"the 1-D tensor containing the indices to index, " | ||
"dtype should be int32, int64"); | ||
AddAttr<int>("axis", | ||
"(int, default 0), " | ||
"the dimension in which we index.") | ||
.SetDefault(0); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 不建议添加默认值,因为没有合理的默认值 此外,axis 希望也能支持 Tensor 输入,亦即也作为一个 Input 而不只是 Attr. 具体实现方法可以参考 |
||
AddAttr<float>("added_value", | ||
"(float, default 0.0f) The value to add.") | ||
.SetDefault(0.0f); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 不建议添加默认值,因为没有合理的默认值 |
||
AddOutput("Out", | ||
"(Tensor, default Tensor<float>)," | ||
" the output of IndexAddOp, whose dtype and shape is the same as X."); | ||
AddComment(R"DOC( | ||
index_add operator. | ||
Add the elements of the input tensor with value | ||
by selecting the indices in the order given in 'index' | ||
on the axis 'axis'. | ||
|
||
This operator also supports inplace modification. | ||
)DOC"); | ||
} | ||
}; | ||
|
||
template <typename T> | ||
class IndexAddGradOpMaker : public framework::SingleGradOpMaker<T> { | ||
public: | ||
using framework::SingleGradOpMaker<T>::SingleGradOpMaker; | ||
|
||
void Apply(GradOpPtr<T> op) const override { | ||
op->SetType("index_add_grad"); | ||
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); | ||
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); | ||
op->SetAttrMap(this->Attrs()); | ||
} | ||
}; | ||
|
||
class IndexAddGradOp : public framework::OperatorWithKernel { | ||
public: | ||
using framework::OperatorWithKernel::OperatorWithKernel; | ||
|
||
protected: | ||
framework::OpKernelType GetExpectedKernelType( | ||
const framework::ExecutionContext& ctx) const override { | ||
return framework::OpKernelType( | ||
OperatorWithKernel::IndicateVarDataType( | ||
ctx, framework::GradVarName("Out")), ctx.GetPlace()); | ||
} | ||
}; | ||
|
||
DECLARE_INPLACE_OP_INFERER(IndexAddInplaceInferer, {"X", "Out"}); | ||
DECLARE_INPLACE_OP_INFERER(IndexAddGradInplaceInferer, | ||
{framework::GradVarName("Out"), | ||
framework::GradVarName("X")}); | ||
// DECLARE_NO_NEED_BUFFER_VARS_INFERER(IndexAddGradNoNeedBufferVarsInferer, "X"); | ||
|
||
} // namespace operators | ||
} // namespace paddle | ||
|
||
namespace ops = paddle::operators; | ||
DECLARE_INFER_SHAPE_FUNCTOR(index_add, IndexAddInferShapeFunctor, | ||
PD_INFER_META(phi::IndexAddInferMeta)); | ||
|
||
REGISTER_OPERATOR(index_add, ops::IndexAddOp, ops::IndexAddOpMaker, | ||
ops::IndexAddGradOpMaker<paddle::framework::OpDesc>, | ||
ops::IndexAddGradOpMaker<paddle::imperative::OpBase>, | ||
ops::IndexAddInplaceInferer, IndexAddInferShapeFunctor); | ||
|
||
DECLARE_INFER_SHAPE_FUNCTOR(index_add_grad, IndexAddGradInferShapeFunctor, | ||
PD_INFER_META(phi::IndexAddGradInferMeta)); | ||
|
||
REGISTER_OPERATOR(index_add_grad, ops::IndexAddGradOp, | ||
ops::IndexAddGradInplaceInferer, | ||
// ops::IndexAddGradNoNeedBufferVarsInferer, | ||
IndexAddGradInferShapeFunctor); |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -184,6 +184,17 @@ void HuberLossInferMeta(const MetaTensor& input_meta, | |
MetaTensor* residual, | ||
MetaConfig config = MetaConfig()); | ||
|
||
void IndexAddInferMeta(const MetaTensor& x, | ||
const MetaTensor& index, | ||
int axis, | ||
float added_value, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 如前面所属,axis, index, add_value 都支持 Tensor 或者 非 Tensor 的话,这些也都需要相应修改。 |
||
MetaTensor* output); | ||
|
||
void IndexAddGradInferMeta(const MetaTensor& out_grad, | ||
int axis, | ||
float added_value, | ||
MetaTensor* x_grad); | ||
|
||
void IndexSampleInferMeta(const MetaTensor& x, | ||
const MetaTensor& y, | ||
MetaTensor* out, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/phi/kernels/index_add_grad_kernel.h" | ||
#include "paddle/phi/core/kernel_registry.h" | ||
#include "paddle/phi/core/utils/data_type.h" | ||
#include "paddle/phi/kernels/copy_kernel.h" | ||
|
||
namespace phi { | ||
|
||
template <typename T, typename Context> | ||
void IndexAddGradKernel(const Context& dev_ctx, | ||
const DenseTensor& out_grad, | ||
int axis, | ||
float added_value, | ||
DenseTensor* x_grad) { | ||
phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad); | ||
} | ||
|
||
} // namespace phi | ||
|
||
PD_REGISTER_KERNEL(index_add_grad, | ||
CPU, | ||
ALL_LAYOUT, | ||
phi::IndexAddGradKernel, | ||
float, | ||
phi::dtype::float16, | ||
double, | ||
int, | ||
int64_t) {} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#pragma once | ||
|
||
#include "paddle/phi/core/dense_tensor.h" | ||
#include "paddle/phi/kernels/copy_kernel.h" | ||
#include "paddle/phi/kernels/funcs/blas/blas.h" | ||
#include "paddle/phi/kernels/funcs/eigen/common.h" | ||
#include "paddle/phi/kernels/funcs/math_function.h" | ||
|
||
namespace phi { | ||
template <typename Context, typename T, typename IndexT = int> | ||
void IndexAddInner(const Context& ctx, | ||
const DenseTensor& index, | ||
DenseTensor* output, | ||
int axis, | ||
T added_val) { | ||
auto output_dim = output->dims(); | ||
auto output_dim_size = output_dim.size(); | ||
auto index_size = index.dims()[0]; | ||
|
||
DenseTensor index_cpu_copy; | ||
if (!paddle::platform::is_cpu_place(index.place())) { | ||
phi::Copy(ctx, index, phi::CPUPlace(), true, &index_cpu_copy); | ||
} | ||
const IndexT* index_data = paddle::platform::is_cpu_place(index.place()) | ||
? index.data<IndexT>() | ||
: index_cpu_copy.data<IndexT>(); | ||
|
||
auto slice_size = 1; | ||
for (auto i = axis + 1; i < output_dim_size; i++) { | ||
slice_size *= output_dim[i]; | ||
} | ||
|
||
auto outer_nums = 1; | ||
for (auto i = 0; i < axis; i++) { | ||
outer_nums *= output_dim[i]; | ||
} | ||
|
||
for (int i = 0; i < index_size; i++) { | ||
PADDLE_ENFORCE_GE( | ||
index_data[i], | ||
0, | ||
phi::errors::InvalidArgument( | ||
"Variable value (index) of OP(index_add) " | ||
"expected >= 0 and < %ld, but got %ld. Please check input " | ||
"value.", | ||
output_dim[axis], | ||
index_data[i])); | ||
PADDLE_ENFORCE_LT( | ||
index_data[i], | ||
output_dim[axis], | ||
phi::errors::InvalidArgument( | ||
"Variable value (index) of OP(index_add) " | ||
"expected >= 0 and < %ld, but got %ld. Please check input " | ||
"value.", | ||
output_dim[axis], | ||
index_data[i])); | ||
} | ||
|
||
output->Resize(phi::make_ddim({outer_nums, output_dim[axis], slice_size})); | ||
|
||
auto output_tensor = EigenTensor<T, 3>::From(*output); | ||
auto& place = *ctx.eigen_device(); | ||
for (auto j = 0; j < index_size; j++) { | ||
IndexT index_value = index_data[j]; | ||
auto output_t = output_tensor.chip(index_value, 1); | ||
// output_t.device(place) = output_t.constant(fill_val); | ||
output_t.device(place) += output_t.constant(added_val); | ||
} | ||
output->Resize(output_dim); | ||
} | ||
|
||
} // namespace phi |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/phi/kernels/index_add_kernel.h" | ||
#include "paddle/phi/core/kernel_registry.h" | ||
#include "paddle/phi/core/utils/data_type.h" | ||
#include "paddle/phi/kernels/copy_kernel.h" | ||
#include "paddle/phi/kernels/cpu/index_add_impl.h" | ||
#include "paddle/phi/kernels/funcs/eigen/common.h" | ||
|
||
namespace phi { | ||
|
||
template <typename T, typename Context> | ||
void IndexAddKernel(const Context& dev_ctx, | ||
const DenseTensor& x, | ||
const DenseTensor& index, | ||
int axis, | ||
float added_value, | ||
DenseTensor* output) { | ||
phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, output); | ||
if (axis < 0) { | ||
axis += x.dims().size(); | ||
} | ||
const auto& index_type = index.dtype(); | ||
|
||
bool index_type_match = | ||
index_type == phi::DataType::INT32 || index_type == phi::DataType::INT64; | ||
PADDLE_ENFORCE_EQ(index_type_match, | ||
true, | ||
phi::errors::InvalidArgument( | ||
"Input(Index) holds the wrong type, it holds %s, but " | ||
"desires to be %s or %s", | ||
index_type, | ||
phi::DataType::INT32, | ||
phi::DataType::INT64)); | ||
|
||
auto added_val = static_cast<T>(added_value); | ||
if (index_type == phi::DataType::INT32) { | ||
IndexAddInner<Context, T, int>(dev_ctx, index, output, axis, added_val); | ||
} else if (index_type == phi::DataType::INT64) { | ||
IndexAddInner<Context, T, int64_t>(dev_ctx, index, output, axis, added_val); | ||
} | ||
} | ||
|
||
} // namespace phi | ||
|
||
PD_REGISTER_KERNEL(index_add, | ||
CPU, | ||
ALL_LAYOUT, | ||
phi::IndexAddKernel, | ||
float, | ||
phi::dtype::float16, | ||
double, | ||
int, | ||
int64_t) {} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 希望也支持 bool, bfloat16 和 complex 数据类型。 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Index 希望也能支持 python 的 list 或者 tuple of ints.
这样写是不能支持传入 list or tuple of ints, 这和设计中的描述也不一致。