Skip to content

Commit

Permalink
Sema: Use InferredGenericSignatureRequest in TypeCheckAttr.cpp
Browse files Browse the repository at this point in the history
  • Loading branch information
slavapestov committed Oct 20, 2021
1 parent d847bc1 commit 1262193
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 64 deletions.
99 changes: 35 additions & 64 deletions lib/Sema/TypeCheckAttr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
#include "swift/AST/DiagnosticsParse.h"
#include "swift/AST/Effects.h"
#include "swift/AST/GenericEnvironment.h"
#include "swift/AST/GenericSignatureBuilder.h"
#include "swift/AST/ImportCache.h"
#include "swift/AST/ModuleNameLookup.h"
#include "swift/AST/NameLookup.h"
Expand Down Expand Up @@ -2231,28 +2230,17 @@ void AttributeChecker::visitSpecializeAttr(SpecializeAttr *attr) {
return;
}

// Form a new generic signature based on the old one.
GenericSignatureBuilder Builder(D->getASTContext());
InferredGenericSignatureRequest request{
DC->getParentModule(),
genericSig.getPointer(),
/*genericParams=*/nullptr,
WhereClauseOwner(FD, attr),
/*addedRequirements=*/{},
/*inferenceSources=*/{},
/*allowConcreteGenericParams=*/true};

// First, add the old generic signature.
Builder.addGenericSignature(genericSig);

// Go over the set of requirements, adding them to the builder.
WhereClauseOwner(FD, attr).visitRequirements(TypeResolutionStage::Interface,
[&](const Requirement &req, RequirementRepr *reqRepr) {
// Add the requirement to the generic signature builder.
using FloatingRequirementSource =
GenericSignatureBuilder::FloatingRequirementSource;
Builder.addRequirement(req, reqRepr,
FloatingRequirementSource::forExplicit(
reqRepr->getSeparatorLoc()),
nullptr, DC->getParentModule());
return false;
});

// Check the result.
auto specializedSig = std::move(Builder).computeGenericSignature(
/*allowConcreteGenericParams=*/true);
auto specializedSig = evaluateOrDefault(Ctx.evaluator, request,
GenericSignature());

// Check the validity of provided requirements.
checkSpecializeAttrRequirements(attr, genericSig, specializedSig, Ctx);
Expand Down Expand Up @@ -4266,7 +4254,8 @@ bool resolveDifferentiableAttrDerivativeGenericSignature(
// - If the `@differentiable` attribute has a `where` clause, use it to
// compute the derivative generic signature.
// - Otherwise, use the original function's generic signature by default.
derivativeGenSig = original->getGenericSignature();
auto originalGenSig = original->getGenericSignature();
derivativeGenSig = originalGenSig;

// Handle the `where` clause, if it exists.
// - Resolve attribute where clause requirements and store in the attribute
Expand All @@ -4291,7 +4280,6 @@ bool resolveDifferentiableAttrDerivativeGenericSignature(
return true;
}

auto originalGenSig = original->getGenericSignature();
if (!originalGenSig) {
// `where` clauses are valid only when the original function is generic.
diags
Expand All @@ -4304,51 +4292,34 @@ bool resolveDifferentiableAttrDerivativeGenericSignature(
return true;
}

// Build a new generic signature for autodiff derivative functions.
GenericSignatureBuilder builder(ctx);
// Add the original function's generic signature.
builder.addGenericSignature(originalGenSig);

using FloatingRequirementSource =
GenericSignatureBuilder::FloatingRequirementSource;

bool errorOccurred = false;
WhereClauseOwner(original, attr)
.visitRequirements(
TypeResolutionStage::Structural,
[&](const Requirement &req, RequirementRepr *reqRepr) {
switch (req.getKind()) {
case RequirementKind::SameType:
case RequirementKind::Superclass:
case RequirementKind::Conformance:
break;

// Layout requirements are not supported.
case RequirementKind::Layout:
diags
.diagnose(attr->getLocation(),
diag::differentiable_attr_layout_req_unsupported)
.highlight(reqRepr->getSourceRange());
errorOccurred = true;
return false;
}
InferredGenericSignatureRequest request{
original->getParentModule(),
originalGenSig.getPointer(),
/*genericParams=*/nullptr,
WhereClauseOwner(original, attr),
/*addedRequirements=*/{},
/*inferenceSources=*/{},
/*allowConcreteParams=*/true};

// Compute generic signature for derivative functions.
derivativeGenSig = evaluateOrDefault(ctx.evaluator, request,
GenericSignature());

// Add requirement to generic signature builder.
builder.addRequirement(
req, reqRepr, FloatingRequirementSource::forExplicit(
reqRepr->getSeparatorLoc()),
nullptr, original->getModuleContext());
return false;
});
bool hadInvalidRequirements = false;
for (auto req : derivativeGenSig.requirementsNotSatisfiedBy(originalGenSig)) {
if (req.getKind() == RequirementKind::Layout) {
// Layout requirements are not supported.
diags
.diagnose(attr->getLocation(),
diag::differentiable_attr_layout_req_unsupported);
hadInvalidRequirements = true;
}
}

if (errorOccurred) {
if (hadInvalidRequirements) {
attr->setInvalid();
return true;
}

// Compute generic signature for derivative functions.
derivativeGenSig = std::move(builder).computeGenericSignature(
/*allowConcreteGenericParams=*/true);
}

attr->setDerivativeGenericSignature(derivativeGenSig);
Expand Down
1 change: 1 addition & 0 deletions test/AutoDiff/Sema/differentiable_attr_type_checking.swift
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ func invalidRequirementConformance<Scalar>(x: Scalar) -> Scalar {
return x
}

// expected-error @+1 {{'@differentiable' attribute does not yet support layout requirements}}
@differentiable(reverse where T: AnyObject)
func invalidAnyObjectRequirement<T: Differentiable>(x: T) -> T {
return x
Expand Down

0 comments on commit 1262193

Please sign in to comment.