Commit

feat: support serialize and deserialize for atomic model's compute output stat (#3649)

- implement serialize and deserialize for the atomic model's compute output stat
- extensive changes to the linear atomic models' serialization. @anyangml, please help check the changes.

---------

Co-authored-by: Han Wang <[email protected]>
wanghan-iapcm and Han Wang authored Apr 9, 2024
1 parent baffd39 commit fd2daeb
Showing 10 changed files with 312 additions and 116 deletions.
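
For context, the core of the change is that the per-type output statistics (out_bias, out_std) now travel inside the serialized dict under an "@variables" key and are written back onto the object after construction. Below is a minimal, self-contained sketch of that round-trip pattern; ToyAtomicModel is a made-up stand-in for illustration, not the deepmd-kit API.

```python
import copy

import numpy as np


class ToyAtomicModel:
    """Toy stand-in that mimics the "@variables" round-trip pattern."""

    def __init__(self, type_map, rcond=None, preset_out_bias=None):
        self.type_map = type_map
        self.rcond = rcond
        self.preset_out_bias = preset_out_bias
        # one output quantity, one component per type, as a stand-in shape
        self.out_bias = np.zeros([1, len(type_map), 1])
        self.out_std = np.ones([1, len(type_map), 1])

    def serialize(self) -> dict:
        return {
            "type_map": self.type_map,
            "rcond": self.rcond,
            "preset_out_bias": self.preset_out_bias,
            "@variables": {"out_bias": self.out_bias, "out_std": self.out_std},
        }

    @classmethod
    def deserialize(cls, data: dict) -> "ToyAtomicModel":
        data = copy.deepcopy(data)
        variables = data.pop("@variables")  # stats travel separately ...
        obj = cls(**data)                   # ... from the constructor args
        obj.out_bias = variables["out_bias"]
        obj.out_std = variables["out_std"]
        return obj


model = ToyAtomicModel(["O", "H"])
model.out_bias[:] = 0.5
restored = ToyAtomicModel.deserialize(model.serialize())
assert np.allclose(restored.out_bias, model.out_bias)
```
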
113 changes: 113 additions & 0 deletions deepmd/dpmodel/atomic_model/base_atomic_model.py
@@ -1,4 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
from typing import (
Dict,
List,
@@ -30,11 +31,44 @@ def __init__(
type_map: List[str],
atom_exclude_types: List[int] = [],
pair_exclude_types: List[Tuple[int, int]] = [],
rcond: Optional[float] = None,
preset_out_bias: Optional[Dict[str, np.ndarray]] = None,
):
super().__init__()
self.type_map = type_map
self.reinit_atom_exclude(atom_exclude_types)
self.reinit_pair_exclude(pair_exclude_types)
self.rcond = rcond
self.preset_out_bias = preset_out_bias

def init_out_stat(self):
"""Initialize the output bias."""
ntypes = self.get_ntypes()
self.bias_keys: List[str] = list(self.fitting_output_def().keys())
self.max_out_size = max(
[self.atomic_output_def()[kk].size for kk in self.bias_keys]
)
self.n_out = len(self.bias_keys)
out_bias_data = np.zeros([self.n_out, ntypes, self.max_out_size])
out_std_data = np.ones([self.n_out, ntypes, self.max_out_size])
self.out_bias = out_bias_data
self.out_std = out_std_data

def __setitem__(self, key, value):
if key in ["out_bias"]:
self.out_bias = value
elif key in ["out_std"]:
self.out_std = value
else:
raise KeyError(key)

def __getitem__(self, key):
if key in ["out_bias"]:
return self.out_bias
elif key in ["out_std"]:
return self.out_std
else:
raise KeyError(key)

def get_type_map(self) -> List[str]:
"""Get the type map."""
@@ -132,6 +166,7 @@ def forward_common_atomic(
fparam=fparam,
aparam=aparam,
)
ret_dict = self.apply_out_stat(ret_dict, atype)

# nf x nloc
atom_mask = ext_atom_mask[:, :nloc].astype(np.int32)
@@ -150,6 +185,84 @@

def serialize(self) -> dict:
return {
"type_map": self.type_map,
"atom_exclude_types": self.atom_exclude_types,
"pair_exclude_types": self.pair_exclude_types,
"rcond": self.rcond,
"preset_out_bias": self.preset_out_bias,
"@variables": {
"out_bias": self.out_bias,
"out_std": self.out_std,
},
}

@classmethod
def deserialize(cls, data: dict) -> "BaseAtomicModel":
data = copy.deepcopy(data)
variables = data.pop("@variables")
obj = cls(**data)
for kk in variables.keys():
obj[kk] = variables[kk]
return obj

def apply_out_stat(
self,
ret: Dict[str, np.ndarray],
atype: np.ndarray,
):
"""Apply the stat to each atomic output.
The developer may override the method to define how the bias is applied
to the atomic output of the model.
Parameters
----------
ret
The returned dict by the forward_atomic method
atype
The atom types. nf x nloc
"""
out_bias, out_std = self._fetch_out_stat(self.bias_keys)
for kk in self.bias_keys:
# nf x nloc x odims, out_bias: ntypes x odims
ret[kk] = ret[kk] + out_bias[kk][atype]
return ret

def _varsize(
self,
shape: List[int],
) -> int:
output_size = 1
len_shape = len(shape)
for i in range(len_shape):
output_size *= shape[i]
return output_size

def _get_bias_index(
self,
kk: str,
) -> int:
res: List[int] = []
for i, e in enumerate(self.bias_keys):
if e == kk:
res.append(i)
assert len(res) == 1
return res[0]

def _fetch_out_stat(
self,
keys: List[str],
) -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray]]:
ret_bias = {}
ret_std = {}
ntypes = self.get_ntypes()
for kk in keys:
idx = self._get_bias_index(kk)
isize = self._varsize(self.atomic_output_def()[kk].shape)
ret_bias[kk] = self.out_bias[idx, :, :isize].reshape(
[ntypes] + list(self.atomic_output_def()[kk].shape) # noqa: RUF005
)
ret_std[kk] = self.out_std[idx, :, :isize].reshape(
[ntypes] + list(self.atomic_output_def()[kk].shape) # noqa: RUF005
)
return ret_bias, ret_std
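
As a side note on apply_out_stat above: out_bias[kk] is a per-type table of shape ntypes x odims, and indexing it with atype (shape nf x nloc) broadcasts to nf x nloc x odims, which matches the atomic output. A small numpy check of that shape arithmetic (made-up shapes, not deepmd-kit code):

```python
import numpy as np

nf, nloc, ntypes, odims = 2, 3, 4, 1
rng = np.random.default_rng(0)

atomic_out = rng.normal(size=(nf, nloc, odims))   # ret[kk]: nf x nloc x odims
out_bias = rng.normal(size=(ntypes, odims))       # per-type bias: ntypes x odims
atype = rng.integers(0, ntypes, size=(nf, nloc))  # atom types: nf x nloc

# out_bias[atype] has shape nf x nloc x odims, so the addition is elementwise.
biased = atomic_out + out_bias[atype]
assert biased.shape == (nf, nloc, odims)
```
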
12 changes: 7 additions & 5 deletions deepmd/dpmodel/atomic_model/dp_atomic_model.py
@@ -49,11 +49,12 @@ def __init__(
type_map: List[str],
**kwargs,
):
super().__init__(type_map, **kwargs)
self.type_map = type_map
self.descriptor = descriptor
self.fitting = fitting
self.type_map = type_map
super().__init__(type_map, **kwargs)
super().init_out_stat()

def fitting_output_def(self) -> FittingOutputDef:
"""Get the output def of the fitting net."""
@@ -136,7 +137,7 @@ def serialize(self) -> dict:
{
"@class": "Model",
"type": "standard",
"@version": 1,
"@version": 2,
"type_map": self.type_map,
"descriptor": self.descriptor.serialize(),
"fitting": self.fitting.serialize(),
@@ -147,13 +148,14 @@
@classmethod
def deserialize(cls, data) -> "DPAtomicModel":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
check_version_compatibility(data.pop("@version", 1), 2, 2)
data.pop("@class")
data.pop("type")
descriptor_obj = BaseDescriptor.deserialize(data.pop("descriptor"))
fitting_obj = BaseFitting.deserialize(data.pop("fitting"))
type_map = data.pop("type_map")
obj = cls(descriptor_obj, fitting_obj, type_map=type_map, **data)
data["descriptor"] = descriptor_obj
data["fitting"] = fitting_obj
obj = super().deserialize(data)
return obj

def get_dim_fparam(self) -> int:
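
On the version bump in DPAtomicModel: raising the serialized "@version" from 1 to 2 makes old checkpoints fail fast at load time instead of being misread. Assuming check_version_compatibility(actual, maximum, minimum) rejects versions outside the supported window (a hedged reading of the call sites above, not a copy of the deepmd-kit helper), the guard behaves roughly like this:

```python
def check_version_compatibility(actual: int, maximum: int, minimum: int = 1) -> None:
    # Sketch of the guard as used above: reject data whose "@version" falls
    # outside the [minimum, maximum] window supported by this class.
    if not minimum <= actual <= maximum:
        raise ValueError(
            f"incompatible serialization version {actual}; "
            f"supported range is [{minimum}, {maximum}]"
        )


check_version_compatibility(2, 2, 2)      # version-2 data: accepted
# check_version_compatibility(1, 2, 2)    # version-1 data: would raise ValueError
```
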
75 changes: 33 additions & 42 deletions deepmd/dpmodel/atomic_model/linear_atomic_model.py
@@ -52,6 +52,8 @@ def __init__(
type_map: List[str],
**kwargs,
):
super().__init__(type_map, **kwargs)
super().init_out_stat()
self.models = models
sub_model_type_maps = [md.get_type_map() for md in models]
err_msg = []
@@ -66,7 +68,6 @@
self.mapping_list.append(self.remap_atype(tpmp, self.type_map))
assert len(err_msg) == 0, "\n".join(err_msg)
self.mixed_types_list = [model.mixed_types() for model in self.models]
super().__init__(type_map, **kwargs)

def mixed_types(self) -> bool:
"""If true, the model
@@ -86,7 +87,7 @@ def get_rcut(self) -> float:

def get_type_map(self) -> List[str]:
"""Get the type map."""
raise self.type_map
return self.type_map

def get_model_rcuts(self) -> List[float]:
"""Get the cut-off radius for each individual models."""
@@ -218,27 +219,30 @@ def fitting_output_def(self) -> FittingOutputDef:
)

def serialize(self) -> dict:
return {
"@class": "Model",
"type": "linear",
"@version": 1,
"models": [model.serialize() for model in self.models],
"type_map": self.type_map,
}
dd = super().serialize()
dd.update(
{
"@class": "Model",
"@version": 2,
"type": "linear",
"models": [model.serialize() for model in self.models],
"type_map": self.type_map,
}
)
return dd

@classmethod
def deserialize(cls, data: dict) -> "LinearEnergyAtomicModel":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
data.pop("@class")
data.pop("type")
type_map = data.pop("type_map")
check_version_compatibility(data.pop("@version", 2), 2, 2)
data.pop("@class", None)
data.pop("type", None)
models = [
BaseAtomicModel.get_class_by_type(model["type"]).deserialize(model)
for model in data["models"]
]
data.pop("models")
return cls(models, type_map, **data)
data["models"] = models
return super().deserialize(data)

def _compute_weight(
self,
@@ -312,24 +316,21 @@ def __init__(
**kwargs,
):
models = [dp_model, zbl_model]
super().__init__(models, type_map, **kwargs)
self.dp_model = dp_model
self.zbl_model = zbl_model
kwargs["models"] = models
kwargs["type_map"] = type_map
super().__init__(**kwargs)

self.sw_rmin = sw_rmin
self.sw_rmax = sw_rmax
self.smin_alpha = smin_alpha

def serialize(self) -> dict:
dd = BaseAtomicModel.serialize(self)
dd = super().serialize()
dd.update(
{
"@class": "Model",
"type": "zbl",
"@version": 2,
"models": LinearEnergyAtomicModel(
models=[self.models[0], self.models[1]], type_map=self.type_map
).serialize(),
"type": "zbl",
"sw_rmin": self.sw_rmin,
"sw_rmax": self.sw_rmax,
"smin_alpha": self.smin_alpha,
@@ -340,25 +341,15 @@
@classmethod
def deserialize(cls, data) -> "DPZBLLinearEnergyAtomicModel":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 2, 1)
data.pop("@class")
data.pop("type")
sw_rmin = data.pop("sw_rmin")
sw_rmax = data.pop("sw_rmax")
smin_alpha = data.pop("smin_alpha")
linear_model = LinearEnergyAtomicModel.deserialize(data.pop("models"))
dp_model, zbl_model = linear_model.models
type_map = linear_model.type_map

return cls(
dp_model=dp_model,
zbl_model=zbl_model,
sw_rmin=sw_rmin,
sw_rmax=sw_rmax,
type_map=type_map,
smin_alpha=smin_alpha,
**data,
)
check_version_compatibility(data.pop("@version", 1), 2, 2)
models = [
BaseAtomicModel.get_class_by_type(model["type"]).deserialize(model)
for model in data["models"]
]
data["dp_model"], data["zbl_model"] = models[0], models[1]
data.pop("@class", None)
data.pop("type", None)
return super().deserialize(data)

def _compute_weight(
self,
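
The rewritten LinearEnergyAtomicModel.deserialize above rebuilds each sub-model from its own "type" tag via a class registry and then hands the remaining dict to the parent class. A toy sketch of that registry-dispatch pattern follows; the ToyModel registry is an illustration only, not deepmd-kit's BaseAtomicModel.get_class_by_type.

```python
from typing import Dict, List, Type


class ToyModel:
    registry: Dict[str, Type["ToyModel"]] = {}

    def __init_subclass__(cls, type_name: str = "", **kwargs):
        # Register each subclass under its "type" tag so deserialize can
        # dispatch on the tag stored in the serialized dict.
        super().__init_subclass__(**kwargs)
        if type_name:
            ToyModel.registry[type_name] = cls

    @classmethod
    def get_class_by_type(cls, type_name: str) -> Type["ToyModel"]:
        return cls.registry[type_name]

    @classmethod
    def deserialize(cls, data: dict) -> "ToyModel":
        return cls()


class ToyDP(ToyModel, type_name="standard"):
    pass


class ToyZBL(ToyModel, type_name="pairtab"):
    pass


serialized = {"models": [{"type": "standard"}, {"type": "pairtab"}]}
models: List[ToyModel] = [
    ToyModel.get_class_by_type(d["type"]).deserialize(d) for d in serialized["models"]
]
assert isinstance(models[0], ToyDP) and isinstance(models[1], ToyZBL)
```
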
12 changes: 6 additions & 6 deletions deepmd/dpmodel/atomic_model/pairtab_atomic_model.py
@@ -64,6 +64,7 @@ def __init__(
**kwargs,
):
super().__init__(type_map, **kwargs)
super().init_out_stat()
self.tab_file = tab_file
self.rcut = rcut
self.type_map = type_map
@@ -136,7 +137,7 @@ def serialize(self) -> dict:
{
"@class": "Model",
"type": "pairtab",
"@version": 1,
"@version": 2,
"tab": self.tab.serialize(),
"rcut": self.rcut,
"sel": self.sel,
@@ -148,14 +149,13 @@
@classmethod
def deserialize(cls, data) -> "PairTabAtomicModel":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
check_version_compatibility(data.pop("@version", 1), 2, 2)
data.pop("@class")
data.pop("type")
rcut = data.pop("rcut")
sel = data.pop("sel")
type_map = data.pop("type_map")
tab = PairTab.deserialize(data.pop("tab"))
tab_model = cls(None, rcut, sel, type_map, **data)
data["tab_file"] = None
tab_model = super().deserialize(data)

tab_model.tab = tab
tab_model.tab_info = tab_model.tab.tab_info
nspline, ntypes = tab_model.tab_info[-2:].astype(int)
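
In PairTabAtomicModel.deserialize above, the table no longer passes through the constructor (tab_file is reset to None); the deserialized table and its derived arrays are re-attached afterwards. A toy sketch of that post-construction fix-up pattern (hypothetical ToyTabModel, not the deepmd-kit class):

```python
import copy

import numpy as np


class ToyTabModel:
    # Simplified stand-in: the constructor takes no table data, and
    # deserialize re-attaches the table after building the object.
    def __init__(self, tab_file=None, rcut=5.0):
        self.tab_file = tab_file
        self.rcut = rcut
        self.tab_data = None

    def serialize(self) -> dict:
        return {"rcut": self.rcut, "tab": {"data": self.tab_data.tolist()}}

    @classmethod
    def deserialize(cls, data: dict) -> "ToyTabModel":
        data = copy.deepcopy(data)
        tab = data.pop("tab")
        data["tab_file"] = None                  # the file path is not serialized
        obj = cls(**data)
        obj.tab_data = np.asarray(tab["data"])   # re-attach derived state
        return obj


m = ToyTabModel(rcut=6.0)
m.tab_data = np.linspace(0.0, 1.0, 5)
m2 = ToyTabModel.deserialize(m.serialize())
assert m2.rcut == 6.0 and np.allclose(m2.tab_data, m.tab_data)
```
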
