【Hackathon 7th No.36】Add API conversion rules for the Paddle code conversion tool (Group 3) - part #479

Open
wants to merge 42 commits into PaddlePaddle:master from hackathon7-part1

Commits (42)
28ed535
fix
enkilee Sep 19, 2024
cdc060b
code style
enkilee Sep 19, 2024
131dac1
fix
enkilee Sep 19, 2024
fdd7ae6
fix
enkilee Sep 19, 2024
9eeb83d
fix
enkilee Sep 20, 2024
73e0929
fix
enkilee Sep 24, 2024
9ea271b
Merge branch 'PaddlePaddle:master' into hackathon7-part1
enkilee Sep 24, 2024
9e9cab5
CI
enkilee Sep 24, 2024
7e1675a
Merge remote-tracking branch 'refs/remotes/origin/hackathon7-part1' i…
enkilee Sep 24, 2024
ce46bff
fix
enkilee Sep 25, 2024
70cfd55
fix
enkilee Sep 25, 2024
206980a
fix
enkilee Sep 25, 2024
7c8b49e
fix
enkilee Sep 25, 2024
6a892d0
fix
enkilee Sep 25, 2024
f155669
fix
enkilee Sep 26, 2024
7c8108f
test blackman
enkilee Sep 26, 2024
563a1be
fix
enkilee Sep 26, 2024
4aee5c1
fix
enkilee Sep 26, 2024
f40a61e
fix
enkilee Sep 26, 2024
ecd7677
fix
enkilee Sep 26, 2024
6b22746
fix
enkilee Sep 26, 2024
5aca4f5
test
enkilee Sep 26, 2024
4d9440d
fix
enkilee Sep 26, 2024
fb27f13
fix
enkilee Sep 26, 2024
854bd4c
fix
enkilee Sep 26, 2024
6c0d004
fix
enkilee Sep 26, 2024
2b56f1a
fix
enkilee Sep 26, 2024
3761ca9
fix
enkilee Sep 26, 2024
51f6f81
fix
enkilee Sep 27, 2024
c66d5ed
fix
enkilee Sep 27, 2024
6eef583
fix
enkilee Sep 27, 2024
476edb3
fix
enkilee Sep 27, 2024
9c2aa00
fix
enkilee Sep 27, 2024
67972d8
fix
enkilee Sep 27, 2024
466d826
test
enkilee Sep 27, 2024
3c5e0e4
fix
enkilee Sep 27, 2024
38d7bc8
fix
enkilee Sep 27, 2024
3ce42da
fix
enkilee Sep 27, 2024
a34f023
fix
enkilee Sep 27, 2024
814e9fd
fix
enkilee Sep 27, 2024
d2d0a82
fix
enkilee Sep 27, 2024
fc5b930
fix
enkilee Sep 27, 2024
189 changes: 189 additions & 0 deletions paconvert/api_mapping.json
@@ -14571,6 +14571,195 @@
"input": "x"
}
},
"torch.signal.windows.blackman": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.audio.functional.get_window",
"min_input_args": 1,
"args_list": [
"M",
"*",
"sym",
"dtype",
"layout",
"device",
"requires_grad"
],
"kwargs_change": {
"M": "win_length",
"sym": "fftbins",
"dtype": "dtype"
},
"paddle_default_kwargs": {
"dtype": "paddle.float64",
"window": "'blackman'"
}
},
"torch.signal.windows.cosine": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.audio.functional.get_window",
"min_input_args": 1,
"args_list": [
"M",
"*",
"sym",
"dtype",
"layout",
"device",
"requires_grad"
],
"kwargs_change": {
"M": "win_length",
"sym": "fftbins",
"dtype": "dtype"
},
"paddle_default_kwargs": {
"dtype": "paddle.float64",
"window": "'cosine'"
}
},
"torch.signal.windows.exponential": {
"Matcher": "SignalWindowsWatcher",
"paddle_api": "paddle.audio.functional.get_window",
"min_input_args": 1,
"args_list": [
"M",
"*",
"center",
"tau",
@enkilee (Contributor Author), Sep 30, 2024:

Yes, after testing, this does need to be transcribed; a PR has been opened to update the docs.
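As a hedged sketch of the transcription discussed above, with tau folded into the (name, value) window tuple that the matcher builds (the exact keyword rendering is up to GenericMatcher, and the tool may emit dtype=paddle.float64 instead of the string form):

import torch
import paddle

# PyTorch source
x = torch.signal.windows.exponential(10, tau=0.5)
# Intended Paddle rewrite per this mapping (sketch)
y = paddle.audio.functional.get_window(('exponential', 0.5), 10, dtype='float64')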

"sym",
"dtype",
"layout",
"device",
"requires_grad"
],
"unsupport_args": [
"center"
],
"kwargs_change": {
"M": "win_length",
"sym": "fftbins",
"dtype": "dtype"
},
"paddle_default_kwargs": {
"dtype": "paddle.float64"
}
},
"torch.signal.windows.gaussian": {
"Matcher": "SignalWindowsWatcher",
"paddle_api": "paddle.audio.functional.get_window",
"min_input_args": 1,
"args_list": [
"M",
"*",
"std",
"sym",
"dtype",
"layout",
"device",
"requires_grad"
],
"kwargs_change": {
"M": "win_length",
"sym": "fftbins",
"dtype": "dtype"
},
"paddle_default_kwargs": {
"dtype": "paddle.float64"
}
},
"torch.signal.windows.general_cosine": {
"Matcher": "SignalWindowsWatcher",
"paddle_api": "paddle.audio.functional.get_window",
"min_input_args": 1,
"args_list": [
"M",
"*",
"a",
"sym",
"dtype",
"layout",
"device",
"requires_grad"
],
"kwargs_change": {
"M": "win_length",
"a": "",
Collaborator:

The transcription described in the docs does not look like simply dropping a? Is there a diff that does not match the docs? Please remember to correct the docs for anything that does not match.

@enkilee (Contributor Author), Sep 30, 2024:

In the docs, a needs to be transcribed: a receives a list, and Paddle combines it into a tuple of (str, list). a is a required parameter with no default value, and Paddle has no a parameter, so in general_cosine it is mapped to empty here and its value is taken from the PyTorch call.

PyTorch:

torch.signal.windows.general_cosine(10, a=[0.46, 0.23, 0.31])

Paddle:

paddle.audio.functional.get_window(('general_cosine', [0.46, 0.23, 0.31]), 10)

"sym": "fftbins",
"dtype": "dtype"
},
"paddle_default_kwargs": {
"dtype": "paddle.float64"
}
},
"torch.signal.windows.general_hamming": {
"Matcher": "SignalWindowsWatcher",
"paddle_api": "paddle.audio.functional.get_window",
"min_input_args": 1,
"args_list": [
"M",
"*",
"alpha",
"sym",
"dtype",
"layout",
"device",
"requires_grad"
],
"kwargs_change": {
"M": "win_length",
"sym": "fftbins",
"dtype": "dtype"
},
"paddle_default_kwargs": {
"dtype": "paddle.float64"
}
},
"torch.signal.windows.hamming": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.audio.functional.get_window",
"min_input_args": 1,
"args_list": [
"M",
"*",
"sym",
"dtype",
"layout",
"device",
"requires_grad"
],
"kwargs_change": {
"M": "win_length",
"sym": "fftbins",
"dtype": "dtype"
},
"paddle_default_kwargs": {
"dtype": "paddle.float64",
"window": "'hamming'"
}
},
"torch.signal.windows.hann": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.audio.functional.get_window",
"min_input_args": 1,
"args_list": [
"M",
"*",
"sym",
"dtype",
"layout",
"device",
"requires_grad"
],
"kwargs_change": {
"M": "win_length",
"sym": "fftbins",
"dtype": "dtype"
},
"paddle_default_kwargs": {
"dtype": "paddle.float64",
Collaborator:

What is the reason for this setting? Does it need to be called out in the docs?

@enkilee (Contributor Author), Sep 30, 2024:

In Paddle the default dtype is float64; a PR has been opened to update the docs.

def get_window(
    window: _WindowLiteral | tuple[_WindowLiteral, float],
    win_length: int,
    fftbins: bool = True,
    dtype: str = 'float64',
) -> Tensor:

def _gaussian(
    M: int, std: float, sym: bool = True, dtype: str = 'float64'
) -> Tensor:

def _exponential(
    M: int, center=None, tau=1.0, sym: bool = True, dtype: str = 'float64'
) -> Tensor:
"window": "'hann'"
}
},
"torch.signbit": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.signbit",
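For the fixed-shape windows above (blackman, cosine, hamming, hann), the mapping is a plain GenericMatcher rename: M becomes win_length, sym becomes fftbins, and the window name plus a float64 dtype are injected through paddle_default_kwargs. A hedged sketch of the intended before/after (the converter may emit dtype=paddle.float64 rather than the string form):

import torch
import paddle

# PyTorch source
x = torch.signal.windows.blackman(5)
# Intended Paddle rewrite per the mapping above (sketch)
y = paddle.audio.functional.get_window('blackman', 5, dtype='float64')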
28 changes: 28 additions & 0 deletions paconvert/api_matcher.py
@@ -519,6 +519,34 @@ def generate_code(self, kwargs):
return super().generate_code(kwargs)


class SignalWindowsWatcher(BaseMatcher):
def generate_code(self, kwargs):
new_kwargs = {}
if "exponential" in self.torch_api:
if "tau" in kwargs:
tau_value = float(str(kwargs.pop("tau")).split("=")[-1].strip("()"))
new_kwargs["window"] = ("exponential", tau_value)
else:
new_kwargs["window"] = ("exponential", 1.0)
if "gaussian" in self.torch_api:
if "std" in kwargs:
std_value = float(str(kwargs.pop("std")).split("=")[-1].strip("()"))
new_kwargs["window"] = ("gaussian", std_value)
else:
new_kwargs["window"] = ("gaussian", 1.0)
if "general_hamming" in self.torch_api:
if "alpha" in kwargs:
alpha_value = float(str(kwargs.pop("alpha")).split("=")[-1].strip("()"))
new_kwargs["window"] = ("general_hamming", alpha_value)
else:
new_kwargs["window"] = ("general_hamming", 0.54)
if "general_cosine" in self.torch_api:
a_value = [v for v in kwargs.values()][1]
new_kwargs["window"] = ("general_cosine", a_value)
new_kwargs.update(kwargs)
return GenericMatcher.generate_code(self, new_kwargs)


class Num2TensorBinaryWithAlphaMatcher(BaseMatcher):
def generate_code(self, kwargs):
kwargs_change = self.api_mapping.get("kwargs_change", {})
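For the parameterized windows routed through SignalWindowsWatcher, the extra shape argument (tau, std, alpha, or the coefficient list a) is folded into a (name, value) window tuple and M becomes win_length, matching the tuple form shown in the general_cosine example above. A hedged sketch of the rewrites the matcher aims to produce, assuming paddle.audio.functional.get_window accepts that tuple form (keyword rendering is up to GenericMatcher):

import paddle

# torch.signal.windows.gaussian(10, std=2.0)
w1 = paddle.audio.functional.get_window(('gaussian', 2.0), 10, dtype='float64')
# torch.signal.windows.general_hamming(10, alpha=0.54)
w2 = paddle.audio.functional.get_window(('general_hamming', 0.54), 10, dtype='float64')
# torch.signal.windows.general_cosine(10, a=[0.46, 0.23, 0.31])
w3 = paddle.audio.functional.get_window(('general_cosine', [0.46, 0.23, 0.31]), 10, dtype='float64')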
79 changes: 79 additions & 0 deletions tests/test_signal_windows_blackman.py
@@ -0,0 +1,79 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import textwrap

from apibase import APIBase

obj = APIBase("torch.signal.windows.blackman")


def test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
result = torch.signal.windows.blackman(5)
"""
)
obj.run(pytorch_code, ["result"], check_value=False, check_dtype=False)


def test_case_2():
pytorch_code = textwrap.dedent(
"""
import torch
result = torch.signal.windows.blackman(5, dtype=torch.float64)
"""
)
obj.run(pytorch_code, ["result"], check_value=False)


def test_case_3():
pytorch_code = textwrap.dedent(
"""
import torch
result = torch.signal.windows.blackman(5, dtype=torch.float64, requires_grad=True)
"""
)
obj.run(pytorch_code, ["result"], check_value=False)


def test_case_4():
pytorch_code = textwrap.dedent(
"""
import torch
result = torch.signal.windows.blackman(5, dtype=torch.float64, layout=torch.strided, requires_grad=True)
"""
)
obj.run(pytorch_code, ["result"], check_value=False)


def test_case_5():
pytorch_code = textwrap.dedent(
"""
import torch
result = torch.signal.windows.blackman(5, dtype=torch.float64, layout=torch.strided, device=torch.device('cpu'), requires_grad=True)
"""
)
obj.run(pytorch_code, ["result"], check_value=False)


def test_case_6():
pytorch_code = textwrap.dedent(
"""
import torch
result = torch.signal.windows.blackman(5, sym=False, dtype=torch.float64)
"""
)
obj.run(pytorch_code, ["result"], check_value=False)
79 changes: 79 additions & 0 deletions tests/test_signal_windows_cosine.py
@@ -0,0 +1,79 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import textwrap

from apibase import APIBase

obj = APIBase("torch.signal.windows.cosine")


def test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
result = torch.signal.windows.cosine(10)
"""
)
obj.run(pytorch_code, ["result"], check_value=False, check_dtype=False)


def test_case_2():
pytorch_code = textwrap.dedent(
"""
import torch
result = torch.signal.windows.cosine(10, dtype=torch.float64)
"""
)
obj.run(pytorch_code, ["result"], check_value=False)


def test_case_3():
pytorch_code = textwrap.dedent(
"""
import torch
result = torch.signal.windows.cosine(10, dtype=torch.float64, requires_grad=True)
"""
)
obj.run(pytorch_code, ["result"], check_value=False)


def test_case_4():
pytorch_code = textwrap.dedent(
"""
import torch
result = torch.signal.windows.cosine(10, dtype=torch.float64, layout=torch.strided, requires_grad=True)
"""
)
obj.run(pytorch_code, ["result"], check_value=False)


def test_case_5():
pytorch_code = textwrap.dedent(
"""
import torch
result = torch.signal.windows.cosine(10, dtype=torch.float64, layout=torch.strided, device=torch.device('cpu'), requires_grad=True)
"""
)
obj.run(pytorch_code, ["result"], check_value=False)


def test_case_6():
pytorch_code = textwrap.dedent(
"""
import torch
result = torch.signal.windows.cosine(10, sym=False, dtype=torch.float64)
"""
)
obj.run(pytorch_code, ["result"], check_value=False)