Skip to content

Commit

Permalink
Refactor index_remapping_array initialization for PT2 export (pytorch…
Browse files Browse the repository at this point in the history
…#2424)

Summary:
Pull Request resolved: pytorch#2424

We end up getting this error
```
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq(u0, 1) (unhinted: Eq(u0, 1)).  (Size-like symbols: none)
```
w/ tracing `if self.index_remappings_array_offsets[-1] == 0` in FakeTensorMode. Changing it to check len(index_remapping) is logically the same, and seems to fix the error.

Reviewed By: sryap

Differential Revision: D54779307

fbshipit-source-id: 609402fa67a50f2bedfb3883413dfe3f7f6ca314
  • Loading branch information
qxy11 authored and facebook-github-bot committed Mar 14, 2024
1 parent b55904b commit fefddbb
Showing 1 changed file with 6 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1453,15 +1453,16 @@ def set_index_remappings_array(
device=self.current_device,
dtype=torch.int64,
)
if self.index_remappings_array_offsets[-1] == 0:

index_remappings_filter_nones = []
for mapping in index_remapping:
if mapping is not None:
index_remappings_filter_nones.append(mapping)
if len(index_remappings_filter_nones) == 0:
self.index_remappings_array = torch.empty(
0, dtype=torch.int32, device=self.current_device
)
else:
index_remappings_filter_nones = []
for mapping in index_remapping:
if mapping is not None:
index_remappings_filter_nones.append(mapping)
self.index_remappings_array = torch.cat(index_remappings_filter_nones).to(
self.current_device
)
Expand Down

0 comments on commit fefddbb

Please sign in to comment.