feat: compute output stat for a dict of labels. (#3628)
- Add a unit test for it.
- At the moment, only energy is supported in the `base_atomic_model`;
handling of multiple output stats will be implemented in a future PR.

---------

Co-authored-by: Han Wang <[email protected]>
wanghan-iapcm and Han Wang authored Apr 1, 2024
1 parent 15e4926 commit 8e0cc90
Showing 5 changed files with 207 additions and 52 deletions.
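
The central change: `compute_output_stats` now returns a dict keyed by label name instead of a single energy tensor. A minimal sketch of the new call pattern, using a toy sample layout inferred from this diff (it assumes no atom-type exclusion is configured in the sample):

```python
import torch

from deepmd.pt.utils.stat import compute_output_stats

# One toy "system": per-frame "energy" labels plus a "natoms" vector laid
# out as [nloc, nall, n_type0, n_type1] (layout inferred from the diff).
sampled = [
    {
        "energy": torch.tensor([[10.0], [11.0]]),
        "natoms": torch.tensor([[3, 3, 2, 1], [3, 3, 2, 1]]),
    }
]
ret = compute_output_stats(sampled, ntypes=2, keys=["energy"])
bias_energy = ret["energy"]  # per-type bias tensor, shape [ntypes, 1]
```
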
12 changes: 5 additions & 7 deletions deepmd/pt/model/atomic_model/base_atomic_model.py
@@ -192,10 +192,6 @@ def serialize(self) -> dict:

def get_forward_wrapper_func(self) -> Callable[..., torch.Tensor]:
"""Get a forward wrapper of the atomic model for output bias calculation."""
model_output_type = list(self.atomic_output_def().keys())
if "mask" in model_output_type:
model_output_type.pop(model_output_type.index("mask"))
out_name = model_output_type[0]

def model_forward(coord, atype, box, fparam=None, aparam=None):
with torch.no_grad():  # no_grad is essential for a pure torch forward function used with auto_batchsize
@@ -220,7 +216,7 @@ def model_forward(coord, atype, box, fparam=None, aparam=None):
fparam=fparam,
aparam=aparam,
)
return atomic_ret[out_name].detach()
return {kk: vv.detach() for kk, vv in atomic_ret.items()}

return model_forward
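
With the hard-coded `out_name` gone, the wrapper's contract is simply "a dict of detached per-atom outputs". A toy stand-in illustrating that contract (names and shapes are illustrative, not the real model API):

```python
import torch


def toy_model_forward(coord, atype, box, fparam=None, aparam=None):
    # Mirrors {kk: vv.detach() for kk, vv in atomic_ret.items()}: every
    # atomic output is returned, detached, keyed by its output name.
    nframes, nloc = atype.shape
    return {"energy": torch.zeros(nframes, nloc, 1)}


ret = toy_model_forward(
    torch.zeros(2, 3, 3), torch.zeros(2, 3, dtype=torch.long), None
)
energy = ret["energy"]  # callers now select the label they need by key
```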

@@ -287,14 +283,16 @@ def change_out_bias(
delta_bias = compute_output_stats(
merged,
self.get_ntypes(),
keys=["energy"],
model_forward=self.get_forward_wrapper_func(),
)
)["energy"]
self.set_out_bias(delta_bias, add=True)
elif bias_adjust_mode == "set-by-statistic":
bias_atom = compute_output_stats(
merged,
self.get_ntypes(),
)
keys=["energy"],
)["energy"]
self.set_out_bias(bias_atom)
else:
raise RuntimeError("Unknown bias_adjust_mode: " + bias_adjust_mode)
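
The two modes differ only in what is regressed and how the result is applied: `change-by-statistic` regresses label minus model prediction and adds the delta to the existing bias, while `set-by-statistic` regresses the raw labels and overwrites the bias. A toy contrast with plain arrays (`np.linalg.lstsq` stands in for `compute_stats_from_redu`, which this diff does not show):

```python
import numpy as np

natoms = np.array([[2.0, 1.0], [1.0, 2.0]])  # nframes x ntypes
label = np.array([[10.0], [11.0]])           # reference energies per frame
pred = np.array([[7.0], [7.0]])              # current model energies

# change-by-statistic: solve natoms @ delta ≈ label - pred, then add it
delta = np.linalg.lstsq(natoms, label - pred, rcond=None)[0]
# set-by-statistic: solve natoms @ bias ≈ label, then overwrite
bias = np.linalg.lstsq(natoms, label, rcond=None)[0]
```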
9 changes: 7 additions & 2 deletions deepmd/pt/model/atomic_model/pairtab_atomic_model.py
@@ -228,8 +228,13 @@ def compute_or_load_stat(
"""
bias_atom_e = compute_output_stats(
merged, self.ntypes, stat_file_path, self.rcond, self.atom_ener
)
merged,
self.ntypes,
keys=["energy"],
stat_file_path=stat_file_path,
rcond=self.rcond,
atom_ener=self.atom_ener,
)["energy"]
self.bias_atom_e.copy_(
torch.tensor(bias_atom_e, device=env.DEVICE).view([self.ntypes, 1])
)
9 changes: 7 additions & 2 deletions deepmd/pt/model/task/invar_fitting.py
@@ -165,8 +165,13 @@ def compute_output_stats(
"""
bias_atom_e = compute_output_stats(
merged, self.ntypes, stat_file_path, self.rcond, self.atom_ener
)
merged,
self.ntypes,
keys=["energy"],
stat_file_path=stat_file_path,
rcond=self.rcond,
atom_ener=self.atom_ener,
)["energy"]
self.bias_atom_e.copy_(bias_atom_e.view([self.ntypes, self.dim_out]))

def output_def(self) -> FittingOutputDef:
129 changes: 88 additions & 41 deletions deepmd/pt/utils/stat.py
@@ -78,9 +78,39 @@ def make_stat_input(datasets, dataloaders, nbatches):
return lst


def restore_from_file(
stat_file_path: DPPath,
keys: List[str] = ["energy"],
) -> Optional[dict]:
if stat_file_path is None:
return None
stat_files = [stat_file_path / f"bias_atom_{kk}" for kk in keys]
if any(not (ii.is_file()) for ii in stat_files):
return None
ret = {}

for kk in keys:
fp = stat_file_path / f"bias_atom_{kk}"
assert fp.is_file()
ret[kk] = fp.load_numpy()
return ret


def save_to_file(
stat_file_path: DPPath,
results: dict,
):
assert stat_file_path is not None
stat_file_path.mkdir(exist_ok=True, parents=True)
for kk, vv in results.items():
fp = stat_file_path / f"bias_atom_{kk}"
fp.save_numpy(vv)
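
These helpers persist one array per key under `bias_atom_{key}` and return `None` on restore unless every requested key is present. A round-trip sketch patterned on the test below (the file name is illustrative):

```python
import h5py
import numpy as np

from deepmd.pt.utils.stat import restore_from_file, save_to_file
from deepmd.utils.path import DPPath

with h5py.File("stat_cache.h5", "w"):
    pass  # create an empty HDF5 file to back the DPPath

stat_path = DPPath("stat_cache.h5", "a")
save_to_file(stat_path, {"energy": np.zeros([2, 1])})
cached = restore_from_file(stat_path, keys=["energy"])  # {"energy": array}
missing = restore_from_file(stat_path, keys=["dos"])    # None: key absent
```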


def compute_output_stats(
merged: Union[Callable[[], List[dict]], List[dict]],
ntypes: int,
keys: List[str] = ["energy"],
stat_file_path: Optional[DPPath] = None,
rcond: Optional[float] = None,
atom_ener: Optional[List[float]] = None,
@@ -112,17 +142,15 @@ def compute_output_stats(
which will be subtracted from the energy label of the data.
The difference will then be used to calculate the delta complement energy bias for each type.
"""
if stat_file_path is not None:
stat_file_path = stat_file_path / "bias_atom_e"
if stat_file_path is not None and stat_file_path.is_file():
bias_atom_e = stat_file_path.load_numpy()
else:
bias_atom_e = restore_from_file(stat_file_path, keys)

if bias_atom_e is None:
if callable(merged):
# only fetch the data once
sampled = merged()
else:
sampled = merged
energy = [item["energy"] for item in sampled]
outputs = {kk: [item[kk] for item in sampled] for kk in keys}
data_mixed_type = "real_natoms_vec" in sampled[0]
natoms_key = "natoms" if not data_mixed_type else "real_natoms_vec"
for system in sampled:
@@ -133,7 +161,7 @@ def compute_output_stats(
system[natoms_key][:, 2:] *= type_mask.unsqueeze(0)
input_natoms = [item[natoms_key] for item in sampled]
# shape: (nframes, ndim)
merged_energy = to_numpy_array(torch.cat(energy))
merged_output = {kk: to_numpy_array(torch.cat(outputs[kk])) for kk in keys}
# shape: (nframes, ntypes)
merged_natoms = to_numpy_array(torch.cat(input_natoms)[:, 2:])
if atom_ener is not None and len(atom_ener) > 0:
@@ -144,16 +172,20 @@ def compute_output_stats(
assigned_atom_ener = None
if model_forward is None:
# only use statistics result
bias_atom_e, _ = compute_stats_from_redu(
merged_energy,
merged_natoms,
assigned_bias=assigned_atom_ener,
rcond=rcond,
)
# [0]: take the first output (mean) of compute_stats_from_redu
bias_atom_e = {
kk: compute_stats_from_redu(
merged_output[kk],
merged_natoms,
assigned_bias=assigned_atom_ener,
rcond=rcond,
)[0]
for kk in keys
}
else:
# subtract the model bias and output the delta bias
auto_batch_size = AutoBatchSize()
energy_predict = []
model_predict = {kk: [] for kk in keys}
for system in sampled:
nframes = system["coord"].shape[0]
coord, atype, box, natoms = (
@@ -174,34 +206,49 @@ def model_forward_auto_batch_size(*args, **kwargs):
**kwargs,
)

energy = (
model_forward_auto_batch_size(
coord, atype, box, fparam=fparam, aparam=aparam
)
.reshape(nframes, -1)
.sum(-1)
sample_predict = model_forward_auto_batch_size(
coord, atype, box, fparam=fparam, aparam=aparam
)
energy_predict.append(to_numpy_array(energy).reshape([nframes, 1]))

energy_predict = np.concatenate(energy_predict)
bias_diff = merged_energy - energy_predict
bias_atom_e, _ = compute_stats_from_redu(
bias_diff,
merged_natoms,
assigned_bias=assigned_atom_ener,
rcond=rcond,
)
unbias_e = energy_predict + merged_natoms @ bias_atom_e

for kk in keys:
model_predict[kk].append(
to_numpy_array(
torch.sum(sample_predict[kk], dim=1) # nf x nloc x odims
)
)

model_predict = {kk: np.concatenate(model_predict[kk]) for kk in keys}

bias_diff = {kk: merged_output[kk] - model_predict[kk] for kk in keys}
bias_atom_e = {
kk: compute_stats_from_redu(
bias_diff[kk],
merged_natoms,
assigned_bias=assigned_atom_ener,
rcond=rcond,
)[0]
for kk in keys
}
unbias_e = {
kk: model_predict[kk] + merged_natoms @ bias_atom_e[kk] for kk in keys
}
atom_numbs = merged_natoms.sum(-1)
rmse_ae = np.sqrt(
np.mean(
np.square((unbias_e.ravel() - merged_energy.ravel()) / atom_numbs)
for kk in keys:
rmse_ae = np.sqrt(
np.mean(
np.square(
(unbias_e[kk].ravel() - merged_output[kk].ravel())
/ atom_numbs
)
)
)
)
log.info(
f"RMSE of energy per atom after linear regression is: {rmse_ae} eV/atom."
)
log.info(
f"RMSE of {kk} per atom after linear regression is: {rmse_ae} in the unit of {kk}."
)

if stat_file_path is not None:
stat_file_path.save_numpy(bias_atom_e)
assert all(x is not None for x in [bias_atom_e])
return to_torch_tensor(bias_atom_e)
save_to_file(stat_file_path, bias_atom_e)

ret = {kk: to_torch_tensor(bias_atom_e[kk]) for kk in keys}

return ret
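
End to end, the `model_forward` branch now performs the same three steps per key: reduce per-atom predictions to per-frame values, regress the residual against atom counts, and report a per-atom RMSE. A self-contained numeric sketch (toy numbers; `np.linalg.lstsq` again stands in for `compute_stats_from_redu`):

```python
import numpy as np
import torch

keys = ["energy"]
merged_natoms = np.array([[2.0, 1.0], [1.0, 2.0]])      # nframes x ntypes
merged_output = {"energy": np.array([[10.0], [11.0]])}  # per-frame labels

# per-atom predictions [nframes, nloc, odims], summed over the atom axis
sample_predict = {"energy": torch.full((2, 3, 1), 2.0)}
model_predict = {
    kk: torch.sum(sample_predict[kk], dim=1).numpy() for kk in keys
}

bias_diff = {kk: merged_output[kk] - model_predict[kk] for kk in keys}
bias_atom_e = {
    kk: np.linalg.lstsq(merged_natoms, bias_diff[kk], rcond=None)[0]
    for kk in keys
}
unbias_e = {
    kk: model_predict[kk] + merged_natoms @ bias_atom_e[kk] for kk in keys
}
atom_numbs = merged_natoms.sum(-1)
for kk in keys:
    rmse_ae = np.sqrt(
        np.mean(
            np.square(
                (unbias_e[kk].ravel() - merged_output[kk].ravel()) / atom_numbs
            )
        )
    )
    print(f"per-atom RMSE for {kk}: {rmse_ae}")  # 0.0 for these toy numbers
```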
100 changes: 100 additions & 0 deletions source/tests/pt/test_stat.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import json
import os
import tempfile
import unittest
from abc import (
ABC,
@@ -11,6 +12,7 @@
)

import dpdata
import h5py
import numpy as np
import torch

@@ -29,7 +31,14 @@
from deepmd.pt.utils.dataloader import (
DpLoaderSet,
)
from deepmd.pt.utils.stat import (
compute_output_stats,
)
from deepmd.pt.utils.stat import make_stat_input
from deepmd.pt.utils.stat import make_stat_input as my_make
from deepmd.pt.utils.utils import (
to_numpy_array,
)
from deepmd.tf.common import (
expand_sys_str,
)
@@ -47,6 +56,9 @@
from deepmd.utils.data import (
DataRequirementItem,
)
from deepmd.utils.path import (
DPPath,
)

CUR_DIR = os.path.dirname(__file__)

@@ -325,5 +337,93 @@ def tf_compute_input_stats(self):
)


class TestOutputStat(unittest.TestCase):
def setUp(self):
self.data_file = [str(Path(__file__).parent / "water/data/data_0")]
self.type_map = ["O", "H"] # by dataset
self.data = DpLoaderSet(
self.data_file,
batch_size=1,
type_map=self.type_map,
)
self.data.add_data_requirement(energy_data_requirement)
self.sampled = make_stat_input(
self.data.systems,
self.data.dataloaders,
nbatches=1,
)
self.tempdir = tempfile.TemporaryDirectory()
h5file = str((Path(self.tempdir.name) / "testcase.h5").resolve())
with h5py.File(h5file, "w") as f:
pass
self.stat_file_path = DPPath(h5file, "a")

def tearDown(self):
self.tempdir.cleanup()

def test_calc_and_load(self):
stat_file_path = self.stat_file_path
type_map = self.type_map

# compute from sample
ret0 = compute_output_stats(
self.sampled,
len(type_map),
keys=["energy"],
stat_file_path=stat_file_path,
atom_ener=None,
model_forward=None,
)
# ground truth
ntest = 1
atom_nums = np.tile(
np.bincount(to_numpy_array(self.sampled[0]["atype"][0])),
(ntest, 1),
)
energy_diff = to_numpy_array(self.sampled[0]["energy"][:ntest])
ground_truth_shift = np.linalg.lstsq(atom_nums, energy_diff, rcond=None)[0]

# check values
np.testing.assert_almost_equal(
to_numpy_array(ret0["energy"]), ground_truth_shift, decimal=10
)
# self.assertTrue(stat_file_path.is_dir())

def raise_error():
raise RuntimeError

# hack!!!
# supposed to load the stat from file; if it is computed from the samples
# instead, an error will be raised.
ret1 = compute_output_stats(
raise_error,
len(type_map),
keys=["energy"],
stat_file_path=stat_file_path,
atom_ener=None,
model_forward=None,
)
np.testing.assert_almost_equal(
to_numpy_array(ret0["energy"]), to_numpy_array(ret1["energy"]), decimal=10
)

def test_assigned(self):
atom_ener = np.array([3.0, 5.0]).reshape(2, 1)
stat_file_path = self.stat_file_path
type_map = self.type_map

# from assigned atom_ener
ret2 = compute_output_stats(
self.sampled,
len(type_map),
keys=["energy"],
stat_file_path=stat_file_path,
atom_ener=atom_ener,
model_forward=None,
)
np.testing.assert_almost_equal(
to_numpy_array(ret2["energy"]), atom_ener, decimal=10
)


if __name__ == "__main__":
unittest.main()
