Skip to content

Commit

Permalink
Update type protocols for new structure
Browse files Browse the repository at this point in the history
  • Loading branch information
aidangomez committed Dec 14, 2015
1 parent 4d936a7 commit 1d33b74
Show file tree
Hide file tree
Showing 3 changed files with 138 additions and 7 deletions.
48 changes: 45 additions & 3 deletions Source/1D/LinearType.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
// THE SOFTWARE.

/// The `LinearType` protocol should be implemented by any collection that stores its values in a contiguous memory block. This is the building block for one-dimensional operations that are single-instruction, multiple-data (SIMD).
public protocol LinearType : CollectionType {
public protocol LinearType: TensorType {
typealias Element

/// The pointer to the beginning of the memory block
Expand All @@ -40,18 +40,60 @@ public extension LinearType {
public var count: Int {
return (endIndex - startIndex + step - 1) / step
}

public var dimensions: [Int] {
return [count]
}
}

internal extension LinearType {
var span: Span {
return Span(start: [startIndex], end: [endIndex])
}

func indexIsValid(index: Int) -> Bool {
return startIndex <= index && index < endIndex
}
}

public protocol MutableLinearType : LinearType, MutableCollectionType {
public protocol MutableLinearType : LinearType, MutableTensorType {
/// The mutable pointer to the beginning of the memory block
var mutablePointer: UnsafeMutablePointer<Element> { get }
}

extension Array : LinearType {
public typealias Slice = ArraySlice<Element>

public var step: Int {
return 1
}


public subscript(indices: [Int]) -> Element {
get {
assert(indices.count == 1)
return self[indices[0]]
}
set {
assert(indices.count == 1)
self[indices[0]] = newValue
}
}

public subscript(intervals: [IntervalType]) -> Slice {
get {
assert(indices.count == 1)
let start = intervals[0].start ?? startIndex
let end = intervals[0].end ?? endIndex
return self[start..<end]
}
set {
assert(indices.count == 1)
let start = intervals[0].start ?? startIndex
let end = intervals[0].end ?? endIndex
self[start..<end] = newValue
}
}

public var pointer: UnsafePointer<Element> {
return withUnsafeBufferPointer{ return $0.baseAddress }
}
Expand Down
19 changes: 15 additions & 4 deletions Source/2D/QuadraticType.swift
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,15 @@ public enum QuadraticArrangement {
case ColumnMajor
}

public protocol QuadraticType {
public protocol QuadraticType: TensorType {
typealias Element

/// The arrangement of rows and columns
var arragement: QuadraticArrangement { get }
var arrangement: QuadraticArrangement { get }

/// The pointer to the beginning of the memory block
var pointer: UnsafePointer<Element> { get }

/// The number of rows
var rows: Int { get }

Expand All @@ -43,16 +43,27 @@ public protocol QuadraticType {

/// The step size between major-axis elements
var stride: Int { get }

/// The step of the base elements
var step: Int { get }
}

public extension QuadraticType {
/// The number of valid element in the memory block, taking into account the step size.
public var count: Int {
return rows * columns
}

public var dimensions: [Int] {
if arrangement == .RowMajor {
return [rows, columns]
} else {
return [columns, rows]
}
}
}

public protocol MutableQuadraticType : QuadraticType {
public protocol MutableQuadraticType: QuadraticType, MutableTensorType {
/// The mutable pointer to the beginning of the memory block
var mutablePointer: UnsafeMutablePointer<Element> { get }
}
78 changes: 78 additions & 0 deletions Source/ND/TensorType.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
// Copyright © 2015 Venture Media Labs.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

public protocol TensorType {
typealias Element
typealias Slice

/// The pointer to the beginning of the memory block
var pointer: UnsafePointer<Element> { get }

/// The count of each dimension
var dimensions: [Int] { get }

subscript(intervals: [IntervalType]) -> Slice { get }
subscript(intervals: [Int]) -> Element { get }
}

internal extension TensorType {
var span: Span {
return Span(zeroTo: dimensions)
}
}

public extension TensorType {
/// The number of valid element in the memory block, taking into account the step size.
public var count: Int {
return dimensions.reduce(1, combine: *)
}

/// The number of dimensions
public var rank: Int {
return dimensions.count
}

public func linearIndex(indices: [Int]) -> Int {
assert(indexIsValid(indices))
var index = indices[0]
for (i, dim) in dimensions[1..<dimensions.count].enumerate() {
index = (dim * index) + indices[i+1]
}
return index
}

public func indexIsValid(indices: [Int]) -> Bool {
assert(indices.count == dimensions.count)
for (i, index) in indices.enumerate() {
if index < span[i].startIndex || span[i].endIndex <= index {
return false
}
}
return true
}
}

public protocol MutableTensorType: TensorType {
/// The mutable pointer to the beginning of the memory block
var mutablePointer: UnsafeMutablePointer<Element> { get }

subscript(intervals: [IntervalType]) -> Slice { get set }
subscript(intervals: [Int]) -> Element { get set }
}

0 comments on commit 1d33b74

Please sign in to comment.