Align settings for new device key. (pytorch#98224)
Summary: As title.

Test Plan: All CI tests should pass.

Reviewed By: yuhc

Differential Revision: D44341331

Pull Request resolved: pytorch#98224
Approved by: https://github.com/jackm321, https://github.com/ezyang
egienvalue authored and pytorchmergebot committed Apr 4, 2023
1 parent 86505c6 commit d47a4bf
Showing 4 changed files with 5 additions and 4 deletions.
1 change: 1 addition & 0 deletions aten/src/ATen/TensorIterator.cpp
@@ -1507,6 +1507,7 @@ void TensorIteratorBase::build(TensorIteratorConfig& config) {
   // Nothing beyond this point is important for meta functions, so it's fine to exit early here.
   // Extend the condition to ORT tensors as ORT tensors also don't have storage.
   if (privateuse1_without_storage  ||
+      common_device_.type() == DeviceType::MTIA ||
       common_device_.type() == DeviceType::XLA  ||
       common_device_.type() == DeviceType::IPU  ||
       common_device_.type() == DeviceType::Lazy ||
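The effect of the added condition can be sketched as a small Python model (a hypothetical illustration, not PyTorch's actual implementation; the function name and device set below mirror only the part of the condition visible in this hunk):

```python
# Hypothetical model of the early-exit check in TensorIteratorBase::build.
# After this change, MTIA joins the backends whose tensors don't have real
# storage and therefore skip the rest of the build steps.
STORAGELESS_LIKE_DEVICES = {"mtia", "xla", "ipu", "lazy"}

def exits_build_early(device_type, privateuse1_without_storage=False):
    # Mirrors only the clauses shown in this diff hunk.
    return privateuse1_without_storage or device_type in STORAGELESS_LIKE_DEVICES
```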
2 changes: 1 addition & 1 deletion c10/core/Backend.h
@@ -112,7 +112,7 @@ static inline Backend dispatchKeyToBackend(DispatchKey t) {
     return Backend::QuantizedXPU;
   } else if (t == DispatchKey::HPU || t == DispatchKey::AutogradHPU) {
     return Backend::HPU;
-  } else if (t == DispatchKey::MTIA) {
+  } else if (t == DispatchKey::MTIA || t == DispatchKey::AutogradMTIA) {
     return Backend::MTIA;
   } else if (t == DispatchKey::PrivateUse1) {
     return Backend::PrivateUse1;
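Sketched in Python (hypothetical; the real mapping is the C++ `dispatchKeyToBackend` above), the change makes both the plain and the Autograd MTIA dispatch keys resolve to the MTIA backend, matching the existing HPU pattern:

```python
# Hypothetical sketch of the DispatchKey -> Backend resolution touched here.
def dispatch_key_to_backend(key):
    if key in ("HPU", "AutogradHPU"):
        return "HPU"
    # Before this change only "MTIA" was handled; "AutogradMTIA" now
    # also resolves to the MTIA backend.
    if key in ("MTIA", "AutogradMTIA"):
        return "MTIA"
    if key == "PrivateUse1":
        return "PrivateUse1"
    raise ValueError(f"unhandled dispatch key: {key}")
```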
2 changes: 1 addition & 1 deletion c10/core/Device.h
@@ -149,7 +149,7 @@ struct C10_API Device final {
   /// Return true if the device supports arbitrary strides.
   bool supports_as_strided() const noexcept {
     return type_ != DeviceType::IPU && type_ != DeviceType::XLA &&
-        type_ != DeviceType::Lazy;
+        type_ != DeviceType::Lazy && type_ != DeviceType::MTIA;
   }

   /// Same string as returned from operator<<.
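A minimal Python model of the updated predicate (hypothetical names; the real check is the C++ member function above) shows MTIA joining the devices without arbitrary-stride support:

```python
# Hypothetical model of c10::Device::supports_as_strided after this change:
# MTIA is added to the set of devices that cannot view storage with
# arbitrary strides.
NON_STRIDED_DEVICES = {"ipu", "xla", "lazy", "mtia"}

def supports_as_strided(device_type):
    return device_type not in NON_STRIDED_DEVICES
```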
4 changes: 2 additions & 2 deletions torch/_tensor_str.py
@@ -402,9 +402,9 @@ def _str_intern(inp, *, tensor_contents=None):
         suffixes.append("device='" + str(self.device) + "'")

     # Tensor printing performs tensor operations like slice, indexing, etc to make it in a
-    # representable format. These operations on ipu/xla/lazy tensors result in compilations. Hence,
+    # representable format. These operations on ipu/xla/lazy/mtia tensors result in compilations. Hence,
     # to avoid compilations, copying the tensor to cpu before printing.
-    if self.device.type in ["xla", "lazy", "ipu"]:
+    if self.device.type in ["xla", "lazy", "ipu", "mtia"]:
         self = self.to("cpu")

     # TODO: add an API to map real -> complex dtypes
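The copy-to-CPU guard above can be illustrated with a self-contained toy (the `Tensor` class here is a hypothetical stand-in, not `torch.Tensor`):

```python
# Hypothetical stand-in demonstrating the guard added in _tensor_str.py.
class Tensor:
    """Toy tensor tracking only its device type."""
    def __init__(self, device_type):
        self.device_type = device_type

    def to(self, device_type):
        return Tensor(device_type)

def prepare_for_print(t):
    # Slicing/indexing during printing can trigger compilations on these
    # backends (now including mtia), so the tensor is moved to CPU first.
    if t.device_type in ["xla", "lazy", "ipu", "mtia"]:
        t = t.to("cpu")
    return t
```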