Skip to content

Commit

Permalink
fix: manage testing models in a standard way (#4028)
Browse files Browse the repository at this point in the history
Fix #2103. Migrate three models (se_e2_a, se_e2_r, and fparam_aparam)
for the Python unit tests. Fix several bugs. Old files are kept until
the C++ tests are also migrated.

Note that several models (for example, the dipole model due to #3672)
cannot be serialized yet.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Introduced a structured framework for managing and testing models with
YAML files.
- Added comprehensive configurations for energy calculations and
molecular simulations in YAML format.
- Implemented new test cases for the `DeepPot` and `DeepPotNeighborList`
classes.

- **Bug Fixes**
- Improved robustness in tensor reshaping, resolving potential dimension
mismatches.

- **Tests**
- Enhanced unit tests with a case-based approach for better adaptability
and maintainability.
- Consolidated tests by relocating obsolete classes to streamline the
test suite.

- **Chores**
- Updated deserialization functions for better type safety and input
handling.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz authored Aug 5, 2024
1 parent 3af248f commit 8201ebc
Show file tree
Hide file tree
Showing 22 changed files with 8,964 additions and 1,723 deletions.
22 changes: 13 additions & 9 deletions deepmd/pt/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,26 +380,28 @@ def _eval_model(
natoms = len(atom_types[0])

coord_input = torch.tensor(
coords.reshape([-1, natoms, 3]),
coords.reshape([nframes, natoms, 3]),
dtype=GLOBAL_PT_FLOAT_PRECISION,
device=DEVICE,
)
type_input = torch.tensor(atom_types, dtype=torch.long, device=DEVICE)
if cells is not None:
box_input = torch.tensor(
cells.reshape([-1, 3, 3]),
cells.reshape([nframes, 3, 3]),
dtype=GLOBAL_PT_FLOAT_PRECISION,
device=DEVICE,
)
else:
box_input = None
if fparam is not None:
fparam_input = to_torch_tensor(fparam.reshape(-1, self.get_dim_fparam()))
fparam_input = to_torch_tensor(
fparam.reshape(nframes, self.get_dim_fparam())
)
else:
fparam_input = None
if aparam is not None:
aparam_input = to_torch_tensor(
aparam.reshape(-1, natoms, self.get_dim_aparam())
aparam.reshape(nframes, natoms, self.get_dim_aparam())
)
else:
aparam_input = None
Expand Down Expand Up @@ -451,31 +453,33 @@ def _eval_model_spin(
natoms = len(atom_types[0])

coord_input = torch.tensor(
coords.reshape([-1, natoms, 3]),
coords.reshape([nframes, natoms, 3]),
dtype=GLOBAL_PT_FLOAT_PRECISION,
device=DEVICE,
)
type_input = torch.tensor(atom_types, dtype=torch.long, device=DEVICE)
spin_input = torch.tensor(
spins.reshape([-1, natoms, 3]),
spins.reshape([nframes, natoms, 3]),
dtype=GLOBAL_PT_FLOAT_PRECISION,
device=DEVICE,
)
if cells is not None:
box_input = torch.tensor(
cells.reshape([-1, 3, 3]),
cells.reshape([nframes, 3, 3]),
dtype=GLOBAL_PT_FLOAT_PRECISION,
device=DEVICE,
)
else:
box_input = None
if fparam is not None:
fparam_input = to_torch_tensor(fparam.reshape(-1, self.get_dim_fparam()))
fparam_input = to_torch_tensor(
fparam.reshape(nframes, self.get_dim_fparam())
)
else:
fparam_input = None
if aparam is not None:
aparam_input = to_torch_tensor(
aparam.reshape(-1, natoms, self.get_dim_aparam())
aparam.reshape(nframes, natoms, self.get_dim_aparam())
)
else:
aparam_input = None
Expand Down
4 changes: 2 additions & 2 deletions deepmd/pt/utils/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def deserialize_to_file(model_file: str, data: dict) -> None:
model = BaseModel.deserialize(data["model"])
# JIT will be happy this way...
model.model_def_script = json.dumps(data["model_def_script"])
model = torch.jit.script(model)
if "min_nbor_dist" in data.get("@variables", {}):
model.min_nbor_dist = data["@variables"]["min_nbor_dist"]
model.min_nbor_dist = float(data["@variables"]["min_nbor_dist"])
model = torch.jit.script(model)
torch.jit.save(model, model_file)
3 changes: 2 additions & 1 deletion deepmd/tf/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,10 +720,11 @@ def prod_force_virial(
"""
[net_deriv] = tf.gradients(atom_ener, self.descrpt_reshape)
tf.summary.histogram("net_derivative", net_deriv)
nf = tf.shape(self.nlist)[0]
net_deriv_reshape = tf.reshape(
net_deriv,
[
np.asarray(-1, dtype=np.int64),
nf,
natoms[0] * np.asarray(self.ndescrpt, dtype=np.int64),
],
)
Expand Down
3 changes: 2 additions & 1 deletion deepmd/tf/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,10 +512,11 @@ def prod_force_virial(
"""
[net_deriv] = tf.gradients(atom_ener, self.descrpt_reshape)
tf.summary.histogram("net_derivative", net_deriv)
nf = tf.shape(self.nlist)[0]
net_deriv_reshape = tf.reshape(
net_deriv,
[
np.asarray(-1, dtype=np.int64),
nf,
natoms[0] * np.asarray(self.ndescrpt, dtype=np.int64),
],
)
Expand Down
4 changes: 2 additions & 2 deletions deepmd/tf/utils/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,13 @@ def deserialize_to_file(model_file: str, data: dict) -> None:
if model.get_numb_fparam() > 0:
inputs["fparam"] = tf.placeholder(
GLOBAL_TF_FLOAT_PRECISION,
[None, model.get_numb_fparam()],
[None],
name="t_fparam",
)
if model.get_numb_aparam() > 0:
inputs["aparam"] = tf.placeholder(
GLOBAL_TF_FLOAT_PRECISION,
[None, model.get_numb_aparam()],
[None],
name="t_aparam",
)
model.build(
Expand Down
27 changes: 10 additions & 17 deletions source/tests/consistent/model/test_frozen.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,36 +31,29 @@
from deepmd.tf.model.model import Model as FrozenModelTF
else:
FrozenModelTF = None
from pathlib import (
Path,
)

from deepmd.entrypoints.convert_backend import (
convert_backend,
)
from deepmd.utils.argcheck import (
model_args,
)

original_model = str(Path(__file__).parent.parent.parent / "infer" / "deeppot.dp")
from ...infer.case import (
get_cases,
)

pt_model = "deeppot_for_consistent_frozen.pth"
tf_model = "deeppot_for_consistent_frozen.pb"
dp_model = original_model
dp_model = "deeppot_for_consistent_frozen.dp"


def setUpModule():
convert_backend(
INPUT=dp_model,
OUTPUT=tf_model,
)
convert_backend(
INPUT=dp_model,
OUTPUT=pt_model,
)
case = get_cases()["se_e2_a"]
case.get_model(".dp", dp_model)
case.get_model(".pb", tf_model)
case.get_model(".pth", pt_model)


def tearDownModule():
for model_file in (pt_model, tf_model):
for model_file in (dp_model, pt_model, tf_model):
try:
os.remove(model_file)
except FileNotFoundError:
Expand Down
1 change: 1 addition & 0 deletions source/tests/infer/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
deepmd_test_models*/
1 change: 1 addition & 0 deletions source/tests/infer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
195 changes: 195 additions & 0 deletions source/tests/infer/case.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
"""Manage testing models in a standard way.

For each model, a YAML file ending with `-testcase.yaml` must be given. It should contain the following keys:

- `key`: The key of the model.
- `filename`: The path to the model file, relative to the directory of the YAML file.
- `ntypes`: The number of atomic types.
- `rcut`: The cutoff radius.
- `type_map`: The mapping between atomic types and atomic names.
- `dim_fparam`: The number of frame parameters.
- `dim_aparam`: The number of atomic parameters.
- `results`: A list of results. Each result should contain the following keys:
    - `atype`: The atomic types.
    - `coord`: The atomic coordinates.
    - `box`: The simulation box, or null if there is no box.
    - `fparam` (optional): The frame parameters.
    - `aparam` (optional): The atomic parameters.
    - `atomic_energy` or `energy` (optional): The atomic energies or the total energy.
    - `force` (optional): The atomic forces.
    - `atomic_virial` or `virial` (optional): The atomic virials or the total virial.
    - `descriptor` (optional): The descriptor.
"""

import tempfile
from functools import (
lru_cache,
)
from pathlib import (
Path,
)
from typing import (
Dict,
Optional,
)

import numpy as np
import yaml

from deepmd.entrypoints.convert_backend import (
convert_backend,
)

this_directory = Path(__file__).parent.resolve()
# create a temporary directory under this directory
# to store the temporary model files
# it will be deleted when the program exits
tempdir = tempfile.TemporaryDirectory(dir=this_directory, prefix="deepmd_test_models_")


class Result:
    """Reference inputs and expected outputs for one configuration.

    Only ``atype`` and ``coord`` are mandatory. Every other key may be
    missing or explicitly null; the matching attribute is then ``None``.

    Parameters
    ----------
    data : dict
        Dictionary containing the results.

    Attributes
    ----------
    atype : np.ndarray
        The atomic types (int64).
    nloc : int
        The number of atoms.
    coord : np.ndarray
        The atomic coordinates, reshaped to (nloc, 3).
    box : np.ndarray or None
        The simulation box, reshaped to (3, 3).
    fparam : np.ndarray or None
        The frame parameters, flattened to one dimension.
    aparam : np.ndarray or None
        The atomic parameters, reshaped to (nloc, -1).
    atomic_energy : np.ndarray or None
        The atomic energies, reshaped to (nloc, 1).
    energy : np.ndarray or None
        The total energy, shape (1,). It is the sum of ``atomic_energy``
        when atomic energies are given, otherwise the ``energy`` entry.
    force : np.ndarray or None
        The atomic forces, reshaped to (nloc, 3).
    atomic_virial : np.ndarray or None
        The atomic virials, reshaped to (nloc, 9).
    virial : np.ndarray or None
        The total virial, shape (9,). It is the sum of ``atomic_virial``
        when atomic virials are given, otherwise the ``virial`` entry.
    descriptor : np.ndarray or None
        The descriptor, reshaped to (nloc, -1).
    """

    def __init__(self, data: dict) -> None:
        self.atype = np.array(data["atype"], dtype=np.int64)
        self.nloc = self.atype.size
        self.coord = np.array(data["coord"], dtype=np.float64).reshape(self.nloc, 3)

        self.box = self._optional(data, "box", (3, 3))
        self.fparam = self._optional(data, "fparam", (-1,))
        self.aparam = self._optional(data, "aparam", (self.nloc, -1))

        # Per-atom energies take precedence: the total is derived from them.
        self.atomic_energy = self._optional(data, "atomic_energy", (self.nloc, 1))
        if self.atomic_energy is not None:
            self.energy = np.sum(self.atomic_energy, axis=0)
        else:
            self.energy = self._optional(data, "energy", (1,))

        self.force = self._optional(data, "force", (self.nloc, 3))

        # Same precedence rule as for the energies.
        self.atomic_virial = self._optional(data, "atomic_virial", (self.nloc, 9))
        if self.atomic_virial is not None:
            self.virial = np.sum(self.atomic_virial, axis=0)
        else:
            self.virial = self._optional(data, "virial", (9,))

        self.descriptor = self._optional(data, "descriptor", (self.nloc, -1))

    @staticmethod
    def _optional(data: dict, key: str, shape: tuple) -> Optional[np.ndarray]:
        """Return ``data[key]`` as a float64 array of ``shape``, or None if missing/null."""
        value = data.get(key)
        if value is None:
            return None
        return np.array(value, dtype=np.float64).reshape(shape)


class Case:
    """Test case.

    Loads the YAML description of one reference model and can convert that
    model to other backend formats on request.

    Parameters
    ----------
    filename : str
        The path to the test case file. The model path stored in the file
        is resolved relative to the directory of this file.
    """

    def __init__(self, filename: str):
        # YAML streams are UTF-8; do not depend on the platform locale.
        with open(filename, encoding="utf-8") as file:
            config = yaml.safe_load(file)
        self.key = config["key"]
        self.filename = str(Path(filename).parent / config["filename"])
        self.results = [Result(data) for data in config["results"]]
        self.ntypes = config["ntypes"]
        self.rcut = config["rcut"]
        self.type_map = config["type_map"]
        self.dim_fparam = config["dim_fparam"]
        self.dim_aparam = config["dim_aparam"]
        # Maps (suffix, requested out_file) -> path of the converted model.
        # A per-instance dict replaces ``functools.lru_cache`` on the method,
        # which could not notice that a converted file had been deleted.
        self._model_cache: Dict[tuple, str] = {}

    def get_model(self, suffix: str, out_file: Optional[str] = None) -> str:
        """Get the model file with the specified suffix.

        The converted file is cached per ``(suffix, out_file)`` pair. The
        cached path is reused only while it still exists on disk; otherwise
        the model is converted again.

        Parameters
        ----------
        suffix : str
            The suffix of the model file.
        out_file : str, optional
            The path to the output model file. If not given, a temporary file will be created.

        Returns
        -------
        str
            The path to the model file.
        """
        cache_key = (suffix, out_file)
        cached = self._model_cache.get(cache_key)
        if cached is not None and Path(cached).exists():
            return cached
        # generate a temporary model file
        if out_file is None:
            # Only the unique name is needed; close the handle right away so
            # that the converter can write to the path.
            with tempfile.NamedTemporaryFile(
                suffix=suffix, dir=tempdir.name, delete=False, prefix=self.key + "_"
            ) as tmp:
                out_file = tmp.name
        convert_backend(INPUT=self.filename, OUTPUT=out_file)
        self._model_cache[cache_key] = out_file
        return out_file


@lru_cache
def get_cases() -> Dict[str, Case]:
    """Collect every test case found next to this module.

    Each ``*-testcase.yaml`` file in this directory is loaded as a
    :class:`Case`. The result is cached, so the files are parsed only once.

    Returns
    -------
    Dict[str, Case]
        The test cases, indexed by their ``key``.

    Examples
    --------
    To get a specific case:

    >>> get_cases()["se_e2_a"]
    """
    loaded = (Case(path) for path in this_directory.glob("*-testcase.yaml"))
    return {case.key: case for case in loaded}
Loading

0 comments on commit 8201ebc

Please sign in to comment.