Use DeviceCapabilities in a few places
will-cromar committed Dec 8, 2023
1 parent 510edf0 commit 7956ce5
Showing 9 changed files with 25 additions and 61 deletions.
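The commit replaces scattered hardware-type checks (XlaDeviceType::TPU, NEURON, CUDA, and so on) with queries against a DeviceCapabilities struct exposed by the runtime's ComputationClient. Below is a minimal sketch of the new call-site pattern; only GetComputationClient(), GetDeviceCapabilities(), and the supports_float64 field come from this commit, while the helper function and the include path are illustrative assumptions.

#include "torch_xla/csrc/runtime/runtime.h"  // assumed header for GetComputationClient()

namespace torch_xla {

// Hypothetical helper: instead of asking "is this a TPU?", ask the runtime
// what the current device can do and branch on that.
bool ShouldDowncastFloat64() {
  const auto& caps = runtime::GetComputationClient()->GetDeviceCapabilities();
  return !caps.supports_float64;
}

}  // namespace torch_xla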
31 changes: 2 additions & 29 deletions torch_xla/csrc/convert_ops.cpp
@@ -58,35 +58,8 @@ xla::XlaOp ConvertTo(xla::XlaOp op, xla::PrimitiveType from,
if (from == to) {
return op;
}
XlaDeviceType hw_type =
static_cast<XlaDeviceType>(bridge::GetDeviceOrCurrent(device).type());
if (hw_type != XlaDeviceType::TPU) {
return xla::ConvertElementType(op, to);
}
switch (from) {
case xla::PrimitiveType::PRED:
case xla::PrimitiveType::S8:
case xla::PrimitiveType::U8:
case xla::PrimitiveType::S16:
case xla::PrimitiveType::U16:
case xla::PrimitiveType::S32:
case xla::PrimitiveType::U32:
case xla::PrimitiveType::BF16:
case xla::PrimitiveType::F32:
return xla::ConvertElementType(op, to);
case xla::PrimitiveType::S64:
case xla::PrimitiveType::U64: {
switch (to) {
case xla::PrimitiveType::PRED:
return ExplicitBooleanConvert(op, from);
default:
return xla::ConvertElementType(op, to);
}
break;
}
default:
XLA_ERROR() << "Unsupported XLA type " << from;
}

return xla::ConvertElementType(op, to);
}

xla::XlaOp ConvertToRaw(xla::XlaOp op, xla::PrimitiveType from,
8 changes: 3 additions & 5 deletions torch_xla/csrc/data_ops.cpp
@@ -30,13 +30,11 @@ bool IsSparseGather(const xla::Shape& input_shape,
const xla::Shape& index_shape, int64_t dim) {
// Conservative sparsity check for multi-platform support
// to avoid gather on a single float on TPU.
XlaDeviceType hw_type =
static_cast<XlaDeviceType>(bridge::GetCurrentDevice().type());
if (hw_type == XlaDeviceType::TPU || hw_type == XlaDeviceType::NEURON) {
auto device_capabilities = runtime::GetComputationClient()->GetDeviceCapabilities();
if (device_capabilities.dense_gather_factor) {
// XLA_DENSE_GATHER_FACTOR can be used to finely control the
// sparsity check.
static int dense_gather_factor =
runtime::sys_util::GetEnvInt("XLA_DENSE_GATHER_FACTOR", 8192);
static int dense_gather_factor = *device_capabilities.dense_gather_factor;
int64_t input_elements = input_shape.dimensions()[dim];
// Use a very conservative check so that we run dense gather
// most of the time on TPU.
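In the hunk above, the conservative dense-gather heuristic is now gated on whether the backend publishes dense_gather_factor at all: the presence of the std::optional decides if the check runs, and *device_capabilities.dense_gather_factor replaces the old XLA_DENSE_GATHER_FACTOR environment lookup (default 8192). A toy, self-contained illustration of that presence-then-dereference pattern follows; the names and values here are illustrative, not part of the commit.

#include <cstdint>
#include <initializer_list>
#include <iostream>
#include <optional>

int main() {
  // One backend publishes a factor, another opts out of the heuristic.
  std::optional<int32_t> tpu_like_factor = 8192;
  std::optional<int32_t> cpu_like_factor = std::nullopt;

  for (const auto& factor : {tpu_like_factor, cpu_like_factor}) {
    if (factor) {
      // Mirrors `if (device_capabilities.dense_gather_factor)` above: the
      // heuristic only runs when a value was published.
      std::cout << "dense-gather heuristic enabled, factor=" << *factor << "\n";
    } else {
      std::cout << "dense-gather heuristic skipped\n";
    }
  }
  return 0;
}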
22 changes: 6 additions & 16 deletions torch_xla/csrc/dtype.cpp
@@ -75,15 +75,6 @@ bool Use32BitLong() {
return use_32bit_long;
}

bool IsTpuDevice(XlaDeviceType hw_type) {
static bool spmd_device_is_tpu =
(hw_type == XlaDeviceType::SPMD) &&
// HACK: find a better way to decide if SPMD is actually a TPU without
// accessing the runtime.
runtime::sys_util::GetEnvString("PJRT_DEVICE", "") == "TPU";
return (hw_type == XlaDeviceType::TPU) || spmd_device_is_tpu;
}

} // namespace

at::ScalarType TorchTypeFromXlaType(xla::PrimitiveType xla_type) {
@@ -153,7 +144,7 @@ xla::PrimitiveType XlaTypeFromTorchType(at::ScalarType scalar_type) {

xla::PrimitiveType MaybeDowncastToXlaDeviceType(
xla::PrimitiveType type, const torch::lazy::BackendDevice& device) {
XlaDeviceType hw_type = static_cast<XlaDeviceType>(device.type());
auto device_capabilities = runtime::GetComputationClient()->GetDeviceCapabilities();
switch (type) {
case xla::PrimitiveType::F64:
if (UseF16()) {
@@ -162,8 +153,7 @@ xla::PrimitiveType MaybeDowncastToXlaDeviceType(
if (UseBF16()) {
return xla::PrimitiveType::BF16;
}
if (DowncastBF16() || DowncastF16() || IsTpuDevice(hw_type) ||
hw_type == XlaDeviceType::NEURON) {
if (DowncastBF16() || DowncastF16() || !device_capabilities.supports_float64) {
return xla::PrimitiveType::F32;
}
return xla::PrimitiveType::F64;
@@ -174,20 +164,20 @@
return UseBF16() || DowncastBF16() ? xla::PrimitiveType::BF16
: xla::PrimitiveType::F32;
case xla::PrimitiveType::U16:
return !IsTpuDevice(hw_type) && hw_type != XlaDeviceType::NEURON
return device_capabilities.supports_int16
? xla::PrimitiveType::U16
: xla::PrimitiveType::U32;
case xla::PrimitiveType::S16:
return !IsTpuDevice(hw_type) && hw_type != XlaDeviceType::NEURON
return device_capabilities.supports_int16
? xla::PrimitiveType::S16
: xla::PrimitiveType::S32;
case xla::PrimitiveType::S64:
return Use32BitLong() ? xla::PrimitiveType::S32 : xla::PrimitiveType::S64;
case xla::PrimitiveType::U64:
return Use32BitLong() ? xla::PrimitiveType::U32 : xla::PrimitiveType::U64;
case xla::PrimitiveType::C128:
return !IsTpuDevice(hw_type) ? xla::PrimitiveType::C128
: xla::PrimitiveType::C64;
return device_capabilities.supports_complex128 ? xla::PrimitiveType::C128
: xla::PrimitiveType::C64;
default:
return type;
}
3 changes: 2 additions & 1 deletion torch_xla/csrc/layout_manager.cpp
@@ -184,7 +184,8 @@ xla::Shape MakeArrayShapeFromDimensions(
return MakeShapeWithLayout(type, dimensions, dynamic_dimensions,
*layout_ptr);
}
if (dimensions.size() > 1 && hw_type == XlaDeviceType::TPU) {
auto device_capabilities = runtime::GetComputationClient()->GetDeviceCapabilities();
if (dimensions.size() > 1 && device_capabilities.use_tpu_layout) {
return MakeTpuShape(dimensions, dynamic_dimensions, type);
}
return MakeTorchTensorLayout(dimensions, dynamic_dimensions, type);
12 changes: 2 additions & 10 deletions torch_xla/csrc/random.cpp
@@ -17,16 +17,8 @@ namespace torch_xla {
namespace {

std::string GetDefaultGitGeneratorName() {
XlaDeviceType hw_type =
static_cast<XlaDeviceType>(bridge::GetCurrentDevice().type());
switch (hw_type) {
case XlaDeviceType::GPU:
case XlaDeviceType::CUDA:
case XlaDeviceType::ROCM:
return "three_fry";
default:
return "default";
}
auto device_capabilities = runtime::GetComputationClient()->GetDeviceCapabilities();
return device_capabilities.default_rng_bit_generator_name.value_or("default");
}

xla::BitGeneratorTy GetBitGenerator() {
2 changes: 2 additions & 0 deletions torch_xla/csrc/runtime/computation_client.h
@@ -345,6 +345,8 @@ class ComputationClient {
// Return the XlaCoordinator for the runtime.
virtual XlaCoordinator& GetCoordinator() = 0;

virtual const DeviceCapabilities& GetDeviceCapabilities() const = 0;

// Utility API around the vector based Compile() API to compile a single
// computation.
ComputationPtr Compile(xla::XlaComputation computation,
3 changes: 3 additions & 0 deletions torch_xla/csrc/runtime/device_capabilities.h
@@ -8,7 +8,10 @@ namespace runtime {

struct DeviceCapabilities {
bool supports_float64;
bool supports_int16;
bool supports_complex128;
bool supports_bool;
bool use_tpu_layout;
// TODO figure out
// https://github.com/pytorch/xla/blob/2c6e4a773cc70cdea3c606d410b3aef8f7dfb6f7/torch_xla/csrc/resize_ops.cpp#L268
std::optional<int32_t> dense_gather_factor; // TODO better name
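For reference, a hedged sketch of how a backend might fill in this struct. The field names come from the hunk above plus default_rng_bit_generator_name, which random.cpp reads earlier in this diff; the file is truncated here, so the exact field order and any remaining fields are unknown, and the values below are assumptions rather than what this commit sets. The sketch uses C++20 designated initializers so a literal cannot silently land on the wrong capability, unlike the positional braces in pjrt_computation_client.cc below.

#include <cstdint>
#include <iostream>
#include <optional>
#include <string>

// Local mirror of the struct for a standalone example; the real definition
// lives in torch_xla/csrc/runtime/device_capabilities.h.
struct DeviceCapabilities {
  bool supports_float64;
  bool supports_int16;
  bool supports_complex128;
  bool supports_bool;
  bool use_tpu_layout;
  std::optional<int32_t> dense_gather_factor;
  std::optional<std::string> default_rng_bit_generator_name;
};

int main() {
  // Hypothetical TPU-like backend: no f64/int16/complex128, TPU layouts,
  // dense-gather heuristic enabled, default RNG bit generator.
  DeviceCapabilities tpu_like{
      .supports_float64 = false,
      .supports_int16 = false,
      .supports_complex128 = false,
      .supports_bool = true,
      .use_tpu_layout = true,
      .dense_gather_factor = 8192,
      .default_rng_bit_generator_name = std::nullopt,
  };
  std::cout << "use_tpu_layout=" << tpu_like.use_tpu_layout << "\n";
  return 0;
}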
2 changes: 2 additions & 0 deletions torch_xla/csrc/runtime/pjrt_computation_client.cc
@@ -127,6 +127,7 @@ PjRtComputationClient::PjRtComputationClient() {
const PJRT_Api* c_api =
static_cast<xla::PjRtCApiClient*>(client_.get())->pjrt_c_api();
profiler::RegisterProfilerForPlugin(c_api);
device_capabilities_ = DeviceCapabilities{false, false, 8192, 100, std::nullopt};
} else if (device_type == "TPU_LEGACY") {
XLA_ERROR() << "TPU_LEGACY client is no longer available.";
} else if (device_type == "GPU" || device_type == "CUDA" ||
@@ -176,6 +177,7 @@ PjRtComputationClient::PjRtComputationClient() {
/*kv_get=*/kv_get,
/*kv_put=*/kv_put)
.value());
device_capabilities_ = DeviceCapabilities{true, true, std::nullopt, std::nullopt, "three_fry"};
} else if (device_type == "XPU") {
TF_VLOG(1) << "Initializing PjRt XPU client...";
XLA_CHECK_OK(
3 changes: 3 additions & 0 deletions torch_xla/csrc/runtime/pjrt_computation_client.h
@@ -97,6 +97,8 @@ class PjRtComputationClient : public ComputationClient {

bool CoordinatorInitialized() const override;

const DeviceCapabilities& GetDeviceCapabilities() const override { return device_capabilities_; };

// NOT IMPLEMENTED

MemoryInfo GetMemoryInfo(const std::string& device) override {
@@ -114,6 +116,7 @@
OperationManager operation_manager_;
tsl::thread::ThreadPool pool_ = tsl::thread::ThreadPool(
tsl::Env::Default(), "pjrt", std::thread::hardware_concurrency());
DeviceCapabilities device_capabilities_;

xla::PjRtDevice* StringToPjRtDevice(const std::string& device);

