Skip to content

Commit

Permalink
Auto merge of rust-lang#91844 - nnethercote:rm-ObligationCauseData-2,…
Browse files Browse the repository at this point in the history
… r=Mark-Simulacrum

Eliminate `ObligationCauseData`

This makes `Obligation` two words bigger, but avoids allocating a lot of the time.

I previously tried this in rust-lang#73983 and it didn't help much, but local timings look more promising now.

r? `@ghost`
  • Loading branch information
bors committed Dec 20, 2021
2 parents e95e084 + f09b1fa commit ed7a206
Show file tree
Hide file tree
Showing 22 changed files with 135 additions and 139 deletions.
22 changes: 11 additions & 11 deletions compiler/rustc_infer/src/infer/error_reporting/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -604,7 +604,7 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
exp_found: Option<ty::error::ExpectedFound<Ty<'tcx>>>,
terr: &TypeError<'tcx>,
) {
match cause.code {
match *cause.code() {
ObligationCauseCode::Pattern { origin_expr: true, span: Some(span), root_ty } => {
let ty = self.resolve_vars_if_possible(root_ty);
if ty.is_suggestable() {
Expand Down Expand Up @@ -781,7 +781,7 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
}
_ => {
if let ObligationCauseCode::BindingObligation(_, binding_span) =
cause.code.peel_derives()
cause.code().peel_derives()
{
if matches!(terr, TypeError::RegionsPlaceholderMismatch) {
err.span_note(*binding_span, "the lifetime requirement is introduced here");
Expand Down Expand Up @@ -1729,10 +1729,10 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
}
_ => exp_found,
};
debug!("exp_found {:?} terr {:?} cause.code {:?}", exp_found, terr, cause.code);
debug!("exp_found {:?} terr {:?} cause.code {:?}", exp_found, terr, cause.code());
if let Some(exp_found) = exp_found {
let should_suggest_fixes = if let ObligationCauseCode::Pattern { root_ty, .. } =
&cause.code
cause.code()
{
// Skip if the root_ty of the pattern is not the same as the expected_ty.
// If these types aren't equal then we've probably peeled off a layer of arrays.
Expand Down Expand Up @@ -1827,15 +1827,15 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
exp_span, exp_found.expected, exp_found.found,
);

if let ObligationCauseCode::CompareImplMethodObligation { .. } = &cause.code {
if let ObligationCauseCode::CompareImplMethodObligation { .. } = cause.code() {
return;
}

match (
self.get_impl_future_output_ty(exp_found.expected),
self.get_impl_future_output_ty(exp_found.found),
) {
(Some(exp), Some(found)) if same_type_modulo_infer(exp, found) => match &cause.code {
(Some(exp), Some(found)) if same_type_modulo_infer(exp, found) => match cause.code() {
ObligationCauseCode::IfExpression(box IfExpressionCause { then, .. }) => {
diag.multipart_suggestion(
"consider `await`ing on both `Future`s",
Expand Down Expand Up @@ -1875,7 +1875,7 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
Applicability::MaybeIncorrect,
);
}
(Some(ty), _) if same_type_modulo_infer(ty, exp_found.found) => match cause.code {
(Some(ty), _) if same_type_modulo_infer(ty, exp_found.found) => match cause.code() {
ObligationCauseCode::Pattern { span: Some(span), .. }
| ObligationCauseCode::IfExpression(box IfExpressionCause { then: span, .. }) => {
diag.span_suggestion_verbose(
Expand Down Expand Up @@ -1927,7 +1927,7 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
.map(|field| (field.ident.name, field.ty(self.tcx, expected_substs)))
.find(|(_, ty)| same_type_modulo_infer(ty, exp_found.found))
{
if let ObligationCauseCode::Pattern { span: Some(span), .. } = cause.code {
if let ObligationCauseCode::Pattern { span: Some(span), .. } = *cause.code() {
if let Ok(snippet) = self.tcx.sess.source_map().span_to_snippet(span) {
let suggestion = if expected_def.is_struct() {
format!("{}.{}", snippet, name)
Expand Down Expand Up @@ -2064,7 +2064,7 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
}
}
if let MatchExpressionArm(box MatchExpressionArmCause { source, .. }) =
trace.cause.code
*trace.cause.code()
{
if let hir::MatchSource::TryDesugar = source {
if let Some((expected_ty, found_ty)) = self.values_str(trace.values) {
Expand Down Expand Up @@ -2659,7 +2659,7 @@ impl<'tcx> ObligationCauseExt<'tcx> for ObligationCause<'tcx> {
fn as_failure_code(&self, terr: &TypeError<'tcx>) -> FailureCode {
use self::FailureCode::*;
use crate::traits::ObligationCauseCode::*;
match self.code {
match self.code() {
CompareImplMethodObligation { .. } => Error0308("method not compatible with trait"),
CompareImplTypeObligation { .. } => Error0308("type not compatible with trait"),
MatchExpressionArm(box MatchExpressionArmCause { source, .. }) => {
Expand Down Expand Up @@ -2694,7 +2694,7 @@ impl<'tcx> ObligationCauseExt<'tcx> for ObligationCause<'tcx> {

fn as_requirement_str(&self) -> &'static str {
use crate::traits::ObligationCauseCode::*;
match self.code {
match self.code() {
CompareImplMethodObligation { .. } => "method type is compatible with trait",
CompareImplTypeObligation { .. } => "associated type is compatible with trait",
ExprAssignable => "expression is assignable",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,15 @@ impl<'a, 'tcx> NiceRegionError<'a, 'tcx> {
};
// If we added a "points at argument expression" obligation, we remove it here, we care
// about the original obligation only.
let code = match &cause.code {
let code = match cause.code() {
ObligationCauseCode::FunctionArgumentObligation { parent_code, .. } => &*parent_code,
_ => &cause.code,
_ => cause.code(),
};
let (parent, impl_def_id) = match code {
ObligationCauseCode::MatchImpl(parent, impl_def_id) => (parent, impl_def_id),
_ => return None,
};
let binding_span = match parent.code {
let binding_span = match *parent.code() {
ObligationCauseCode::BindingObligation(_def_id, binding_span) => binding_span,
_ => return None,
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ impl<'tcx> NiceRegionError<'_, 'tcx> {
);
let mut err = self.tcx().sess.struct_span_err(span, &msg);

let leading_ellipsis = if let ObligationCauseCode::ItemObligation(def_id) = cause.code {
let leading_ellipsis = if let ObligationCauseCode::ItemObligation(def_id) = *cause.code() {
err.span_label(span, "doesn't satisfy where-clause");
err.span_label(
self.tcx().def_span(def_id),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ impl<'a, 'tcx> NiceRegionError<'a, 'tcx> {
sup_r,
) if **sub_r == RegionKind::ReStatic => {
// This is for an implicit `'static` requirement coming from `impl dyn Trait {}`.
if let ObligationCauseCode::UnifyReceiver(ctxt) = &cause.code {
if let ObligationCauseCode::UnifyReceiver(ctxt) = cause.code() {
// This may have a closure and it would cause ICE
// through `find_param_with_region` (#78262).
let anon_reg_sup = tcx.is_suitable_region(sup_r)?;
Expand Down Expand Up @@ -184,7 +184,7 @@ impl<'a, 'tcx> NiceRegionError<'a, 'tcx> {
}
if let SubregionOrigin::Subtype(box TypeTrace { cause, .. }) = sub_origin {
if let ObligationCauseCode::ReturnValue(hir_id)
| ObligationCauseCode::BlockTailExpression(hir_id) = &cause.code
| ObligationCauseCode::BlockTailExpression(hir_id) = cause.code()
{
let parent_id = tcx.hir().get_parent_item(*hir_id);
if let Some(fn_decl) = tcx.hir().fn_decl_by_hir_id(parent_id) {
Expand Down Expand Up @@ -226,7 +226,7 @@ impl<'a, 'tcx> NiceRegionError<'a, 'tcx> {

let mut override_error_code = None;
if let SubregionOrigin::Subtype(box TypeTrace { cause, .. }) = &sup_origin {
if let ObligationCauseCode::UnifyReceiver(ctxt) = &cause.code {
if let ObligationCauseCode::UnifyReceiver(ctxt) = cause.code() {
// Handle case of `impl Foo for dyn Bar { fn qux(&self) {} }` introducing a
// `'static` lifetime when called as a method on a binding: `bar.qux()`.
if self.find_impl_on_dyn_trait(&mut err, param.param_ty, &ctxt) {
Expand All @@ -235,9 +235,9 @@ impl<'a, 'tcx> NiceRegionError<'a, 'tcx> {
}
}
if let SubregionOrigin::Subtype(box TypeTrace { cause, .. }) = &sub_origin {
let code = match &cause.code {
ObligationCauseCode::MatchImpl(parent, ..) => &parent.code,
_ => &cause.code,
let code = match cause.code() {
ObligationCauseCode::MatchImpl(parent, ..) => parent.code(),
_ => cause.code(),
};
if let (ObligationCauseCode::ItemObligation(item_def_id), None) =
(code, override_error_code)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ impl<'a, 'tcx> NiceRegionError<'a, 'tcx> {
ValuePairs::Types(sub_expected_found),
ValuePairs::Types(sup_expected_found),
CompareImplMethodObligation { trait_item_def_id, .. },
) = (&sub_trace.values, &sup_trace.values, &sub_trace.cause.code)
) = (&sub_trace.values, &sup_trace.values, sub_trace.cause.code())
{
if sup_expected_found == sub_expected_found {
self.emit_err(
Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_infer/src/infer/error_reporting/note.rs
Original file line number Diff line number Diff line change
Expand Up @@ -359,13 +359,13 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
match placeholder_origin {
infer::Subtype(box ref trace)
if matches!(
&trace.cause.code.peel_derives(),
&trace.cause.code().peel_derives(),
ObligationCauseCode::BindingObligation(..)
) =>
{
// Hack to get around the borrow checker because trace.cause has an `Rc`.
if let ObligationCauseCode::BindingObligation(_, span) =
&trace.cause.code.peel_derives()
&trace.cause.code().peel_derives()
{
let span = *span;
let mut err = self.report_concrete_failure(placeholder_origin, sub, sup);
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_infer/src/infer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1824,7 +1824,7 @@ impl<'tcx> SubregionOrigin<'tcx> {
where
F: FnOnce() -> Self,
{
match cause.code {
match *cause.code() {
traits::ObligationCauseCode::ReferenceOutlivesReferent(ref_type) => {
SubregionOrigin::ReferenceOutlivesReferent(ref_type, cause.span)
}
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_infer/src/infer/outlives/obligations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ impl<'cx, 'tcx> InferCtxt<'cx, 'tcx> {
infer::RelateParamBound(
cause.span,
sup_type,
match cause.code.peel_derives() {
match cause.code().peel_derives() {
ObligationCauseCode::BindingObligation(_, span) => Some(*span),
_ => None,
},
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_infer/src/traits/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ impl TraitObligation<'_> {

// `PredicateObligation` is used a lot. Make sure it doesn't unintentionally get bigger.
#[cfg(all(target_arch = "x86_64", target_pointer_width = "64"))]
static_assert_size!(PredicateObligation<'_>, 32);
static_assert_size!(PredicateObligation<'_>, 48);

pub type PredicateObligations<'tcx> = Vec<PredicateObligation<'tcx>>;

Expand Down
82 changes: 40 additions & 42 deletions compiler/rustc_middle/src/traits/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,7 @@ use rustc_span::{Span, DUMMY_SP};
use smallvec::SmallVec;

use std::borrow::Cow;
use std::fmt;
use std::hash::{Hash, Hasher};
use std::ops::Deref;

pub use self::select::{EvaluationCache, EvaluationResult, OverflowError, SelectionCache};

Expand Down Expand Up @@ -80,38 +78,14 @@ pub enum Reveal {

/// The reason why we incurred this obligation; used for error reporting.
///
/// As the happy path does not care about this struct, storing this on the heap
/// ends up increasing performance.
/// Non-misc `ObligationCauseCode`s are stored on the heap. This gives the
/// best trade-off between keeping the type small (which makes copies cheaper)
/// while not doing too many heap allocations.
///
/// We do not want to intern this as there are a lot of obligation causes which
/// only live for a short period of time.
#[derive(Clone, PartialEq, Eq, Hash, Lift)]
pub struct ObligationCause<'tcx> {
/// `None` for `ObligationCause::dummy`, `Some` otherwise.
data: Option<Lrc<ObligationCauseData<'tcx>>>,
}

const DUMMY_OBLIGATION_CAUSE_DATA: ObligationCauseData<'static> =
ObligationCauseData { span: DUMMY_SP, body_id: hir::CRATE_HIR_ID, code: MiscObligation };

// Correctly format `ObligationCause::dummy`.
impl<'tcx> fmt::Debug for ObligationCause<'tcx> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
ObligationCauseData::fmt(self, f)
}
}

impl<'tcx> Deref for ObligationCause<'tcx> {
type Target = ObligationCauseData<'tcx>;

#[inline(always)]
fn deref(&self) -> &Self::Target {
self.data.as_deref().unwrap_or(&DUMMY_OBLIGATION_CAUSE_DATA)
}
}

#[derive(Clone, Debug, PartialEq, Eq, Lift)]
pub struct ObligationCauseData<'tcx> {
pub struct ObligationCause<'tcx> {
pub span: Span,

/// The ID of the fn body that triggered this obligation. This is
Expand All @@ -122,46 +96,58 @@ pub struct ObligationCauseData<'tcx> {
/// information.
pub body_id: hir::HirId,

pub code: ObligationCauseCode<'tcx>,
/// `None` for `MISC_OBLIGATION_CAUSE_CODE` (a common case, occurs ~60% of
/// the time). `Some` otherwise.
code: Option<Lrc<ObligationCauseCode<'tcx>>>,
}

impl Hash for ObligationCauseData<'_> {
// This custom hash function speeds up hashing for `Obligation` deduplication
// greatly by skipping the `code` field, which can be large and complex. That
// shouldn't affect hash quality much since there are several other fields in
// `Obligation` which should be unique enough, especially the predicate itself
// which is hashed as an interned pointer. See #90996.
impl Hash for ObligationCause<'_> {
fn hash<H: Hasher>(&self, state: &mut H) {
self.body_id.hash(state);
self.span.hash(state);
std::mem::discriminant(&self.code).hash(state);
}
}

const MISC_OBLIGATION_CAUSE_CODE: ObligationCauseCode<'static> = MiscObligation;

impl<'tcx> ObligationCause<'tcx> {
#[inline]
pub fn new(
span: Span,
body_id: hir::HirId,
code: ObligationCauseCode<'tcx>,
) -> ObligationCause<'tcx> {
ObligationCause { data: Some(Lrc::new(ObligationCauseData { span, body_id, code })) }
ObligationCause {
span,
body_id,
code: if code == MISC_OBLIGATION_CAUSE_CODE { None } else { Some(Lrc::new(code)) },
}
}

pub fn misc(span: Span, body_id: hir::HirId) -> ObligationCause<'tcx> {
ObligationCause::new(span, body_id, MiscObligation)
}

pub fn dummy_with_span(span: Span) -> ObligationCause<'tcx> {
ObligationCause::new(span, hir::CRATE_HIR_ID, MiscObligation)
}

#[inline(always)]
pub fn dummy() -> ObligationCause<'tcx> {
ObligationCause { data: None }
ObligationCause { span: DUMMY_SP, body_id: hir::CRATE_HIR_ID, code: None }
}

pub fn dummy_with_span(span: Span) -> ObligationCause<'tcx> {
ObligationCause { span, body_id: hir::CRATE_HIR_ID, code: None }
}

pub fn make_mut(&mut self) -> &mut ObligationCauseData<'tcx> {
Lrc::make_mut(self.data.get_or_insert_with(|| Lrc::new(DUMMY_OBLIGATION_CAUSE_DATA)))
pub fn make_mut_code(&mut self) -> &mut ObligationCauseCode<'tcx> {
Lrc::make_mut(self.code.get_or_insert_with(|| Lrc::new(MISC_OBLIGATION_CAUSE_CODE)))
}

pub fn span(&self, tcx: TyCtxt<'tcx>) -> Span {
match self.code {
match *self.code() {
ObligationCauseCode::CompareImplMethodObligation { .. }
| ObligationCauseCode::MainFunctionType
| ObligationCauseCode::StartFunctionType => {
Expand All @@ -174,6 +160,18 @@ impl<'tcx> ObligationCause<'tcx> {
_ => self.span,
}
}

#[inline]
pub fn code(&self) -> &ObligationCauseCode<'tcx> {
self.code.as_deref().unwrap_or(&MISC_OBLIGATION_CAUSE_CODE)
}

pub fn clone_code(&self) -> Lrc<ObligationCauseCode<'tcx>> {
match &self.code {
Some(code) => code.clone(),
None => Lrc::new(MISC_OBLIGATION_CAUSE_CODE),
}
}
}

#[derive(Clone, Debug, PartialEq, Eq, Hash, Lift)]
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_middle/src/ty/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,7 @@ impl<T> Trait<T> for X {
proj_ty,
values,
body_owner_def_id,
&cause.code,
cause.code(),
);
}
(_, ty::Projection(proj_ty)) => {
Expand Down
Loading

0 comments on commit ed7a206

Please sign in to comment.