Move shape and operand definitions to base node (pytorch#75223)
Summary:
First stage of breaking up pytorch#74710

Moves the shape and operand definitions from `TsNode` to the base `Node`

CC: wconstab JackCaoG henrytwo

Partially Fixes pytorch#74628

Pull Request resolved: pytorch#75223

Reviewed By: zou3519

Differential Revision: D35410285

Pulled By: wconstab

fbshipit-source-id: bb84d3fb636882cbe7e18af4b35ff2c0e22aaa58
(cherry picked from commit a4144c9)
antoniojkim authored and pytorchmergebot committed Apr 6, 2022
1 parent b8a4708 commit e1b4117
Showing 7 changed files with 120 additions and 170 deletions.
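The effect of the change is easiest to see in the test nodes below: once the base `Node` owns the shape and operand storage, a leaf node no longer needs its own `shape_` member or `shape()`/`shapes()`/`operands()`/`operand()` overrides. A minimal sketch of that pattern — not part of the commit; `ExampleLeafNode` is a hypothetical name, and the internal lazy-tensor headers are assumed to be on the include path:

```cpp
// Hypothetical leaf node, sketched to illustrate the refactor below: the
// Shape is handed straight to the base Node constructor, so the class keeps
// no shape_ member and overrides none of shape()/shapes()/operands()/operand().
#include <torch/csrc/lazy/core/ir.h>

#include <utility>

namespace torch {
namespace lazy {

class ExampleLeafNode : public Node {
 public:
  explicit ExampleLeafNode(Shape shape)
      // Uses the new leaf constructor Node(OpKind, Shape, num_outputs, hash_seed).
      : Node(OpKind(), std::move(shape), /*num_outputs=*/1, /*hash_seed=*/0) {}
};

}  // namespace lazy
}  // namespace torch
```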
17 changes: 3 additions & 14 deletions test/cpp/lazy/test_cache.cpp
@@ -24,11 +24,8 @@ class CacheNode : public Node {
const Output& operand(size_t i) const override {
TORCH_INTERNAL_ASSERT(false, "Can't access operand[i] of test node");
}
const Shape& shape(size_t i) const override { return shape_; }
c10::ArrayRef<Shape> shapes() const override { return {shape_}; }
private:
std::string str_;
Shape shape_;
};

TEST(CacheTest, BasicTest) {
@@ -66,15 +63,7 @@ TEST(CacheTest, BasicTest) {
class CacheNodeWithShape : public TsNode {
public:
explicit CacheNodeWithShape(const Shape& shape)
: TsNode(OpKind(), shape, /* num_outputs */ 1, /* seed */ 0),
shape_(shape) {}

const Shape& getShape() const {
return shape_;
}

private:
Shape shape_;
: TsNode(OpKind(), shape, /* num_outputs */ 1, /* seed */ 0){}
};

TEST(CacheTest, ShapeCacheTestForDynamicShape) {
@@ -89,8 +78,8 @@ TEST(CacheTest, ShapeCacheTestForDynamicShape) {
* Make sure the cached shape for node (2, 4) is not used for node (4, 2)
*/
for (auto& node : nodes) {
EXPECT_EQ(node.getShape(), node.GetOpShape([&]() {
return node.getShape();
EXPECT_EQ(node.shape(), node.GetOpShape([&]() {
return node.shape();
}));
}

3 changes: 0 additions & 3 deletions test/cpp/lazy/test_ir.cpp
Expand Up @@ -25,11 +25,8 @@ class TestLeafNode : public Node {
const Output& operand(size_t i) const override {
TORCH_INTERNAL_ASSERT(false, "Can't access operand[i] of leaf node");
}
const Shape& shape(size_t i) const override { return shape_; }
c10::ArrayRef<Shape> shapes() const override { return {shape_}; }
private:
size_t param_;
Shape shape_;
};

TEST(IrTest, BasicTest) {
16 changes: 0 additions & 16 deletions test/cpp/lazy/test_ir_util.cpp
Expand Up @@ -22,22 +22,6 @@ class IrUtilNode : public Node {
operands_as_outputs_.emplace_back(v.node.get(), v.index);
operands_.push_back(std::move(v.node));
}

const std::vector<Output>& operands() const override {
return operands_as_outputs_;
}

const Output& operand(size_t i) const override {
return operands_as_outputs_.at(i);
}

const Shape& shape(size_t i) const override { return shape_; }
c10::ArrayRef<Shape> shapes() const override { return {shape_}; }

private:
std::vector<NodePtr> operands_;
std::vector<Output> operands_as_outputs_;
Shape shape_;
};

/* a
85 changes: 81 additions & 4 deletions torch/csrc/lazy/core/ir.cpp
@@ -41,6 +41,19 @@ hash_t OpKind::hash() const {
return StringHash(op.toQualString());
}

hash_t OperandHashes(const OpList& operands, const hash_t& seed, bool bakeInSizes) {
hash_t hash = seed;
for (auto& operand : operands) {
if (!operand) {
hash = HashCombine(hash, static_cast<uint64_t>(kNullOpt));
continue;
}
auto operand_hash = bakeInSizes ? operand.hash_with_sizes() : operand.hash_without_sizes();
hash = HashCombine(hash, operand_hash);
}
return hash;
}

bool Node::enableDynamicShape() {
static bool enabled = std::getenv("LTC_ENABLE_DYNAMIC_SHAPES") != nullptr;
return enabled || FLAGS_ltc_enable_dynamic_shapes;
@@ -62,20 +75,84 @@ Node::Node(OpKind op, size_t num_outputs, std::function<hash_t(bool)> node_hash_
dag_hash_with_sizes_(node_hash_fn(true)),
metadata_(GetMetaDataIfDebugging()) {}


Node::Node(OpKind op, OpList operands, std::vector<Shape>&& shapes,
size_t num_outputs, hash_t hash_seed)
: Node(op, num_outputs,
// TODO(WHC) this is inefficient (having to compute node_hash twice
// since I can't call hash() yet) so probably move dag_hash
// initialization to a separate function?
/* node_hash */ HashCombine(op.hash(), hash_seed),
/* dag_hash */
[&](bool bakeInSizes) { return OperandHashes(operands, HashCombine(op.hash(), hash_seed), bakeInSizes); }) {
// Move shapes into node
shapes_.insert(
shapes_.end(),
std::make_move_iterator(shapes.begin()),
std::make_move_iterator(shapes.end()));

for (auto& operand : operands) {
// Ideally, optional operands should be filtered by the leaf node classes,
// but it's just much easier to do it here.
// TODO(alanwaketan): Find a way to move the below logic to the leaf node
// classes.
if (!operand) {
continue;
}

AddOperand(operand.node, operand.index);
}
}

Node::Node(OpKind op, OpList operands, size_t num_outputs, hash_t hash_seed)
: Node(op, operands, std::vector<Shape>{}, num_outputs, hash_seed) {}

Node::Node(OpKind op, Shape shape, size_t num_outputs, hash_t hash_seed)
: Node(op, num_outputs, [&](bool bakeInSizes) -> hash_t { return GetOpHash(op, shape, hash_seed, bakeInSizes); }) {
shapes_.push_back(std::move(shape));
}

Node::~Node() = default;

hash_t Node::GetOpHash(OpKind op, const Shape& shape, hash_t hash_seed, bool bakeInSizes) {
hash_t h = HashCombine(op.hash(), shape.hash(bakeInSizes));
return HashCombine(h, hash_seed);
}

// Retrieves the full shape of the IR Node.
c10::ArrayRef<Shape> Node::shapes() const { return shapes_; }

// Retrieves the shape of the output at a given index.
const Shape& Node::shape(size_t output_index) const {
return shapes_.at(output_index);
}

const std::vector<Output>& Node::operands() const {
return operands_as_outputs_;
}
const Output& Node::operand(size_t i) const {
return operands_as_outputs_.at(i);
}

std::string Node::ToString() const {
std::stringstream ss;
ss << op();
ss << shapes() << " " << op();
if (num_outputs() > 1) {
ss << ", num_outputs=" << num_outputs();
}
if (!metadata_.scope.empty()) {
ss << ", scope=" << metadata_.scope;
if (!metadata().scope.empty()) {
ss << ", scope=" << metadata().scope;
}
EmitShortFrameInfo(ss, metadata_.frame_info);
EmitShortFrameInfo(ss, metadata().frame_info);
return ss.str();
}

void Node::AddOperand(NodePtr node, size_t index) {
CHECK_LT(index, node->num_outputs());
operands_.push_back(std::move(node));
operands_as_outputs_.emplace_back(operands_.back().get(), index);
}


} // namespace lazy
} // namespace torch
37 changes: 33 additions & 4 deletions torch/csrc/lazy/core/ir.h
@@ -67,6 +67,8 @@ inline std::ostream& operator<<(std::ostream& stream, const OpKind& op) {

using OpList = c10::ArrayRef<Value>;

hash_t OperandHashes(const OpList& operands, const hash_t& seed, bool bakeInSizes);

// A node in the graph. Nodes for operations which require extra data to be
// stored for lowering should inherit from this class and add operation-
// specific members there. For example, a constant might create a new
@@ -88,8 +90,22 @@ class TORCH_API Node {
// Constructor used to create leaf nodes.
Node(OpKind op, size_t num_outputs, std::function<hash_t(bool)> node_hash_fn);

// Construct node with operands and shapes
Node(OpKind op, OpList operands, std::vector<Shape>&& shapes,
size_t num_outputs = 1, hash_t hash_seed = kHashSeed);

// Construct node with operands and no shape
Node(OpKind op, OpList operands, size_t num_outputs = 1,
hash_t hash_seed = kHashSeed);

// Construct node with shape and no operands
Node(OpKind op, Shape shape, size_t num_outputs = 1,
hash_t hash_seed = kHashSeed);

virtual ~Node();

static hash_t GetOpHash(OpKind op, const Shape& shape, hash_t hash_seed, bool bakeInSizes);

const OpKind& op() const {
return op_;
}
@@ -98,13 +114,15 @@ class TORCH_API Node {
return num_outputs_;
}

virtual c10::ArrayRef<Shape> shapes() const = 0;
// Retrieves the full shape of the IR Node.
virtual c10::ArrayRef<Shape> shapes() const;

virtual const Shape& shape(size_t output_index = 0) const = 0;
// Retrieves the shape of the output at a given index.
virtual const Shape& shape(size_t output_index = 0) const;

virtual const std::vector<Output>& operands() const = 0;
virtual const std::vector<Output>& operands() const;

virtual const Output& operand(size_t i) const = 0;
virtual const Output& operand(size_t i) const;

hash_t node_hash() const {
return node_hash_;
@@ -161,6 +179,17 @@ class TORCH_API Node {
// The IR framework user can attach a user defined metadata object deriving
// from UserMetaData.
std::shared_ptr<UserMetaData> user_metadata_;

protected:
// Adds node's index output number as operand.
void AddOperand(NodePtr node, size_t index = 0);

std::vector<Shape> shapes_;
// A node holds a real reference to its operands.
std::vector<NodePtr> operands_;
// Outputs do not hold references on the nodes, and neither do the uses, since
// otherwise we get into circular reference counting.
std::vector<Output> operands_as_outputs_;
};


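With the constructors and accessors above available in the base class, an operand-carrying node can defer to `Node` for storage in the same way. A rough sketch under the same assumptions — `ExampleBinaryNode` is a hypothetical name, and a real op would pass a proper `OpKind` and computed shapes:

```cpp
// Hypothetical operand-carrying node: operands and shapes are forwarded to
// the base Node constructor, which stores them in operands_/shapes_ and backs
// the operands()/operand(i)/shapes()/shape(i) accessors declared above.
#include <torch/csrc/lazy/core/ir.h>

#include <utility>
#include <vector>

namespace torch {
namespace lazy {

class ExampleBinaryNode : public Node {
 public:
  ExampleBinaryNode(const Value& lhs, const Value& rhs, std::vector<Shape> shapes)
      // Uses Node(OpKind, OpList, std::vector<Shape>&&, num_outputs, hash_seed);
      // the base class adds each non-null operand via AddOperand.
      : Node(OpKind(), {lhs, rhs}, std::move(shapes), /*num_outputs=*/1) {}
};

}  // namespace lazy
}  // namespace torch
```

The `ts_node.cpp` diff below then removes `TsNode`'s own copies of this machinery: `OperandHashes`, the operand/shape constructors, `ToString`, `GetOpHash`, and `AddOperand` all now live in `Node`.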
85 changes: 0 additions & 85 deletions torch/csrc/lazy/ts_backend/ts_node.cpp
@@ -17,79 +17,18 @@ namespace {
namespace torch {
namespace lazy {

void TsNodeSetShapeDeferred(
NodePtr node, const std::function<Shape()>& shape_fn) {
if (auto tsnode = std::dynamic_pointer_cast<TsNode>(node)) {
tsnode->SetShapeDeferred(shape_fn);
return;
}
throw std::runtime_error("Expected TsNode but could not dynamic cast");
}

hash_t OperandHashes(const OpList& operands, const hash_t& seed, bool bakeInSizes) {
hash_t hash = seed;
for (auto& operand : operands) {
if (!operand) {
hash = HashCombine(hash, static_cast<uint64_t>(kNullOpt));
continue;
}
auto operand_hash = bakeInSizes ? operand.hash_with_sizes() : operand.hash_without_sizes();
hash = HashCombine(hash, operand_hash);
}
return hash;
}

TsNode::TsNode(OpKind op, OpList operands, std::vector<Shape>&& shapes,
size_t num_outputs, hash_t hash_seed)
: Node(op, num_outputs,
// TODO(WHC) this is inefficient (having to compute node_hash twice
// since I can't call hash() yet) so probably move dag_hash
// initialization to a separate function?
/* node_hash */ HashCombine(op.hash(), hash_seed),
/* dag_hash */
[&](bool bakeInSizes) { return OperandHashes(operands, HashCombine(op.hash(), hash_seed), bakeInSizes); }),
shapes_(shapes),
python_stacktrace_(GetFirstUserFrameInPythonIfEnabled()) {
for (auto& operand : operands) {
// Ideally, optional operands should be filtered by the leaf node classes,
// but it's just much easier to do it here.
// TODO(alanwaketan): Find a way to move the below logic to the leaf node
// classes.
if (!operand) {
continue;
}

AddOperand(operand.node, operand.index);
}
}

TsNode::TsNode(OpKind op, OpList operands,
const std::function<Shape()>& shape_fn,
size_t num_outputs, hash_t hash_seed)
: TsNode(op, operands, std::vector<Shape>{}, num_outputs, hash_seed) {
shapes_.push_back(GetOpShape(shape_fn));
}

TsNode::TsNode(OpKind op, OpList operands, size_t num_outputs,
hash_t hash_seed)
: TsNode(op, operands, std::vector<Shape>{}, num_outputs, hash_seed) {}

void TsNode::SetShapeDeferred(
const std::function<Shape()>& shape_fn) {
shapes_.push_back(GetOpShape(shape_fn));
}

TsNode::TsNode(OpKind op, Shape shape, size_t num_outputs, hash_t hash_seed)
: Node(op, num_outputs, [&](bool bakeInSizes) -> hash_t { return GetOpHash(op, shape, hash_seed, bakeInSizes); }),
python_stacktrace_(GetFirstUserFrameInPythonIfEnabled())
{
shapes_.push_back(std::move(shape));
}

const Shape& TsNode::shape(size_t output_index) const {
return shapes_.at(output_index);
}

using ShapeCache = Cache<hash_t, Shape, HashReducer>;

ShapeCache* GetShapeCache() {
@@ -109,30 +48,6 @@ Shape TsNode::GetOpShape(
return *shape;
}

std::string TsNode::ToString() const {
std::stringstream ss;
ss << shapes() << " " << op();
if (num_outputs() > 1) {
ss << ", num_outputs=" << num_outputs();
}
if (!metadata().scope.empty()) {
ss << ", scope=" << metadata().scope;
}
EmitShortFrameInfo(ss, metadata().frame_info);
return ss.str();
}

hash_t TsNode::GetOpHash(OpKind op, const Shape& shape, hash_t hash_seed, bool bakeInSizes) {
hash_t h = HashCombine(op.hash(), shape.hash(bakeInSizes));
return HashCombine(h, hash_seed);
}

void TsNode::AddOperand(NodePtr node, size_t index) {
CHECK_LT(index, node->num_outputs());
operands_.push_back(std::move(node));
operands_as_outputs_.emplace_back(operands_.back().get(), index);
}

TSOpVector TsNode::Lower(std::shared_ptr<torch::jit::GraphFunction> function,
TSLoweringContext* loctx) const {
// TODO(whc) beginning to invert the design here. Move to provide a Lower()