Skip to content

Commit

Permalink
Chore: migrate and refactor polar and dos bias (#3662)
Browse files Browse the repository at this point in the history
- [x] remove polar and dos bias from fitting net
- [x] update apply bias
- [x] update UT for polar stat

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
anyangml and pre-commit-ci[bot] authored Apr 18, 2024
1 parent cac8715 commit f550f24
Show file tree
Hide file tree
Showing 26 changed files with 452 additions and 802 deletions.
4 changes: 2 additions & 2 deletions deepmd/dpmodel/fitting/polarizability_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def output_def(self):
return FittingOutputDef(
[
OutputVariableDef(
self.var_name,
"polarizability",
[3, 3],
reduciable=True,
r_differentiable=False,
Expand Down Expand Up @@ -280,4 +280,4 @@ def call(
# (nframes, nloc, 3, 3)
bias = np.expand_dims(bias, axis=-1) * eye
out = out + bias
return {self.var_name: out}
return {"polarizability": out}
8 changes: 4 additions & 4 deletions deepmd/entrypoints/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -897,7 +897,7 @@ def test_polar(
polar = polar.reshape((polar.shape[0], -1, 9))[:, sel_mask, :].reshape(
(polar.shape[0], -1)
)
rmse_f = rmse(polar - test_data["atomic_polarizability"][:numb_test])
rmse_f = rmse(polar - test_data["atom_polarizability"][:numb_test])

log.info(f"# number of test data : {numb_test:d} ")
log.info(f"Polarizability RMSE : {rmse_f:e}")
Expand Down Expand Up @@ -926,7 +926,7 @@ def test_polar(
pe = np.concatenate(
(
np.reshape(
test_data["atomic_polarizability"][:numb_test],
test_data["atom_polarizability"][:numb_test],
[-1, 9 * sel_natoms],
),
np.reshape(polar, [-1, 9 * sel_natoms]),
Expand Down Expand Up @@ -1037,7 +1037,7 @@ def test_dipole(
dipole = dipole.reshape((dipole.shape[0], -1, 3))[:, sel_mask, :].reshape(
(dipole.shape[0], -1)
)
rmse_f = rmse(dipole - test_data["atomic_dipole"][:numb_test])
rmse_f = rmse(dipole - test_data["atom_dipole"][:numb_test])

log.info(f"# number of test data : {numb_test:d}")
log.info(f"Dipole RMSE : {rmse_f:e}")
Expand All @@ -1061,7 +1061,7 @@ def test_dipole(
pe = np.concatenate(
(
np.reshape(
test_data["atomic_dipole"][:numb_test], [-1, 3 * sel_natoms]
test_data["atom_dipole"][:numb_test], [-1, 3 * sel_natoms]
),
np.reshape(dipole, [-1, 3 * sel_natoms]),
),
Expand Down
6 changes: 3 additions & 3 deletions deepmd/pt/loss/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,14 +93,14 @@ def forward(self, input_dict, model, label, natoms, learning_rate=0.0, mae=False
if (
self.has_local_weight
and self.tensor_name in model_pred
and "atomic_" + self.label_name in label
and "atom_" + self.label_name in label
):
find_local = label.get("find_" + "atomic_" + self.label_name, 0.0)
find_local = label.get("find_" + "atom_" + self.label_name, 0.0)
local_weight = self.local_weight * find_local
local_tensor_pred = model_pred[self.tensor_name].reshape(
[-1, natoms, self.tensor_size]
)
local_tensor_label = label["atomic_" + self.label_name].reshape(
local_tensor_label = label["atom_" + self.label_name].reshape(
[-1, natoms, self.tensor_size]
)
diff = (local_tensor_pred - local_tensor_label).reshape(
Expand Down
36 changes: 35 additions & 1 deletion deepmd/pt/model/atomic_model/polar_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,39 @@ def apply_out_stat(
ret: Dict[str, torch.Tensor],
atype: torch.Tensor,
):
# TODO: migrate bias
"""Apply the stat to each atomic output.
Parameters
----------
ret
The returned dict by the forward_atomic method
atype
The atom types. nf x nloc
"""
out_bias, out_std = self._fetch_out_stat(self.bias_keys)

if self.fitting_net.shift_diag:
nframes, nloc = atype.shape
device = out_bias[self.bias_keys[0]].device
dtype = out_bias[self.bias_keys[0]].dtype
for kk in self.bias_keys:
ntypes = out_bias[kk].shape[0]
temp = torch.zeros(ntypes, dtype=dtype, device=device)
for i in range(ntypes):
temp[i] = torch.mean(torch.diagonal(out_bias[kk][i].reshape(3, 3)))
modified_bias = temp[atype]

# (nframes, nloc, 1)
modified_bias = (
modified_bias.unsqueeze(-1) * self.fitting_net.scale[atype]
)

eye = torch.eye(3, dtype=dtype, device=device)
eye = eye.repeat(nframes, nloc, 1, 1)
# (nframes, nloc, 3, 3)
modified_bias = modified_bias.unsqueeze(-1) * eye

# nf x nloc x odims, out_bias: ntypes x odims
ret[kk] = ret[kk] + modified_bias
return ret
8 changes: 4 additions & 4 deletions deepmd/pt/model/model/polar_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ def forward(
)
if self.get_fitting_net() is not None:
model_predict = {}
model_predict["polar"] = model_ret["polar"]
model_predict["global_polar"] = model_ret["polar_redu"]
model_predict["polar"] = model_ret["polarizability"]
model_predict["global_polar"] = model_ret["polarizability_redu"]
if "mask" in model_ret:
model_predict["mask"] = model_ret["mask"]
else:
Expand Down Expand Up @@ -85,8 +85,8 @@ def forward_lower(
)
if self.get_fitting_net() is not None:
model_predict = {}
model_predict["polar"] = model_ret["polar"]
model_predict["global_polar"] = model_ret["polar_redu"]
model_predict["polar"] = model_ret["polarizability"]
model_predict["global_polar"] = model_ret["polarizability_redu"]
else:
model_predict = model_ret
return model_predict
66 changes: 0 additions & 66 deletions deepmd/pt/model/task/dos.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,11 @@
import copy
import logging
from typing import (
Callable,
List,
Optional,
Union,
)

import numpy as np
import torch

from deepmd.dpmodel import (
Expand All @@ -30,13 +28,6 @@
from deepmd.pt.utils.utils import (
to_numpy_array,
)
from deepmd.utils.out_stat import (
compute_stats_from_atomic,
compute_stats_from_redu,
)
from deepmd.utils.path import (
DPPath,
)
from deepmd.utils.version import (
check_version_compatibility,
)
Expand Down Expand Up @@ -105,63 +96,6 @@ def output_def(self) -> FittingOutputDef:
]
)

def compute_output_stats(
self,
merged: Union[Callable[[], List[dict]], List[dict]],
stat_file_path: Optional[DPPath] = None,
) -> None:
"""
Compute the output statistics (e.g. dos bias) for the fitting net from packed data.
Parameters
----------
merged : Union[Callable[[], List[dict]], List[dict]]
- List[dict]: A list of data samples from various data systems.
Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
originating from the `i`-th data system.
- Callable[[], List[dict]]: A lazy function that returns data samples in the above format
only when needed. Since the sampling process can be slow and memory-intensive,
the lazy function helps by only sampling once.
stat_file_path : Optional[DPPath]
The path to the stat file.
"""
if stat_file_path is not None:
stat_file_path = stat_file_path / "bias_dos"
if stat_file_path is not None and stat_file_path.is_file():
bias_dos = stat_file_path.load_numpy()
else:
if callable(merged):
# only get data for once
sampled = merged()
else:
sampled = merged
for sys in range(len(sampled)):
nframs = sampled[sys]["atype"].shape[0]

if "atom_dos" in sampled[sys]:
bias_dos = compute_stats_from_atomic(
sampled[sys]["atom_dos"].numpy(force=True),
sampled[sys]["atype"].numpy(force=True),
)[0]
else:
sys_type_count = np.zeros(
(nframs, self.ntypes), dtype=env.GLOBAL_NP_FLOAT_PRECISION
)
for itype in range(self.ntypes):
type_mask = sampled[sys]["atype"] == itype
sys_type_count[:, itype] = type_mask.sum(dim=1).numpy(
force=True
)
sys_bias_redu = sampled[sys]["dos"].numpy(force=True)

bias_dos = compute_stats_from_redu(
sys_bias_redu, sys_type_count, rcond=self.rcond
)[0]
if stat_file_path is not None:
stat_file_path.save_numpy(bias_dos)
self.bias_dos = torch.tensor(bias_dos, device=env.DEVICE)

@classmethod
def deserialize(cls, data: dict) -> "DOSFittingNet":
data = copy.deepcopy(data)
Expand Down
42 changes: 0 additions & 42 deletions deepmd/pt/model/task/invar_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,8 @@
import copy
import logging
from typing import (
Callable,
List,
Optional,
Union,
)

import torch
Expand All @@ -24,12 +22,6 @@
from deepmd.pt.utils.env import (
DEFAULT_PRECISION,
)
from deepmd.pt.utils.stat import (
compute_output_stats,
)
from deepmd.utils.path import (
DPPath,
)
from deepmd.utils.version import (
check_version_compatibility,
)
Expand Down Expand Up @@ -146,40 +138,6 @@ def deserialize(cls, data: dict) -> "GeneralFitting":
check_version_compatibility(data.pop("@version", 1), 1, 1)
return super().deserialize(data)

def compute_output_stats(
self,
merged: Union[Callable[[], List[dict]], List[dict]],
stat_file_path: Optional[DPPath] = None,
):
"""
Compute the output statistics (e.g. energy bias) for the fitting net from packed data.
Parameters
----------
merged : Union[Callable[[], List[dict]], List[dict]]
- List[dict]: A list of data samples from various data systems.
Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
originating from the `i`-th data system.
- Callable[[], List[dict]]: A lazy function that returns data samples in the above format
only when needed. Since the sampling process can be slow and memory-intensive,
the lazy function helps by only sampling once.
stat_file_path : Optional[DPPath]
The path to the stat file.
"""
# [0] to get the mean (bias)
bias_atom_e = compute_output_stats(
merged,
self.ntypes,
keys=[self.var_name],
stat_file_path=stat_file_path,
rcond=self.rcond,
preset_bias={self.var_name: self.atom_ener}
if self.atom_ener is not None
else None,
)[0][self.var_name]
self.bias_atom_e.copy_(bias_atom_e.view([self.ntypes, self.dim_out]))

def output_def(self) -> FittingOutputDef:
return FittingOutputDef(
[
Expand Down
Loading

0 comments on commit f550f24

Please sign in to comment.