Skip to content

Commit

Permalink
Make profiler trace analysis faster
Browse files Browse the repository at this point in the history
Once the profiler has run, we sum all of the flops for all the operators.
To do so, we need to de-duplicate events (e.g. SDPA calling MM: we should count the flops only once). When there is ambiguity, we choose to count the flops of the parent operator (if we know its flops).
This process was quite slow when we had a lot of events (~45k for a Llama7B on difformers), and could take up to 5 min (`O(N^2)` algorithm). Now it takes ~1 s (`O(N*ln(N))`).

ghstack-source-id: 7aac17cf85ae68990e37467c40aa0d5a89e8927c
Pull Request resolved: fairinternal/xformers#1203

__original_commit__ = fairinternal/xformers@9820ed0
  • Loading branch information
danthe3rd authored and xFormers Bot committed Aug 26, 2024
1 parent 6d2200c commit 1ee38a9
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 34 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Improved
- Profiler: Fix computation of FLOPS for the attention when using xFormers
- Profiler: Fix MFU/HFU calculation when multiple dtypes are used
- Profiler: Trace analysis to compute MFU & HFU is now much faster
- fMHA/splitK: Fixed `nan` in the output when using a `torch.Tensor` bias where a lot of consecutive keys are masked with `-inf`
- Update Flash-Attention version to `v2.6.3` *when building from scratch*
- When using the most recent version of Flash-Attention, it is no longer possible to mix it with the cutlass backend. In other words, it is no longer possible to use the cutlass Fw with the flash Bw.
Expand Down
68 changes: 34 additions & 34 deletions xformers/profiler/profile_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import math
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, Sequence, Set, cast
from typing import Any, Dict, List, Optional, Sequence, cast

import torch

Expand Down Expand Up @@ -152,47 +152,47 @@ def compute_mfu(self, hardware_flops: Dict[torch.dtype, float]) -> float:
return hfu_seconds / self.total_time_s

@staticmethod
def from_profile(
events: Sequence[torch._C._autograd._KinetoEvent],
) -> "AnalyzedTrace":
events = [_replace_if_needed(e) for e in events]
# All dispatcher ops
all_ops = [
def _find_all_root_events_with_flops(
all_events: Sequence[torch._C._autograd._KinetoEvent],
) -> Sequence[torch._C._autograd._KinetoEvent]:
# Filters-out non-dispatch ops
# Or operations without flop counted
all_ops_with_flops = [
e
for e in events
for e in all_events
if (
e.device_type().name == "CPU"
and (e.dtypes() or e.shapes() or e.flops() > 0)
and (e.dtypes() or e.shapes())
and e.flops() > 0
)
]

root_ops: Set[torch._C._autograd._KinetoEvent] = set()

def _find_parent_op(
e: torch._C._autograd._KinetoEvent,
) -> torch._C._autograd._KinetoEvent:
e_range = [e.start_ns(), e.start_ns() + e.duration_ns()]
candidate = e
for parent in all_ops:
if parent.device_type() != e.device_type():
continue
if parent.start_thread_id() != e.start_thread_id():
continue
p_range = [parent.start_ns(), parent.start_ns() + parent.duration_ns()]
if not (p_range[0] < e_range[0] < e_range[1] < p_range[1]):
continue
# We take the longest parent with flops
events_per_group: Dict[
Any, List[torch._C._autograd._KinetoEvent]
] = defaultdict(list)
for e in all_ops_with_flops:
events_per_group[(e.start_thread_id(), e.device_type())].append(e)
root_events: List[torch._C._autograd._KinetoEvent] = []
for events in events_per_group.values():
# We assume that 2 events are either non-overlapping,
# or one is contained entirely within the other
events.sort(key=lambda e: (e.start_ns(), -e.duration_ns()))
current_root: Optional[torch._C._autograd._KinetoEvent] = None
for e in events:
if (
parent.flops() > 0
and candidate.duration_ns() < parent.duration_ns()
current_root is None
or e.start_ns()
> current_root.start_ns() + current_root.duration_ns()
):
candidate = parent
return candidate
current_root = e
root_events.append(e)
return root_events

for op in all_ops:
if op.flops() == 0:
continue
root_ops.add(_find_parent_op(op))
@staticmethod
def from_profile(
events: Sequence[torch._C._autograd._KinetoEvent],
) -> "AnalyzedTrace":
events = [_replace_if_needed(e) for e in events]
root_ops = AnalyzedTrace._find_all_root_events_with_flops(events)

operations_per_dtype_fw: Dict[torch.dtype, float] = defaultdict(float)
operations_per_dtype_bw: Dict[torch.dtype, float] = defaultdict(float)
Expand Down

0 comments on commit 1ee38a9

Please sign in to comment.