compute output stat for atomic model (#3642)
This PR:
- Breaking change: the base atomic model is now a module.
  - Reason: the output stat is a data attribute of the base atomic model.
- Implements `compute_or_load_output_stat` for the base atomic model. The method computes both the bias and the std.
- The derived atomic models call the `compute_or_load_output_stat` method to compute the output stat.
- The atomic model provides `apply_out_stat`; a derived class may override this method to define how the statistics are applied to the atomic model's output (see the sketch after this list). @anyangml may need this.
- `out_stat` supports statistics of output tensors of any shape.
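A minimal sketch of how a derived PyTorch atomic model might plug into the new machinery, assuming the `BaseAtomicModel` from this PR. `MyAtomicModel` is hypothetical, and the abstract descriptor/fitting methods (plus the `init_out_stat()` call) that a real model needs are omitted; it only illustrates calling `compute_or_load_out_stat` and overriding `apply_out_stat`:

```python
import torch
from typing import Callable, Dict, List, Optional, Union

from deepmd.pt.model.atomic_model.base_atomic_model import BaseAtomicModel
from deepmd.utils.path import DPPath


class MyAtomicModel(BaseAtomicModel):
    """Hypothetical derived model; descriptor/fitting plumbing omitted."""

    def compute_or_load_stat(
        self,
        sampled_func: Union[Callable[[], List[dict]], List[dict]],
        stat_file_path: Optional[DPPath] = None,
    ) -> None:
        # descriptor statistics would be handled here; the base-class helper
        # then computes (or loads) the per-type output bias and std buffers
        self.compute_or_load_out_stat(sampled_func, stat_file_path)

    def apply_out_stat(
        self,
        ret: Dict[str, torch.Tensor],
        atype: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        # override the default "shift by the per-type bias" behaviour, e.g.
        # rescale by the stored std before shifting (purely illustrative)
        out_bias, out_std = self._fetch_out_stat(self.bias_keys)
        for kk in self.bias_keys:
            # ret[kk]: nf x nloc x odims; out_bias[kk], out_std[kk]: ntypes x odims
            ret[kk] = ret[kk] * out_std[kk][atype] + out_bias[kk][atype]
        return ret
```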

@iProzd please check if I took it correctly in
[ce7ec1f](ce7ec1f)

To be done:
- Atomic statistics of the bias and std. @anyangml
- Serialization and deserialization.

---------

Signed-off-by: Han Wang <[email protected]>
Co-authored-by: Han Wang <[email protected]>
Co-authored-by: Duo <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
4 people authored Apr 7, 2024
1 parent edb9da8 commit 39d027e
Showing 18 changed files with 823 additions and 121 deletions.
6 changes: 6 additions & 0 deletions deepmd/dpmodel/atomic_model/base_atomic_model.py
@@ -27,13 +27,19 @@
class BaseAtomicModel(BaseAtomicModel_):
def __init__(
self,
type_map: List[str],
atom_exclude_types: List[int] = [],
pair_exclude_types: List[Tuple[int, int]] = [],
):
super().__init__()
self.type_map = type_map
self.reinit_atom_exclude(atom_exclude_types)
self.reinit_pair_exclude(pair_exclude_types)

def get_type_map(self) -> List[str]:
"""Get the type map."""
return self.type_map

def reinit_atom_exclude(
self,
exclude_types: List[int] = [],
6 changes: 1 addition & 5 deletions deepmd/dpmodel/atomic_model/dp_atomic_model.py
@@ -53,7 +53,7 @@ def __init__(
self.descriptor = descriptor
self.fitting = fitting
self.type_map = type_map
super().__init__(**kwargs)
super().__init__(type_map, **kwargs)

def fitting_output_def(self) -> FittingOutputDef:
"""Get the output def of the fitting net."""
@@ -67,10 +67,6 @@ def get_sel(self) -> List[int]:
"""Get the neighbor selection."""
return self.descriptor.get_sel()

def get_type_map(self) -> List[str]:
"""Get the type map."""
return self.type_map

def mixed_types(self) -> bool:
"""If true, the model
1. assumes total number of atoms aligned across frames;
2 changes: 1 addition & 1 deletion deepmd/dpmodel/atomic_model/linear_atomic_model.py
@@ -66,7 +66,7 @@ def __init__(
self.mapping_list.append(self.remap_atype(tpmp, self.type_map))
assert len(err_msg) == 0, "\n".join(err_msg)
self.mixed_types_list = [model.mixed_types() for model in self.models]
super().__init__(**kwargs)
super().__init__(type_map, **kwargs)

def mixed_types(self) -> bool:
"""If true, the model
3 changes: 0 additions & 3 deletions deepmd/dpmodel/atomic_model/make_base_atomic_model.py
@@ -51,9 +51,6 @@ def atomic_output_def(self) -> FittingOutputDef:
"""
return self.fitting_output_def()

def get_output_keys(self) -> List[str]:
return list(self.atomic_output_def().keys())

@abstractmethod
def get_rcut(self) -> float:
"""Get the cut-off radius."""
6 changes: 5 additions & 1 deletion deepmd/dpmodel/atomic_model/pairtab_atomic_model.py
@@ -59,16 +59,20 @@ def __init__(
rcut: float,
sel: Union[int, List[int]],
type_map: List[str],
rcond: Optional[float] = None,
atom_ener: Optional[List[float]] = None,
**kwargs,
):
super().__init__()
super().__init__(type_map, **kwargs)
self.tab_file = tab_file
self.rcut = rcut
self.type_map = type_map

self.tab = PairTab(self.tab_file, rcut=rcut)
self.type_map = type_map
self.ntypes = len(type_map)
self.rcond = rcond
self.atom_ener = atom_ener

if self.tab_file is not None:
self.tab_info, self.tab_data = self.tab.get()
4 changes: 4 additions & 0 deletions deepmd/dpmodel/output_def.py
@@ -224,6 +224,10 @@ def __init__(
if not self.r_differentiable:
raise ValueError("only r_differentiable variable can calculate hessian")

@property
def size(self):
return self.output_size


class FittingOutputDef:
"""Defines the shapes and other properties of the fitting network outputs.
212 changes: 202 additions & 10 deletions deepmd/pt/model/atomic_model/base_atomic_model.py
@@ -23,6 +23,7 @@
from deepmd.pt.utils import (
AtomExcludeMask,
PairExcludeMask,
env,
)
from deepmd.pt.utils.nlist import (
extend_input_and_build_neighbor_list,
@@ -35,19 +36,88 @@
)

log = logging.getLogger(__name__)
dtype = env.GLOBAL_PT_FLOAT_PRECISION
device = env.DEVICE

BaseAtomicModel_ = make_base_atomic_model(torch.Tensor)


class BaseAtomicModel(BaseAtomicModel_):
class BaseAtomicModel(torch.nn.Module, BaseAtomicModel_):
"""The base of atomic model.
Parameters
----------
type_map
Mapping atom type to the name (str) of the type.
For example `type_map[1]` gives the name of the type 1.
atom_exclude_types
Exclude the atomic contribution of the given types
pair_exclude_types
Exclude the pair of atoms of the given types from computing the output
of the atomic model. Implemented by removing the pairs from the nlist.
rcond : float, optional
The condition number for the regression of atomic energy.
preset_out_bias : Dict[str, List[Optional[torch.Tensor]]], optional
Specifying atomic energy contribution in vacuum. Given by key:value pairs.
The value is a list specifying the bias; the elements can be None or np.array of the output shape.
For example: [None, [2.]] means type 0 is not set and type 1 is set to [2.].
The `set_davg_zero` key in the descriptor should be set.
"""

def __init__(
self,
type_map: List[str],
atom_exclude_types: List[int] = [],
pair_exclude_types: List[Tuple[int, int]] = [],
rcond: Optional[float] = None,
preset_out_bias: Optional[Dict[str, torch.Tensor]] = None,
):
super().__init__()
torch.nn.Module.__init__(self)
BaseAtomicModel_.__init__(self)
self.type_map = type_map
self.reinit_atom_exclude(atom_exclude_types)
self.reinit_pair_exclude(pair_exclude_types)
self.rcond = rcond
self.preset_out_bias = preset_out_bias

def init_out_stat(self):
"""Initialize the output bias."""
ntypes = self.get_ntypes()
self.bias_keys: List[str] = list(self.fitting_output_def().keys())
self.max_out_size = max(
[self.atomic_output_def()[kk].size for kk in self.bias_keys]
)
self.n_out = len(self.bias_keys)
out_bias_data = torch.zeros(
[self.n_out, ntypes, self.max_out_size], dtype=dtype, device=device
)
out_std_data = torch.ones(
[self.n_out, ntypes, self.max_out_size], dtype=dtype, device=device
)
self.register_buffer("out_bias", out_bias_data)
self.register_buffer("out_std", out_std_data)

def __setitem__(self, key, value):
if key in ["out_bias"]:
self.out_bias = value
elif key in ["out_std"]:
self.out_std = value
else:
raise KeyError(key)

def __getitem__(self, key):
if key in ["out_bias"]:
return self.out_bias
elif key in ["out_std"]:
return self.out_std
else:
raise KeyError(key)

@torch.jit.export
def get_type_map(self) -> List[str]:
"""Get the type map."""
return self.type_map

def reinit_atom_exclude(
self,
@@ -165,6 +235,7 @@ def forward_common_atomic(
fparam=fparam,
aparam=aparam,
)
ret_dict = self.apply_out_stat(ret_dict, atype)

# nf x nloc
atom_mask = ext_atom_mask[:, :nloc].to(torch.int32)
@@ -210,9 +281,60 @@ def compute_or_load_stat(
"""
raise NotImplementedError

def compute_or_load_out_stat(
self,
merged: Union[Callable[[], List[dict]], List[dict]],
stat_file_path: Optional[DPPath] = None,
):
"""
Compute the output statistics (e.g. energy bias) for the fitting net from packed data.
Parameters
----------
merged : Union[Callable[[], List[dict]], List[dict]]
- List[dict]: A list of data samples from various data systems.
Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
originating from the `i`-th data system.
- Callable[[], List[dict]]: A lazy function that returns data samples in the above format
only when needed. Since the sampling process can be slow and memory-intensive,
the lazy function helps by only sampling once.
stat_file_path : Optional[DPPath]
The path to the stat file.
"""
self.change_out_bias(
merged,
stat_file_path=stat_file_path,
bias_adjust_mode="set-by-statistic",
)

def apply_out_stat(
self,
ret: Dict[str, torch.Tensor],
atype: torch.Tensor,
):
"""Apply the stat to each atomic output.
The developer may override the method to define how the bias is applied
to the atomic output of the model.
Parameters
----------
ret
The returned dict by the forward_atomic method
atype
The atom types. nf x nloc
"""
out_bias, out_std = self._fetch_out_stat(self.bias_keys)
for kk in self.bias_keys:
# nf x nloc x odims, out_bias: ntypes x odims
ret[kk] = ret[kk] + out_bias[kk][atype]
return ret

def change_out_bias(
self,
sample_merged,
stat_file_path: Optional[DPPath] = None,
bias_adjust_mode="change-by-statistic",
) -> None:
"""Change the output bias according to the input data and the pretrained model.
@@ -231,22 +353,32 @@ def change_out_bias(
'change-by-statistic' : perform predictions on labels of target dataset,
and do least square on the errors to obtain the target shift as bias.
'set-by-statistic' : directly use the statistic output bias in the target dataset.
stat_file_path : Optional[DPPath]
The path to the stat file.
"""
if bias_adjust_mode == "change-by-statistic":
delta_bias = compute_output_stats(
delta_bias, out_std = compute_output_stats(
sample_merged,
self.get_ntypes(),
keys=self.get_output_keys(),
keys=list(self.atomic_output_def().keys()),
stat_file_path=stat_file_path,
model_forward=self._get_forward_wrapper_func(),
)["energy"]
self.set_out_bias(delta_bias, add=True)
rcond=self.rcond,
preset_bias=self.preset_out_bias,
)
# self.set_out_bias(delta_bias, add=True)
self._store_out_stat(delta_bias, out_std, add=True)
elif bias_adjust_mode == "set-by-statistic":
bias_atom = compute_output_stats(
bias_out, std_out = compute_output_stats(
sample_merged,
self.get_ntypes(),
keys=self.get_output_keys(),
)["energy"]
self.set_out_bias(bias_atom)
keys=list(self.atomic_output_def().keys()),
stat_file_path=stat_file_path,
rcond=self.rcond,
preset_bias=self.preset_out_bias,
)
# self.set_out_bias(bias_out)
self._store_out_stat(bias_out, std_out)
else:
raise RuntimeError("Unknown bias_adjust_mode mode: " + bias_adjust_mode)

@@ -279,3 +411,63 @@ def model_forward(coord, atype, box, fparam=None, aparam=None):
return {kk: vv.detach() for kk, vv in atomic_ret.items()}

return model_forward

def _varsize(
self,
shape: List[int],
) -> int:
output_size = 1
len_shape = len(shape)
for i in range(len_shape):
output_size *= shape[i]
return output_size

def _get_bias_index(
self,
kk: str,
) -> int:
res: List[int] = []
for i, e in enumerate(self.bias_keys):
if e == kk:
res.append(i)
assert len(res) == 1
return res[0]

def _store_out_stat(
self,
out_bias: Dict[str, torch.Tensor],
out_std: Dict[str, torch.Tensor],
add: bool = False,
):
ntypes = self.get_ntypes()
out_bias_data = torch.clone(self.out_bias)
out_std_data = torch.clone(self.out_std)
for kk in out_bias.keys():
assert kk in out_std.keys()
idx = self._get_bias_index(kk)
size = self._varsize(self.atomic_output_def()[kk].shape)
if not add:
out_bias_data[idx, :, :size] = out_bias[kk].view(ntypes, size)
else:
out_bias_data[idx, :, :size] += out_bias[kk].view(ntypes, size)
out_std_data[idx, :, :size] = out_std[kk].view(ntypes, size)
self.out_bias.copy_(out_bias_data)
self.out_std.copy_(out_std_data)

def _fetch_out_stat(
self,
keys: List[str],
) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
ret_bias = {}
ret_std = {}
ntypes = self.get_ntypes()
for kk in keys:
idx = self._get_bias_index(kk)
isize = self._varsize(self.atomic_output_def()[kk].shape)
ret_bias[kk] = self.out_bias[idx, :, :isize].view(
[ntypes] + list(self.atomic_output_def()[kk].shape) # noqa: RUF005
)
ret_std[kk] = self.out_std[idx, :, :isize].view(
[ntypes] + list(self.atomic_output_def()[kk].shape) # noqa: RUF005
)
return ret_bias, ret_std
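For reference, a standalone sketch of the buffer layout used by `init_out_stat`, `_store_out_stat`, and `_fetch_out_stat` above: per-key statistics are padded into `[n_out, ntypes, max_out_size]` buffers and viewed back in each key's native shape. The key names and shapes here are made up for illustration:

```python
import torch
from typing import Dict, List


def varsize(shape: List[int]) -> int:
    """Flattened length of an output variable, mirroring BaseAtomicModel._varsize."""
    size = 1
    for s in shape:
        size *= s
    return size


# hypothetical output definition: two keys with different per-atom shapes
out_shapes: Dict[str, List[int]] = {"energy": [1], "polar": [3, 3]}
bias_keys = list(out_shapes)  # the key order fixes the first buffer index
ntypes = 4
max_out_size = max(varsize(s) for s in out_shapes.values())

# same layout as the buffers registered in init_out_stat: padded so every key fits
out_bias = torch.zeros([len(bias_keys), ntypes, max_out_size])

# store per-key statistics (ntypes x <key shape>), as _store_out_stat does
stats = {kk: torch.randn([ntypes] + out_shapes[kk]) for kk in bias_keys}
for idx, kk in enumerate(bias_keys):
    size = varsize(out_shapes[kk])
    out_bias[idx, :, :size] = stats[kk].view(ntypes, size)

# fetch them back in each key's native shape, as _fetch_out_stat does
fetched = {
    kk: out_bias[idx, :, : varsize(out_shapes[kk])].view([ntypes] + out_shapes[kk])
    for idx, kk in enumerate(bias_keys)
}
assert fetched["polar"].shape == (ntypes, 3, 3)
assert torch.allclose(fetched["energy"], stats["energy"])
```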
