Skip to content

Commit

Permalink
[pt2] grad support (pytorch#102264)
Browse files Browse the repository at this point in the history
Teach dynamo about grad

Pull Request resolved: pytorch#102264
Approved by: https://github.com/zou3519
  • Loading branch information
kshitij12345 authored and pytorchmergebot committed Jun 21, 2023
1 parent 6d2887c commit d552c27
Show file tree
Hide file tree
Showing 14 changed files with 870 additions and 138 deletions.
24 changes: 3 additions & 21 deletions aten/src/ATen/functorch/TensorWrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,9 @@ void dumpTensor(std::ostream& ss, const Tensor& tensor) {
}

void TensorWrapper::refreshMetadata() {
auto dim = value_.dim();
auto sizes = value_.sizes();
auto strides = value_.strides();
storage_offset_ = value_.storage_offset();
sizes_and_strides_.resize(value_.dim());
for (int64_t i = 0; i < dim; i++) {
sizes_and_strides_.size_at_unchecked(i) = sizes[i];
sizes_and_strides_.stride_at_unchecked(i) = strides[i];
}
// update size, strides and storage_offset
set_sizes_and_strides(
value_.sym_sizes(), value_.sym_strides(), value_.sym_storage_offset());

refresh_numel();
refresh_contiguous();
Expand Down Expand Up @@ -159,18 +153,6 @@ TensorWrapper::TensorWrapper(
set_storage_access_should_throw();
}

// The following are some internal inherited methods that we do not support.
// They should never get called.
void TensorWrapper::set_size(int64_t dim, int64_t new_size) {
TORCH_INTERNAL_ASSERT(false, "Can't set_size for TensorWrapper");
}
void TensorWrapper::set_stride(int64_t dim, int64_t new_stride) {
TORCH_INTERNAL_ASSERT(false, "Can't set_stride for TensorWrapper");
}
void TensorWrapper::set_storage_offset(int64_t storage_offset) {
TORCH_INTERNAL_ASSERT(false, "Can't set_storage_offset for TensorWrapper");
}

const char* TensorWrapper::tensorimpl_type_name() const {
return "TensorWrapper";
}
Expand Down
5 changes: 0 additions & 5 deletions aten/src/ATen/functorch/TensorWrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,6 @@ struct TORCH_API TensorWrapper : public c10::TensorImpl {
bool is_immutable = false, // if true, this came from an operation that aliases an immutable tensor
bool use_value_sizes_strides = true);

// Override a bunch of methods inherited from TensorImpl to return error messages
void set_size(int64_t dim, int64_t new_size) override;
void set_stride(int64_t dim, int64_t new_stride) override;
void set_storage_offset(int64_t storage_offset) override;

void refreshMetadata();

const Tensor& value() const {
Expand Down
Loading

0 comments on commit d552c27

Please sign in to comment.