Skip to content

Commit

Permalink
[AutoDiff upstream] Add SIL transpose function type calculation. (swi…
Browse files Browse the repository at this point in the history
…ftlang#29755)

Add `SILFunctionType::getAutoDiffTransposeFunctionType`.

It computes the transpose `SILFucntionType` for an original `SILFunctionType`,
given:

- Linearity parameter indices
- Transpose function generic signature (optional)
- Other auxiliary parameters

Add doc comments explaining typing rules, preconditions, and other details.
Add `isTranspose` flag to `autodiff::getConstrainedDerivativeGenericSignature`.

Partially resolves TF-1125.
Unblocks TF-1141: upstream `differentiability_witness_function` instruction.
  • Loading branch information
dan-zheng authored Feb 11, 2020
1 parent 9092b82 commit c2ac96f
Show file tree
Hide file tree
Showing 4 changed files with 146 additions and 10 deletions.
13 changes: 8 additions & 5 deletions include/swift/AST/AutoDiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -240,14 +240,17 @@ void getSubsetParameterTypes(IndexSubset *indices, AnyFunctionType *type,
/// "Constrained" derivative generic signatures require all differentiability
/// parameters to conform to the `Differentiable` protocol.
///
/// Returns the "constrained" derivative generic signature given:
/// "Constrained" transpose generic signatures additionally require all
/// linearity parameters to satisfy `Self == Self.TangentVector`.
///
/// Returns the "constrained" derivative/transpose generic signature given:
/// - An original SIL function type.
/// - Differentiability parameter indices.
/// - A possibly "unconstrained" derivative generic signature.
GenericSignature
getConstrainedDerivativeGenericSignature(SILFunctionType *originalFnTy,
IndexSubset *diffParamIndices,
GenericSignature derivativeGenSig);
GenericSignature getConstrainedDerivativeGenericSignature(
SILFunctionType *originalFnTy, IndexSubset *diffParamIndices,
GenericSignature derivativeGenSig, LookupConformanceFn lookupConformance,
bool isTranspose = false);

} // end namespace autodiff

Expand Down
36 changes: 36 additions & 0 deletions include/swift/AST/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -4542,6 +4542,42 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
CanGenericSignature derivativeFunctionGenericSignature = nullptr,
bool isReabstractionThunk = false);

/// Returns the type of the transpose function for the given parameter
/// indices, transpose function generic signature (optional), and other
/// auxiliary parameters.
///
/// Preconditions:
/// - Linearity parameters corresponding to parameter indices must conform to
/// `Differentiable` and satisfy `Self == Self.TangentVector`.
///
/// Typing rules, given:
/// - Original function type: $(T0, T1, ...) -> (R0, R1, ...)
///
/// Transpose function type:
/// - Takes non-linearity parameters, followed by original results, as
/// parameters.
/// - Returns linearity parameters.
///
/// A "constrained transpose generic signature" is computed from
/// `transposeFunctionGenericSignature`, if specified. Otherwise, it is
/// computed from the original generic signature. A "constrained transpose
/// generic signature" requires all linearity parameters to conform to
/// `Differentiable` and to satisfy `Self == Self.TangentVector`; this is
/// important for correctness.
///
/// This "constrained transpose generic signature" is used for
/// parameter/result type lowering. It is used as the actual generic signature
/// of the transpose function type iff the original function type has a
/// generic signature and not all generic parameters are bound to concrete
/// types. Otherwise, no transpose generic signature is used.
///
/// Other properties of the original function type are copied exactly:
/// `ExtInfo`, callee convention, witness method conformance, etc.
CanSILFunctionType getAutoDiffTransposeFunctionType(
IndexSubset *parameterIndices, Lowering::TypeConverter &TC,
LookupConformanceFn lookupConformance,
CanGenericSignature transposeFunctionGenericSignature = nullptr);

ExtInfo getExtInfo() const {
return ExtInfo(Bits.SILFunctionType.ExtInfoBits, getClangFunctionType());
}
Expand Down
13 changes: 11 additions & 2 deletions lib/AST/AutoDiff.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,20 +83,29 @@ void autodiff::getSubsetParameterTypes(IndexSubset *subset,

GenericSignature autodiff::getConstrainedDerivativeGenericSignature(
SILFunctionType *originalFnTy, IndexSubset *diffParamIndices,
GenericSignature derivativeGenSig) {
GenericSignature derivativeGenSig, LookupConformanceFn lookupConformance,
bool isTranspose) {
if (!derivativeGenSig)
derivativeGenSig = originalFnTy->getSubstGenericSignature();
if (!derivativeGenSig)
return nullptr;
// Constrain all differentiability parameters to `Differentiable`.
auto &ctx = originalFnTy->getASTContext();
auto *diffableProto = ctx.getProtocol(KnownProtocolKind::Differentiable);
SmallVector<Requirement, 4> requirements;
for (unsigned paramIdx : diffParamIndices->getIndices()) {
// Require differentiability parameters to conform to `Differentiable`.
auto paramType = originalFnTy->getParameters()[paramIdx].getInterfaceType();
Requirement req(RequirementKind::Conformance, paramType,
diffableProto->getDeclaredType());
requirements.push_back(req);
if (isTranspose) {
// Require linearity parameters to additionally satisfy
// `Self == Self.TangentVector`.
auto tanSpace = paramType->getAutoDiffTangentSpace(lookupConformance);
auto paramTanType = tanSpace->getCanonicalType();
Requirement req(RequirementKind::SameType, paramType, paramTanType);
requirements.push_back(req);
}
}
return evaluateOrDefault(
ctx.evaluator,
Expand Down
94 changes: 91 additions & 3 deletions lib/SIL/SILFunctionType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -260,11 +260,13 @@ CanSILFunctionType SILFunctionType::getAutoDiffDerivativeFunctionType(
if (isDiffParamIndex(valueAndIndex.index()))
diffParams.push_back(valueAndIndex.value());

// Get the canonical derivative function generic signature.
// Get the "constrained" derivative function generic signature.
if (!derivativeFnGenSig)
derivativeFnGenSig = getSubstGenericSignature();
derivativeFnGenSig = autodiff::getConstrainedDerivativeGenericSignature(
this, parameterIndices, derivativeFnGenSig).getCanonicalSignature();
derivativeFnGenSig =
autodiff::getConstrainedDerivativeGenericSignature(
this, parameterIndices, derivativeFnGenSig, lookupConformance)
.getCanonicalSignature();

// Given a type, returns its formal SIL parameter info.
auto getTangentParameterInfoForOriginalResult =
Expand Down Expand Up @@ -401,6 +403,92 @@ CanSILFunctionType SILFunctionType::getAutoDiffDerivativeFunctionType(
ctx, getWitnessMethodConformanceOrInvalid());
}

CanSILFunctionType SILFunctionType::getAutoDiffTransposeFunctionType(
IndexSubset *parameterIndices, Lowering::TypeConverter &TC,
LookupConformanceFn lookupConformance,
CanGenericSignature transposeFnGenSig) {
// Get the "constrained" transpose function generic signature.
if (!transposeFnGenSig)
transposeFnGenSig = getSubstGenericSignature();
transposeFnGenSig = autodiff::getConstrainedDerivativeGenericSignature(
this, parameterIndices, transposeFnGenSig,
lookupConformance, /*isLinear*/ true)
.getCanonicalSignature();

// Given a type, returns its formal SIL parameter info.
auto getParameterInfoForOriginalResult =
[&](const SILResultInfo &result) -> SILParameterInfo {
AbstractionPattern pattern(transposeFnGenSig, result.getInterfaceType());
auto &tl = TC.getTypeLowering(pattern, result.getInterfaceType(),
TypeExpansionContext::minimal());
ParameterConvention newConv;
switch (result.getConvention()) {
case ResultConvention::Owned:
case ResultConvention::Autoreleased:
newConv = tl.isTrivial() ? ParameterConvention::Direct_Unowned
: ParameterConvention::Direct_Guaranteed;
break;
case ResultConvention::Unowned:
case ResultConvention::UnownedInnerPointer:
newConv = ParameterConvention::Direct_Unowned;
break;
case ResultConvention::Indirect:
newConv = ParameterConvention::Indirect_In_Guaranteed;
break;
}
return {result.getInterfaceType()->getCanonicalType(transposeFnGenSig),
newConv};
};

// Given a type, returns its formal SIL result info.
auto getResultInfoForOriginalParameter =
[&](const SILParameterInfo &param) -> SILResultInfo {
AbstractionPattern pattern(transposeFnGenSig, param.getInterfaceType());
auto &tl = TC.getTypeLowering(pattern, param.getInterfaceType(),
TypeExpansionContext::minimal());
ResultConvention newConv;
switch (param.getConvention()) {
case ParameterConvention::Direct_Owned:
case ParameterConvention::Direct_Guaranteed:
case ParameterConvention::Direct_Unowned:
newConv =
tl.isTrivial() ? ResultConvention::Unowned : ResultConvention::Owned;
break;
case ParameterConvention::Indirect_In:
case ParameterConvention::Indirect_Inout:
case ParameterConvention::Indirect_In_Constant:
case ParameterConvention::Indirect_In_Guaranteed:
case ParameterConvention::Indirect_InoutAliasable:
newConv = ResultConvention::Indirect;
break;
}
return {param.getInterfaceType()->getCanonicalType(transposeFnGenSig),
newConv};
};

SmallVector<SILParameterInfo, 4> newParameters;
SmallVector<SILResultInfo, 4> newResults;
for (auto param : llvm::enumerate(getParameters())) {
if (parameterIndices->contains(param.index()))
newResults.push_back(getResultInfoForOriginalParameter(param.value()));
else
newParameters.push_back(param.value());
}
for (auto &res : getResults())
newParameters.push_back(getParameterInfoForOriginalResult(res));
// Transpose function type has a generic signature only if the original
// function type does, and if `transposeFnGenSig` does not have all concrete
// generic parameters.
CanGenericSignature canGenSig;
if (getSubstGenericSignature() && transposeFnGenSig &&
!transposeFnGenSig->areAllParamsConcrete())
canGenSig = transposeFnGenSig;
return SILFunctionType::get(
canGenSig, getExtInfo(), getCoroutineKind(), getCalleeConvention(),
newParameters, getYields(), newResults, getOptionalErrorResult(),
getSubstitutions(), isGenericSignatureImplied(), getASTContext());
}

static CanType getKnownType(Optional<CanType> &cacheSlot, ASTContext &C,
StringRef moduleName, StringRef typeName) {
if (!cacheSlot) {
Expand Down

0 comments on commit c2ac96f

Please sign in to comment.