From fd2daeb5934e2f3fd2f47183b3fedf87c287965c Mon Sep 17 00:00:00 2001 From: Han Wang <92130845+wanghan-iapcm@users.noreply.github.com> Date: Tue, 9 Apr 2024 11:19:39 +0800 Subject: [PATCH] feat: support serialize and deserialize for atomic model's compute output stat (#3649) - implement serialize and deserialize for atomic model's compute output stat - intensive change to the linear atomic models' serialization. @anyangml please help checking the changes. --------- Co-authored-by: Han Wang --- .../dpmodel/atomic_model/base_atomic_model.py | 113 ++++++++++++++++++ .../dpmodel/atomic_model/dp_atomic_model.py | 12 +- .../atomic_model/linear_atomic_model.py | 75 +++++------- .../atomic_model/pairtab_atomic_model.py | 12 +- .../model/atomic_model/base_atomic_model.py | 53 ++++++-- .../pt/model/atomic_model/dp_atomic_model.py | 9 +- .../model/atomic_model/linear_atomic_model.py | 60 ++++------ .../atomic_model/pairtab_atomic_model.py | 18 +-- deepmd/tf/model/model.py | 22 +++- .../tests/pt/model/test_atomic_model_stat.py | 54 ++++++++- 10 files changed, 312 insertions(+), 116 deletions(-) diff --git a/deepmd/dpmodel/atomic_model/base_atomic_model.py b/deepmd/dpmodel/atomic_model/base_atomic_model.py index dbb344d5ca..9e43851157 100644 --- a/deepmd/dpmodel/atomic_model/base_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/base_atomic_model.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import copy from typing import ( Dict, List, @@ -30,11 +31,44 @@ def __init__( type_map: List[str], atom_exclude_types: List[int] = [], pair_exclude_types: List[Tuple[int, int]] = [], + rcond: Optional[float] = None, + preset_out_bias: Optional[Dict[str, np.ndarray]] = None, ): super().__init__() self.type_map = type_map self.reinit_atom_exclude(atom_exclude_types) self.reinit_pair_exclude(pair_exclude_types) + self.rcond = rcond + self.preset_out_bias = preset_out_bias + + def init_out_stat(self): + """Initialize the output bias.""" + ntypes = self.get_ntypes() + self.bias_keys: List[str] = list(self.fitting_output_def().keys()) + self.max_out_size = max( + [self.atomic_output_def()[kk].size for kk in self.bias_keys] + ) + self.n_out = len(self.bias_keys) + out_bias_data = np.zeros([self.n_out, ntypes, self.max_out_size]) + out_std_data = np.ones([self.n_out, ntypes, self.max_out_size]) + self.out_bias = out_bias_data + self.out_std = out_std_data + + def __setitem__(self, key, value): + if key in ["out_bias"]: + self.out_bias = value + elif key in ["out_std"]: + self.out_std = value + else: + raise KeyError(key) + + def __getitem__(self, key): + if key in ["out_bias"]: + return self.out_bias + elif key in ["out_std"]: + return self.out_std + else: + raise KeyError(key) def get_type_map(self) -> List[str]: """Get the type map.""" @@ -132,6 +166,7 @@ def forward_common_atomic( fparam=fparam, aparam=aparam, ) + ret_dict = self.apply_out_stat(ret_dict, atype) # nf x nloc atom_mask = ext_atom_mask[:, :nloc].astype(np.int32) @@ -150,6 +185,84 @@ def forward_common_atomic( def serialize(self) -> dict: return { + "type_map": self.type_map, "atom_exclude_types": self.atom_exclude_types, "pair_exclude_types": self.pair_exclude_types, + "rcond": self.rcond, + "preset_out_bias": self.preset_out_bias, + "@variables": { + "out_bias": self.out_bias, + "out_std": self.out_std, + }, } + + @classmethod + def deserialize(cls, data: dict) -> "BaseAtomicModel": + data = copy.deepcopy(data) + variables = data.pop("@variables") + obj = cls(**data) + for kk in variables.keys(): + obj[kk] = variables[kk] + return obj + + def apply_out_stat( + self, + ret: Dict[str, np.ndarray], + atype: np.ndarray, + ): + """Apply the stat to each atomic output. + The developer may override the method to define how the bias is applied + to the atomic output of the model. + + Parameters + ---------- + ret + The returned dict by the forward_atomic method + atype + The atom types. nf x nloc + + """ + out_bias, out_std = self._fetch_out_stat(self.bias_keys) + for kk in self.bias_keys: + # nf x nloc x odims, out_bias: ntypes x odims + ret[kk] = ret[kk] + out_bias[kk][atype] + return ret + + def _varsize( + self, + shape: List[int], + ) -> int: + output_size = 1 + len_shape = len(shape) + for i in range(len_shape): + output_size *= shape[i] + return output_size + + def _get_bias_index( + self, + kk: str, + ) -> int: + res: List[int] = [] + for i, e in enumerate(self.bias_keys): + if e == kk: + res.append(i) + assert len(res) == 1 + return res[0] + + def _fetch_out_stat( + self, + keys: List[str], + ) -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray]]: + ret_bias = {} + ret_std = {} + ntypes = self.get_ntypes() + for kk in keys: + idx = self._get_bias_index(kk) + isize = self._varsize(self.atomic_output_def()[kk].shape) + ret_bias[kk] = self.out_bias[idx, :, :isize].reshape( + [ntypes] + list(self.atomic_output_def()[kk].shape) # noqa: RUF005 + ) + ret_std[kk] = self.out_std[idx, :, :isize].reshape( + [ntypes] + list(self.atomic_output_def()[kk].shape) # noqa: RUF005 + ) + return ret_bias, ret_std diff --git a/deepmd/dpmodel/atomic_model/dp_atomic_model.py b/deepmd/dpmodel/atomic_model/dp_atomic_model.py index d39e236d07..b13bfc17ba 100644 --- a/deepmd/dpmodel/atomic_model/dp_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/dp_atomic_model.py @@ -49,11 +49,12 @@ def __init__( type_map: List[str], **kwargs, ): + super().__init__(type_map, **kwargs) self.type_map = type_map self.descriptor = descriptor self.fitting = fitting self.type_map = type_map - super().__init__(type_map, **kwargs) + super().init_out_stat() def fitting_output_def(self) -> FittingOutputDef: """Get the output def of the fitting net.""" @@ -136,7 +137,7 @@ def serialize(self) -> dict: { "@class": "Model", "type": "standard", - "@version": 1, + "@version": 2, "type_map": self.type_map, "descriptor": self.descriptor.serialize(), "fitting": self.fitting.serialize(), @@ -147,13 +148,14 @@ def serialize(self) -> dict: @classmethod def deserialize(cls, data) -> "DPAtomicModel": data = copy.deepcopy(data) - check_version_compatibility(data.pop("@version", 1), 1, 1) + check_version_compatibility(data.pop("@version", 1), 2, 2) data.pop("@class") data.pop("type") descriptor_obj = BaseDescriptor.deserialize(data.pop("descriptor")) fitting_obj = BaseFitting.deserialize(data.pop("fitting")) - type_map = data.pop("type_map") - obj = cls(descriptor_obj, fitting_obj, type_map=type_map, **data) + data["descriptor"] = descriptor_obj + data["fitting"] = fitting_obj + obj = super().deserialize(data) return obj def get_dim_fparam(self) -> int: diff --git a/deepmd/dpmodel/atomic_model/linear_atomic_model.py b/deepmd/dpmodel/atomic_model/linear_atomic_model.py index e4a85d7bc2..b38d309fd7 100644 --- a/deepmd/dpmodel/atomic_model/linear_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/linear_atomic_model.py @@ -52,6 +52,8 @@ def __init__( type_map: List[str], **kwargs, ): + super().__init__(type_map, **kwargs) + super().init_out_stat() self.models = models sub_model_type_maps = [md.get_type_map() for md in models] err_msg = [] @@ -66,7 +68,6 @@ def __init__( self.mapping_list.append(self.remap_atype(tpmp, self.type_map)) assert len(err_msg) == 0, "\n".join(err_msg) self.mixed_types_list = [model.mixed_types() for model in self.models] - super().__init__(type_map, **kwargs) def mixed_types(self) -> bool: """If true, the model @@ -86,7 +87,7 @@ def get_rcut(self) -> float: def get_type_map(self) -> List[str]: """Get the type map.""" - raise self.type_map + return self.type_map def get_model_rcuts(self) -> List[float]: """Get the cut-off radius for each individual models.""" @@ -218,27 +219,30 @@ def fitting_output_def(self) -> FittingOutputDef: ) def serialize(self) -> dict: - return { - "@class": "Model", - "type": "linear", - "@version": 1, - "models": [model.serialize() for model in self.models], - "type_map": self.type_map, - } + dd = super().serialize() + dd.update( + { + "@class": "Model", + "@version": 2, + "type": "linear", + "models": [model.serialize() for model in self.models], + "type_map": self.type_map, + } + ) + return dd @classmethod def deserialize(cls, data: dict) -> "LinearEnergyAtomicModel": data = copy.deepcopy(data) - check_version_compatibility(data.pop("@version", 1), 1, 1) - data.pop("@class") - data.pop("type") - type_map = data.pop("type_map") + check_version_compatibility(data.pop("@version", 2), 2, 2) + data.pop("@class", None) + data.pop("type", None) models = [ BaseAtomicModel.get_class_by_type(model["type"]).deserialize(model) for model in data["models"] ] - data.pop("models") - return cls(models, type_map, **data) + data["models"] = models + return super().deserialize(data) def _compute_weight( self, @@ -312,24 +316,21 @@ def __init__( **kwargs, ): models = [dp_model, zbl_model] - super().__init__(models, type_map, **kwargs) - self.dp_model = dp_model - self.zbl_model = zbl_model + kwargs["models"] = models + kwargs["type_map"] = type_map + super().__init__(**kwargs) self.sw_rmin = sw_rmin self.sw_rmax = sw_rmax self.smin_alpha = smin_alpha def serialize(self) -> dict: - dd = BaseAtomicModel.serialize(self) + dd = super().serialize() dd.update( { "@class": "Model", - "type": "zbl", "@version": 2, - "models": LinearEnergyAtomicModel( - models=[self.models[0], self.models[1]], type_map=self.type_map - ).serialize(), + "type": "zbl", "sw_rmin": self.sw_rmin, "sw_rmax": self.sw_rmax, "smin_alpha": self.smin_alpha, @@ -340,25 +341,15 @@ def serialize(self) -> dict: @classmethod def deserialize(cls, data) -> "DPZBLLinearEnergyAtomicModel": data = copy.deepcopy(data) - check_version_compatibility(data.pop("@version", 1), 2, 1) - data.pop("@class") - data.pop("type") - sw_rmin = data.pop("sw_rmin") - sw_rmax = data.pop("sw_rmax") - smin_alpha = data.pop("smin_alpha") - linear_model = LinearEnergyAtomicModel.deserialize(data.pop("models")) - dp_model, zbl_model = linear_model.models - type_map = linear_model.type_map - - return cls( - dp_model=dp_model, - zbl_model=zbl_model, - sw_rmin=sw_rmin, - sw_rmax=sw_rmax, - type_map=type_map, - smin_alpha=smin_alpha, - **data, - ) + check_version_compatibility(data.pop("@version", 1), 2, 2) + models = [ + BaseAtomicModel.get_class_by_type(model["type"]).deserialize(model) + for model in data["models"] + ] + data["dp_model"], data["zbl_model"] = models[0], models[1] + data.pop("@class", None) + data.pop("type", None) + return super().deserialize(data) def _compute_weight( self, diff --git a/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py b/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py index 2d3bccb258..d3d179e6e2 100644 --- a/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py @@ -64,6 +64,7 @@ def __init__( **kwargs, ): super().__init__(type_map, **kwargs) + super().init_out_stat() self.tab_file = tab_file self.rcut = rcut self.type_map = type_map @@ -136,7 +137,7 @@ def serialize(self) -> dict: { "@class": "Model", "type": "pairtab", - "@version": 1, + "@version": 2, "tab": self.tab.serialize(), "rcut": self.rcut, "sel": self.sel, @@ -148,14 +149,13 @@ def serialize(self) -> dict: @classmethod def deserialize(cls, data) -> "PairTabAtomicModel": data = copy.deepcopy(data) - check_version_compatibility(data.pop("@version", 1), 1, 1) + check_version_compatibility(data.pop("@version", 1), 2, 2) data.pop("@class") data.pop("type") - rcut = data.pop("rcut") - sel = data.pop("sel") - type_map = data.pop("type_map") tab = PairTab.deserialize(data.pop("tab")) - tab_model = cls(None, rcut, sel, type_map, **data) + data["tab_file"] = None + tab_model = super().deserialize(data) + tab_model.tab = tab tab_model.tab_info = tab_model.tab.tab_info nspline, ntypes = tab_model.tab_info[-2:].astype(int) diff --git a/deepmd/pt/model/atomic_model/base_atomic_model.py b/deepmd/pt/model/atomic_model/base_atomic_model.py index e750b6a54e..10fa2a7bd9 100644 --- a/deepmd/pt/model/atomic_model/base_atomic_model.py +++ b/deepmd/pt/model/atomic_model/base_atomic_model.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later - +import copy import logging from typing import ( Callable, @@ -31,6 +31,10 @@ from deepmd.pt.utils.stat import ( compute_output_stats, ) +from deepmd.pt.utils.utils import ( + to_numpy_array, + to_torch_tensor, +) from deepmd.utils.path import ( DPPath, ) @@ -89,12 +93,8 @@ def init_out_stat(self): [self.atomic_output_def()[kk].size for kk in self.bias_keys] ) self.n_out = len(self.bias_keys) - out_bias_data = torch.zeros( - [self.n_out, ntypes, self.max_out_size], dtype=dtype, device=device - ) - out_std_data = torch.ones( - [self.n_out, ntypes, self.max_out_size], dtype=dtype, device=device - ) + out_bias_data = self._default_bias() + out_std_data = self._default_std() self.register_buffer("out_bias", out_bias_data) self.register_buffer("out_std", out_std_data) @@ -254,10 +254,37 @@ def forward_common_atomic( def serialize(self) -> dict: return { + "type_map": self.type_map, "atom_exclude_types": self.atom_exclude_types, "pair_exclude_types": self.pair_exclude_types, + "rcond": self.rcond, + "preset_out_bias": self.preset_out_bias, + "@variables": { + "out_bias": to_numpy_array(self.out_bias), + "out_std": to_numpy_array(self.out_std), + }, } + @classmethod + def deserialize(cls, data: dict) -> "BaseAtomicModel": + data = copy.deepcopy(data) + variables = data.pop("@variables", None) + variables = ( + {"out_bias": None, "out_std": None} if variables is None else variables + ) + obj = cls(**data) + obj["out_bias"] = ( + to_torch_tensor(variables["out_bias"]) + if variables["out_bias"] is not None + else obj._default_bias() + ) + obj["out_std"] = ( + to_torch_tensor(variables["out_std"]) + if variables["out_std"] is not None + else obj._default_std() + ) + return obj + def compute_or_load_stat( self, merged: Union[Callable[[], List[dict]], List[dict]], @@ -410,6 +437,18 @@ def model_forward(coord, atype, box, fparam=None, aparam=None): return model_forward + def _default_bias(self): + ntypes = self.get_ntypes() + return torch.zeros( + [self.n_out, ntypes, self.max_out_size], dtype=dtype, device=device + ) + + def _default_std(self): + ntypes = self.get_ntypes() + return torch.ones( + [self.n_out, ntypes, self.max_out_size], dtype=dtype, device=device + ) + def _varsize( self, shape: List[int], diff --git a/deepmd/pt/model/atomic_model/dp_atomic_model.py b/deepmd/pt/model/atomic_model/dp_atomic_model.py index 6cbbf5ed7f..182196aca5 100644 --- a/deepmd/pt/model/atomic_model/dp_atomic_model.py +++ b/deepmd/pt/model/atomic_model/dp_atomic_model.py @@ -100,7 +100,7 @@ def serialize(self) -> dict: dd.update( { "@class": "Model", - "@version": 1, + "@version": 2, "type": "standard", "type_map": self.type_map, "descriptor": self.descriptor.serialize(), @@ -112,13 +112,14 @@ def serialize(self) -> dict: @classmethod def deserialize(cls, data) -> "DPAtomicModel": data = copy.deepcopy(data) - check_version_compatibility(data.pop("@version", 1), 1, 1) + check_version_compatibility(data.pop("@version", 1), 2, 1) data.pop("@class", None) data.pop("type", None) descriptor_obj = BaseDescriptor.deserialize(data.pop("descriptor")) fitting_obj = BaseFitting.deserialize(data.pop("fitting")) - type_map = data.pop("type_map", None) - obj = cls(descriptor_obj, fitting_obj, type_map=type_map, **data) + data["descriptor"] = descriptor_obj + data["fitting"] = fitting_obj + obj = super().deserialize(data) return obj def forward_atomic( diff --git a/deepmd/pt/model/atomic_model/linear_atomic_model.py b/deepmd/pt/model/atomic_model/linear_atomic_model.py index c5abc4575c..b58594d3ce 100644 --- a/deepmd/pt/model/atomic_model/linear_atomic_model.py +++ b/deepmd/pt/model/atomic_model/linear_atomic_model.py @@ -268,27 +268,30 @@ def fitting_output_def(self) -> FittingOutputDef: ) def serialize(self) -> dict: - return { - "@class": "Model", - "@version": 1, - "type": "linear", - "models": [model.serialize() for model in self.models], - "type_map": self.type_map, - } + dd = super().serialize() + dd.update( + { + "@class": "Model", + "@version": 2, + "type": "linear", + "models": [model.serialize() for model in self.models], + "type_map": self.type_map, + } + ) + return dd @classmethod def deserialize(cls, data: dict) -> "LinearEnergyAtomicModel": data = copy.deepcopy(data) - check_version_compatibility(data.pop("@version", 1), 1, 1) - data.pop("@class") - data.pop("type") - type_map = data.pop("type_map") + check_version_compatibility(data.get("@version", 2), 2, 1) + data.pop("@class", None) + data.pop("type", None) models = [ BaseAtomicModel.get_class_by_type(model["type"]).deserialize(model) for model in data["models"] ] - data.pop("models") - return cls(models, type_map, **data) + data["models"] = models + return super().deserialize(data) def _compute_weight( self, extended_coord, extended_atype, nlists_ @@ -418,7 +421,9 @@ def __init__( **kwargs, ): models = [dp_model, zbl_model] - super().__init__(models, type_map, **kwargs) + kwargs["models"] = models + kwargs["type_map"] = type_map + super().__init__(**kwargs) self.sw_rmin = sw_rmin self.sw_rmax = sw_rmax @@ -428,15 +433,12 @@ def __init__( self.zbl_weight = torch.empty(0, dtype=torch.float64, device=env.DEVICE) def serialize(self) -> dict: - dd = BaseAtomicModel.serialize(self) + dd = super().serialize() dd.update( { "@class": "Model", "@version": 2, "type": "zbl", - "models": LinearEnergyAtomicModel( - models=[self.models[0], self.models[1]], type_map=self.type_map - ).serialize(), "sw_rmin": self.sw_rmin, "sw_rmax": self.sw_rmax, "smin_alpha": self.smin_alpha, @@ -448,24 +450,14 @@ def serialize(self) -> dict: def deserialize(cls, data) -> "DPZBLLinearEnergyAtomicModel": data = copy.deepcopy(data) check_version_compatibility(data.pop("@version", 1), 2, 1) - sw_rmin = data.pop("sw_rmin") - sw_rmax = data.pop("sw_rmax") - smin_alpha = data.pop("smin_alpha") - linear_model = LinearEnergyAtomicModel.deserialize(data.pop("models")) - dp_model, zbl_model = linear_model.models - type_map = linear_model.type_map - + models = [ + BaseAtomicModel.get_class_by_type(model["type"]).deserialize(model) + for model in data["models"] + ] + data["dp_model"], data["zbl_model"] = models[0], models[1] data.pop("@class", None) data.pop("type", None) - return cls( - dp_model=dp_model, - zbl_model=zbl_model, - sw_rmin=sw_rmin, - sw_rmax=sw_rmax, - type_map=type_map, - smin_alpha=smin_alpha, - **data, - ) + return super().deserialize(data) def _compute_weight( self, diff --git a/deepmd/pt/model/atomic_model/pairtab_atomic_model.py b/deepmd/pt/model/atomic_model/pairtab_atomic_model.py index b4639fcbb4..4f8bce78e1 100644 --- a/deepmd/pt/model/atomic_model/pairtab_atomic_model.py +++ b/deepmd/pt/model/atomic_model/pairtab_atomic_model.py @@ -71,8 +71,6 @@ def __init__( rcut: float, sel: Union[int, List[int]], type_map: List[str], - rcond: Optional[float] = None, - atom_ener: Optional[List[float]] = None, **kwargs, ): super().__init__(type_map, **kwargs) @@ -81,8 +79,6 @@ def __init__( self.rcut = rcut self.tab = self._set_pairtab(tab_file, rcut) - self.rcond = rcond - self.atom_ener = atom_ener self.type_map = type_map self.ntypes = len(type_map) @@ -169,14 +165,12 @@ def serialize(self) -> dict: dd.update( { "@class": "Model", - "@version": 1, + "@version": 2, "type": "pairtab", "tab": self.tab.serialize(), "rcut": self.rcut, "sel": self.sel, "type_map": self.type_map, - "rcond": self.rcond, - "atom_ener": self.atom_ener, } ) return dd @@ -184,16 +178,12 @@ def serialize(self) -> dict: @classmethod def deserialize(cls, data) -> "PairTabAtomicModel": data = copy.deepcopy(data) - check_version_compatibility(data.pop("@version", 1), 1, 1) - rcut = data.pop("rcut") - sel = data.pop("sel") - type_map = data.pop("type_map") - rcond = data.pop("rcond") - atom_ener = data.pop("atom_ener") + check_version_compatibility(data.pop("@version", 1), 2, 1) tab = PairTab.deserialize(data.pop("tab")) data.pop("@class", None) data.pop("type", None) - tab_model = cls(None, rcut, sel, type_map, rcond, atom_ener, **data) + data["tab_file"] = None + tab_model = super().deserialize(data) tab_model.tab = tab tab_model.register_buffer("tab_info", torch.from_numpy(tab_model.tab.tab_info)) diff --git a/deepmd/tf/model/model.py b/deepmd/tf/model/model.py index 76bcc6072b..fc8f862e3b 100644 --- a/deepmd/tf/model/model.py +++ b/deepmd/tf/model/model.py @@ -14,6 +14,8 @@ Union, ) +import numpy as np + from deepmd.common import ( j_get_type, ) @@ -785,11 +787,16 @@ def deserialize(cls, data: dict, suffix: str = "") -> "Descriptor": The deserialized descriptor """ data = copy.deepcopy(data) - check_version_compatibility(data.pop("@version", 1), 1, 1) + check_version_compatibility(data.pop("@version", 2), 2, 1) descriptor = Descriptor.deserialize(data.pop("descriptor"), suffix=suffix) fitting = Fitting.deserialize(data.pop("fitting"), suffix=suffix) + # BEGINE not supported keys data.pop("atom_exclude_types") data.pop("pair_exclude_types") + data.pop("rcond", None) + data.pop("preset_out_bias", None) + data.pop("@variables", None) + # END not supported keys return cls( descriptor=descriptor, fitting_net=fitting, @@ -813,14 +820,23 @@ def serialize(self, suffix: str = "") -> dict: raise NotImplementedError("type embedding is not supported") if self.spin is not None: raise NotImplementedError("spin is not supported") + + ntypes = len(self.get_type_map()) + dict_fit = self.fitting.serialize(suffix=suffix) return { "@class": "Model", "type": "standard", - "@version": 1, + "@version": 2, "type_map": self.type_map, "descriptor": self.descrpt.serialize(suffix=suffix), - "fitting": self.fitting.serialize(suffix=suffix), + "fitting": dict_fit, # not supported yet "atom_exclude_types": [], "pair_exclude_types": [], + "rcond": None, + "preset_out_bias": None, + "@variables": { + "out_bias": np.zeros([1, ntypes, dict_fit["dim_out"]]), + "out_std": np.ones([1, ntypes, dict_fit["dim_out"]]), + }, } diff --git a/source/tests/pt/model/test_atomic_model_stat.py b/source/tests/pt/model/test_atomic_model_stat.py index e266cf215a..3dc80a0155 100644 --- a/source/tests/pt/model/test_atomic_model_stat.py +++ b/source/tests/pt/model/test_atomic_model_stat.py @@ -12,6 +12,7 @@ import numpy as np import torch +from deepmd.dpmodel.atomic_model import DPAtomicModel as DPDPAtomicModel from deepmd.dpmodel.output_def import ( FittingOutputDef, OutputVariableDef, @@ -20,12 +21,16 @@ BaseAtomicModel, DPAtomicModel, ) -from deepmd.pt.model.descriptor.dpa1 import ( +from deepmd.pt.model.descriptor import ( DescrptDPA1, + DescrptSeA, ) from deepmd.pt.model.task.base_fitting import ( BaseFitting, ) +from deepmd.pt.model.task.ener import ( + InvarFitting, +) from deepmd.pt.utils import ( env, ) @@ -441,3 +446,50 @@ def cvt_ret(x): expected_ret1["bar"] = ret0["bar"] + bar_bias[at] for kk in ["foo", "pix", "bar"]: np.testing.assert_almost_equal(ret1[kk], expected_ret1[kk]) + + def test_serialize(self): + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ).to(env.DEVICE) + ft = InvarFitting( + "foo", + self.nt, + ds.get_dim_out(), + 1, + mixed_types=ds.mixed_types(), + ).to(env.DEVICE) + type_map = ["A", "B"] + md0 = DPAtomicModel( + ds, + ft, + type_map=type_map, + ).to(env.DEVICE) + args = [ + to_torch_tensor(ii) for ii in [self.coord_ext, self.atype_ext, self.nlist] + ] + # nf x nloc + at = self.atype_ext[:, :nloc] + + def cvt_ret(x): + return {kk: to_numpy_array(vv) for kk, vv in x.items()} + + md0.compute_or_load_out_stat( + self.merged_output_stat, stat_file_path=self.stat_file_path + ) + ret0 = md0.forward_common_atomic(*args) + ret0 = cvt_ret(ret0) + md1 = DPAtomicModel.deserialize(md0.serialize()) + ret1 = md1.forward_common_atomic(*args) + ret1 = cvt_ret(ret1) + + for kk in ["foo"]: + np.testing.assert_almost_equal(ret0[kk], ret1[kk]) + + md2 = DPDPAtomicModel.deserialize(md0.serialize()) + args = [self.coord_ext, self.atype_ext, self.nlist] + ret2 = md2.forward_common_atomic(*args) + for kk in ["foo"]: + np.testing.assert_almost_equal(ret0[kk], ret2[kk])