From a26b680308f02b4c0a3648aac7b76bebef0ae0ef Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Thu, 11 Apr 2024 09:17:58 +0800 Subject: [PATCH] Chore: refactor atomic bias (#3654) `var_name` is no longer required as user input for `polar` and `dipole`. --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- deepmd/dpmodel/fitting/dipole_fitting.py | 7 +- .../dpmodel/fitting/polarizability_fitting.py | 7 +- deepmd/pt/model/task/dipole.py | 5 +- deepmd/pt/model/task/polarizability.py | 5 +- deepmd/pt/utils/stat.py | 369 +++++++++++++--- deepmd/tf/fit/dipole.py | 1 - deepmd/tf/fit/polar.py | 1 - deepmd/utils/out_stat.py | 1 + .../tests/consistent/fitting/test_dipole.py | 1 - source/tests/consistent/fitting/test_polar.py | 1 - .../pt/model/test_atomic_model_atomic_stat.py | 406 ++++++++++++++++++ ...at.py => test_atomic_model_global_stat.py} | 2 + .../pt/model/test_linear_atomic_model_stat.py | 1 + source/tests/pt/test_training.py | 1 + 14 files changed, 723 insertions(+), 85 deletions(-) create mode 100644 source/tests/pt/model/test_atomic_model_atomic_stat.py rename source/tests/pt/model/{test_atomic_model_stat.py => test_atomic_model_global_stat.py} (99%) diff --git a/deepmd/dpmodel/fitting/dipole_fitting.py b/deepmd/dpmodel/fitting/dipole_fitting.py index 6d6324770c..98325f41ee 100644 --- a/deepmd/dpmodel/fitting/dipole_fitting.py +++ b/deepmd/dpmodel/fitting/dipole_fitting.py @@ -36,8 +36,6 @@ class DipoleFitting(GeneralFitting): Parameters ---------- - var_name - The name of the output variable. ntypes The number of atom types. dim_descrpt @@ -86,7 +84,6 @@ class DipoleFitting(GeneralFitting): def __init__( self, - var_name: str, ntypes: int, dim_descrpt: int, embedding_width: int, @@ -124,7 +121,7 @@ def __init__( self.r_differentiable = r_differentiable self.c_differentiable = c_differentiable super().__init__( - var_name=var_name, + var_name="dipole", ntypes=ntypes, dim_descrpt=dim_descrpt, neuron=neuron, @@ -161,6 +158,8 @@ def serialize(self) -> dict: def deserialize(cls, data: dict) -> "GeneralFitting": data = copy.deepcopy(data) check_version_compatibility(data.pop("@version", 1), 1, 1) + var_name = data.pop("var_name", None) + assert var_name == "dipole" return super().deserialize(data) def output_def(self): diff --git a/deepmd/dpmodel/fitting/polarizability_fitting.py b/deepmd/dpmodel/fitting/polarizability_fitting.py index 5d75037137..2a691e963d 100644 --- a/deepmd/dpmodel/fitting/polarizability_fitting.py +++ b/deepmd/dpmodel/fitting/polarizability_fitting.py @@ -39,8 +39,6 @@ class PolarFitting(GeneralFitting): Parameters ---------- - var_name - The name of the output variable. ntypes The number of atom types. 
dim_descrpt @@ -88,7 +86,6 @@ class PolarFitting(GeneralFitting): def __init__( self, - var_name: str, ntypes: int, dim_descrpt: int, embedding_width: int, @@ -145,7 +142,7 @@ def __init__( self.shift_diag = shift_diag self.constant_matrix = np.zeros(ntypes, dtype=GLOBAL_NP_FLOAT_PRECISION) super().__init__( - var_name=var_name, + var_name="polar", ntypes=ntypes, dim_descrpt=dim_descrpt, neuron=neuron, @@ -201,6 +198,8 @@ def serialize(self) -> dict: def deserialize(cls, data: dict) -> "GeneralFitting": data = copy.deepcopy(data) check_version_compatibility(data.pop("@version", 1), 2, 1) + var_name = data.pop("var_name", None) + assert var_name == "polar" return super().deserialize(data) def output_def(self): diff --git a/deepmd/pt/model/task/dipole.py b/deepmd/pt/model/task/dipole.py index ca445c8588..cddbbf5291 100644 --- a/deepmd/pt/model/task/dipole.py +++ b/deepmd/pt/model/task/dipole.py @@ -39,8 +39,6 @@ class DipoleFittingNet(GeneralFitting): Parameters ---------- - var_name : str - The atomic property to fit, 'dipole'. ntypes : int Element count. dim_descrpt : int @@ -97,7 +95,7 @@ def __init__( self.r_differentiable = r_differentiable self.c_differentiable = c_differentiable super().__init__( - var_name=kwargs.pop("var_name", "dipole"), + var_name="dipole", ntypes=ntypes, dim_descrpt=dim_descrpt, neuron=neuron, @@ -131,6 +129,7 @@ def serialize(self) -> dict: def deserialize(cls, data: dict) -> "GeneralFitting": data = copy.deepcopy(data) check_version_compatibility(data.pop("@version", 1), 1, 1) + data.pop("var_name", None) return super().deserialize(data) def output_def(self) -> FittingOutputDef: diff --git a/deepmd/pt/model/task/polarizability.py b/deepmd/pt/model/task/polarizability.py index 544d23555c..cd944996be 100644 --- a/deepmd/pt/model/task/polarizability.py +++ b/deepmd/pt/model/task/polarizability.py @@ -47,8 +47,6 @@ class PolarFittingNet(GeneralFitting): Parameters ---------- - var_name : str - The atomic property to fit, 'polar'. ntypes : int Element count. 
dim_descrpt : int @@ -127,7 +125,7 @@ def __init__( ntypes, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE ) super().__init__( - var_name=kwargs.pop("var_name", "polar"), + var_name="polar", ntypes=ntypes, dim_descrpt=dim_descrpt, neuron=neuron, @@ -180,6 +178,7 @@ def serialize(self) -> dict: def deserialize(cls, data: dict) -> "GeneralFitting": data = copy.deepcopy(data) check_version_compatibility(data.pop("@version", 1), 2, 1) + data.pop("var_name", None) return super().deserialize(data) def output_def(self) -> FittingOutputDef: diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index d85741b231..77da1e01f1 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -1,5 +1,8 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import logging +from collections import ( + defaultdict, +) from typing import ( Callable, Dict, @@ -23,6 +26,7 @@ to_torch_tensor, ) from deepmd.utils.out_stat import ( + compute_stats_from_atomic, compute_stats_from_redu, ) from deepmd.utils.path import ( @@ -171,10 +175,9 @@ def model_forward_auto_batch_size(*args, **kwargs): for kk in keys: model_predict[kk].append( to_numpy_array( - torch.sum(sample_predict[kk], dim=1) # nf x nloc x odims + sample_predict[kk] # nf x nloc x odims ) ) - model_predict = {kk: np.concatenate(model_predict[kk]) for kk in keys} return model_predict @@ -203,6 +206,31 @@ def _make_preset_out_bias( return np.array(nbias) +def _fill_stat_with_global( + atomic_stat: Union[np.ndarray, None], + global_stat: np.ndarray, +): + """This function is used to fill atomic stat with global stat. + + Parameters + ---------- + atomic_stat : Union[np.ndarray, None] + The atomic stat. + global_stat : np.ndarray + The global stat. + If the atomic stat is None, use the global stat. + If the atomic stat is not None but has NaN values (missing atypes), fill them with the global stat. + """ + if atomic_stat is None: + return global_stat + else: + return np.nan_to_num( + np.where( + np.isnan(atomic_stat) & ~np.isnan(global_stat), global_stat, atomic_stat + ) + ) + + def compute_output_stats( merged: Union[Callable[[], List[dict]], List[dict]], ntypes: int, @@ -246,87 +274,294 @@ def compute_output_stats( # failed to restore the bias from stat file.
compute if bias_atom_e is None: - # only get data for once + # only get data once; sampled is a list of dict[str, torch.Tensor] sampled = merged() if callable(merged) else merged + if model_forward is not None: + model_pred = _compute_model_predict(sampled, keys, model_forward) + else: + model_pred = None + # remove the keys that are not in the sample keys = [keys] if isinstance(keys, str) else keys assert isinstance(keys, list) - new_keys = [ii for ii in keys if ii in sampled[0].keys()] + new_keys = [ + ii + for ii in keys + if (ii in sampled[0].keys()) or ("atom_" + ii in sampled[0].keys()) + ] del keys keys = new_keys - # get label dict from sample - outputs = {kk: [item[kk] for item in sampled] for kk in keys} - data_mixed_type = "real_natoms_vec" in sampled[0] - natoms_key = "natoms" if not data_mixed_type else "real_natoms_vec" - for system in sampled: - if "atom_exclude_types" in system: - type_mask = AtomExcludeMask( - ntypes, system["atom_exclude_types"] - ).get_type_mask() - system[natoms_key][:, 2:] *= type_mask.unsqueeze(0) - input_natoms = [item[natoms_key] for item in sampled] - # shape: (nframes, ndim) - merged_output = {kk: to_numpy_array(torch.cat(outputs[kk])) for kk in keys} - # shape: (nframes, ntypes) - merged_natoms = to_numpy_array(torch.cat(input_natoms)[:, 2:]) - nf = merged_natoms.shape[0] - if preset_bias is not None: - assigned_atom_ener = { - kk: _make_preset_out_bias(ntypes, preset_bias[kk]) - if kk in preset_bias.keys() - else None - for kk in keys - } - else: - assigned_atom_ener = {kk: None for kk in keys} - - if model_forward is None: - stats_input = merged_output - else: - # subtract the model bias and output the delta bias - model_predict = _compute_model_predict(sampled, keys, model_forward) - stats_input = {kk: merged_output[kk] - model_predict[kk] for kk in keys} + # split systems based on label - bias_atom_e = {} - std_atom_e = {} + atomic_sampled_idx = defaultdict(list) + global_sampled_idx = defaultdict(list) for kk in keys: - bias_atom_e[kk], std_atom_e[kk] = compute_stats_from_redu( - stats_input[kk], - merged_natoms, - assigned_bias=assigned_atom_ener[kk], - rcond=rcond, - ) - bias_atom_e, std_atom_e = _post_process_stat(bias_atom_e, std_atom_e) + for idx, system in enumerate(sampled): + if (("find_atom_" + kk) in system) and ( + system["find_atom_" + kk] > 0.0 + ): + atomic_sampled_idx[kk].append(idx) + elif (("find_" + kk) in system) and (system["find_" + kk] > 0.0): + global_sampled_idx[kk].append(idx) + + else: + continue + + # use index to gather model predictions for the corresponding systems.
+ + model_pred_g = ( + { + kk: [vv[idx] for idx in global_sampled_idx[kk]] + for kk, vv in model_pred.items() + } + if model_pred + else None + ) + model_pred_a = ( + { + kk: [vv[idx] for idx in atomic_sampled_idx[kk]] + for kk, vv in model_pred.items() + } + if model_pred + else None + ) - # unbias_e is only used for print rmse - if model_forward is None: - unbias_e = { - kk: merged_natoms @ bias_atom_e[kk].reshape(ntypes, -1) for kk in keys + # concat all frames within those systems + model_pred_g = ( + { + kk: np.concatenate(model_pred_g[kk]) + for kk in model_pred_g.keys() + if len(model_pred_g[kk]) > 0 + } - else: - unbias_e = { - kk: model_predict[kk].reshape(nf, -1) + - merged_natoms @ bias_atom_e[kk].reshape(ntypes, -1) - for kk in keys + if model_pred + else None + ) + model_pred_a = ( + { + kk: np.concatenate(model_pred_a[kk]) + for kk in model_pred_a.keys() + if len(model_pred_a[kk]) > 0 + } - atom_numbs = merged_natoms.sum(-1) + if model_pred + else None + ) - def rmse(x): - return np.sqrt(np.mean(np.square(x))) + # compute stat + bias_atom_g, std_atom_g = compute_output_stats_global( + sampled, + ntypes, + keys, + rcond, + preset_bias, + model_pred_g, + ) + bias_atom_a, std_atom_a = compute_output_stats_atomic( + sampled, + ntypes, + keys, + model_pred_a, + ) + # merge global/atomic bias + bias_atom_e, std_atom_e = {}, {} for kk in keys: - rmse_ae = rmse( - (unbias_e[kk].reshape(nf, -1) - merged_output[kk].reshape(nf, -1)) - / atom_numbs[:, None] - ) - log.info( - f"RMSE of {kk} per atom after linear regression is: {rmse_ae} in the unit of {kk}." - ) + # use atomic bias whenever available + if kk in bias_atom_a: + bias_atom_e[kk] = bias_atom_a[kk] + std_atom_e[kk] = std_atom_a[kk] + else: + bias_atom_e[kk] = None + std_atom_e[kk] = None + # use global bias to fill missing atomic bias + if kk in bias_atom_g: + bias_atom_e[kk] = _fill_stat_with_global( + bias_atom_e[kk], bias_atom_g[kk] + ) + std_atom_e[kk] = _fill_stat_with_global(std_atom_e[kk], std_atom_g[kk]) + if (bias_atom_e[kk] is None) or (std_atom_e[kk] is None): + raise RuntimeError("Fail to compute stat.") if stat_file_path is not None: _save_to_file(stat_file_path, bias_atom_e, std_atom_e) - ret_bias = {kk: to_torch_tensor(vv) for kk, vv in bias_atom_e.items()} - ret_std = {kk: to_torch_tensor(vv) for kk, vv in std_atom_e.items()} + bias_atom_e = {kk: to_torch_tensor(vv) for kk, vv in bias_atom_e.items()} + std_atom_e = {kk: to_torch_tensor(vv) for kk, vv in std_atom_e.items()} + return bias_atom_e, std_atom_e - return ret_bias, ret_std + +def compute_output_stats_global( + sampled: List[dict], + ntypes: int, + keys: List[str], + rcond: Optional[float] = None, + preset_bias: Optional[Dict[str, List[Optional[torch.Tensor]]]] = None, + model_pred: Optional[Dict[str, np.ndarray]] = None, +): + """This function only handles stat computation from reduced global labels.""" + # get label dict from sample; for each key, only picking the systems with global labels.
+ outputs = { + kk: [ + system[kk] + for system in sampled + if kk in system and system.get(f"find_{kk}", 0) > 0 + ] + for kk in keys + } + + data_mixed_type = "real_natoms_vec" in sampled[0] + natoms_key = "natoms" if not data_mixed_type else "real_natoms_vec" + for system in sampled: + if "atom_exclude_types" in system: + type_mask = AtomExcludeMask( + ntypes, system["atom_exclude_types"] + ).get_type_mask() + system[natoms_key][:, 2:] *= type_mask.unsqueeze(0) + + input_natoms = { + kk: [ + item[natoms_key] + for item in sampled + if kk in item and item.get(f"find_{kk}", 0) > 0 + ] + for kk in keys + } + # shape: (nframes, ndim) + merged_output = { + kk: to_numpy_array(torch.cat(outputs[kk])) + for kk in keys + if len(outputs[kk]) > 0 + } + # shape: (nframes, ntypes) + + merged_natoms = { + kk: to_numpy_array(torch.cat(input_natoms[kk])[:, 2:]) + for kk in keys + if len(input_natoms[kk]) > 0 + } + nf = {kk: merged_natoms[kk].shape[0] for kk in keys if kk in merged_natoms} + if preset_bias is not None: + assigned_atom_ener = { + kk: _make_preset_out_bias(ntypes, preset_bias[kk]) + if kk in preset_bias.keys() + else None + for kk in keys + } + else: + assigned_atom_ener = {kk: None for kk in keys} + + if model_pred is None: + stats_input = merged_output + else: + # subtract the model bias and output the delta bias + + model_pred = {kk: np.sum(model_pred[kk], axis=1) for kk in keys} + stats_input = { + kk: merged_output[kk] - model_pred[kk] for kk in keys if kk in merged_output + } + + bias_atom_e = {} + std_atom_e = {} + for kk in keys: + if kk in stats_input: + bias_atom_e[kk], std_atom_e[kk] = compute_stats_from_redu( + stats_input[kk], + merged_natoms[kk], + assigned_bias=assigned_atom_ener[kk], + rcond=rcond, + ) + else: + # this key does not have global labels, skip it. + continue + bias_atom_e, std_atom_e = _post_process_stat(bias_atom_e, std_atom_e) + + # unbias_e is only used for printing RMSE + + if model_pred is None: + unbias_e = { + kk: merged_natoms[kk] @ bias_atom_e[kk].reshape(ntypes, -1) + for kk in bias_atom_e.keys() + } + else: + unbias_e = { + kk: model_pred[kk].reshape(nf[kk], -1) + + merged_natoms[kk] @ bias_atom_e[kk].reshape(ntypes, -1) + for kk in bias_atom_e.keys() + } + atom_numbs = {kk: merged_natoms[kk].sum(-1) for kk in bias_atom_e.keys()} + + def rmse(x): + return np.sqrt(np.mean(np.square(x))) + + for kk in bias_atom_e.keys(): + rmse_ae = rmse( + (unbias_e[kk].reshape(nf[kk], -1) - merged_output[kk].reshape(nf[kk], -1)) + / atom_numbs[kk][:, None] + ) + log.info( + f"RMSE of {kk} per atom after linear regression is: {rmse_ae} in the unit of {kk}." + ) + return bias_atom_e, std_atom_e + + +def compute_output_stats_atomic( + sampled: List[dict], + ntypes: int, + keys: List[str], + model_pred: Optional[Dict[str, np.ndarray]] = None, +): + # get label dict from sample; for each key, only picking the systems with atomic labels.
+ outputs = { + kk: [ + system["atom_" + kk] + for system in sampled + if ("atom_" + kk) in system and system.get(f"find_atom_{kk}", 0) > 0 + ] + for kk in keys + } + natoms = { + kk: [ + system["atype"] + for system in sampled + if ("atom_" + kk) in system and system.get(f"find_atom_{kk}", 0) > 0 + ] + for kk in keys + } + # shape: (nframes, nloc, ndim) + merged_output = { + kk: to_numpy_array(torch.cat(outputs[kk])) + for kk in keys + if len(outputs[kk]) > 0 + } + merged_natoms = { + kk: to_numpy_array(torch.cat(natoms[kk])) for kk in keys if len(natoms[kk]) > 0 + } + + if model_pred is None: + stats_input = merged_output + else: + # subtract the model bias and output the delta bias + stats_input = { + kk: merged_output[kk] - model_pred[kk] for kk in keys if kk in merged_output + } + + bias_atom_e = {} + std_atom_e = {} + + for kk in keys: + if kk in stats_input: + bias_atom_e[kk], std_atom_e[kk] = compute_stats_from_atomic( + stats_input[kk], + merged_natoms[kk], + ) + # correction for missing types + missing_types = ntypes - merged_natoms[kk].max() - 1 + if missing_types > 0: + nan_padding = np.empty((missing_types, bias_atom_e[kk].shape[1])) + nan_padding.fill(np.nan) + bias_atom_e[kk] = np.concatenate([bias_atom_e[kk], nan_padding], axis=0) + std_atom_e[kk] = np.concatenate([std_atom_e[kk], nan_padding], axis=0) + else: + # this key does not have atomic labels, skip it. + continue + bias_atom_e, std_atom_e = _post_process_stat(bias_atom_e, std_atom_e) + return bias_atom_e, std_atom_e diff --git a/deepmd/tf/fit/dipole.py b/deepmd/tf/fit/dipole.py index f98d52c7bd..d99c793415 100644 --- a/deepmd/tf/fit/dipole.py +++ b/deepmd/tf/fit/dipole.py @@ -362,7 +362,6 @@ def serialize(self, suffix: str) -> dict: "@class": "Fitting", "type": "dipole", "@version": 1, - "var_name": "dipole", "ntypes": self.ntypes, "dim_descrpt": self.dim_descrpt, "embedding_width": self.dim_rot_mat_1, diff --git a/deepmd/tf/fit/polar.py b/deepmd/tf/fit/polar.py index 473b57ff54..c124bd3ef4 100644 --- a/deepmd/tf/fit/polar.py +++ b/deepmd/tf/fit/polar.py @@ -555,7 +555,6 @@ def serialize(self, suffix: str) -> dict: "@class": "Fitting", "type": "polar", "@version": 1, - "var_name": "polar", "ntypes": self.ntypes, "dim_descrpt": self.dim_descrpt, "embedding_width": self.dim_rot_mat_1, diff --git a/deepmd/utils/out_stat.py b/deepmd/utils/out_stat.py index 1dcbcb1280..9678f8ed72 100644 --- a/deepmd/utils/out_stat.py +++ b/deepmd/utils/out_stat.py @@ -112,6 +112,7 @@ def compute_stats_from_atomic( assert output.ndim == 3 assert atype.ndim == 2 assert output.shape[:2] == atype.shape + # compute output bias nframes, nloc, ndim = output.shape ntypes = atype.max() + 1 diff --git a/source/tests/consistent/fitting/test_dipole.py b/source/tests/consistent/fitting/test_dipole.py index 18a29934ca..4f33d58c10 100644 --- a/source/tests/consistent/fitting/test_dipole.py +++ b/source/tests/consistent/fitting/test_dipole.py @@ -94,7 +94,6 @@ def addtional_data(self) -> dict: "ntypes": self.ntypes, "dim_descrpt": self.inputs.shape[-1], "mixed_types": mixed_types, - "var_name": "dipole", "embedding_width": 30, } diff --git a/source/tests/consistent/fitting/test_polar.py b/source/tests/consistent/fitting/test_polar.py index 5b55c6d333..a6e0e07784 100644 --- a/source/tests/consistent/fitting/test_polar.py +++ b/source/tests/consistent/fitting/test_polar.py @@ -94,7 +94,6 @@ def addtional_data(self) -> dict: "ntypes": self.ntypes, "dim_descrpt": self.inputs.shape[-1], "mixed_types": mixed_types, - "var_name": "polar", "embedding_width":
30, } diff --git a/source/tests/pt/model/test_atomic_model_atomic_stat.py b/source/tests/pt/model/test_atomic_model_atomic_stat.py new file mode 100644 index 0000000000..8f365a09fe --- /dev/null +++ b/source/tests/pt/model/test_atomic_model_atomic_stat.py @@ -0,0 +1,406 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import tempfile +import unittest +from pathlib import ( + Path, +) +from typing import ( + Optional, +) + +import h5py +import numpy as np +import torch + +from deepmd.dpmodel.output_def import ( + FittingOutputDef, + OutputVariableDef, +) +from deepmd.pt.model.atomic_model import ( + BaseAtomicModel, + DPAtomicModel, +) +from deepmd.pt.model.descriptor.dpa1 import ( + DescrptDPA1, +) +from deepmd.pt.model.task.base_fitting import ( + BaseFitting, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.utils import ( + to_numpy_array, + to_torch_tensor, +) +from deepmd.utils.path import ( + DPPath, +) + +from .test_env_mat import ( + TestCaseSingleFrameWithNlist, +) + +dtype = env.GLOBAL_PT_FLOAT_PRECISION + + +class FooFitting(torch.nn.Module, BaseFitting): + def output_def(self): + return FittingOutputDef( + [ + OutputVariableDef( + "foo", + [1], + reduciable=True, + r_differentiable=True, + c_differentiable=True, + ), + OutputVariableDef( + "bar", + [1, 2], + reduciable=True, + r_differentiable=True, + c_differentiable=True, + ), + ] + ) + + def serialize(self) -> dict: + raise NotImplementedError + + def forward( + self, + descriptor: torch.Tensor, + atype: torch.Tensor, + gr: Optional[torch.Tensor] = None, + g2: Optional[torch.Tensor] = None, + h2: Optional[torch.Tensor] = None, + fparam: Optional[torch.Tensor] = None, + aparam: Optional[torch.Tensor] = None, + ): + nf, nloc, _ = descriptor.shape + ret = {} + ret["foo"] = ( + torch.Tensor( + [ + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + ] + ) + .view([nf, nloc, *self.output_def()["foo"].shape]) + .to(env.GLOBAL_PT_FLOAT_PRECISION) + .to(env.DEVICE) + ) + ret["bar"] = ( + torch.Tensor( + [ + [1.0, 2.0, 3.0, 7.0, 8.0, 9.0], + [4.0, 5.0, 6.0, 10.0, 11.0, 12.0], + ] + ) + .view([nf, nloc, *self.output_def()["bar"].shape]) + .to(env.GLOBAL_PT_FLOAT_PRECISION) + .to(env.DEVICE) + ) + return ret + + +class TestAtomicModelStat(unittest.TestCase, TestCaseSingleFrameWithNlist): + def tearDown(self): + self.tempdir.cleanup() + + def setUp(self): + TestCaseSingleFrameWithNlist.setUp(self) + self.merged_output_stat = [ + { + "coord": to_torch_tensor(np.zeros([2, 3, 3])), + "atype": to_torch_tensor( + np.array([[0, 0, 1], [0, 1, 1]], dtype=np.int32) + ), + "atype_ext": to_torch_tensor( + np.array([[0, 0, 1, 0], [0, 1, 1, 0]], dtype=np.int32) + ), + "box": to_torch_tensor(np.zeros([2, 3, 3])), + "natoms": to_torch_tensor( + np.array([[3, 3, 2, 1], [3, 3, 1, 2]], dtype=np.int32) + ), + # bias of foo: 5, 6 + "atom_foo": to_torch_tensor( + np.array([[5.0, 5.0, 5.0], [5.0, 6.0, 7.0]]).reshape(2, 3, 1) + ), + # bias of bar: [1, 5], [3, 2] + "bar": to_torch_tensor( + np.array([5.0, 12.0, 7.0, 9.0]).reshape(2, 1, 2) + ), + "find_atom_foo": np.float32(1.0), + "find_bar": np.float32(1.0), + }, + { + "coord": to_torch_tensor(np.zeros([2, 3, 3])), + "atype": to_torch_tensor( + np.array([[0, 0, 1], [0, 1, 1]], dtype=np.int32) + ), + "atype_ext": to_torch_tensor( + np.array([[0, 0, 1, 0], [0, 1, 1, 0]], dtype=np.int32) + ), + "box": to_torch_tensor(np.zeros([2, 3, 3])), + "natoms": to_torch_tensor( + np.array([[3, 3, 2, 1], [3, 3, 1, 2]], dtype=np.int32) + ), + # bias of foo: 5, 6 from atomic label. 
+ "foo": to_torch_tensor(np.array([5.0, 7.0]).reshape(2, 1)), + # bias of bar: [1, 5], [3, 2] + "bar": to_torch_tensor( + np.array([5.0, 12.0, 7.0, 9.0]).reshape(2, 1, 2) + ), + "find_foo": np.float32(1.0), + "find_bar": np.float32(1.0), + }, + ] + self.tempdir = tempfile.TemporaryDirectory() + h5file = str((Path(self.tempdir.name) / "testcase.h5").resolve()) + with h5py.File(h5file, "w") as f: + pass + self.stat_file_path = DPPath(h5file, "a") + + def test_output_stat(self): + nf, nloc, nnei = self.nlist.shape + ds = DescrptDPA1( + self.rcut, + self.rcut_smth, + sum(self.sel), + self.nt, + ).to(env.DEVICE) + ft = FooFitting().to(env.DEVICE) + type_map = ["foo", "bar"] + md0 = DPAtomicModel( + ds, + ft, + type_map=type_map, + ).to(env.DEVICE) + args = [ + to_torch_tensor(ii) for ii in [self.coord_ext, self.atype_ext, self.nlist] + ] + # nf x nloc + at = self.atype_ext[:, :nloc] + + def cvt_ret(x): + return {kk: to_numpy_array(vv) for kk, vv in x.items()} + + # 1. test run without bias + # nf x na x odim + ret0 = md0.forward_common_atomic(*args) + ret0 = cvt_ret(ret0) + expected_ret0 = {} + expected_ret0["foo"] = np.array( + [ + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + ] + ).reshape([nf, nloc, *md0.fitting_output_def()["foo"].shape]) + expected_ret0["bar"] = np.array( + [ + [1.0, 2.0, 3.0, 7.0, 8.0, 9.0], + [4.0, 5.0, 6.0, 10.0, 11.0, 12.0], + ] + ).reshape([nf, nloc, *md0.fitting_output_def()["bar"].shape]) + for kk in ["foo", "bar"]: + np.testing.assert_almost_equal(ret0[kk], expected_ret0[kk]) + + # 2. test bias is applied + md0.compute_or_load_out_stat( + self.merged_output_stat, stat_file_path=self.stat_file_path + ) + ret1 = md0.forward_common_atomic(*args) + ret1 = cvt_ret(ret1) + # nt x odim + foo_bias = np.array([5.0, 6.0]).reshape(2, 1) + bar_bias = np.array([1.0, 5.0, 3.0, 2.0]).reshape(2, 1, 2) + expected_ret1 = {} + expected_ret1["foo"] = ret0["foo"] + foo_bias[at] + expected_ret1["bar"] = ret0["bar"] + bar_bias[at] + for kk in ["foo", "bar"]: + np.testing.assert_almost_equal(ret1[kk], expected_ret1[kk]) + + # 3. test bias load from file + def raise_error(): + raise RuntimeError + + md0.compute_or_load_out_stat(raise_error, stat_file_path=self.stat_file_path) + ret2 = md0.forward_common_atomic(*args) + ret2 = cvt_ret(ret2) + for kk in ["foo", "bar"]: + np.testing.assert_almost_equal(ret1[kk], ret2[kk]) + + # 4. 
test change bias + BaseAtomicModel.change_out_bias( + md0, self.merged_output_stat, bias_adjust_mode="change-by-statistic" + ) + args = [ + to_torch_tensor(ii) + for ii in [ + self.coord_ext, + to_numpy_array(self.merged_output_stat[0]["atype_ext"]), + self.nlist, + ] + ] + ret3 = md0.forward_common_atomic(*args) + ret3 = cvt_ret(ret3) + + expected_ret3 = {} + # new bias [2.666, 1.333] + expected_ret3["foo"] = np.array( + [[3.6667, 4.6667, 4.3333], [6.6667, 6.3333, 7.3333]] + ).reshape(2, 3, 1) + for kk in ["foo"]: + np.testing.assert_almost_equal(ret3[kk], expected_ret3[kk], decimal=4) + + +class TestAtomicModelStatMergeGlobalAtomic( + unittest.TestCase, TestCaseSingleFrameWithNlist +): + def tearDown(self): + self.tempdir.cleanup() + + def setUp(self): + TestCaseSingleFrameWithNlist.setUp(self) + self.merged_output_stat = [ + { + "coord": to_torch_tensor(np.zeros([2, 3, 3])), + "atype": to_torch_tensor( + np.array([[0, 0, 0], [0, 0, 0]], dtype=np.int32) + ), + "atype_ext": to_torch_tensor( + np.array([[0, 0, 1, 0], [0, 1, 1, 0]], dtype=np.int32) + ), + "box": to_torch_tensor(np.zeros([2, 3, 3])), + "natoms": to_torch_tensor( + np.array([[3, 3, 2, 1], [3, 3, 1, 2]], dtype=np.int32) + ), + # bias of foo: 5.5, nan + "atom_foo": to_torch_tensor( + np.array([[5.0, 5.0, 5.0], [5.0, 6.0, 7.0]]).reshape(2, 3, 1) + ), + # bias of bar: [1, 5], [3, 2] + "bar": to_torch_tensor( + np.array([5.0, 12.0, 7.0, 9.0]).reshape(2, 1, 2) + ), + "find_atom_foo": np.float32(1.0), + "find_bar": np.float32(1.0), + }, + { + "coord": to_torch_tensor(np.zeros([2, 3, 3])), + "atype": to_torch_tensor( + np.array([[0, 0, 1], [0, 1, 1]], dtype=np.int32) + ), + "atype_ext": to_torch_tensor( + np.array([[0, 0, 1, 0], [0, 1, 1, 0]], dtype=np.int32) + ), + "box": to_torch_tensor(np.zeros([2, 3, 3])), + "natoms": to_torch_tensor( + np.array([[3, 3, 2, 1], [3, 3, 1, 2]], dtype=np.int32) + ), + # bias of foo: 5.5, 3 from atomic label. + "foo": to_torch_tensor(np.array([5.0, 7.0]).reshape(2, 1)), + # bias of bar: [1, 5], [3, 2] + "bar": to_torch_tensor( + np.array([5.0, 12.0, 7.0, 9.0]).reshape(2, 1, 2) + ), + "find_foo": np.float32(1.0), + "find_bar": np.float32(1.0), + }, + ] + self.tempdir = tempfile.TemporaryDirectory() + h5file = str((Path(self.tempdir.name) / "testcase.h5").resolve()) + with h5py.File(h5file, "w") as f: + pass + self.stat_file_path = DPPath(h5file, "a") + + def test_output_stat(self): + nf, nloc, nnei = self.nlist.shape + ds = DescrptDPA1( + self.rcut, + self.rcut_smth, + sum(self.sel), + self.nt, + ).to(env.DEVICE) + ft = FooFitting().to(env.DEVICE) + type_map = ["foo", "bar"] + md0 = DPAtomicModel( + ds, + ft, + type_map=type_map, + ).to(env.DEVICE) + args = [ + to_torch_tensor(ii) for ii in [self.coord_ext, self.atype_ext, self.nlist] + ] + # nf x nloc + at = self.atype_ext[:, :nloc] + + def cvt_ret(x): + return {kk: to_numpy_array(vv) for kk, vv in x.items()} + + # 1. test run without bias + # nf x na x odim + ret0 = md0.forward_common_atomic(*args) + ret0 = cvt_ret(ret0) + expected_ret0 = {} + expected_ret0["foo"] = np.array( + [ + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + ] + ).reshape([nf, nloc, *md0.fitting_output_def()["foo"].shape]) + expected_ret0["bar"] = np.array( + [ + [1.0, 2.0, 3.0, 7.0, 8.0, 9.0], + [4.0, 5.0, 6.0, 10.0, 11.0, 12.0], + ] + ).reshape([nf, nloc, *md0.fitting_output_def()["bar"].shape]) + for kk in ["foo", "bar"]: + np.testing.assert_almost_equal(ret0[kk], expected_ret0[kk]) + + # 2. 
test bias is applied + md0.compute_or_load_out_stat( + self.merged_output_stat, stat_file_path=self.stat_file_path + ) + ret1 = md0.forward_common_atomic(*args) + ret1 = cvt_ret(ret1) + # nt x odim + foo_bias = np.array([5.5, 3.0]).reshape(2, 1) + bar_bias = np.array([1.0, 5.0, 3.0, 2.0]).reshape(2, 1, 2) + expected_ret1 = {} + expected_ret1["foo"] = ret0["foo"] + foo_bias[at] + expected_ret1["bar"] = ret0["bar"] + bar_bias[at] + for kk in ["foo", "bar"]: + np.testing.assert_almost_equal(ret1[kk], expected_ret1[kk]) + + # 3. test bias load from file + def raise_error(): + raise RuntimeError + + md0.compute_or_load_out_stat(raise_error, stat_file_path=self.stat_file_path) + ret2 = md0.forward_common_atomic(*args) + ret2 = cvt_ret(ret2) + for kk in ["foo", "bar"]: + np.testing.assert_almost_equal(ret1[kk], ret2[kk]) + + # 4. test change bias + BaseAtomicModel.change_out_bias( + md0, self.merged_output_stat, bias_adjust_mode="change-by-statistic" + ) + args = [ + to_torch_tensor(ii) + for ii in [ + self.coord_ext, + to_numpy_array(self.merged_output_stat[0]["atype_ext"]), + self.nlist, + ] + ] + ret3 = md0.forward_common_atomic(*args) + ret3 = cvt_ret(ret3) + expected_ret3 = {} + # new bias [2, -5] + expected_ret3["foo"] = np.array([[3, 4, -2], [6, 0, 1]]).reshape(2, 3, 1) + for kk in ["foo"]: + np.testing.assert_almost_equal(ret3[kk], expected_ret3[kk], decimal=4) diff --git a/source/tests/pt/model/test_atomic_model_stat.py b/source/tests/pt/model/test_atomic_model_global_stat.py similarity index 99% rename from source/tests/pt/model/test_atomic_model_stat.py rename to source/tests/pt/model/test_atomic_model_global_stat.py index 3dc80a0155..ca71b604ce 100644 --- a/source/tests/pt/model/test_atomic_model_stat.py +++ b/source/tests/pt/model/test_atomic_model_global_stat.py @@ -155,6 +155,8 @@ def setUp(self): "bar": to_torch_tensor( np.array([5.0, 12.0, 7.0, 9.0]).reshape(2, 1, 2) ), + "find_foo": np.float32(1.0), + "find_bar": np.float32(1.0), } ] self.tempdir = tempfile.TemporaryDirectory() diff --git a/source/tests/pt/model/test_linear_atomic_model_stat.py b/source/tests/pt/model/test_linear_atomic_model_stat.py index 010cecf9f8..f7feeda550 100644 --- a/source/tests/pt/model/test_linear_atomic_model_stat.py +++ b/source/tests/pt/model/test_linear_atomic_model_stat.py @@ -154,6 +154,7 @@ def setUp(self): ), # bias of foo: 1, 3 "energy": to_torch_tensor(np.array([5.0, 7.0]).reshape(2, 1)), + "find_energy": np.float32(1.0), } ] self.tempdir = tempfile.TemporaryDirectory() diff --git a/source/tests/pt/test_training.py b/source/tests/pt/test_training.py index 1635ad56ea..f0a988607e 100644 --- a/source/tests/pt/test_training.py +++ b/source/tests/pt/test_training.py @@ -106,6 +106,7 @@ def setUp(self): self.config["training"]["training_data"]["systems"] = data_file self.config["training"]["validation_data"]["systems"] = data_file self.config["model"] = deepcopy(model_dos) + self.config["model"]["type_map"] = ["H"] self.config["training"]["numb_steps"] = 1 self.config["training"]["save_freq"] = 1 self.not_all_grad = True
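
To make the new atomic/global merge rule concrete, here is a minimal, self-contained NumPy sketch. The standalone helper restates the patched `_fill_stat_with_global` from `deepmd/pt/utils/stat.py` above, and the toy arrays are hypothetical, chosen to mirror the `TestAtomicModelStatMergeGlobalAtomic` case where atom type 1 never appears with atomic labels:

    import numpy as np

    def fill_stat_with_global(atomic_stat, global_stat):
        # Same rule as the patched _fill_stat_with_global: keep the per-type
        # atomic stat where it exists; where it is NaN (a type that never
        # appeared with atomic labels), fall back to the global stat obtained
        # by least squares on the reduced labels.
        if atomic_stat is None:
            return global_stat
        return np.nan_to_num(
            np.where(
                np.isnan(atomic_stat) & ~np.isnan(global_stat),
                global_stat,
                atomic_stat,
            )
        )

    # ntypes = 2, one output dim; type 1 is NaN-padded by
    # compute_output_stats_atomic because it is a missing type.
    atomic_bias = np.array([[5.5], [np.nan]])
    global_bias = np.array([[5.5], [3.0]])
    print(fill_stat_with_global(atomic_bias, global_bias))  # [[5.5] [3. ]]

Keys with only global labels thus still receive a per-type bias via `compute_stats_from_redu`, keys with atomic labels use the direct per-type statistics from `compute_stats_from_atomic`, and the merge prefers the atomic result wherever it is defined.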