From ed2eec1e9242b86fa3598b37eb57a41a50e21fd6 Mon Sep 17 00:00:00 2001 From: scxfjiang Date: Mon, 16 Sep 2024 06:42:39 -0500 Subject: [PATCH] scrub navi --- .../xla/xla/stream_executor/device_description.h | 13 ++++++++----- .../xla/xla/stream_executor/rocm/rocm_driver.cc | 8 ++++++-- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/third_party/xla/xla/stream_executor/device_description.h b/third_party/xla/xla/stream_executor/device_description.h index 8fdc78c6f45cf2..29dbc19967034d 100644 --- a/third_party/xla/xla/stream_executor/device_description.h +++ b/third_party/xla/xla/stream_executor/device_description.h @@ -205,16 +205,19 @@ class RocmComputeCapability { return absl::c_count(kList, gfx_version()) != 0; } - bool navi21() const { return gfx_version() == "gfx1030"; } + bool gfx10_rx68xx() const { return gfx_version() == "gfx1030"; } - bool navi31() const { return gfx_version() == "gfx1100"; } + bool gfx10_rx69xx() const { return gfx_version() == "gfx1030"; } + + bool gfx11_rx7900() const { return gfx_version() == "gfx1100"; } bool has_nhwc_layout_support() const { return gfx9_mi100_or_later(); } bool has_bf16_dtype_support() const { return gfx9_mi100_or_later(); } bool has_fast_fp16_support() const { - return gfx9_mi100_or_later() || navi21() || navi31(); + return gfx9_mi100_or_later() || gfx10_rx68xx() || gfx10_rx69xx() || + gfx11_rx7900(); } bool has_mfma_instr_support() const { return gfx9_mi100_or_later(); } @@ -251,8 +254,8 @@ class RocmComputeCapability { "gfx908", // MI100 "gfx90a", // MI200 "gfx940", "gfx941", "gfx942", // MI300 - "gfx1030", // Navi21 - "gfx1100" // Navi31 + "gfx1030", // RX68xx / RX69xx + "gfx1100" // RX7900 }; }; diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_driver.cc b/third_party/xla/xla/stream_executor/rocm/rocm_driver.cc index 0b0871a0fde2e9..f293720492c319 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_driver.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_driver.cc @@ -1998,12 +1998,16 @@ static absl::StatusOr GetSimpleAttribute(hipDevice_t device, const uint64_t RESERVED_GFX908 = 1048576 * 512; const uint64_t RESERVED_GFX9_X = 1048576 * 1024; const uint64_t RESERVED_GFX10_X = 1048576 * 512; - if (compute_capability.gfx_version() == "gfx908") { + const uint64_t RESERVED_GFX11_X = 1048576 * 512; + if (compute_capability.gfx9_mi100()) { *reserve = RESERVED_GFX908; } else if (compute_capability.gfx9_mi200_or_later()) { *reserve = RESERVED_GFX9_X; - } else if (compute_capability.navi21() || compute_capability.navi31()) { + } else if (compute_capability.gfx10_rx68xx() || + compute_capability.gfx10_rx69xx()) { *reserve = RESERVED_GFX10_X; + } else if (compute_capability.gfx11_rx7900()) { + *reserve = RESERVED_GFX11_X; } return true;