Skip to content

Commit

Permalink
Merge pull request swiftlang#24741 from xymus/default-definition-request
Browse files Browse the repository at this point in the history
Request to lazily compute the default definition type of associated types
  • Loading branch information
xymus authored May 14, 2019
2 parents 0bb5771 + 547ba1c commit d4e6b58
Show file tree
Hide file tree
Showing 9 changed files with 101 additions and 51 deletions.
15 changes: 7 additions & 8 deletions include/swift/AST/Decl.h
Original file line number Diff line number Diff line change
Expand Up @@ -3045,17 +3045,19 @@ class AssociatedTypeDecl : public AbstractTypeParamDecl {
SourceLoc KeywordLoc;

/// The default definition.
TypeLoc DefaultDefinition;
TypeRepr *DefaultDefinition;

/// The where clause attached to the associated type.
TrailingWhereClause *TrailingWhere;

LazyMemberLoader *Resolver = nullptr;
uint64_t ResolverContextData;

friend class DefaultDefinitionTypeRequest;

public:
AssociatedTypeDecl(DeclContext *dc, SourceLoc keywordLoc, Identifier name,
SourceLoc nameLoc, TypeLoc defaultDefinition,
SourceLoc nameLoc, TypeRepr *defaultDefinition,
TrailingWhereClause *trailingWhere);
AssociatedTypeDecl(DeclContext *dc, SourceLoc keywordLoc, Identifier name,
SourceLoc nameLoc, TrailingWhereClause *trailingWhere,
Expand All @@ -3071,17 +3073,14 @@ class AssociatedTypeDecl : public AbstractTypeParamDecl {
bool hasDefaultDefinitionType() const {
// If we have a TypeRepr, return true immediately without kicking off
// a request.
return !DefaultDefinition.isNull() || getDefaultDefinitionType();
return DefaultDefinition || getDefaultDefinitionType();
}

/// Retrieve the default definition type.
Type getDefaultDefinitionType() const;

TypeLoc &getDefaultDefinitionLoc() {
return DefaultDefinition;
}

const TypeLoc &getDefaultDefinitionLoc() const {
/// Retrieve the default definition as written in the source.
TypeRepr *getDefaultDefinitionTypeRepr() const {
return DefaultDefinition;
}

Expand Down
24 changes: 24 additions & 0 deletions include/swift/AST/TypeCheckRequests.h
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,30 @@ class RequirementSignatureRequest :
void cacheResult(ArrayRef<Requirement> value) const;
};

/// Compute the default definition type of an associated type.
class DefaultDefinitionTypeRequest :
public SimpleRequest<DefaultDefinitionTypeRequest,
CacheKind::Cached,
Type,
AssociatedTypeDecl *> {
public:
using SimpleRequest::SimpleRequest;

private:
friend SimpleRequest;

// Evaluation.
llvm::Expected<Type> evaluate(Evaluator &evaluator, AssociatedTypeDecl *decl) const;

public:
// Cycle handling
void diagnoseCycle(DiagnosticEngine &diags) const;
void noteCycleStep(DiagnosticEngine &diags) const;

// Caching.
bool isCached() const { return true; }
};

/// Describes the owner of a where clause, from which we can extract
/// requirements.
struct WhereClauseOwner {
Expand Down
1 change: 1 addition & 0 deletions include/swift/AST/TypeCheckerTypeIDZone.def
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ SWIFT_TYPEID(IsFinalRequest)
SWIFT_TYPEID(IsDynamicRequest)
SWIFT_TYPEID(RequirementRequest)
SWIFT_TYPEID(RequirementSignatureRequest)
SWIFT_TYPEID(DefaultDefinitionTypeRequest)
SWIFT_TYPEID(USRGenerationRequest)
SWIFT_TYPEID(StructuralTypeRequest)
SWIFT_TYPEID(DefaultTypeRequest)
Expand Down
21 changes: 9 additions & 12 deletions lib/AST/Decl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3539,7 +3539,7 @@ SourceRange GenericTypeParamDecl::getSourceRange() const {

AssociatedTypeDecl::AssociatedTypeDecl(DeclContext *dc, SourceLoc keywordLoc,
Identifier name, SourceLoc nameLoc,
TypeLoc defaultDefinition,
TypeRepr *defaultDefinition,
TrailingWhereClause *trailingWhere)
: AbstractTypeParamDecl(DeclKind::AssociatedType, dc, name, nameLoc),
KeywordLoc(keywordLoc), DefaultDefinition(defaultDefinition),
Expand All @@ -3552,8 +3552,9 @@ AssociatedTypeDecl::AssociatedTypeDecl(DeclContext *dc, SourceLoc keywordLoc,
LazyMemberLoader *definitionResolver,
uint64_t resolverData)
: AbstractTypeParamDecl(DeclKind::AssociatedType, dc, name, nameLoc),
KeywordLoc(keywordLoc), TrailingWhere(trailingWhere),
Resolver(definitionResolver), ResolverContextData(resolverData) {
KeywordLoc(keywordLoc), DefaultDefinition(nullptr),
TrailingWhere(trailingWhere), Resolver(definitionResolver),
ResolverContextData(resolverData) {
assert(Resolver && "missing resolver");
}

Expand All @@ -3566,21 +3567,17 @@ void AssociatedTypeDecl::computeType() {
}

Type AssociatedTypeDecl::getDefaultDefinitionType() const {
if (Resolver) {
const_cast<AssociatedTypeDecl *>(this)->DefaultDefinition
= TypeLoc::withoutLoc(
Resolver->loadAssociatedTypeDefault(this, ResolverContextData));
const_cast<AssociatedTypeDecl *>(this)->Resolver = nullptr;
}
return DefaultDefinition.getType();
return evaluateOrDefault(getASTContext().evaluator,
DefaultDefinitionTypeRequest{const_cast<AssociatedTypeDecl *>(this)},
Type());
}

SourceRange AssociatedTypeDecl::getSourceRange() const {
SourceLoc endLoc;
if (auto TWC = getTrailingWhereClause()) {
endLoc = TWC->getSourceRange().End;
} else if (getDefaultDefinitionLoc().hasLocation()) {
endLoc = getDefaultDefinitionLoc().getSourceRange().End;
} else if (auto defaultDefinition = getDefaultDefinitionTypeRepr()) {
endLoc = defaultDefinition->getEndLoc();
} else if (!getInherited().empty()) {
endLoc = getInherited().back().getSourceRange().End;
} else {
Expand Down
14 changes: 14 additions & 0 deletions lib/AST/TypeCheckRequests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,20 @@ void RequirementSignatureRequest::cacheResult(ArrayRef<Requirement> value) const
proto->setRequirementSignature(value);
}

//----------------------------------------------------------------------------//
// DefaultDefinitionTypeRequest computation.
//----------------------------------------------------------------------------//

void DefaultDefinitionTypeRequest::diagnoseCycle(DiagnosticEngine &diags) const {
auto decl = std::get<0>(getStorage());
diags.diagnose(decl, diag::circular_reference);
}

void DefaultDefinitionTypeRequest::noteCycleStep(DiagnosticEngine &diags) const {
auto decl = std::get<0>(getStorage());
diags.diagnose(decl, diag::circular_reference_through);
}

//----------------------------------------------------------------------------//
// Requirement computation.
//----------------------------------------------------------------------------//
Expand Down
9 changes: 6 additions & 3 deletions lib/Sema/TypeCheckAccess.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -626,7 +626,8 @@ class AccessControlChecker : public AccessControlCheckerBase,
}
});
});
checkTypeAccess(assocType->getDefaultDefinitionLoc(), assocType,
checkTypeAccess(assocType->getDefaultDefinitionType(),
assocType->getDefaultDefinitionTypeRepr(), assocType,
/*mayBeInferred*/false,
[&](AccessScope typeAccessScope,
const TypeRepr *thisComplainRepr,
Expand Down Expand Up @@ -1215,7 +1216,8 @@ class UsableFromInlineChecker : public AccessControlCheckerBase,
highlightOffendingType(TC, diag, complainRepr);
});
});
checkTypeAccess(assocType->getDefaultDefinitionLoc(), assocType,
checkTypeAccess(assocType->getDefaultDefinitionType(),
assocType->getDefaultDefinitionTypeRepr(), assocType,
/*mayBeInferred*/false,
[&](AccessScope typeAccessScope,
const TypeRepr *complainRepr,
Expand Down Expand Up @@ -1775,7 +1777,8 @@ class ExportabilityChecker : public DeclVisitor<ExportabilityChecker> {
checkType(requirement, assocType, getDiagnoseCallback(assocType),
getDiagnoseCallback(assocType));
});
checkType(assocType->getDefaultDefinitionLoc(), assocType,
checkType(assocType->getDefaultDefinitionType(),
assocType->getDefaultDefinitionTypeRepr(), assocType,
getDiagnoseCallback(assocType), getDiagnoseCallback(assocType));

if (assocType->getTrailingWhereClause()) {
Expand Down
64 changes: 38 additions & 26 deletions lib/Sema/TypeCheckDecl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1413,6 +1413,25 @@ RequirementSignatureRequest::evaluate(Evaluator &evaluator, ProtocolDecl *proto)
return reqSignature->getRequirements();
}

llvm::Expected<Type>
DefaultDefinitionTypeRequest::evaluate(Evaluator &evaluator,
AssociatedTypeDecl *assocType) const {
if (assocType->Resolver) {
auto defaultType = assocType->Resolver->loadAssociatedTypeDefault(
assocType, assocType->ResolverContextData);
assocType->Resolver = nullptr;
return defaultType;
}

TypeRepr *defaultDefinition = assocType->getDefaultDefinitionTypeRepr();
if (defaultDefinition) {
auto resolution = TypeResolution::forInterface(assocType->getDeclContext());
return resolution.resolveType(defaultDefinition, None);
}

return Type();
}

namespace {
/// How to generate the raw value for each element of an enum that doesn't
/// have one explicitly specified.
Expand Down Expand Up @@ -2652,6 +2671,25 @@ class DeclChecker : public DeclVisitor<DeclChecker> {

// Trigger the checking for overridden declarations.
(void)AT->getOverriddenDecls();

auto defaultType = AT->getDefaultDefinitionType();
if (defaultType && !defaultType->hasError()) {
// associatedtype X = X is invalid
auto mentionsItself =
defaultType.findIf([&](Type defaultType) {
if (auto DMT = defaultType->getAs<DependentMemberType>()) {
return DMT->getAssocType() == AT;
}
return false;
});

if (mentionsItself) {
TC.diagnose(AT->getDefaultDefinitionTypeRepr()->getLoc(),
diag::recursive_decl_reference,
AT->getDescriptiveKind(), AT->getName());
AT->diagnose(diag::kind_declared_here, DescriptiveDeclKind::Type);
}
}
}

void checkUnsupportedNestedType(NominalTypeDecl *NTD) {
Expand Down Expand Up @@ -3754,32 +3792,6 @@ void TypeChecker::validateDecl(ValueDecl *D) {

DeclValidationRAII IBV(assocType);

// Check the default definition, if there is one.
TypeLoc &defaultDefinition = assocType->getDefaultDefinitionLoc();
if (!defaultDefinition.isNull()) {
if (validateType(
defaultDefinition,
TypeResolution::forInterface(
assocType->getDeclContext()),
None)) {
defaultDefinition.setInvalidType(Context);
} else {
// associatedtype X = X is invalid
auto mentionsItself =
defaultDefinition.getType().findIf([&](Type type) {
if (auto DMT = type->getAs<DependentMemberType>()) {
return DMT->getAssocType() == assocType;
}
return false;
});

if (mentionsItself) {
diagnose(defaultDefinition.getLoc(), diag::recursive_decl_reference,
assocType->getDescriptiveKind(), assocType->getName());
diagnose(assocType, diag::kind_declared_here, DescriptiveDeclKind::Type);
}
}
}
// Finally, set the interface type.
if (!assocType->hasInterfaceType())
assocType->computeType();
Expand Down
2 changes: 1 addition & 1 deletion lib/Sema/TypeCheckProtocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5648,7 +5648,7 @@ void TypeChecker::inferDefaultWitnesses(ProtocolDecl *proto) {
diagnose(defaultedAssocType, diag::assoc_type_default_here,
assocType->getFullName(), defaultAssocType)
.highlight(
defaultedAssocType->getDefaultDefinitionLoc().getSourceRange());
defaultedAssocType->getDefaultDefinitionTypeRepr()->getSourceRange());

continue;
}
Expand Down
2 changes: 1 addition & 1 deletion lib/Sema/TypeCheckProtocolInference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1866,7 +1866,7 @@ bool AssociatedTypeInference::diagnoseAmbiguousSolutions(
// Otherwise, we have a default.
diags.diagnose(assocType, diag::associated_type_deduction_default,
type)
.highlight(assocType->getDefaultDefinitionLoc().getSourceRange());
.highlight(assocType->getDefaultDefinitionTypeRepr()->getSourceRange());
};

diagnoseWitness(firstMatch, firstType);
Expand Down

0 comments on commit d4e6b58

Please sign in to comment.