Skip to content

Commit

Permalink
add partition_by_partition_cost (pytorch#47280)
Browse files Browse the repository at this point in the history
Summary:
This PR adds support for calculating the cost of a partitioned graph, partition by partition, based on the per-node cost. In a partitioned graph, the top partitions (partitions without parents) are collected as starting points, and a DFS from each of them finds the critical path among all partitions in the graph.

Pull Request resolved: pytorch#47280

Reviewed By: gcatron

Differential Revision: D24735932

Pulled By: scottxu0730

fbshipit-source-id: 96653a8208554d2c3624e6c8718628f7c13e320b
  • Loading branch information
scottxu0730 authored and facebook-github-bot committed Nov 5, 2020
1 parent 878032d commit 5107a41
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 21 deletions.
21 changes: 15 additions & 6 deletions test/test_fx_experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from torch.fx.experimental.rewriter import RewritingTracer
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.jit_utils import JitTestCase
from torch.fx.experimental.partitioner_utils import get_latency_of_one_partition, \
NodeLatency
from torch.fx.experimental.partitioner_utils import NodeLatency, \
get_partition_to_latency_mapping, get_latency_of_partitioned_graph
from typing import Union, Callable

def symbolic_trace_with_rewrite(root: Union[torch.nn.Module, Callable]) -> GraphModule:
Expand Down Expand Up @@ -212,10 +212,19 @@ def get_node_to_latency_mapping(fx_module: GraphModule):
module_with_submodules = ret.module_with_submodules
self.assertEqual(traced(a), module_with_submodules(a))
partitions = partitioner.partitions
partition_latency_0 = get_latency_of_one_partition(partitions[0], node_to_latency_mapping)
assert (128., 80., 160.) == partition_latency_0
partition_latency_1 = get_latency_of_one_partition(partitions[1], node_to_latency_mapping)
assert (16., 32., 32) == partition_latency_1
partition_to_latency_mapping = get_partition_to_latency_mapping(partitions, node_to_latency_mapping)
for p in partition_to_latency_mapping:
if p.partition_id == 0:
assert partition_to_latency_mapping[p] == (128., 80., 160.)
else:
assert partition_to_latency_mapping[p] == (16., 32., 32.)
transfer_rate_bytes_per_sec = 0.5
critical_path_latency_sec = get_latency_of_partitioned_graph(
partitions,
partition_to_latency_mapping,
transfer_rate_bytes_per_sec
)
assert critical_path_latency_sec == 208.

def test_call_to_assert_no_msg(self):

Expand Down
113 changes: 98 additions & 15 deletions torch/fx/experimental/partitioner_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,17 @@

class NodeLatency(NamedTuple):
# Latency due to the memory bandwidth
mem_latency: float
mem_latency_sec: float
# Latency due to the computation
compute_latency: float
computer_latency_sec: float

class PartitionLatency(NamedTuple):
# Sum of all nodes' memory latency on the critical path
mem_latency: float
mem_latency_sec: float
# Sum of all nodes' compute latency on the critical path
compute_latency: float
computer_latency_sec: float
# Latency of the critical path
overall_latency: float
overall_latency_sec: float

def get_latency_of_one_partition(
partition: Partition,
Expand Down Expand Up @@ -45,30 +45,113 @@ def dfs_helper(node: Node, partition_latency) -> PartitionLatency:
"""
node_latency = node_to_latency_mapping[node]
# Calculate the current overall latency of the partition
overall_latency = partition_latency.overall_latency + max(node_latency.compute_latency, node_latency.mem_latency)
overall_latency_sec = partition_latency.overall_latency_sec + \
max(node_latency.computer_latency_sec, node_latency.mem_latency_sec)
# Update the mem latency of this path
mem_latency = partition_latency.mem_latency + node_latency.mem_latency
mem_latency_sec = partition_latency.mem_latency_sec + node_latency.mem_latency_sec
# Update the compute latency of this path
compute_latency = partition_latency.compute_latency + node_latency.compute_latency
computer_latency_sec = partition_latency.computer_latency_sec + node_latency.computer_latency_sec
# Get all users of this node that are in this partition
users = set(node.users).intersection(partition.nodes)
if users:
max_latency = PartitionLatency(mem_latency=0., compute_latency=0., overall_latency=0.)
max_latency = PartitionLatency(mem_latency_sec=0., computer_latency_sec=0., overall_latency_sec=0.)
for n in users:
# Get new partition latency recursively
new_partition_latency = dfs_helper(n, PartitionLatency(mem_latency, compute_latency, overall_latency))
if new_partition_latency.overall_latency > max_latency.overall_latency:
new_partition_latency = dfs_helper(n, PartitionLatency(mem_latency_sec, computer_latency_sec, overall_latency_sec))
if new_partition_latency.overall_latency_sec > max_latency.overall_latency_sec:
max_latency = new_partition_latency
return max_latency
# If there is no user, the node is at bottom of the partition
return PartitionLatency(mem_latency, compute_latency, overall_latency)
return PartitionLatency(mem_latency_sec, computer_latency_sec, overall_latency_sec)
# Main part starts
# Get all top level nodes of this partition
top_nodes = get_top_nodes(partition)
critical_path_latency = PartitionLatency(mem_latency=0., compute_latency=0., overall_latency=0.)
critical_path_latency = PartitionLatency(mem_latency_sec=0., computer_latency_sec=0., overall_latency_sec=0.)
# Go through all top nodes and find the largest latency (critical path latency)
for node in top_nodes:
partition_latency = dfs_helper(node, PartitionLatency(mem_latency=0., compute_latency=0., overall_latency=0.))
if partition_latency.overall_latency > critical_path_latency.overall_latency:
partition_latency = dfs_helper(node, PartitionLatency(mem_latency_sec=0., computer_latency_sec=0., overall_latency_sec=0.))
if partition_latency.overall_latency_sec > critical_path_latency.overall_latency_sec:
critical_path_latency = partition_latency
return critical_path_latency

def get_partition_to_latency_mapping(
    partitions: List[Partition],
    node_to_latency_mapping: Dict[Node, NodeLatency]
) -> Dict[Partition, PartitionLatency]:
    """Build a dictionary from each partition to its latency.

    Every partition in ``partitions`` is handed, together with
    ``node_to_latency_mapping``, to ``get_latency_of_one_partition``; the
    value it returns is stored under that partition as the key.
    """
    return {
        current: get_latency_of_one_partition(current, node_to_latency_mapping)
        for current in partitions
    }

def get_comm_latency_between(parent_partition: Partition, child_partition: Partition, transfer_rate_bytes_per_sec: float):
    """Compute the communication latency from a parent partition to a child.

    For every node of ``child_partition``, the nodes found in its ``args``
    and ``kwargs`` are gathered. Each such input that belongs to
    ``parent_partition.nodes`` contributes its ``size_bytes.output_size``
    exactly once, even when several child nodes consume it. Inputs without a
    ``size_bytes`` attribute (or with it set to ``None``) contribute nothing.

    The return value is the accumulated size multiplied by
    ``transfer_rate_bytes_per_sec``.

    NOTE(review): multiplying a byte count by a rate named "bytes per sec"
    does not yield seconds; a division may have been intended. The current
    arithmetic is kept as-is -- confirm against callers and tests before
    changing it.
    """
    total_bytes = 0
    # Parent nodes whose output has already been accounted for
    counted = set()
    for consumer in child_partition.nodes:
        # A dict is used as an insertion-ordered set of this node's inputs
        producers: Dict[Node, None] = {}
        for arguments in (consumer.args, consumer.kwargs):
            map_arg(arguments, lambda n: producers.setdefault(n))
        for producer in producers:
            # Only first-time inputs that live in the parent partition count
            if producer in counted or producer not in parent_partition.nodes:
                continue
            counted.add(producer)
            size_info = getattr(producer, "size_bytes", None)
            if size_info is not None:
                total_bytes += size_info.output_size
    return total_bytes * transfer_rate_bytes_per_sec

def get_latency_of_partitioned_graph(
    partitions: List[Partition],
    partition_to_latency_mapping: Dict[Partition, PartitionLatency],
    transfer_rate_bytes_per_sec: float
) -> float:
    """Return the latency of the critical path through all partitions.

    Starting from every partition that has no parents, the partition graph
    is walked depth-first along ``children``. The latency of a path is the
    sum of each visited partition's ``overall_latency_sec`` plus the
    communication latency of every parent-to-child edge on the path, as
    computed by ``get_comm_latency_between``. The largest path latency found
    is returned; ``0.`` is returned when there is no partition without
    parents.

    Args:
        partitions: all partitions of the graph.
        partition_to_latency_mapping: latency of each partition; must hold
            an entry for every partition reachable during the walk.
        transfer_rate_bytes_per_sec: forwarded unchanged to
            ``get_comm_latency_between``.

    No visited set is kept, so a cycle among partitions would recurse
    without end, and a partition reachable through several paths is
    re-evaluated once per path.
    """
    def dfs_helper(partition: Partition, latency_so_far_sec: float) -> float:
        """Return the largest latency of any path continuing through ``partition``."""
        # Account for the current partition itself
        latency_so_far_sec += partition_to_latency_mapping[partition].overall_latency_sec
        if not partition.children:
            # Bottom of the path: nothing more to add
            return latency_so_far_sec
        max_latency_sec = 0.
        for child in partition.children:
            # Cost of moving data across the edge partition -> child
            comm_latency_sec = get_comm_latency_between(partition, child, transfer_rate_bytes_per_sec)
            max_latency_sec = max(
                max_latency_sec,
                dfs_helper(child, latency_so_far_sec + comm_latency_sec),
            )
        return max_latency_sec

    # Partitions without parents are the starting points of all paths
    top_partitions = [partition for partition in partitions if not partition.parents]
    critical_path_latency_sec = 0.
    for partition in top_partitions:
        critical_path_latency_sec = max(critical_path_latency_sec, dfs_helper(partition, 0.))
    return critical_path_latency_sec

0 comments on commit 5107a41

Please sign in to comment.