Skip to content

Commit

Permalink
Add support for partial_apply and convert_escape_to_noescape (tensorf…
Browse files Browse the repository at this point in the history
…low#263)

* Add support for partial_apply and convert_escape_to_noescape

* Review feedback

* Remove .contains
  • Loading branch information
Adam Paszke authored and pschuh committed Sep 6, 2019
1 parent c39ec75 commit 6428965
Show file tree
Hide file tree
Showing 7 changed files with 140 additions and 11 deletions.
28 changes: 26 additions & 2 deletions Sources/SIL/SIL.swift
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,10 @@ public enum Instruction {
// cond_fail %9 : $Builtin.Int1, "arithmetic overflow"
case condFail(_ operand: Operand, _ message: String)

// https://github.com/apple/swift/blob/master/docs/SIL.rst#convert-escape-to-noescape
// convert_escape_to_noescape [not_guaranteed] %29 : $@callee_guaranteed () -> Bool to $@noescape @callee_guaranteed () -> Bool
case convertEscapeToNoescape(_ notGuaranteed: Bool, _ escaped: Bool, _ operand: Operand, _ type: Type)

// https://github.com/apple/swift/blob/master/docs/SIL.rst#copy-addr
// copy_addr %1 to [initialization] %33 : $*Self
case copyAddr(_ take: Bool, _ value: String, _ initialization: Bool, _ operand: Operand)
Expand Down Expand Up @@ -186,12 +190,19 @@ public enum Instruction {

// https://github.com/apple/swift/blob/master/docs/SIL.rst#load
// load %117 : $*Optional<Int>
case load(_ operand: Operand)
case load(_ kind: LoadOwnership?, _ operand: Operand)

// https://github.com/apple/swift/blob/master/docs/SIL.rst#metatype
// metatype $@thin Int.Type
case metatype(_ type: Type)

// https://github.com/apple/swift/blob/master/docs/SIL.rst#partial-apply
// partial_apply [callee_guaranteed] [on_stack] %27(%28) : $@convention(thin) (@guaranteed Int) -> Bool
case partialApply(
_ calleeGuaranteed: Bool, _ onStack: Bool, _ value: String,
_ substitutions: [Type], _ arguments: [String], _ type: Type
)

// https://github.com/apple/swift/blob/master/docs/SIL.rst#pointer-to-address
// pointer_to_address %4 : $Builtin.RawPointer to [strict] $*Int
case pointerToAddress(_ operand: Operand, _ strict: Bool, _ type: Type)
Expand All @@ -202,7 +213,7 @@ public enum Instruction {

// https://github.com/apple/swift/blob/master/docs/SIL.rst#store
// store %88 to %89 : $*StrideTo<Int>
case store(_ value: String, _ operand: Operand)
case store(_ value: String, _ kind: StoreOwnership?, _ operand: Operand)

// https://github.com/apple/swift/blob/master/docs/SIL.rst#string-literal
// string_literal utf8 "Fatal error"
Expand Down Expand Up @@ -450,3 +461,16 @@ public enum TypeRequirement: Equatable {
case conformance(_ lhs: Type, _ rhs: Type)
case equality(_ lhs: Type, _ rhs: Type)
}

// Reverse-engineered from -emit-silgen
public enum LoadOwnership {
case copy
case take
case trivial
}

// Reverse-engineered from -emit-silgen
public enum StoreOwnership {
case `init`
case trivial
}
10 changes: 6 additions & 4 deletions Sources/SIL/SILAnalysis.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,16 @@ extension Instruction {
}
switch self {
case .allocStack(_, _): return []
case let .apply(_, value, _, arguments, _): return arguments + [value]
case let .apply(_, value, _, arguments, _): return [value] + arguments
case .beginAccess(_, _, _, _, _): return []
case let .beginApply(_, value, _, arguments, _): return arguments + [value]
case let .beginApply(_, value, _, arguments, _): return [value] + arguments
case let .beginBorrow(operand): return [operand.value]
case let .br(_, operands): return unwrap(operands)
case let .builtin(_, operands, _): return unwrap(operands)
case let .condBr(cond, _, trueOperands, _, falseOperands):
return unwrap(trueOperands) + unwrap(falseOperands) + [cond]
case let .condFail(operand, _): return [operand.value]
case let .convertEscapeToNoescape(_, _, operand, _): return [operand.value]
case let .copyAddr(_, value, _, operand): return [value, operand.value]
case let .copyValue(operand): return [operand.value]
case let .deallocStack(operand): return [operand.value]
Expand All @@ -29,11 +30,12 @@ extension Instruction {
case .functionRef(_, _): return []
case let .indexAddr(addr, index): return [addr.value, index.value]
case .integerLiteral(_, _): return []
case let .load(operand): return [operand.value]
case let .load(_, operand): return [operand.value]
case .metatype(_): return []
case let .partialApply(_, _, value, _, arguments, _): return [value] + arguments
case let .pointerToAddress(operand, _, _): return [operand.value]
case let .return(operand): return [operand.value]
case let .store(value, operand): return [value, operand.value]
case let .store(value, _, operand): return [value, operand.value]
case .stringLiteral(_, _): return []
case let .struct(_, operands): return unwrap(operands)
case let .structElementAddr(operand, _): return [operand.value]
Expand Down
39 changes: 36 additions & 3 deletions Sources/SIL/SILParser.swift
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,13 @@ class SILParser: Parser {
try take(",")
let message = try parseString()
return .condFail(operand, message)
case "convert_escape_to_noescape":
let notGuaranteed = skip("[not_guaranteed]")
let escaped = skip("[escaped]")
let operand = try parseOperand()
try take("to")
let type = try parseType()
return .convertEscapeToNoescape(notGuaranteed, escaped, operand, type)
case "copy_addr":
let take = skip("[take]")
let value = try parseValue()
Expand Down Expand Up @@ -183,11 +190,28 @@ class SILParser: Parser {
let value = try parseInt()
return .integerLiteral(type, value)
case "load":
var ownership: LoadOwnership?
if skip("[copy]") {
ownership = .copy
} else if skip("[take]") {
ownership = .take
} else if skip("[trivial]") {
ownership = .trivial
}
let operand = try parseOperand()
return .load(operand)
return .load(ownership, operand)
case "metatype":
let type = try parseType()
return .metatype(type)
case "partial_apply":
let calleeGuaranteed = skip("[callee_guaranteed]")
let onStack = skip("[on_stack]")
let value = try parseValue()
let substitutions = try parseNilOrMany("<", ",", ">") { try parseNakedType() } ?? []
let arguments = try parseMany("(", ",", ")") { try parseValue() }
try take(":")
let type = try parseType()
return .partialApply(calleeGuaranteed, onStack, value, substitutions, arguments, type)
case "pointer_to_address":
let operand = try parseOperand()
try take("to")
Expand All @@ -200,9 +224,14 @@ class SILParser: Parser {
case "store":
let value = try parseValue()
try take("to")
let _ = skip("[trivial]") // Used in ownership SSA
var ownership: StoreOwnership?
if skip("[init]") {
ownership = .init
} else if skip("[trivial]") {
ownership = .trivial
}
let operand = try parseOperand()
return .store(value, operand)
return .store(value, ownership, operand)
case "string_literal":
let encoding = try parseEncoding()
let value = try parseString()
Expand Down Expand Up @@ -485,6 +514,10 @@ class SILParser: Parser {
} else if skip("*") {
let type = try parseNakedType()
return .addressType(type)
} else if skip("[") {
let subtype = try parseNakedType()
try take("]")
return .specializedType(.namedType("Array"), [subtype])
} else if peek("(") {
let types = try parseMany("(", ",", ")") { try parseNakedType() }
if skip("->") {
Expand Down
43 changes: 41 additions & 2 deletions Sources/SIL/SILPrinter.swift
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,13 @@ class SILPrinter: Printer {
print(operand)
print(", ")
literal(message)
case let .convertEscapeToNoescape(notGuaranteed, escaped, operand, type):
print("convert_escape_to_noescape ")
print(when: notGuaranteed, "[not_guaranteed] ")
print(when: escaped, "[escaped] ")
print(operand)
print(" to ")
print(type)
case let .copyAddr(take, value, initialization, operand):
print("copy_addr ")
print(when: take, "[take] ")
Expand Down Expand Up @@ -157,12 +164,25 @@ class SILPrinter: Printer {
print(type)
print(", ")
literal(value)
case let .load(operand):
case let .load(maybeOwnership, operand):
print("load ")
if let ownership = maybeOwnership {
print(ownership)
print(" ")
}
print(operand)
case let .metatype(type):
print("metatype ")
print(type)
case let .partialApply(calleeGuaranteed, onStack, value, substitutions, arguments, type):
print("partial_apply ")
print(when: calleeGuaranteed, "[callee_guaranteed] ")
print(when: onStack, "[on_stack] ")
print(value)
print(whenEmpty: false, "<", substitutions, ", ", ">") { naked($0) }
print("(", arguments, ", ", ")") { print($0) }
print(" : ")
print(type)
case let .pointerToAddress(operand, strict, type):
print("pointer_to_address ")
print(operand)
Expand All @@ -172,10 +192,14 @@ class SILPrinter: Printer {
case let .return(operand):
print("return ")
print(operand)
case let .store(value, operand):
case let .store(value, maybeOwnership, operand):
print("store ")
print(value)
print(" to ")
if let ownership = maybeOwnership {
print(ownership)
print(" ")
}
print(operand)
case let .stringLiteral(encoding, value):
print("string_literal ")
Expand Down Expand Up @@ -539,6 +563,21 @@ class SILPrinter: Printer {
naked(rhs)
}
}

func print(_ ownership: LoadOwnership) {
switch ownership {
case .copy: print("[copy]")
case .take: print("[take]")
case .trivial: print("[trivial]")
}
}

func print(_ ownership: StoreOwnership) {
switch ownership {
case .`init`: print("[init]")
case .trivial: print("[trivial]")
}
}
}

extension Module: CustomStringConvertible {
Expand Down
1 change: 1 addition & 0 deletions Tests/LinuxMain.swift
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ let silTests = [
testCase(BitsTests.allTests),
testCase(BitstreamTests.allTests),
testCase(SILTests.DescriptionTests.allTests),
testCase(SILParserTests.allTests),
testCase(InstructionTests.allTests),
testCase(ModuleTests.allTests),
testCase(PrinterTests.allTests),
Expand Down
7 changes: 7 additions & 0 deletions Tests/SILTests/InstructionTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ let instructionDefs = [
"br label(%0 : $A, %1 : $B)",
"cond_br %11, bb3, bb2",
"cond_br %12, label(%0 : $A), label(%1 : $B)",
"%94 = convert_escape_to_noescape [not_guaranteed] %93 : $@callee_guaranteed () -> Bool to $@noescape @callee_guaranteed () -> Bool",
"copy_addr %1 to [initialization] %33 : $*Self",
"dealloc_stack %162 : $*IndexingIterator<Range<Int>>",
"debug_value %1 : $Array<Float>, let, name \"input\", argno 2",
Expand All @@ -48,9 +49,15 @@ let instructionDefs = [
"%7 = pointer_to_address %6 : $Builtin.RawPointer to [strict] $*Int",
"%7 = pointer_to_address %6 : $Builtin.RawPointer to $*Int",
"load %117 : $*Optional<Int>",
"%22 = load [copy] %21 : $*TensorShape",
"%71 = load [take] %52 : $*Zip2Sequence<Array<Int>, Array<Int>>",
"%84 = load [trivial] %79 : $*Optional<(Int, Int)>",
"metatype $@thick Self.Type",
"metatype $@thin Int.Type",
"%4 = partial_apply [callee_guaranteed] %2<Scalar>(%3) : $@convention(thin) <τ_0_0 where τ_0_0 : TensorFlowScalar> (@guaranteed Tensor<τ_0_0>) -> Bool",
"store %88 to %89 : $*StrideTo<Int>",
"store %88 to [trivial] %112 : $*Int",
"store %152 to [init] %155 : $*ArraySlice<Int>",
"string_literal utf8 \"Fatal error\"",
// TODO(#24): Parse string literals with control characters.
// "string_literal utf8 \"\\n\"",
Expand Down
23 changes: 23 additions & 0 deletions Tests/SILTests/SILParserTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import Foundation
import XCTest
@testable import SIL

public final class SILParserTests: XCTestCase {
public func testArrayDesugar() {
let instr = "%149 = apply %148<[Int], PartialRangeFrom<Int>>(%143, %146, %144) : $@convention(method) <τ_0_0 where τ_0_0 : MutableCollection><τ_1_0 where τ_1_0 : RangeExpression, τ_0_0.Index == τ_1_0.Bound> (@in_guaranteed τ_1_0, @in_guaranteed τ_0_0) -> @out τ_0_0.SubSequence"
let parser = SILParser(forString: instr)
do {
let def = try parser.parseInstructionDef()
XCTAssertEqual(def.description, instr.replacingOccurrences(of: "[Int]", with: "Array<Int>"))
} catch {
XCTFail("Failed to parse the instruction def: \(error)")
}
}
}

extension SILParserTests {
public static let allTests = [
("testArrayDesugar", testArrayDesugar),
]
}

0 comments on commit 6428965

Please sign in to comment.