Skip to content

Commit

Permalink
gpu_options
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Sep 30, 2024
1 parent d3944bb commit da2e7ec
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 3 deletions.
3 changes: 2 additions & 1 deletion source/api_cc/src/DataModifierTF.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ void DipoleChargeModifierTF::init(const std::string& model,
std::string str = "/gpu:0";
// See
// https://github.com/tensorflow/tensorflow/blame/8fac27b486939f40bc8e362b94a16a4a8bb51869/tensorflow/core/protobuf/config.proto#L80
options.config.visible_device_list = std::to_string(gpu_rank % gpu_num);
options.config.mutable_gpu_options()->visible_device_list =
std::to_string(gpu_rank % gpu_num);
graph::SetDefaultDevice(str, graph_def);
}
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
Expand Down
3 changes: 2 additions & 1 deletion source/api_cc/src/DeepPotTF.cc
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,8 @@ void DeepPotTF::init(const std::string& model,
std::string str = "/gpu:0";
// See
// https://github.com/tensorflow/tensorflow/blame/8fac27b486939f40bc8e362b94a16a4a8bb51869/tensorflow/core/protobuf/config.proto#L80
options.config.visible_device_list = std::to_string(gpu_rank % gpu_num);
options.config.mutable_gpu_options()->visible_device_list =
std::to_string(gpu_rank % gpu_num);
graph::SetDefaultDevice(str, graph_def);
}
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
Expand Down
3 changes: 2 additions & 1 deletion source/api_cc/src/DeepTensorTF.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ void DeepTensorTF::init(const std::string &model,
std::string str = "/gpu:0";
// See
// https://github.com/tensorflow/tensorflow/blame/8fac27b486939f40bc8e362b94a16a4a8bb51869/tensorflow/core/protobuf/config.proto#L80
options.config.visible_device_list = std::to_string(gpu_rank % gpu_num);
options.config.mutable_gpu_options()->visible_device_list =
std::to_string(gpu_rank % gpu_num);
graph::SetDefaultDevice(str, graph_def);
}
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
Expand Down

0 comments on commit da2e7ec

Please sign in to comment.