diff --git a/lib/Sema/CSSolver.cpp b/lib/Sema/CSSolver.cpp index c5bb33e0c47e0..3b1a18fd681eb 100644 --- a/lib/Sema/CSSolver.cpp +++ b/lib/Sema/CSSolver.cpp @@ -1683,14 +1683,91 @@ getOperatorDesignatedNominalTypes(Constraint *bindOverload) { return operatorDecl->getDesignatedNominalTypes(); } +void ConstraintSystem::sortDesignatedTypes( + SmallVectorImpl &nominalTypes, + Constraint *bindOverload) { + auto *tyvar = bindOverload->getFirstType()->castTo(); + llvm::SetVector applicableFns; + getConstraintGraph().gatherConstraints( + tyvar, applicableFns, ConstraintGraph::GatheringKind::EquivalenceClass, + [](Constraint *match) { + return match->getKind() == ConstraintKind::ApplicableFunction; + }); + + // FIXME: This is not true when we run the constraint optimizer. + // assert(applicableFns.size() <= 1); + + // We have a disjunction for an operator but no application of it, + // so it's being passed as an argument. + if (applicableFns.size() == 0) + return; + + // FIXME: We have more than one applicable per disjunction as a + // result of merging disjunction type variables. We may want + // to rip that out at some point. + Constraint *foundApplicable = nullptr; + SmallVector, 2> argumentTypes; + for (auto *applicableFn : applicableFns) { + argumentTypes.clear(); + auto *fnTy = applicableFn->getFirstType()->castTo(); + ArgumentInfoCollector argInfo(*this, fnTy); + // Stop if we hit anything with concrete types or conformances to + // literals. + if (!argInfo.getTypes().empty() || !argInfo.getLiteralProtocols().empty()) { + foundApplicable = applicableFn; + break; + } + } + + if (!foundApplicable) + return; + + // FIXME: It would be good to avoid this redundancy. + auto *fnTy = foundApplicable->getFirstType()->castTo(); + ArgumentInfoCollector argInfo(*this, fnTy); + + size_t nextType = 0; + for (auto argType : argInfo.getTypes()) { + auto *nominal = argType->getAnyNominal(); + for (size_t i = nextType + 1; i < nominalTypes.size(); ++i) { + if (nominal == nominalTypes[i]) { + std::swap(nominalTypes[nextType], nominalTypes[i]); + ++nextType; + break; + } + } + } + + if (nextType + 1 >= nominalTypes.size()) + return; + + for (auto *protocol : argInfo.getLiteralProtocols()) { + auto defaultType = TC.getDefaultType(protocol, DC); + auto *nominal = defaultType->getAnyNominal(); + for (size_t i = nextType + 1; i < nominalTypes.size(); ++i) { + if (nominal == nominalTypes[i]) { + std::swap(nominalTypes[nextType], nominalTypes[i]); + ++nextType; + break; + } + } + } +} + void ConstraintSystem::partitionForDesignatedTypes( ArrayRef Choices, ConstraintMatchLoop forEachChoice, PartitionAppendCallback appendPartition) { - auto designatedNominalTypes = getOperatorDesignatedNominalTypes(Choices[0]); - if (designatedNominalTypes.empty()) + auto types = getOperatorDesignatedNominalTypes(Choices[0]); + if (types.empty()) return; + SmallVector designatedNominalTypes(types.begin(), + types.end()); + + if (designatedNominalTypes.size() > 1) + sortDesignatedTypes(designatedNominalTypes, Choices[0]); + SmallVector, 4> definedInDesignatedType; SmallVector, 4> definedInExtensionOfDesignatedType; diff --git a/lib/Sema/ConstraintSystem.h b/lib/Sema/ConstraintSystem.h index a76baa4f8840c..aa11af7190635 100644 --- a/lib/Sema/ConstraintSystem.h +++ b/lib/Sema/ConstraintSystem.h @@ -3227,6 +3227,12 @@ class ConstraintSystem { typedef std::function &options)> PartitionAppendCallback; + // Attempt to sort nominalTypes based on what we can discover about + // calls into the overloads in the disjunction that bindOverload is + // a part of. + void sortDesignatedTypes(SmallVectorImpl &nominalTypes, + Constraint *bindOverload); + // Partition the choices in a disjunction based on those that match // the designated types for the operator that the disjunction was // formed for. diff --git a/validation-test/Sema/type_checker_perf/fast/rdar17024694.swift b/validation-test/Sema/type_checker_perf/fast/rdar17024694.swift index 0181500267f6b..4a799ea63449d 100644 --- a/validation-test/Sema/type_checker_perf/fast/rdar17024694.swift +++ b/validation-test/Sema/type_checker_perf/fast/rdar17024694.swift @@ -1,4 +1,4 @@ -// RUN: %target-typecheck-verify-swift -solver-expression-time-threshold=1 +// RUN: %target-typecheck-verify-swift -solver-expression-time-threshold=1 -swift-version 5 -solver-disable-shrink -disable-constraint-solver-performance-hacks -solver-enable-operator-designated-types // REQUIRES: tools-release,no_asserts _ = (2...100).reversed().filter({ $0 % 11 == 0 }).map {