Skip to content

Commit

Permalink
Chore: refactor atomic bias (#3654)
Browse files Browse the repository at this point in the history
`var_name` is no longer required as user input for `polar` and `dipole`.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
anyangml and pre-commit-ci[bot] authored Apr 11, 2024
1 parent fd2daeb commit a26b680
Show file tree
Hide file tree
Showing 14 changed files with 723 additions and 85 deletions.
7 changes: 3 additions & 4 deletions deepmd/dpmodel/fitting/dipole_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@ class DipoleFitting(GeneralFitting):
Parameters
----------
var_name
The name of the output variable.
ntypes
The number of atom types.
dim_descrpt
Expand Down Expand Up @@ -86,7 +84,6 @@ class DipoleFitting(GeneralFitting):

def __init__(
self,
var_name: str,
ntypes: int,
dim_descrpt: int,
embedding_width: int,
Expand Down Expand Up @@ -124,7 +121,7 @@ def __init__(
self.r_differentiable = r_differentiable
self.c_differentiable = c_differentiable
super().__init__(
var_name=var_name,
var_name="dipole",
ntypes=ntypes,
dim_descrpt=dim_descrpt,
neuron=neuron,
Expand Down Expand Up @@ -161,6 +158,8 @@ def serialize(self) -> dict:
def deserialize(cls, data: dict) -> "GeneralFitting":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
var_name = data.pop("var_name", None)
assert var_name == "dipole"
return super().deserialize(data)

def output_def(self):
Expand Down
7 changes: 3 additions & 4 deletions deepmd/dpmodel/fitting/polarizability_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,6 @@ class PolarFitting(GeneralFitting):
Parameters
----------
var_name
The name of the output variable.
ntypes
The number of atom types.
dim_descrpt
Expand Down Expand Up @@ -88,7 +86,6 @@ class PolarFitting(GeneralFitting):

def __init__(
self,
var_name: str,
ntypes: int,
dim_descrpt: int,
embedding_width: int,
Expand Down Expand Up @@ -145,7 +142,7 @@ def __init__(
self.shift_diag = shift_diag
self.constant_matrix = np.zeros(ntypes, dtype=GLOBAL_NP_FLOAT_PRECISION)
super().__init__(
var_name=var_name,
var_name="polar",
ntypes=ntypes,
dim_descrpt=dim_descrpt,
neuron=neuron,
Expand Down Expand Up @@ -201,6 +198,8 @@ def serialize(self) -> dict:
def deserialize(cls, data: dict) -> "GeneralFitting":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 2, 1)
var_name = data.pop("var_name", None)
assert var_name == "polar"
return super().deserialize(data)

def output_def(self):
Expand Down
5 changes: 2 additions & 3 deletions deepmd/pt/model/task/dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,6 @@ class DipoleFittingNet(GeneralFitting):
Parameters
----------
var_name : str
The atomic property to fit, 'dipole'.
ntypes : int
Element count.
dim_descrpt : int
Expand Down Expand Up @@ -97,7 +95,7 @@ def __init__(
self.r_differentiable = r_differentiable
self.c_differentiable = c_differentiable
super().__init__(
var_name=kwargs.pop("var_name", "dipole"),
var_name="dipole",
ntypes=ntypes,
dim_descrpt=dim_descrpt,
neuron=neuron,
Expand Down Expand Up @@ -131,6 +129,7 @@ def serialize(self) -> dict:
def deserialize(cls, data: dict) -> "GeneralFitting":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
data.pop("var_name", None)
return super().deserialize(data)

def output_def(self) -> FittingOutputDef:
Expand Down
5 changes: 2 additions & 3 deletions deepmd/pt/model/task/polarizability.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,6 @@ class PolarFittingNet(GeneralFitting):
Parameters
----------
var_name : str
The atomic property to fit, 'polar'.
ntypes : int
Element count.
dim_descrpt : int
Expand Down Expand Up @@ -127,7 +125,7 @@ def __init__(
ntypes, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
)
super().__init__(
var_name=kwargs.pop("var_name", "polar"),
var_name="polar",
ntypes=ntypes,
dim_descrpt=dim_descrpt,
neuron=neuron,
Expand Down Expand Up @@ -180,6 +178,7 @@ def serialize(self) -> dict:
def deserialize(cls, data: dict) -> "GeneralFitting":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 2, 1)
data.pop("var_name", None)
return super().deserialize(data)

def output_def(self) -> FittingOutputDef:
Expand Down
Loading

0 comments on commit a26b680

Please sign in to comment.