Last updated: March 2019
Automatic differentiation and differentiable programming are being incubated in the 'tensorflow' branch of apple/swift and released as part of the Swift for TensorFlow toolchains, which you can play with. The authors will propose this feature through Swift Evolution in 2019.
Speaking in terms of elementary calculus, only functions are "differentiable": only functions have derivatives and can be differentiated. In this document, the terminology "differentiable types" is used as a shorthand for "types that can be used as arguments and results of differentiable functions". This notion is important because not all types are "differentiable" in this sense. For example, types representing real numbers and vector spaces are "differentiable", but strings and integers are not.
Elementary calculus defines differentiation on real numbers: most people are familiar with this definition of "differentiation". However, differentiation is defined for many concepts across different branches of mathematics:
- Scalar differentiation: differentiation on real numbers. This is taught in introductory calculus.
- Vector calculus: a branch of mathematics that involves differentiation of vector fields.
- Differential geometry: a branch of mathematics that involves differentiation of functions over manifolds.
In Swift, we want to build a general system for differentiation that can represent all of these cases. Differentiation should not be limited to functions over specific types (e.g. functions over floating-point numbers); it should also work with functions whose parameters/result are custom types.
This raises some questions: what kind of types can be used as arguments and results of differentiable functions, and how can we generalize them using a protocol?
The `Differentiable` protocol generalizes all types that can be used as arguments and results of differentiable functions.
The compiler can automatically provide default implementations of `Differentiable` protocol requirements for struct types whose stored properties all conform to `Differentiable`.
Here are some examples:
```swift
struct Vector: Differentiable, VectorNumeric {
    // The compiler synthesizes all `Differentiable` protocol requirements
    // when all stored properties conform to `Differentiable`.
    var x, y, z: Float
}

// Differential operators like `gradient(at:in:)` just work!
let v = Vector(x: 1, y: 2, z: 3)
let 𝛁v = gradient(at: v) { v in (v + v).x }
print(𝛁v)
// Vector(x: 2.0, y: 0.0, z: 0.0)
```
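The real `Differentiable` protocol and `gradient(at:in:)` are only available in the Swift for TensorFlow toolchains. The following sketch therefore uses stock Swift to show roughly what the compiler provides for `Vector`. `MiniDifferentiable` and `gradientOfDoubledX(at:)` are hypothetical names: the conformance is written by hand, and the gradient of `(v + v).x` is derived manually rather than by automatic differentiation.

```swift
// `MiniDifferentiable` is a hypothetical stand-in for `Differentiable`,
// reduced to the two associated types this sketch needs.
protocol MiniDifferentiable {
    associatedtype TangentVector: AdditiveArithmetic
    associatedtype CotangentVector: AdditiveArithmetic
}

// Hand-written equivalent of the synthesized conformance: every stored property
// is `Float`, so `Self` can serve as both associated types.
struct Vector: MiniDifferentiable, AdditiveArithmetic {
    typealias TangentVector = Vector
    typealias CotangentVector = Vector

    var x, y, z: Float

    static var zero: Vector { Vector(x: 0, y: 0, z: 0) }

    static func + (lhs: Vector, rhs: Vector) -> Vector {
        Vector(x: lhs.x + rhs.x, y: lhs.y + rhs.y, z: lhs.z + rhs.z)
    }

    static func - (lhs: Vector, rhs: Vector) -> Vector {
        Vector(x: lhs.x - rhs.x, y: lhs.y - rhs.y, z: lhs.z - rhs.z)
    }
}

// Hand-derived reverse-mode gradient of `f(v) = (v + v).x`.
// `f` is linear, so the gradient does not depend on `v`.
func gradientOfDoubledX(at v: Vector) -> Vector.CotangentVector {
    // Pullback of the `.x` projection: the seed 1 flows only into the `x` slot.
    let dSum = Vector(x: 1, y: 0, z: 0)
    // `v + v` reads `v` twice, so both contributions are summed starting from zero.
    var dv = Vector.zero
    dv += dSum
    dv += dSum
    return dv
}

let exampleGradient = gradientOfDoubledX(at: Vector(x: 1, y: 2, z: 3))
print(exampleGradient)
// Vector(x: 2.0, y: 0.0, z: 0.0)
```

Because every stored property of `Vector` is `Float`, one type serves as the value, its tangent vectors, and its gradients.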
A `Differentiable`-conforming type may have stored properties that are not meant to have a derivative with respect to `self`. Use the `@noDerivative` attribute to mark those properties; they will not have a corresponding entry in the synthesized `TangentVector`, `CotangentVector`, and `AllDifferentiableVariables` struct types.
Here’s an example deep learning layer with some `@noDerivative` properties:
```swift
struct DenseLayer: Differentiable {
    // These properties should have derivative values.
    var weight: Tensor<Float>
    var bias: Tensor<Float>

    // These auxiliary properties should not have derivative values.
    // Thus, they are marked with `@noDerivative`.
    //
    // `@noDerivative` properties do not have a corresponding entry in the
    // synthesized associated struct types.
    @noDerivative var useBias: Bool = true
    @noDerivative var previousWeight: Tensor<Float> = Tensor(0)

    init(weight: Tensor<Float>, bias: Tensor<Float>) {
        self.weight = weight
        self.bias = bias
    }

    // The compiler synthesizes all `Differentiable` protocol requirements, adding only
    // properties not marked with `@noDerivative` to the associated tangent space types.

    // Applies the layer to `input`: `matmul(input, weight) + bias`.
    func call(_ input: Tensor<Float>) -> Tensor<Float> {
        return matmul(input, weight) + bias
    }
}

// Differential operators like `gradient(at:in:)` just work!
let dense = DenseLayer(weight: [[1, 1], [1, 1]], bias: [0, 0])
let 𝛁dense = gradient(at: dense) { dense in dense([[3, 3]]).sum() }
dump(𝛁dense)
// ▿ DenseLayer.AllDifferentiableVariables
//   - weight: [[3.0, 3.0], [3.0, 3.0]]
//   - bias: [1.0, 1.0]
```
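To make the effect of `@noDerivative` concrete without the TensorFlow library, the sketch below models the layer with nested `[Float]` arrays and a hand-derived gradient. It is an assumption for illustration, not the synthesized code. The hypothetical `Gradient` struct contains only `weight` and `bias`, mirroring how the synthesized types omit `useBias` and `previousWeight`. It needs Swift 5.2 or later for `callAsFunction`.

```swift
// Hypothetical array-based stand-in for `DenseLayer`; `Tensor` is not used.
struct MiniDenseLayer {
    // Differentiable parameters: `weight` has shape [inputCount][outputCount].
    var weight: [[Float]]
    var bias: [Float]
    // Auxiliary state playing the role of the `@noDerivative` properties.
    var useBias: Bool = true
    var previousWeight: [[Float]] = []

    // Gradient type: mirrors only the differentiable parameters, just as the
    // synthesized `AllDifferentiableVariables` omits `@noDerivative` properties.
    struct Gradient: Equatable {
        var weight: [[Float]]
        var bias: [Float]
    }

    // Forward pass for one input row: out[j] = sum_i input[i] * weight[i][j] + bias[j].
    func callAsFunction(_ input: [Float]) -> [Float] {
        return (0..<bias.count).map { (j: Int) -> Float in
            var acc: Float = useBias ? bias[j] : 0
            for i in 0..<input.count {
                acc += input[i] * weight[i][j]
            }
            return acc
        }
    }

    // Hand-derived gradient of `self(input).reduce(0, +)` with respect to the parameters:
    // d/d weight[i][j] = input[i], and d/d bias[j] = 1 when the bias is used (else 0).
    func gradientOfSum(input: [Float]) -> Gradient {
        let dWeight = input.map { [Float](repeating: $0, count: bias.count) }
        let dBias = [Float](repeating: useBias ? 1 : 0, count: bias.count)
        return Gradient(weight: dWeight, bias: dBias)
    }
}
```

For `weight: [[1, 1], [1, 1]]`, `bias: [0, 0]`, and the input row `[3, 3]`, `gradientOfSum(input:)` returns `weight: [[3, 3], [3, 3]]` and `bias: [1, 1]`, which agrees with the `dump` output of the `DenseLayer` example.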
Here is the full `Differentiable` protocol definition. More explanation is provided below.
```swift
/// A type that mathematically represents a differentiable manifold whose
/// tangent spaces are finite-dimensional.
public protocol Differentiable {
    /// The tangent bundle of this differentiable manifold.
    associatedtype TangentVector: AdditiveArithmetic & Differentiable
        where TangentVector.TangentVector == TangentVector,
              TangentVector.CotangentVector == CotangentVector,
              TangentVector.AllDifferentiableVariables == TangentVector

    /// The cotangent bundle of this differentiable manifold.
    associatedtype CotangentVector: AdditiveArithmetic & Differentiable
        where CotangentVector.TangentVector == CotangentVector,
              CotangentVector.CotangentVector == TangentVector,
              CotangentVector.AllDifferentiableVariables == CotangentVector

    /// The type of all differentiable variables in this type.
    associatedtype AllDifferentiableVariables: Differentiable
        where AllDifferentiableVariables.AllDifferentiableVariables == AllDifferentiableVariables,
              AllDifferentiableVariables.TangentVector == TangentVector,
              AllDifferentiableVariables.CotangentVector == CotangentVector

    /// All differentiable variables in this type.
    var allDifferentiableVariables: AllDifferentiableVariables { get }

    /// Returns `self` moved along the value space towards the given tangent vector.
    /// In Riemannian geometry (mathematics), this represents an exponential map.
    func moved(along direction: TangentVector) -> Self

    /// Converts a cotangent vector to its corresponding tangent vector.
    func tangentVector(from cotangent: CotangentVector) -> TangentVector
}
```
Mathematically, `Differentiable` represents a differentiable manifold, a technical term for smooth-surfaced objects like spheres. The concept generalizes types that are compatible with differentiation, like `Float`, `Double`, `Tensor`, and `SIMD4<Float>`. This definition comes from differential geometry and is quite technical; not all of its details are relevant for most use cases.
Figure: two differentiable manifolds, a sphere and a spheroid (from https://en.wikipedia.org/wiki/Pushforward_(differential)).
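The circle is the one-dimensional analogue of the sphere in the figure, and it shows why `Self` and `TangentVector` can differ: a point on the circle is not a vector, but a small displacement along it is. The following sketch uses a hypothetical simplified protocol, `MiniDifferentiable`, and a hypothetical `Angle` type. Here `moved(along:)` plays the role of the exponential map mentioned in the protocol documentation.

```swift
// Hypothetical, simplified stand-in for `Differentiable`.
protocol MiniDifferentiable {
    associatedtype TangentVector: AdditiveArithmetic
    associatedtype CotangentVector: AdditiveArithmetic
    func moved(along direction: TangentVector) -> Self
    func tangentVector(from cotangent: CotangentVector) -> TangentVector
}

// A point on the unit circle, stored as an angle normalized to [0, 2π).
// The circle is a manifold but not a vector space: moving past 2π wraps
// around, so `Angle` cannot be its own tangent vector type.
struct Angle: MiniDifferentiable {
    // A tangent vector is a signed arc length (in radians) along the circle.
    typealias TangentVector = Double
    typealias CotangentVector = Double

    private(set) var radians: Double

    init(radians: Double) {
        let twoPi = 2 * Double.pi
        var r = radians.truncatingRemainder(dividingBy: twoPi)
        if r < 0 { r += twoPi }
        if r >= twoPi { r -= twoPi }  // guards against rounding up to exactly 2π
        self.radians = r
    }

    // Exponential map of the circle: walk `direction` radians, then wrap.
    func moved(along direction: Double) -> Angle {
        return Angle(radians: radians + direction)
    }

    // Under the standard metric, a cotangent maps to the same number.
    func tangentVector(from cotangent: Double) -> Double {
        return cotangent
    }
}
```

For example, moving `Angle(radians: 6.0)` along a tangent of `1.0` lands at `7.0 - 2π` radians rather than `7.0`.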
Here is a detailed explanation of the `Differentiable` protocol:
- `associatedtype TangentVector` represents the type of directional derivatives computed via forward-mode differentiation.
- `associatedtype CotangentVector` represents the type of gradient values computed via reverse-mode differentiation.
  - `CotangentVector` types are used and produced by differential operators like `gradient` and `pullback`.
- `var allDifferentiableVariables: AllDifferentiableVariables` represents all differentiable variables in an instance of the conforming type, where `associatedtype AllDifferentiableVariables` is the type of all differentiable variables.
  - The motivation/design behind "all differentiable variables" is enabling key-path-based parameter optimization by making parameters and their gradients have the same type. Read the synthesis rules below and the parameter optimization document for more information.
- `TangentVector`, `CotangentVector`, and `AllDifferentiableVariables` are closely related.
  - All three associated types must themselves conform to `Differentiable`.
  - The associated types of these associated types are constrained to be mathematically consistent (where `Foo` is a conforming type):
    - `Foo.TangentVector.TangentVector` is `Foo.TangentVector` itself.
    - `Foo.CotangentVector.TangentVector` is `Foo.CotangentVector` itself.
    - `Foo.TangentVector.CotangentVector` is `Foo.CotangentVector`.
    - `Foo.CotangentVector.CotangentVector` is `Foo.TangentVector`.
    - `Foo.AllDifferentiableVariables` has the same `TangentVector` and `CotangentVector` as `Foo`.
  - Additionally, `TangentVector` and `CotangentVector` must conform to `AdditiveArithmetic`, so that they can be zero-initialized and accumulated via addition. Both operations are needed to apply the chain rule of differentiation.
- Manifold operations.
  - These currently involve `tangentVector(from:)` and `moved(along:)`. These operations can be useful for implementing manifold-related algorithms, like optimization on manifolds, but are not relevant for simple differentiation use cases.
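The `AdditiveArithmetic` requirement is what lets reverse-mode differentiation start every gradient at `.zero` and sum contributions when a value is used more than once. This hand-written sketch (hypothetical `PairCotangent` and `valueWithGradient(x:y:)`; no compiler-generated derivatives) shows that accumulation for f(x, y) = x * y + x, whose gradient is (y + 1, x).

```swift
// Hypothetical cotangent type for a function of two `Double` inputs.
struct PairCotangent: AdditiveArithmetic {
    var x: Double
    var y: Double

    static var zero: PairCotangent { PairCotangent(x: 0, y: 0) }

    static func + (lhs: PairCotangent, rhs: PairCotangent) -> PairCotangent {
        PairCotangent(x: lhs.x + rhs.x, y: lhs.y + rhs.y)
    }

    static func - (lhs: PairCotangent, rhs: PairCotangent) -> PairCotangent {
        PairCotangent(x: lhs.x - rhs.x, y: lhs.y - rhs.y)
    }
}

// Hand-written forward and reverse pass for f(x, y) = x * y + x.
// `x` is read twice, so its two gradient contributions must be added together.
func valueWithGradient(x: Double, y: Double) -> (value: Double, gradient: PairCotangent) {
    // Forward pass.
    let product = x * y
    let value = product + x

    // Reverse pass: start from zero (`AdditiveArithmetic.zero`) and seed d(value) = 1.
    var gradient = PairCotangent.zero
    let seed = 1.0
    // Pullback of `product + x`: the seed flows unchanged into both operands.
    let dProduct = seed
    gradient += PairCotangent(x: seed, y: 0)
    // Pullback of `x * y`: d/dx = y and d/dy = x, scaled by `dProduct`.
    gradient += PairCotangent(x: dProduct * y, y: dProduct * x)
    return (value, gradient)
}
```

At x = 3 and y = 4 the value is 15 and the accumulated gradient is (5, 3).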
The standard library defines conformances to the `Differentiable` protocol for `Float`, `Double`, and `Float80`. Conditional conformances will be added to floating-point SIMD vector types. The `Tensor` type defined in the TensorFlow library also conditionally conforms to `Differentiable`:
```swift
extension Float: Differentiable {
    public typealias TangentVector = Float
    public typealias CotangentVector = Float
    public typealias AllDifferentiableVariables = Float
}
// Conformances for `Double` and `Float80` are defined similarly.

// `Tensor` is defined in the TensorFlow library and represents a multidimensional array.
extension Tensor: Differentiable where Scalar: TensorFlowFloatingPoint {
    public typealias TangentVector = Tensor
    public typealias CotangentVector = Tensor
    public typealias AllDifferentiableVariables = Tensor
}
```
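Conditional conformance lets a generic container be differentiable only for suitable element types. The sketch below mirrors the `Tensor` conformance with a hypothetical `Pair` type and a simplified, hypothetical `MiniDifferentiable` protocol. `BinaryFloatingPoint` stands in for `TensorFlowFloatingPoint` as an assumption of this example.

```swift
// Hypothetical, simplified stand-in for `Differentiable`.
protocol MiniDifferentiable {
    associatedtype TangentVector: AdditiveArithmetic
    associatedtype CotangentVector: AdditiveArithmetic
}

// Mirrors the `Float` conformance: a scalar is its own tangent and cotangent.
extension Float: MiniDifferentiable {
    typealias TangentVector = Float
    typealias CotangentVector = Float
}

// A hypothetical two-element container standing in for `Tensor`.
struct Pair<Scalar: Numeric>: AdditiveArithmetic {
    var first: Scalar
    var second: Scalar

    static var zero: Pair { Pair(first: 0, second: 0) }

    static func + (lhs: Pair, rhs: Pair) -> Pair {
        Pair(first: lhs.first + rhs.first, second: lhs.second + rhs.second)
    }

    static func - (lhs: Pair, rhs: Pair) -> Pair {
        Pair(first: lhs.first - rhs.first, second: lhs.second - rhs.second)
    }
}

// Conditional conformance, analogous to `Tensor: Differentiable where Scalar:
// TensorFlowFloatingPoint`: `Pair<Float>` is differentiable, `Pair<Int>` is not.
extension Pair: MiniDifferentiable where Scalar: BinaryFloatingPoint {
    typealias TangentVector = Pair
    typealias CotangentVector = Pair
}

// Generic code can rely on the conformance: every tangent space has a zero.
func zeroTangent<T: MiniDifferentiable>(of value: T) -> T.TangentVector {
    return .zero
}
```

Calling `zeroTangent(of: Pair<Int>(first: 1, second: 2))` is rejected at compile time, because `Pair<Int>` does not satisfy the `Scalar: BinaryFloatingPoint` condition.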
As shown above, the compiler automatically synthesizes implementations of `Differentiable` requirements for struct types.
Here are the current conditions for synthesis:
- The type must declare a conformance to `Differentiable`, either on the type declaration or on an extension in the same file.
- The conforming type must be a `struct`.
- All stored properties of the conforming type must either conform to `Differentiable` or be marked with the `@noDerivative` attribute.
  - If a non-`Differentiable` stored property is not marked with `@noDerivative`, then it is treated as if it has `@noDerivative`, and the compiler emits a warning (with a fix-it in IDEs) asking the user to make the attribute explicit.
The synthesis behavior is explained below.
Here are the synthesis rules for the three `Differentiable` associated types: `TangentVector`, `CotangentVector`, and `AllDifferentiableVariables`.
Let "differentiation properties" refer to all stored properties of the conforming type that are not marked with `@noDerivative`. The synthesis conditions guarantee that these stored properties all conform to `Differentiable`.
The synthesis rules are:
- Set associated types to `Self`, if possible.
  - If the conforming type conforms to `AdditiveArithmetic`, has no `@noDerivative` stored properties, and all of its stored properties satisfy `Self == Self.TangentVector == Self.CotangentVector == Self.AllDifferentiableVariables`, then all associated types can be set to typealiases of `Self`.
- Otherwise, synthesize a single `AllDifferentiableVariables` member struct. Set `TangentVector` and `CotangentVector` to `AllDifferentiableVariables` if possible; otherwise, synthesize additional member structs.
  - Member struct synthesis: for each differentiation property of the conforming type, a corresponding stored property is synthesized in each member struct, whose type is the matching associated type of the property's type (for example, a property `x: T` becomes `x: T.TangentVector` in the synthesized `TangentVector` struct).
  - `TangentVector` and `CotangentVector` can be set to `AllDifferentiableVariables` if all differentiation properties conform to `AdditiveArithmetic` and satisfy `Self.TangentVector == Self.CotangentVector == Self.AllDifferentiableVariables`. This avoids redundant struct synthesis, and it enables key-path-based parameter optimization because parameters and gradients then have the same type.
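The two rules can be illustrated by writing out, by hand, the associated types they would produce for two hypothetical structs. This is a sketch with a simplified, hypothetical `MiniDifferentiable` protocol, not compiler output: `Point` takes the `Self` path, while `Line`, whose `name` property stands in for a `@noDerivative` property, gets a single member struct reused for all three associated types.

```swift
// Hypothetical, simplified stand-in for `Differentiable`.
protocol MiniDifferentiable {
    associatedtype TangentVector: AdditiveArithmetic
    associatedtype CotangentVector: AdditiveArithmetic
    associatedtype AllDifferentiableVariables
}

// Rule 1: `Point` is `AdditiveArithmetic`, has no `@noDerivative` property, and every
// stored property is `Double` (whose associated types are all `Double`), so every
// associated type can simply be `Self`.
struct Point: MiniDifferentiable, AdditiveArithmetic {
    typealias TangentVector = Point
    typealias CotangentVector = Point
    typealias AllDifferentiableVariables = Point

    var x, y: Double

    static var zero: Point { Point(x: 0, y: 0) }
    static func + (lhs: Point, rhs: Point) -> Point { Point(x: lhs.x + rhs.x, y: lhs.y + rhs.y) }
    static func - (lhs: Point, rhs: Point) -> Point { Point(x: lhs.x - rhs.x, y: lhs.y - rhs.y) }
}

// Rule 2: `name` plays the role of a `@noDerivative` property, so `Self` cannot be used.
// One member struct mirrors the differentiation properties; since each of them is
// `Double` (additive, with equal associated types), it also serves as both
// `TangentVector` and `CotangentVector`.
struct Line: MiniDifferentiable {
    struct AllDifferentiableVariables: AdditiveArithmetic {
        var slope: Double
        var intercept: Double

        static var zero: Self { Self(slope: 0, intercept: 0) }
        static func + (lhs: Self, rhs: Self) -> Self {
            Self(slope: lhs.slope + rhs.slope, intercept: lhs.intercept + rhs.intercept)
        }
        static func - (lhs: Self, rhs: Self) -> Self {
            Self(slope: lhs.slope - rhs.slope, intercept: lhs.intercept - rhs.intercept)
        }
    }
    typealias TangentVector = AllDifferentiableVariables
    typealias CotangentVector = AllDifferentiableVariables

    var slope: Double
    var intercept: Double
    var name: String  // would be `@noDerivative`; it has no entry in the member struct
}
```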
A memberwise initializer is synthesized for the conforming type itself, as well as for all synthesized associated structs. These initializers are needed to differentiate struct property accesses and to synthesize the manifold operation requirements.
`var allDifferentiableVariables: AllDifferentiableVariables` is synthesized as a computed property that mirrors the differentiation properties of the conforming type.
- It is always synthesized with a getter.
- It is synthesized with a setter only when all differentiation properties are mutable and themselves all have mutable `allDifferentiableVariables` properties.
Examples:
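```swift
// Example when `AllDifferentiableVariables == Self`.
var allDifferentiableVariables: AllDifferentiableVariables {
    get { return self }
    set { self = newValue }
}

// Example when `AllDifferentiableVariables != Self`.
// Each member of `AllDifferentiableVariables` has the type of the corresponding
// property's own `AllDifferentiableVariables`.
var allDifferentiableVariables: AllDifferentiableVariables {
    get {
        return AllDifferentiableVariables(x: x.allDifferentiableVariables,
                                          y: y.allDifferentiableVariables, ...)
    }
    set {
        x.allDifferentiableVariables = newValue.x
        y.allDifferentiableVariables = newValue.y
        ...
    }
}
```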
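The setter is what makes key-path-based parameter updates possible: an optimizer can read `allDifferentiableVariables`, update it with a gradient of the same type, and write it back. The sketch below assumes a hypothetical `Model` type with a hand-written `allDifferentiableVariables` property and a hypothetical `sgdStep` helper that lists its key paths explicitly; it is not the Swift for TensorFlow optimizer API.

```swift
// Hypothetical model: `weight` and `bias` are differentiation properties;
// `label` plays the role of a `@noDerivative` property.
struct Model {
    struct AllDifferentiableVariables {
        var weight: Double
        var bias: Double
    }

    var weight: Double
    var bias: Double
    var label: String

    // Mirrors the differentiation properties. The setter exists because both
    // properties are mutable `Double`s.
    var allDifferentiableVariables: AllDifferentiableVariables {
        get { return AllDifferentiableVariables(weight: weight, bias: bias) }
        set {
            weight = newValue.weight
            bias = newValue.bias
        }
    }
}

// One plain gradient-descent step. Parameters and gradients share one type, so
// the same key paths index both. The explicit key-path list is an assumption of
// this sketch.
func sgdStep(_ model: inout Model,
             gradient: Model.AllDifferentiableVariables,
             learningRate: Double) {
    let paths: [WritableKeyPath<Model.AllDifferentiableVariables, Double>] = [\.weight, \.bias]
    var parameters = model.allDifferentiableVariables
    for path in paths {
        parameters[keyPath: path] -= learningRate * gradient[keyPath: path]
    }
    model.allDifferentiableVariables = parameters
}
```

The `label` property is untouched by the update, because it is not part of `AllDifferentiableVariables`.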
Manifold operations are synthesized by forwarding to the same operations on each differentiation property:
```swift
// Let `Foo` be the name of the type conforming to `Differentiable`.
func tangentVector(from cotangent: CotangentVector) -> TangentVector {
    return TangentVector(x: x.tangentVector(from: cotangent.x), ...)
}
func moved(along tangent: TangentVector) -> Foo {
    return Foo(x: x.moved(along: tangent.x), ...)
}

// Potential shortcuts for synthesis:
// When `TangentVector == CotangentVector`:
func tangentVector(from cotangent: CotangentVector) -> TangentVector {
    return cotangent
}
// When `Foo == TangentVector` (so `Foo` conforms to `AdditiveArithmetic`),
// moving along a tangent vector is plain addition:
func moved(along tangent: TangentVector) -> Foo {
    return self + tangent
}
```
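In stock Swift, the two shortcuts can be expressed as constrained protocol extensions that act as default implementations. The following is a sketch of one way to provide them, not the compiler's implementation; `MiniDifferentiable` and `Velocity` are hypothetical names.

```swift
// Hypothetical, simplified stand-in for `Differentiable`.
protocol MiniDifferentiable {
    associatedtype TangentVector: AdditiveArithmetic
    associatedtype CotangentVector: AdditiveArithmetic
    func moved(along direction: TangentVector) -> Self
    func tangentVector(from cotangent: CotangentVector) -> TangentVector
}

// Shortcut: when the tangent and cotangent types coincide, convert by identity.
extension MiniDifferentiable where CotangentVector == TangentVector {
    func tangentVector(from cotangent: CotangentVector) -> TangentVector {
        return cotangent
    }
}

// Shortcut: when a type is its own tangent vector, moving is plain addition.
extension MiniDifferentiable where TangentVector == Self {
    func moved(along direction: TangentVector) -> Self {
        return self + direction
    }
}

// Both shortcuts apply, so these conformances need only the typealiases.
extension Double: MiniDifferentiable {
    typealias TangentVector = Double
    typealias CotangentVector = Double
}

struct Velocity: MiniDifferentiable, AdditiveArithmetic {
    typealias TangentVector = Velocity
    typealias CotangentVector = Velocity

    var dx, dy: Double

    static var zero: Velocity { Velocity(dx: 0, dy: 0) }
    static func + (lhs: Velocity, rhs: Velocity) -> Velocity {
        Velocity(dx: lhs.dx + rhs.dx, dy: lhs.dy + rhs.dy)
    }
    static func - (lhs: Velocity, rhs: Velocity) -> Velocity {
        Velocity(dx: lhs.dx - rhs.dx, dy: lhs.dy - rhs.dy)
    }
}
```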
Let’s look at a complete example:
```swift
struct GenericWrapper<T: Differentiable, U: Differentiable>: Differentiable {
    // `x` and `y` are the "differentiation properties".
    var x: T
    var y: U
    @noDerivative var customFlag: Bool
    @noDerivative var helperVariable: T

    // The compiler synthesizes:
    //
    // struct TangentVector: Differentiable, AdditiveArithmetic {
    //     var x: T.TangentVector
    //     var y: U.TangentVector
    //     ...
    // }
    // struct CotangentVector: Differentiable, AdditiveArithmetic {
    //     var x: T.CotangentVector
    //     var y: U.CotangentVector
    //     ...
    // }
    // struct AllDifferentiableVariables: Differentiable {
    //     var x: T.AllDifferentiableVariables
    //     var y: U.AllDifferentiableVariables
    //     ...
    // }
    // // Getter only: `T` and `U` are only known to have get-only
    // // `allDifferentiableVariables` properties, so no setter is synthesized.
    // var allDifferentiableVariables: AllDifferentiableVariables {
    //     return AllDifferentiableVariables(x: x.allDifferentiableVariables,
    //                                       y: y.allDifferentiableVariables)
    // }
    // func tangentVector(from cotangent: CotangentVector) -> TangentVector {
    //     return TangentVector(x: x.tangentVector(from: cotangent.x),
    //                          y: y.tangentVector(from: cotangent.y))
    // }
    // func moved(along tangent: TangentVector) -> GenericWrapper {
    //     return GenericWrapper(x: x.moved(along: tangent.x),
    //                           y: y.moved(along: tangent.y),
    //                           customFlag: customFlag,
    //                           helperVariable: helperVariable)
    // }
}
```
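For comparison, the following stand-alone sketch writes out the synthesized members for a wrapper by hand. It uses a hypothetical simplified protocol, `MiniDifferentiable`, and a hypothetical `Wrapper` type in place of `GenericWrapper`. `AllDifferentiableVariables` is omitted to keep it short, and the `@noDerivative` properties are ordinary stored properties that `moved(along:)` copies unchanged.

```swift
// Hypothetical, simplified stand-in for `Differentiable`.
protocol MiniDifferentiable {
    associatedtype TangentVector: AdditiveArithmetic
    associatedtype CotangentVector: AdditiveArithmetic
    func moved(along direction: TangentVector) -> Self
    func tangentVector(from cotangent: CotangentVector) -> TangentVector
}

extension Double: MiniDifferentiable {
    typealias TangentVector = Double
    typealias CotangentVector = Double
    func moved(along direction: Double) -> Double { return self + direction }
    func tangentVector(from cotangent: Double) -> Double { return cotangent }
}

// Hand-written version of the members the synthesis rules describe.
struct Wrapper<T: MiniDifferentiable, U: MiniDifferentiable>: MiniDifferentiable {
    var x: T                  // differentiation property
    var y: U                  // differentiation property
    var customFlag: Bool      // would be `@noDerivative`
    var helperVariable: T     // would be `@noDerivative`

    struct TangentVector: AdditiveArithmetic {
        var x: T.TangentVector
        var y: U.TangentVector
        static var zero: Self { Self(x: .zero, y: .zero) }
        static func + (lhs: Self, rhs: Self) -> Self { Self(x: lhs.x + rhs.x, y: lhs.y + rhs.y) }
        static func - (lhs: Self, rhs: Self) -> Self { Self(x: lhs.x - rhs.x, y: lhs.y - rhs.y) }
    }

    struct CotangentVector: AdditiveArithmetic {
        var x: T.CotangentVector
        var y: U.CotangentVector
        static var zero: Self { Self(x: .zero, y: .zero) }
        static func + (lhs: Self, rhs: Self) -> Self { Self(x: lhs.x + rhs.x, y: lhs.y + rhs.y) }
        static func - (lhs: Self, rhs: Self) -> Self { Self(x: lhs.x - rhs.x, y: lhs.y - rhs.y) }
    }

    // Forward each manifold operation to the differentiation properties.
    func tangentVector(from cotangent: CotangentVector) -> TangentVector {
        return TangentVector(x: x.tangentVector(from: cotangent.x),
                             y: y.tangentVector(from: cotangent.y))
    }

    // `@noDerivative` properties are carried over unchanged.
    func moved(along tangent: TangentVector) -> Wrapper {
        return Wrapper(x: x.moved(along: tangent.x),
                       y: y.moved(along: tangent.y),
                       customFlag: customFlag,
                       helperVariable: helperVariable)
    }
}
```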
The authors would like to thank Casey Chu, Dougal Maclaurin, Matthew Johnson, Roy Frostig, Gordon Plotkin, Marc Rasi, Steve Canon, and James Bradbury for their input to the design of the `Differentiable` protocol.