Skip to content

Commit

Permalink
Make rocm files use tsl DsoLoader functions instead of the stream_exe…
Browse files Browse the repository at this point in the history
…cutor wrappers.

PiperOrigin-RevId: 681577126
  • Loading branch information
klucke authored and Google-ML-Automation committed Oct 2, 2024
1 parent 6947dee commit 6fd8234
Show file tree
Hide file tree
Showing 10 changed files with 121 additions and 109 deletions.
10 changes: 10 additions & 0 deletions xla/stream_executor/rocm/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ cc_library(
"@local_config_rocm//rocm:hip",
"@local_config_rocm//rocm:rocm_headers",
"@tsl//tsl/platform:casts",
"@tsl//tsl/platform:dso_loader",
"@tsl//tsl/platform:env",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:logging",
Expand Down Expand Up @@ -353,6 +354,7 @@ cc_library(
"//xla/tsl/util:determinism_for_kernels",
"@local_config_rocm//rocm:rocm_headers",
"@tsl//tsl/platform",
"@tsl//tsl/platform:dso_loader",
"@tsl//tsl/platform:env",
],
alwayslink = True,
Expand Down Expand Up @@ -445,6 +447,7 @@ cc_library(
"//xla/stream_executor/gpu:scoped_activate_context",
"//xla/stream_executor/platform",
"@local_config_rocm//rocm:rocm_headers",
"@tsl//tsl/platform:dso_loader",
"@tsl//tsl/platform:env",
"@tsl//tsl/platform:logging",
],
Expand Down Expand Up @@ -510,6 +513,7 @@ cc_library(
"@com_google_absl//absl/types:span",
"@eigen_archive//:eigen3",
"@local_config_rocm//rocm:rocm_headers",
"@tsl//tsl/platform:dso_loader",
"@tsl//tsl/platform:env",
"@tsl//tsl/platform:env_impl",
"@tsl//tsl/platform:errors",
Expand Down Expand Up @@ -564,6 +568,7 @@ cc_library(
":rocm_platform_id",
"//xla/stream_executor/platform",
"@local_config_rocm//rocm:rocm_headers",
"@tsl//tsl/platform:dso_loader",
"@tsl//tsl/platform:env",
],
alwayslink = True,
Expand Down Expand Up @@ -600,6 +605,7 @@ cc_library(
":rocsolver_if_static",
"//xla/stream_executor/platform",
"@local_config_rocm//rocm:rocm_headers",
"@tsl//tsl/platform:dso_loader",
"@tsl//tsl/platform:env",
],
alwayslink = True,
Expand Down Expand Up @@ -635,6 +641,7 @@ cc_library(
":rocm_platform_id",
"//xla/stream_executor/platform",
"@local_config_rocm//rocm:rocm_headers",
"@tsl//tsl/platform:dso_loader",
"@tsl//tsl/platform:env",
],
alwayslink = True,
Expand Down Expand Up @@ -691,6 +698,7 @@ cc_library(
"//xla/stream_executor/platform",
"@com_google_absl//absl/status",
"@local_config_rocm//rocm:rocm_headers",
"@tsl//tsl/platform:dso_loader",
"@tsl//tsl/platform:env",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:status",
Expand Down Expand Up @@ -725,6 +733,7 @@ cc_library(
"//xla/stream_executor/platform",
"@com_google_absl//absl/status",
"@local_config_rocm//rocm:rocm_headers",
"@tsl//tsl/platform:dso_loader",
"@tsl//tsl/platform:env",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:status",
Expand Down Expand Up @@ -785,6 +794,7 @@ cc_library(
"//xla/stream_executor/platform",
"@local_config_rocm//rocm:rocm_headers",
"@tsl//tsl/platform",
"@tsl//tsl/platform:dso_loader",
"@tsl//tsl/platform:env",
],
alwayslink = True,
Expand Down
32 changes: 16 additions & 16 deletions xla/stream_executor/rocm/hipblaslt_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ limitations under the License.
#include "rocm/include/hipblaslt.h"
#endif
#include "xla/stream_executor/platform/port.h"
#include "tsl/platform/dso_loader.h"
#include "tsl/platform/env.h"

namespace stream_executor {
Expand All @@ -46,22 +47,21 @@ namespace wrap {
#define TO_STR_(x) #x
#define TO_STR(x) TO_STR_(x)

#define HIPBLASLT_API_WRAPPER(api_name) \
template <typename... Args> \
auto api_name(Args... args) -> decltype(::api_name(args...)) { \
using FuncPtrT = std::add_pointer<decltype(::api_name)>::type; \
static FuncPtrT loaded = []() -> FuncPtrT { \
static const char* kName = TO_STR(api_name); \
void* f; \
auto s = tsl::Env::Default() -> GetSymbolFromLibrary( \
stream_executor::internal::CachedDsoLoader::GetHipblasltDsoHandle() \
.value(), \
kName, &f); \
CHECK(s.ok()) << "could not find " << kName \
<< " in hipblaslt lib; dlerror: " << s.message(); \
return reinterpret_cast<FuncPtrT>(f); \
}(); \
return loaded(args...); \
#define HIPBLASLT_API_WRAPPER(api_name) \
template <typename... Args> \
auto api_name(Args... args) -> decltype(::api_name(args...)) { \
using FuncPtrT = std::add_pointer<decltype(::api_name)>::type; \
static FuncPtrT loaded = []() -> FuncPtrT { \
static const char* kName = TO_STR(api_name); \
void* f; \
auto s = tsl::Env::Default()->GetSymbolFromLibrary( \
tsl::internal::CachedDsoLoader::GetHipblasltDsoHandle().value(), \
kName, &f); \
CHECK(s.ok()) << "could not find " << kName \
<< " in hipblaslt lib; dlerror: " << s.message(); \
return reinterpret_cast<FuncPtrT>(f); \
}(); \
return loaded(args...); \
}

#endif
Expand Down
32 changes: 16 additions & 16 deletions xla/stream_executor/rocm/hipsolver_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ limitations under the License.
#include "rocm/include/hipsolver.h"
#endif
#include "xla/stream_executor/platform/port.h"
#include "tsl/platform/dso_loader.h"
#include "tsl/platform/env.h"

namespace stream_executor {
Expand All @@ -48,22 +49,21 @@ namespace wrap {
#define TO_STR_(x) #x
#define TO_STR(x) TO_STR_(x)

#define HIPSOLVER_API_WRAPPER(api_name) \
template <typename... Args> \
auto api_name(Args... args) -> decltype(::api_name(args...)) { \
using FuncPtrT = std::add_pointer<decltype(::api_name)>::type; \
static FuncPtrT loaded = []() -> FuncPtrT { \
static const char* kName = TO_STR(api_name); \
void* f; \
auto s = tsl::Env::Default() -> GetSymbolFromLibrary( \
stream_executor::internal::CachedDsoLoader::GetHipsolverDsoHandle() \
.value(), \
kName, &f); \
CHECK(s.ok()) << "could not find " << kName \
<< " in hipsolver lib; dlerror: " << s.message(); \
return reinterpret_cast<FuncPtrT>(f); \
}(); \
return loaded(args...); \
#define HIPSOLVER_API_WRAPPER(api_name) \
template <typename... Args> \
auto api_name(Args... args) -> decltype(::api_name(args...)) { \
using FuncPtrT = std::add_pointer<decltype(::api_name)>::type; \
static FuncPtrT loaded = []() -> FuncPtrT { \
static const char* kName = TO_STR(api_name); \
void* f; \
auto s = tsl::Env::Default()->GetSymbolFromLibrary( \
tsl::internal::CachedDsoLoader::GetHipsolverDsoHandle().value(), \
kName, &f); \
CHECK(s.ok()) << "could not find " << kName \
<< " in hipsolver lib; dlerror: " << s.message(); \
return reinterpret_cast<FuncPtrT>(f); \
}(); \
return loaded(args...); \
}

#endif
Expand Down
50 changes: 25 additions & 25 deletions xla/stream_executor/rocm/hipsparse_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ limitations under the License.
#endif
#include "xla/stream_executor/platform/platform.h"
#include "xla/stream_executor/platform/port.h"
#include "tsl/platform/dso_loader.h"
#include "tsl/platform/env.h"

namespace stream_executor {
Expand All @@ -47,31 +48,30 @@ namespace wrap {

#else

#define HIPSPARSE_API_WRAPPER(__name) \
static struct DynLoadShim__##__name { \
constexpr static const char* kName = #__name; \
using FuncPtrT = std::add_pointer<decltype(::__name)>::type; \
static void* GetDsoHandle() { \
auto s = \
stream_executor::internal::CachedDsoLoader::GetHipsparseDsoHandle(); \
return s.value(); \
} \
static FuncPtrT LoadOrDie() { \
void* f; \
auto s = tsl::Env::Default()->GetSymbolFromLibrary(GetDsoHandle(), \
kName, &f); \
CHECK(s.ok()) << "could not find " << kName \
<< " in miopen DSO; dlerror: " << s.message(); \
return reinterpret_cast<FuncPtrT>(f); \
} \
static FuncPtrT DynLoad() { \
static FuncPtrT f = LoadOrDie(); \
return f; \
} \
template <typename... Args> \
hipsparseStatus_t operator()(Args... args) { \
return DynLoad()(args...); \
} \
#define HIPSPARSE_API_WRAPPER(__name) \
static struct DynLoadShim__##__name { \
constexpr static const char* kName = #__name; \
using FuncPtrT = std::add_pointer<decltype(::__name)>::type; \
static void* GetDsoHandle() { \
auto s = tsl::internal::CachedDsoLoader::GetHipsparseDsoHandle(); \
return s.value(); \
} \
static FuncPtrT LoadOrDie() { \
void* f; \
auto s = tsl::Env::Default()->GetSymbolFromLibrary(GetDsoHandle(), \
kName, &f); \
CHECK(s.ok()) << "could not find " << kName \
<< " in miopen DSO; dlerror: " << s.message(); \
return reinterpret_cast<FuncPtrT>(f); \
} \
static FuncPtrT DynLoad() { \
static FuncPtrT f = LoadOrDie(); \
return f; \
} \
template <typename... Args> \
hipsparseStatus_t operator()(Args... args) { \
return DynLoad()(args...); \
} \
} __name;

#endif
Expand Down
3 changes: 2 additions & 1 deletion xla/stream_executor/rocm/rocblas_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ limitations under the License.
#include "rocm/include/rocblas/rocblas.h"
#include "rocm/rocm_config.h"
#include "xla/stream_executor/platform/port.h"
#include "tsl/platform/dso_loader.h"
#include "tsl/platform/env.h"
#include "tsl/platform/platform.h"

Expand All @@ -43,7 +44,7 @@ namespace wrap {
} __name;

#else
using stream_executor::internal::CachedDsoLoader::GetRocblasDsoHandle;
using tsl::internal::CachedDsoLoader::GetRocblasDsoHandle;

#define ROCBLAS_API_WRAPPER(__name) \
static struct DynLoadShim__##__name { \
Expand Down
3 changes: 2 additions & 1 deletion xla/stream_executor/rocm/rocm_dnn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ limitations under the License.
#include "xla/stream_executor/stream_executor.h"
#include "xla/tsl/util/determinism.h"
#include "xla/tsl/util/env_var.h"
#include "tsl/platform/dso_loader.h"
#include "tsl/platform/env.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/hash.h"
Expand Down Expand Up @@ -248,7 +249,7 @@ namespace wrap {
static const char* kName; \
using FuncPtrT = std::add_pointer<decltype(::__name)>::type; \
static void* GetDsoHandle() { \
auto s = internal::CachedDsoLoader::GetMiopenDsoHandle(); \
auto s = tsl::internal::CachedDsoLoader::GetMiopenDsoHandle(); \
return s.value(); \
} \
static FuncPtrT LoadOrDie() { \
Expand Down
33 changes: 16 additions & 17 deletions xla/stream_executor/rocm/rocm_driver_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ limitations under the License.

#include "rocm/include/hip/hip_runtime.h"
#include "rocm/rocm_config.h"
#include "xla/stream_executor/platform/port.h"
#include "tsl/platform/dso_loader.h"
#include "tsl/platform/env.h"

namespace stream_executor {
Expand All @@ -46,22 +46,21 @@ namespace wrap {
#define TO_STR_(x) #x
#define TO_STR(x) TO_STR_(x)

#define STREAM_EXECUTOR_HIP_WRAP(hipSymbolName) \
template <typename... Args> \
auto hipSymbolName(Args... args) -> decltype(::hipSymbolName(args...)) { \
using FuncPtrT = std::add_pointer<decltype(::hipSymbolName)>::type; \
static FuncPtrT loaded = []() -> FuncPtrT { \
static const char *kName = TO_STR(hipSymbolName); \
void *f; \
auto s = tsl::Env::Default()->GetSymbolFromLibrary( \
stream_executor::internal::CachedDsoLoader::GetHipDsoHandle() \
.value(), \
kName, &f); \
CHECK(s.ok()) << "could not find " << kName \
<< " in HIP DSO; dlerror: " << s.message(); \
return reinterpret_cast<FuncPtrT>(f); \
}(); \
return loaded(args...); \
#define STREAM_EXECUTOR_HIP_WRAP(hipSymbolName) \
template <typename... Args> \
auto hipSymbolName(Args... args) -> decltype(::hipSymbolName(args...)) { \
using FuncPtrT = std::add_pointer<decltype(::hipSymbolName)>::type; \
static FuncPtrT loaded = []() -> FuncPtrT { \
static const char *kName = TO_STR(hipSymbolName); \
void *f; \
auto s = tsl::Env::Default()->GetSymbolFromLibrary( \
tsl::internal::CachedDsoLoader::GetHipDsoHandle().value(), kName, \
&f); \
CHECK(s.ok()) << "could not find " << kName \
<< " in HIP DSO; dlerror: " << s.message(); \
return reinterpret_cast<FuncPtrT>(f); \
}(); \
return loaded(args...); \
}
#endif

Expand Down
3 changes: 2 additions & 1 deletion xla/stream_executor/rocm/rocm_fft.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ limitations under the License.
#include "xla/stream_executor/rocm/rocm_complex_converters.h"
#include "xla/stream_executor/rocm/rocm_platform_id.h"
#include "xla/stream_executor/stream_executor.h"
#include "tsl/platform/dso_loader.h"
#include "tsl/platform/env.h"
#include "tsl/platform/logging.h"

Expand Down Expand Up @@ -60,7 +61,7 @@ namespace wrap {
static const char *kName; \
using FuncPtrT = std::add_pointer<decltype(::__name)>::type; \
static void *GetDsoHandle() { \
auto s = internal::CachedDsoLoader::GetHipfftDsoHandle(); \
auto s = tsl::internal::CachedDsoLoader::GetHipfftDsoHandle(); \
return s.value(); \
} \
static FuncPtrT LoadOrDie() { \
Expand Down
Loading

0 comments on commit 6fd8234

Please sign in to comment.