[Request Help] "torch._dynamo.exc.UserError: Could not guard on data-dependent expression Eq(u77 - 1, 1)" #8988

Closed
liye0626 opened this issue Mar 6, 2025 · 2 comments
Labels: module: exir, module: user experience, triaged

liye0626 commented Mar 6, 2025

I encountered the following error when exporting the model to PTE:
"torch._dynamo.exc.UserError: Could not guard on data-dependent expression Eq(u77 - 1, 1)"

I suspect I need to give the compiler more hints about the tensor shapes, but I don't know how to do that.

Relevant code (includes lines I added):

        ...
        # Logically, the shapes of the following two tensors are the same (but I did not provide a hint):
        tmp_A = candidates[:, 1:].to(logits.device)
        tmp_B = torch.argmax(logits[:, :-1], dim=-1)
        tmp_A == tmp_B

Full log

W0306 02:30:21.537000 185957 torch/fx/experimental/symbolic_shapes.py:6354] [0/0_2] failed during evaluate_expr(Eq(u77 - 1, 1), hint=None, size_oblivious=True, forcing_spec=False
E0306 02:30:21.537000 185957 torch/fx/experimental/recording.py:299] [0/0_2] failed while running evaluate_expr(*(Eq(u77 - 1, 1), None), **{'fx_node': False, 'size_oblivious': True})

Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 2676, in run_node
    return node.target(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1276, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1817, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1387, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 2291, in _dispatch_impl
    decomposition_table[func](*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_prims_common/wrappers.py", line 291, in _fn
    result = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_prims_common/wrappers.py", line 143, in _fn
    result = fn(**bound.arguments)
  File "/usr/local/lib/python3.10/dist-packages/torch/_refs/__init__.py", line 1095, in _ref
    a, b = _maybe_broadcast(a, b)
  File "/usr/local/lib/python3.10/dist-packages/torch/_refs/__init__.py", line 437, in _maybe_broadcast
    common_shape = _broadcast_shapes(
  File "/usr/local/lib/python3.10/dist-packages/torch/_refs/__init__.py", line 418, in _broadcast_shapes
    if guard_size_oblivious(common_shape[idx] == 1):
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 411, in guard_size_oblivious
    return expr.node.guard_size_oblivious("", 0)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/sym_node.py", line 564, in guard_size_oblivious
    r = self.shape_env.evaluate_expr(
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/recording.py", line 263, in wrapper
    return retlog(fn(*args, **kwargs))
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 6350, in evaluate_expr
    return self._evaluate_expr(
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 6540, in _evaluate_expr
    raise self._make_data_dependent_error(
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq(u77 - 1, 1) (unhinted: Eq(u77 - 1, 1)).  (Size-like symbols: u77)

Caused by: tmp_A == tmp_B  # executorch/examples/models/llama/utils.py:355 in evaluate_posterior (_refs/__init__.py:418 in _broadcast_shapes)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u77"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

User Stack (most recent call last):
  (snipped, see stack below for prefix)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/external_utils.py", line 45, in inner
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/executorch/examples/models/llama/ea_model.py", line 132, in forward
    best_candidate, accept_length, sample_p = evaluate_posterior(
  File "/usr/local/lib/python3.10/dist-packages/executorch/examples/models/llama/utils.py", line 355, in evaluate_posterior
    tmp_A == tmp_B

For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 2561, in get_fake_value
    ret_val = wrap_fake_exception(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 2107, in wrap_fake_exception
    return fn()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 2562, in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 2694, in run_node
    raise RuntimeError(make_error_message(e)).with_traceback(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 2676, in run_node
    return node.target(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1276, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1817, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1387, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 2291, in _dispatch_impl
    decomposition_table[func](*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_prims_common/wrappers.py", line 291, in _fn
    result = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_prims_common/wrappers.py", line 143, in _fn
    result = fn(**bound.arguments)
  File "/usr/local/lib/python3.10/dist-packages/torch/_refs/__init__.py", line 1095, in _ref
    a, b = _maybe_broadcast(a, b)
  File "/usr/local/lib/python3.10/dist-packages/torch/_refs/__init__.py", line 437, in _maybe_broadcast
    common_shape = _broadcast_shapes(
  File "/usr/local/lib/python3.10/dist-packages/torch/_refs/__init__.py", line 418, in _broadcast_shapes
    if guard_size_oblivious(common_shape[idx] == 1):
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 411, in guard_size_oblivious
    return expr.node.guard_size_oblivious("", 0)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/sym_node.py", line 564, in guard_size_oblivious
    r = self.shape_env.evaluate_expr(
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/recording.py", line 263, in wrapper
    return retlog(fn(*args, **kwargs))
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 6350, in evaluate_expr
    return self._evaluate_expr(
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 6540, in _evaluate_expr
    raise self._make_data_dependent_error(
RuntimeError: Failed running call_function <built-in function eq>(*(FakeTensor(..., size=(60 - u76, u77 - 1), dtype=torch.int64), FakeTensor(..., size=(60 - u76, u77 - 1), dtype=torch.int64)), **{}):
Could not guard on data-dependent expression Eq(u77 - 1, 1) (unhinted: Eq(u77 - 1, 1)).  (Size-like symbols: u77)

Caused by: tmp_A == tmp_B  # executorch/examples/models/llama/utils.py:355 in evaluate_posterior (_refs/__init__.py:418 in _broadcast_shapes)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u77"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

User Stack (most recent call last):
  (snipped, see stack below for prefix)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/external_utils.py", line 45, in inner
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/executorch/examples/models/llama/ea_model.py", line 132, in forward
    best_candidate, accept_length, sample_p = evaluate_posterior(
  File "/usr/local/lib/python3.10/dist-packages/executorch/examples/models/llama/utils.py", line 355, in evaluate_posterior
    tmp_A == tmp_B

For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/user/workspace/project/2025/executorch/examples/models/llama/export_llama.py", line 32, in <module>
    main()  # pragma: no cover
  File "/home/user/workspace/project/2025/executorch/examples/models/llama/export_llama.py", line 28, in main
    export_llama(args)
  File "/home/user/workspace/project/2025/executorch/examples/models/llama/export_llama_lib.py", line 520, in export_llama
    builder = _export_llama(args)
  File "/home/user/workspace/project/2025/executorch/examples/models/llama/export_llama_lib.py", line 655, in _export_llama
    builder_exported = _prepare_for_llama_export(args).export()
  File "/usr/local/lib/python3.10/dist-packages/executorch/extension/llm/export/builder.py", line 201, in export
    exported_module = export_for_training(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 168, in export_for_training
    return _export_for_training(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1031, in wrapper
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1004, in wrapper
    ep = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 128, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1825, in _export_for_training
    export_artifact = export_func(  # type: ignore[operator]
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1283, in _strict_export_lower_to_aten_ir
    gm_torch_level = _export_to_torch_ir(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 667, in _export_to_torch_ir
    gm_torch_level, _ = torch._dynamo.export(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1579, in inner
    result_traced = opt_f(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 576, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 1404, in __call__
    return self._torchdynamo_orig_callable(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 565, in __call__
    return _compile(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 1005, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 733, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
  File "/usr/local/lib/python3.10/dist-packages/torch/_utils_internal.py", line 95, in wrapper_function
    return function(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 768, in _compile_inner
    out_code = transform_code_object(code, transform)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/bytecode_transformation.py", line 1402, in transform_code_object
    transformations(instructions, code_options)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 236, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 680, in transform
    tracer.run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2906, in run
    super().run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1076, in run
    while self.step():
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 986, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 683, in wrapper
    return inner_fn(self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1763, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 921, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/nn_module.py", line 443, in call_function
    return tx.inline_user_function_return(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 927, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 3111, in inline_call
    return tracer.inline_call_()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 3248, in inline_call_
    self.run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1076, in run
    while self.step():
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 986, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 683, in wrapper
    return inner_fn(self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1763, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 921, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 437, in call_function
    return super().call_function(tx, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 319, in call_function
    return super().call_function(tx, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 120, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 927, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 3111, in inline_call
    return tracer.inline_call_()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 3248, in inline_call_
    self.run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1076, in run
    while self.step():
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 986, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 683, in wrapper
    return inner_fn(self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1763, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 921, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 319, in call_function
    return super().call_function(tx, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 120, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 927, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 3111, in inline_call
    return tracer.inline_call_()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 3248, in inline_call_
    self.run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1076, in run
    while self.step():
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 986, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 683, in wrapper
    return inner_fn(self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1685, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 921, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 319, in call_function
    return super().call_function(tx, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 120, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 927, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 3111, in inline_call
    return tracer.inline_call_()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 3248, in inline_call_
    self.run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1076, in run
    while self.step():
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 986, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1676, in COMPARE_OP
    self.push(compare_op_handlers[inst.argval](self, self.popn(2), {}))
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builtin.py", line 1003, in call_function
    return handler(tx, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builtin.py", line 979, in _handle_insert_op_in_graph
    return wrap_fx_proxy(tx, proxy)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builder.py", line 2212, in wrap_fx_proxy
    return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builder.py", line 2278, in wrap_fx_proxy_cls
    return _wrap_fx_proxy(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builder.py", line 2374, in _wrap_fx_proxy
    example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 2616, in get_fake_value
    raise UserError(  # noqa: B904
torch._dynamo.exc.UserError: Could not guard on data-dependent expression Eq(u77 - 1, 1) (unhinted: Eq(u77 - 1, 1)).  (Size-like symbols: u77)

Caused by: tmp_A == tmp_B  # executorch/examples/models/llama/utils.py:355 in evaluate_posterior (_refs/__init__.py:418 in _broadcast_shapes)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u77"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

User Stack (most recent call last):
  (snipped, see stack below for prefix)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/external_utils.py", line 45, in inner
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/executorch/examples/models/llama/ea_model.py", line 132, in forward
    best_candidate, accept_length, sample_p = evaluate_posterior(
  File "/usr/local/lib/python3.10/dist-packages/executorch/examples/models/llama/utils.py", line 355, in evaluate_posterior
    tmp_A == tmp_B

For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#constrain-as-size-example

from user code:
   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/external_utils.py", line 45, in inner
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/executorch/examples/models/llama/ea_model.py", line 132, in forward
    best_candidate, accept_length, sample_p = evaluate_posterior(
  File "/usr/local/lib/python3.10/dist-packages/executorch/examples/models/llama/utils.py", line 355, in evaluate_posterior
    tmp_A == tmp_B

cc @JacobSzwejbka @angelayi @mergennachin @byjlw

iseeyuan added the module: exir and module: user experience labels Mar 7, 2025
github-project-automation bot moved this to To triage in ExecuTorch Core Mar 7, 2025
github-project-automation bot moved this to To triage in ExecuTorch DevX Mar 7, 2025
iseeyuan added the triaged label Mar 7, 2025
iseeyuan (Contributor) commented

@angelayi, could you help with this issue, or forward it to the right point of contact?

iseeyuan moved this from To triage to Backlog in ExecuTorch Core Mar 11, 2025
angelayi (Contributor) commented

Hi! Would you like to try a new tool we've been working on to help debug these types of errors?

Here's how you can do so:

  1. Add the import: from torch.export._draft_export import draft_export
  2. Replace your export call with draft_export
  3. It should print a message telling you to run tlparse /tmp/export/log.log --export
  4. Install tlparse with pip install tlparse (you'll need version >= 0.3.34)
  5. Run the command: tlparse /tmp/export/log.log --export
  6. View the generated HTML files. There should hopefully be an error related to Eq(u77 - 1, 1)
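
For context on where an expression like Eq(u77 - 1, 1) comes from: the traceback ends in _broadcast_shapes, on a line that asks whether a dimension equals 1. Below is a simplified pure-Python sketch of that kind of broadcasting rule. It is not PyTorch's actual implementation, and the name broadcast_shapes_sketch is made up for illustration. With concrete integers the comparison against 1 is trivial. With a data-dependent size such as u77 - 1 the tracer has no concrete value to compare, which is presumably why it stops at this branch.

```python
def broadcast_shapes_sketch(*shapes):
    """Simplified broadcasting over concrete integer shapes (illustration only)."""
    ndim = max(len(shape) for shape in shapes)
    result = [1] * ndim
    for shape in shapes:
        offset = ndim - len(shape)  # right-align shorter shapes
        for i, size in enumerate(shape):
            idx = offset + i
            # The "is this dimension 1?" question below is the kind of branch
            # that cannot be answered when the size is an unbacked symbol.
            if result[idx] == 1:
                result[idx] = size
            elif size != 1 and size != result[idx]:
                raise ValueError(
                    f"cannot broadcast size {size} against {result[idx]} at dim {idx}"
                )
    return tuple(result)


same = broadcast_shapes_sketch((3, 4), (3, 4))      # identical shapes
expanded = broadcast_shapes_sketch((3, 1), (1, 4))  # size-1 dims expand
print(same, expanded)
```

Even when both operands have the same symbolic shape, as in the FakeTensor sizes printed in the log, a rule like this still has to decide the "equals 1" branch for each dimension.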

To fix the issue, since you mentioned

Logically, the shapes of the following two tensors are the same (but I did not provide a hint)

you'll probably need to add a torch._check into the graph to hint to the compiler that the shapes of the tensors are equal. It'll look something like:

torch._check(tmp_A.shape[0] == tmp_B.shape[0])
tmp_A == tmp_B
