pt: add real atoms select in c++ interface (#3375)
Real atom selection may be merged into the DeepPot class for both backends in the future.

---------

Signed-off-by: Lysithea <[email protected]>
Signed-off-by: Jinzhe Zeng <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jinzhe Zeng <[email protected]>
Co-authored-by: Han Wang <[email protected]>
4 people authored Mar 1, 2024
1 parent 759bdcb commit ee8b82b
Showing 11 changed files with 74 additions and 76 deletions.
1 change: 1 addition & 0 deletions deepmd/pt/model/atomic_model/base_atomic_model.py
@@ -54,6 +54,7 @@ def reinit_pair_exclude(
# export public methods that are not abstract
get_nsel = torch.jit.export(BaseAtomicModel_.get_nsel)
get_nnei = torch.jit.export(BaseAtomicModel_.get_nnei)
get_ntypes = torch.jit.export(BaseAtomicModel_.get_ntypes)

@torch.jit.export
def get_model_def_script(self) -> str:
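
The exported get_ntypes is what lets the C++ backend query the number of atom types from the scripted model; DeepPotPT::init below reads it with module.run_method("get_ntypes"). A minimal sketch of that call pattern, assuming a hypothetical frozen model path and a standalone main():

// Sketch only: reading scalar metadata from a TorchScript model whose
// methods were exported with torch.jit.export. "frozen_model.pth" is a
// hypothetical path; the real loading logic lives in DeepPotPT::init.
#include <torch/script.h>

#include <iostream>

int main() {
  torch::jit::script::Module module = torch::jit::load("frozen_model.pth");
  double rcut = module.run_method("get_rcut").toDouble();
  int64_t ntypes = module.run_method("get_ntypes").toInt();
  std::cout << "rcut = " << rcut << ", ntypes = " << ntypes << std::endl;
  return 0;
}
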
1 change: 0 additions & 1 deletion deepmd/pt/model/model/make_model.py
@@ -298,7 +298,6 @@ def output_type_cast(
)
return model_ret

@torch.jit.export
def format_nlist(
self,
extended_coord: torch.Tensor,
6 changes: 3 additions & 3 deletions source/api_cc/include/DeepPotPT.h
@@ -1,10 +1,10 @@
// SPDX-License-Identifier: LGPL-3.0-or-later
#pragma once

#include <torch/script.h>
#include <torch/torch.h>

#include "DeepPot.h"
#include "commonPT.h"

namespace deepmd {
/**
@@ -106,7 +106,7 @@ class DeepPotPT : public DeepPotBase {
const std::vector<VALUETYPE>& coord,
const std::vector<int>& atype,
const std::vector<VALUETYPE>& box,
// const int nghost,
const int nghost,
const InputNlist& lmp_list,
const int& ago,
const std::vector<VALUETYPE>& fparam = std::vector<VALUETYPE>(),
@@ -322,7 +322,7 @@ class DeepPotPT : public DeepPotBase {
// copy neighbor list info from host
torch::jit::script::Module module;
double rcut;
NeighborListDataPT nlist_data;
NeighborListData nlist_data;
int max_num_neighbors;
int gpu_id;
bool gpu_enabled;
1 change: 1 addition & 0 deletions source/api_cc/include/common.h
@@ -32,6 +32,7 @@ struct NeighborListData {
void shuffle(const deepmd::AtomMap& map);
void shuffle_exclude_empty(const std::vector<int>& fwd_map);
void make_inlist(InputNlist& inlist);
void padding();
};

/**
24 changes: 0 additions & 24 deletions source/api_cc/include/commonPT.h

This file was deleted.

82 changes: 59 additions & 23 deletions source/api_cc/src/DeepPotPT.cc
@@ -4,6 +4,17 @@

#include "common.h"
using namespace deepmd;
torch::Tensor createNlistTensor(const std::vector<std::vector<int>>& data) {
std::vector<torch::Tensor> row_tensors;

for (const auto& row : data) {
torch::Tensor row_tensor = torch::tensor(row, torch::kInt32).unsqueeze(0);
row_tensors.push_back(row_tensor);
}

torch::Tensor tensor = torch::cat(row_tensors, 0).unsqueeze(0);
return tensor;
}
DeepPotPT::DeepPotPT() : inited(false) {}
DeepPotPT::DeepPotPT(const std::string& model,
const int& gpu_rank,
@@ -60,7 +71,7 @@ void DeepPotPT::init(const std::string& model,

auto rcut_ = module.run_method("get_rcut").toDouble();
rcut = static_cast<double>(rcut_);
ntypes = 0;
ntypes = module.run_method("get_ntypes").toInt();
ntypes_spin = 0;
dfparam = module.run_method("get_dim_fparam").toInt();
daparam = module.run_method("get_dim_aparam").toInt();
@@ -78,6 +89,7 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
const std::vector<VALUETYPE>& coord,
const std::vector<int>& atype,
const std::vector<VALUETYPE>& box,
const int nghost,
const InputNlist& lmp_list,
const int& ago,
const std::vector<VALUETYPE>& fparam,
@@ -86,7 +98,6 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
if (!gpu_enabled) {
device = torch::Device(torch::kCPU);
}
std::vector<VALUETYPE> coord_wrapped = coord;
int natoms = atype.size();
auto options = torch::TensorOptions().dtype(torch::kFloat64);
torch::ScalarType floatType = torch::kFloat64;
@@ -96,18 +107,29 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
}
auto int_options = torch::TensorOptions().dtype(torch::kInt64);
auto int32_options = torch::TensorOptions().dtype(torch::kInt32);

// select real atoms
std::vector<VALUETYPE> dcoord, dforce, aparam_, datom_energy, datom_virial;
std::vector<int> datype, fwd_map, bkw_map;
int nghost_real, nall_real, nloc_real;
int nall = natoms;
select_real_atoms_coord(dcoord, datype, aparam_, nghost_real, fwd_map,
bkw_map, nall_real, nloc_real, coord, atype, aparam,
nghost, ntypes, 1, daparam, nall, aparam_nall);
std::cout << datype.size() << std::endl;
std::vector<VALUETYPE> coord_wrapped = dcoord;
at::Tensor coord_wrapped_Tensor =
torch::from_blob(coord_wrapped.data(), {1, natoms, 3}, options)
torch::from_blob(coord_wrapped.data(), {1, nall_real, 3}, options)
.to(device);
std::vector<int64_t> atype_64(atype.begin(), atype.end());
std::vector<int64_t> atype_64(datype.begin(), datype.end());
at::Tensor atype_Tensor =
torch::from_blob(atype_64.data(), {1, natoms}, int_options).to(device);
torch::from_blob(atype_64.data(), {1, nall_real}, int_options).to(device);
if (ago == 0) {
nlist_data.copy_from_nlist(lmp_list, max_num_neighbors);
nlist_data.copy_from_nlist(lmp_list);
nlist_data.shuffle_exclude_empty(fwd_map);
nlist_data.padding();
}
at::Tensor firstneigh =
torch::from_blob(nlist_data.jlist.data(),
{1, lmp_list.inum, max_num_neighbors}, int32_options);
at::Tensor firstneigh = createNlistTensor(nlist_data.jlist);
firstneigh_tensor = firstneigh.to(torch::kInt64).to(device);
bool do_atom_virial_tensor = true;
c10::optional<torch::Tensor> optional_tensor;
@@ -119,13 +141,13 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
.to(device);
}
c10::optional<torch::Tensor> aparam_tensor;
if (!aparam.empty()) {
aparam_tensor =
torch::from_blob(const_cast<VALUETYPE*>(aparam.data()),
{1, lmp_list.inum,
static_cast<long int>(aparam.size()) / lmp_list.inum},
options)
.to(device);
if (!aparam_.empty()) {
aparam_tensor = torch::from_blob(
const_cast<VALUETYPE*>(aparam_.data()),
{1, lmp_list.inum,
static_cast<long int>(aparam_.size()) / lmp_list.inum},
options)
.to(device);
}
c10::Dict<c10::IValue, c10::IValue> outputs =
module
@@ -145,24 +167,36 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
torch::Tensor flat_atom_energy_ =
atom_energy_.toTensor().view({-1}).to(floatType);
torch::Tensor cpu_atom_energy_ = flat_atom_energy_.to(torch::kCPU);
atom_energy.resize(natoms, 0.0); // resize to nall to be consistenet with TF.
atom_energy.assign(
datom_energy.resize(nall_real,
0.0); // resize to nall to be consistenet with TF.
datom_energy.assign(
cpu_atom_energy_.data_ptr<VALUETYPE>(),
cpu_atom_energy_.data_ptr<VALUETYPE>() + cpu_atom_energy_.numel());
torch::Tensor flat_force_ = force_.toTensor().view({-1}).to(floatType);
torch::Tensor cpu_force_ = flat_force_.to(torch::kCPU);
force.assign(cpu_force_.data_ptr<VALUETYPE>(),
cpu_force_.data_ptr<VALUETYPE>() + cpu_force_.numel());
dforce.assign(cpu_force_.data_ptr<VALUETYPE>(),
cpu_force_.data_ptr<VALUETYPE>() + cpu_force_.numel());
torch::Tensor flat_virial_ = virial_.toTensor().view({-1}).to(floatType);
torch::Tensor cpu_virial_ = flat_virial_.to(torch::kCPU);
virial.assign(cpu_virial_.data_ptr<VALUETYPE>(),
cpu_virial_.data_ptr<VALUETYPE>() + cpu_virial_.numel());
torch::Tensor flat_atom_virial_ =
atom_virial_.toTensor().view({-1}).to(floatType);
torch::Tensor cpu_atom_virial_ = flat_atom_virial_.to(torch::kCPU);
atom_virial.assign(
datom_virial.assign(
cpu_atom_virial_.data_ptr<VALUETYPE>(),
cpu_atom_virial_.data_ptr<VALUETYPE>() + cpu_atom_virial_.numel());
int nframes = 1;
// bkw map
force.resize(static_cast<size_t>(nframes) * fwd_map.size() * 3);
atom_energy.resize(static_cast<size_t>(nframes) * fwd_map.size());
atom_virial.resize(static_cast<size_t>(nframes) * fwd_map.size() * 9);
select_map<VALUETYPE>(force, dforce, bkw_map, 3, nframes, fwd_map.size(),
nall_real);
select_map<VALUETYPE>(atom_energy, datom_energy, bkw_map, 1, nframes,
fwd_map.size(), nall_real);
select_map<VALUETYPE>(atom_virial, datom_virial, bkw_map, 9, nframes,
fwd_map.size(), nall_real);
}
template void DeepPotPT::compute<double, std::vector<ENERGYTYPE>>(
std::vector<ENERGYTYPE>& ener,
@@ -173,6 +207,7 @@ template void DeepPotPT::compute<double, std::vector<ENERGYTYPE>>(
const std::vector<double>& coord,
const std::vector<int>& atype,
const std::vector<double>& box,
const int nghost,
const InputNlist& lmp_list,
const int& ago,
const std::vector<double>& fparam,
@@ -186,6 +221,7 @@ template void DeepPotPT::compute<float, std::vector<ENERGYTYPE>>(
const std::vector<float>& coord,
const std::vector<int>& atype,
const std::vector<float>& box,
const int nghost,
const InputNlist& lmp_list,
const int& ago,
const std::vector<float>& fparam,
@@ -353,7 +389,7 @@ void DeepPotPT::computew(std::vector<double>& ener,
const std::vector<double>& fparam,
const std::vector<double>& aparam) {
compute(ener, force, virial, atom_energy, atom_virial, coord, atype, box,
inlist, ago, fparam, aparam);
nghost, inlist, ago, fparam, aparam);
}
void DeepPotPT::computew(std::vector<double>& ener,
std::vector<float>& force,
@@ -369,7 +405,7 @@ void DeepPotPT::computew(std::vector<double>& ener,
const std::vector<float>& fparam,
const std::vector<float>& aparam) {
compute(ener, force, virial, atom_energy, atom_virial, coord, atype, box,
inlist, ago, fparam, aparam);
nghost, inlist, ago, fparam, aparam);
}
void DeepPotPT::computew_mixed_type(std::vector<double>& ener,
std::vector<double>& force,
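
The heart of the change above is the round trip through the forward and backward maps: select_real_atoms_coord keeps only the real atoms (virtual atoms are typically marked by a negative type), the model runs on that reduced set, and select_map scatters the per-atom results back to the original ordering. A self-contained sketch of the idea, independent of the deepmd API (the -1 type marker and the zero fill for dropped atoms are assumptions of this illustration; the real routines also handle frames, ghost atoms, and aparam):

// Illustration of forward/backward atom mapping; not the deepmd implementation.
#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  std::vector<int> atype = {0, -1, 1, 1, -1, 0};  // -1 marks a virtual atom

  // fwd_map: original index -> selected index (-1 if the atom is dropped)
  // bkw_map: selected index -> original index
  std::vector<int> fwd_map(atype.size(), -1), bkw_map;
  for (std::size_t i = 0; i < atype.size(); ++i) {
    if (atype[i] >= 0) {
      fwd_map[i] = static_cast<int>(bkw_map.size());
      bkw_map.push_back(static_cast<int>(i));
    }
  }

  // pretend the model produced one energy per selected (real) atom
  std::vector<double> selected_energy(bkw_map.size(), 1.0);

  // scatter back to the full atom list; dropped atoms keep 0.0
  std::vector<double> atom_energy(atype.size(), 0.0);
  for (std::size_t j = 0; j < bkw_map.size(); ++j) {
    atom_energy[bkw_map[j]] = selected_energy[j];
  }

  for (double e : atom_energy) {
    std::cout << e << " ";  // prints: 1 0 1 1 0 1
  }
  std::cout << std::endl;
  return 0;
}
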
10 changes: 10 additions & 0 deletions source/api_cc/src/common.cc
@@ -293,6 +293,16 @@ void deepmd::NeighborListData::shuffle_exclude_empty(
ilist = new_ilist;
jlist = new_jlist;
}
void deepmd::NeighborListData::padding() {
size_t max_length = 0;
for (const auto& row : jlist) {
max_length = std::max(max_length, row.size());
}

for (int i = 0; i < jlist.size(); i++) {
jlist[i].resize(max_length, -1);
}
}

void deepmd::NeighborListData::make_inlist(InputNlist& inlist) {
int nloc = ilist.size();
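
padding() is needed because the rows of jlist are ragged after shuffle_exclude_empty, while createNlistTensor above can only concatenate rows of equal length into a dense tensor; the gaps are filled with -1 to mark a missing neighbor. A standalone sketch of the effect:

// Standalone illustration of the padding step: ragged neighbor rows are
// extended with -1 so they can be packed into a rectangular tensor.
#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  std::vector<std::vector<int>> jlist = {{1, 2}, {0}, {0, 1, 3}};

  std::size_t max_length = 0;
  for (const auto& row : jlist) {
    max_length = std::max(max_length, row.size());
  }
  for (auto& row : jlist) {
    row.resize(max_length, -1);
  }

  for (const auto& row : jlist) {
    for (int j : row) {
      std::printf("%3d", j);
    }
    std::printf("\n");
  }
  // Output:
  //   1  2 -1
  //   0 -1 -1
  //   0  1  3
  return 0;
}
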
23 changes: 0 additions & 23 deletions source/api_cc/src/commonPT.cc

This file was deleted.

2 changes: 0 additions & 2 deletions source/api_cc/tests/test_deeppot_pt.cc
@@ -402,7 +402,6 @@ TYPED_TEST(TestInferDeepPotAPt, cpu_lmp_nlist_2rc) {
}

TYPED_TEST(TestInferDeepPotAPt, cpu_lmp_nlist_type_sel) {
GTEST_SKIP() << "Skipping this test for unsupported";
using VALUETYPE = TypeParam;
std::vector<VALUETYPE>& coord = this->coord;
std::vector<int>& atype = this->atype;
@@ -465,7 +464,6 @@ TYPED_TEST(TestInferDeepPotAPt, cpu_lmp_nlist_type_sel) {
}

TYPED_TEST(TestInferDeepPotAPt, cpu_lmp_nlist_type_sel_atomic) {
GTEST_SKIP() << "Skipping this test for unsupported";
using VALUETYPE = TypeParam;
std::vector<VALUETYPE>& coord = this->coord;
std::vector<int>& atype = this->atype;
Binary file modified source/tests/infer/deeppot_sea.pth
Binary file modified source/tests/infer/fparam_aparam.pth
