Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon 7th No.41】NO.41 为 Paddle 代码转换工具新增 API 转换规则(第 8 组) #493

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 136 additions & 4 deletions paconvert/api_mapping.json
Original file line number Diff line number Diff line change
Expand Up @@ -791,7 +791,22 @@
"memory_format"
]
},
"torch.Tensor.cauchy_": {},
"torch.Tensor.cauchy_": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.Tensor.cauchy_",
"min_input_args": 0,
"args_list": [
"median",
"sigma",
"*",
"generator"
],
"kwargs_change": {
"median": "loc",
"sigma":"scale",
"generator":""
}
},
"torch.Tensor.cdouble": {
"Matcher": "TensorCdoubleMatcher",
"paddle_api": "paddle.Tensor.astype",
Expand Down Expand Up @@ -1621,7 +1636,20 @@
"other": "y"
}
},
"torch.Tensor.geometric_": {},
"torch.Tensor.geometric_": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.Tensor.geometric_",
"min_input_args": 1,
"args_list": [
"p",
"*",
"generator"
],
"kwargs_change": {
"p": "probs",
"generator":""
}
},
"torch.Tensor.geqrf": {},
"torch.Tensor.ger": {
"Matcher": "GenericMatcher",
Expand Down Expand Up @@ -1957,7 +1985,10 @@
"paddle_api": "paddle.Tensor.is_floating_point",
"min_input_args": 0
},
"torch.Tensor.is_inference": {},
"torch.Tensor.is_inference": {
"Matcher": "Is_InferenceMatcher",
"paddle_api": "paddle.Tensor.stop_gradient"
},
"torch.Tensor.is_pinned": {
"Matcher": "Is_PinnedMatcher",
"min_input_args": 0
Expand Down Expand Up @@ -2995,7 +3026,16 @@
"Matcher": "UnchangeMatcher",
"min_input_args": 0
},
"torch.Tensor.random_": {},
"torch.Tensor.random_": {
"Matcher": "TensorRandom_Matcher",
"min_input_args": 0,
"args_list": [
"from",
"to",
"*",
"generator"
]
},
"torch.Tensor.ravel": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.Tensor.flatten",
Expand Down Expand Up @@ -6059,6 +6099,19 @@
],
"min_input_args": 1
},
"torch.distributed.rpc.remote":{
"Matcher": "RpcRemoteMatcher",
"paddle_api": "paddle.distributed.rpc.rpc_async",
"min_input_args": 2,
"args_list": [
"to",
"func",
"args",
"kwargs",
"timeout"
]
},
"torch.distributed.optim.DistributedOptimizer":{},
"torch.distributed.rpc.shutdown": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.distributed.rpc.shutdown",
Expand Down Expand Up @@ -6173,6 +6226,85 @@
"validate_args": ""
}
},
"torch.distributions.chi2.Chi2":{
"Matcher": "GenericMatcher",
"paddle_api": "paddle.distribution.Chi2",
"min_input_args": 1,
"args_list": [
"df",
"validate_args"
],
"kwargs_change": {
"validate_args": ""
}
},
"torch.distributions.constraints.Constraint" : {
"Matcher": "DistributionsConstrainMatcher",
"paddle_api": "paddle.distribution.constraint.Constraint",
"abstract": true
},
"torch.distributions.gamma.Gamma":{
"Matcher": "GenericMatcher",
"paddle_api": "paddle.distribution.Gamma",
"min_input_args": 2,
"args_list": [
"concentration",
"rate",
"validate_args"
],
"kwargs_change": {
"validate_args": ""
}
},
"torch.distributions.poisson.Poisson":{
"Matcher": "GenericMatcher",
"paddle_api": "paddle.distribution.Poisson",
"min_input_args": 1,
"args_list": [
"rate",
"validate_args"
],
"kwargs_change": {
"validate_args": ""
}
},
"torch.distributions.lkj_cholesky.LKJCholesky":{
"Matcher": "GenericMatcher",
"paddle_api": "paddle.distribution.LKJCholesky",
"min_input_args": 1,
"args_list": [
"dim",
"concentration",
"validate_args"
],
"kwargs_change": {
"validate_args": ""
}
},
"torch.distributions.studentT.StudentT":{
"Matcher": "GenericMatcher",
"paddle_api": "paddle.distribution.StudentT",
"min_input_args": 1,
"args_list": [
"df",
"loc",
"scale",
"validate_args"
],
"kwargs_change": {
"validate_args": ""
}
},
"torch.distributions.transforms.PositiveDefiniteTransform":{
"Matcher": "TransformsPositiveDefiniteTransformMatcher",
"min_input_args": 0,
"args_list": [
"cache_size"
],
"kwargs_change": {
"cache_size": ""
}
},
"torch.distributions.Categorical": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.distribution.Categorical",
Expand Down
133 changes: 133 additions & 0 deletions paconvert/api_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,45 @@ def generate_code(self, kwargs):
return code


class RpcRemoteMatcher(BaseMatcher):
def generate_aux_code(self):
CODE_TEMPLATE = textwrap.dedent(
"""
import paddle
import paddle.distributed.rpc as rpc
class rpc_remote:
def __init__(to, func, args=None, kwargs=None, timeout=-1):
self.remote = rpc.rpc_async(to=to, fn=func, args=args, kwargs=kwargs, timeout=timeout)

def to_here():
return self.remote.wait()
"""
)
return CODE_TEMPLATE

def generate_code(self, kwargs):
self.write_aux_code()
if "args" not in kwargs.keys():
kwargs["args"] = None
if "kwargs" not in kwargs.keys():
kwargs["kwargs"] = None
if "timeout" not in kwargs.keys():
kwargs["timeout"] = -1
API_TEMPLATE = textwrap.dedent(
"""
paddle_aux.rpc_remote(to={}, func={}, args={}, kwargs={}, timeout={})
"""
)
code = API_TEMPLATE.format(
kwargs["to"],
kwargs["func"],
kwargs["args"],
kwargs["kwargs"],
kwargs["timeout"]
)
return code


class AtleastMatcher(BaseMatcher):
def get_paddle_nodes(self, args, kwargs):
new_args = self.parse_args(args)
Expand Down Expand Up @@ -681,6 +720,100 @@ def generate_code(self, kwargs):
return "unchange"


class DistributionsConstrainMatcher(BaseMatcher):
def generate_aux_code(self):
API_TEMPLATE = textwrap.dedent(
"""
import paddle
def Distributions_Constraint():
class DistributionsConstrain:
def check(self, value):
return paddle.distribution.constraint.Constraint()(value)
return DistributionsConstrain()
"""
)

return API_TEMPLATE
def generate_code(self, kwargs):
self.write_aux_code()
API_TEMPLATE = textwrap.dedent(
"""
paddle_aux.Distributions_Constraint()
"""
)
return API_TEMPLATE


class TransformsPositiveDefiniteTransformMatcher(BaseMatcher):
def generate_aux_code(self):
API_TEMPLATE = textwrap.dedent(
"""
import paddle
from paddle import Tensor
class TransformsPositiveDefiniteTransform:
def __call__(self, x: Tensor):
x = x.tril(-1) + x.diagonal(axis1=-2, axis2=-1).exp().diag_embed()
shape_list = list(range(x.ndim))
shape_list[-1], shape_list[-2] = shape_list[-2], shape_list[-1]
y = x.transpose(perm=shape_list)
return x @ y

def inv(self, y):
y = paddle.linalg.cholesky(y)
return y.tril(-1) + y.diagonal(axis1=-2, axis2=-1).log().diag_embed()

"""
)

return API_TEMPLATE
def generate_code(self, kwargs):
self.write_aux_code()
API_TEMPLATE = textwrap.dedent(
"""
paddle_aux.TransformsPositiveDefiniteTransform()
"""
)
return API_TEMPLATE


class Is_InferenceMatcher(BaseMatcher):
def generate_aux_code(self):
API_TEMPLATE = textwrap.dedent(
"""
import paddle
def is_inference(x):
return not x.stop_gradient
"""
)

return API_TEMPLATE
def generate_code(self, kwargs):
self.write_aux_code()
API_TEMPLATE = textwrap.dedent(
"""
paddle_aux.is_inference({})
"""
)
code = API_TEMPLATE.format(self.paddleClass)
return code


class TensorRandom_Matcher(BaseMatcher):
def generate_code(self, kwargs):
self.write_aux_code()
API_TEMPLATE = textwrap.dedent(
"""
paddle.assign(paddle.cast(paddle.randint(low={}, high={}, shape={}.shape), dtype='float32'), {})
"""
)
code = API_TEMPLATE.format(
kwargs["from"],
kwargs["to"],
self.paddleClass,
self.paddleClass,
)
return code

class TransposeMatcher(BaseMatcher):
def generate_aux_code(self):
API_TEMPLATE = textwrap.dedent(
Expand Down
Loading