Skip to content

Commit

Permalink
Update Tensor and Matrix + Slices
Browse files Browse the repository at this point in the history
* added 2-D Tensor Slice class that has well defined Matrix operations
  • Loading branch information
aidangomez committed Dec 14, 2015
1 parent 6d738f9 commit 2e3b6a6
Show file tree
Hide file tree
Showing 5 changed files with 633 additions and 149 deletions.
219 changes: 219 additions & 0 deletions Source/2D/2DTensorSlice.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
// Copyright © 2015 Venture Media Labs.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.


public class TwoDimensionalTensorSlice<Element: Value> : MutableQuadraticType, Equatable {
public typealias Index = [Int]
public typealias Slice = TwoDimensionalTensorSlice<Element>

public var arrangement: QuadraticArrangement {
return .RowMajor
}

public let dimensions: [Int]
public var count: Int {
return dimensions.reduce(1, combine: *)
}

public let rows: Int
public let columns: Int
public var stride: Int

var base: Tensor<Element>
var span: Span

public var pointer: UnsafePointer<Element> {
return base.pointer
}

public var mutablePointer: UnsafeMutablePointer<Element> {
return base.mutablePointer
}

public var step: Int {
return base.elements.step
}

init(base: Tensor<Element>, span: Span) {
assert(span.dimensions.count == base.dimensions.count)
self.base = base
self.span = span
self.dimensions = span.dimensions

assert(base.spanIsValid(span))
assert(span.dimensions.reduce(0){ $0.1 > 1 ? $0.0 + 1 : $0.0 } <= 2)
assert(span.dimensions.last >= 1)

var rowIndex: Int
if let index = span.dimensions.indexOf({ $0 > 1 }) {
rowIndex = index
} else {
rowIndex = span.dimensions.count - 2
}
rows = span.dimensions[rowIndex]
columns = span.dimensions.last!

stride = span.dimensions.suffixFrom(rowIndex + 1).reduce(1, combine: *)
}

public subscript(row: Int, column: Int) -> Element {
get {
return self[[row, column]]
}
set {
self[[row, column]] = newValue
}
}

public subscript(indices: Index) -> Element {
get {
var index = span.startIndex
let indexReplacementRage: Range<Int> = span.startIndex.count - indices.count..<span.startIndex.count
index.replaceRange(indexReplacementRage, with: indices)
assert(indexIsValid(index))
return base[index]
}
set {
var index = span.startIndex
let indexReplacementRage: Range<Int> = span.startIndex.count - indices.count..<span.startIndex.count
index.replaceRange(indexReplacementRage, with: indices)
assert(indexIsValid(index))
base[index] = newValue
}
}

public subscript(slice: [IntervalType]) -> Slice {
get {
let span = Span(base: self.span, intervals: slice)
return self[span]
}
set {
let span = Span(base: self.span, intervals: slice)
assert(span newValue.span)
self[span] = newValue
}
}

public subscript(slice: IntervalType...) -> Slice {
get {
return self[slice]
}
set {
self[slice] = newValue
}
}

subscript(span: Span) -> Slice {
get {
assert(self.span.contains(span))
return Slice(base: base, span: span)
}
set {
assert(self.span.contains(span))
assert(span newValue.span)
for (lhsIndex, rhsIndex) in zip(span, newValue.span) {
base[lhsIndex] = newValue[rhsIndex]
}
}
}

public var isContiguous: Bool {
let onesCount: Int
if let index = dimensions.indexOf({ $0 != 1 }) {
onesCount = index
} else {
onesCount = rank
}

let diff = (0..<rank).map({ dimensions[$0] - base.dimensions[$0] }).reverse()
let fullCount: Int
if let index = diff.indexOf({ $0 != 0 }) where index.base < count {
fullCount = diff.startIndex.distanceTo(index)
} else {
fullCount = rank
}

return rank - fullCount - onesCount <= 1
}

public func indexIsValid(indices: [Int]) -> Bool {
assert(indices.count == dimensions.count)
for (i, index) in indices.enumerate() {
if index < span[i].startIndex || span[i].endIndex <= index {
return false
}
}
return true
}
}

// MARK: - Equatable

public func ==<T>(lhs: TwoDimensionalTensorSlice<T>, rhs: TwoDimensionalTensorSlice<T>) -> Bool {
assert(lhs.span rhs.span)
for (lhsIndex, rhsIndex) in zip(lhs.span, rhs.span) {
if lhs[lhsIndex] != rhs[rhsIndex] {
return false
}
}
return true
}

public func ==<T: Equatable>(lhs: TwoDimensionalTensorSlice<T>, rhs: TensorSlice<T>) -> Bool {
assert(lhs.span rhs.span)
for (lhsIndex, rhsIndex) in zip(lhs.span, rhs.span) {
if lhs[lhsIndex] != rhs[rhsIndex] {
return false
}
}
return true
}

public func ==<T: Equatable>(lhs: TwoDimensionalTensorSlice<T>, rhs: Tensor<T>) -> Bool {
assert(lhs.span rhs.span)
for (lhsIndex, rhsIndex) in zip(lhs.span, rhs.span) {
if lhs[lhsIndex] != rhs[rhsIndex] {
return false
}
}
return true
}

public func ==<T: Equatable>(lhs: TwoDimensionalTensorSlice<T>, rhs: Matrix<T>) -> Bool {
assert(lhs.span rhs.span)
for (lhsIndex, rhsIndex) in zip(lhs.span, rhs.span) {
if lhs[lhsIndex] != rhs[rhsIndex] {
return false
}
}
return true
}

public func ==<T: Equatable>(lhs: TwoDimensionalTensorSlice<T>, rhs: MatrixSlice<T>) -> Bool {
assert(lhs.span rhs.span)
for (lhsIndex, rhsIndex) in zip(lhs.span, rhs.span) {
if lhs[lhsIndex] != rhs[rhsIndex] {
return false
}
}
return true
}


116 changes: 91 additions & 25 deletions Source/2D/Matrix.swift
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,18 @@

import Accelerate

public class Matrix<Element: Value> : MutableQuadraticType, Equatable, CustomStringConvertible {
public class Matrix<Element: Value>: MutableQuadraticType, Equatable, CustomStringConvertible {
public typealias Index = (Int, Int)
public typealias Slice = MatrixSlice<Element>

public var rows: Int
public var columns: Int
public var elements: ValueArray<Element>

var span: Span {
return Span(zeroTo: dimensions)
}

public var pointer: UnsafePointer<Element> {
return elements.pointer
}
Expand All @@ -35,13 +40,17 @@ public class Matrix<Element: Value> : MutableQuadraticType, Equatable, CustomStr
return elements.mutablePointer
}

public var arragement: QuadraticArrangement {
public var arrangement: QuadraticArrangement {
return .RowMajor
}

public var stride: Int {
return columns
}

public var step: Int {
return elements.step
}

/// Construct a Matrix from a `QuadraticType`
public init<M: QuadraticType where M.Element == Element>(_ quad: M) {
Expand Down Expand Up @@ -88,15 +97,58 @@ public class Matrix<Element: Value> : MutableQuadraticType, Equatable, CustomStr
elements.replaceRange(i*cols..<i*cols+min(cols, row.count), with: row)
}
}

public subscript(row: Int, column: Int) -> Element {

public subscript(indices: Int...) -> Element {
get {
return self[indices]
}
set {
self[indices] = newValue
}
}

public subscript(indices: [Int]) -> Element {
get {
assert(indices.count == 2)
assert(indexIsValidForRow(indices[0], column: indices[1]))
return elements[(indices[0] * columns) + indices[1]]
}
set {
assert(indices.count == 2)
assert(indexIsValidForRow(indices[0], column: indices[1]))
elements[(indices[0] * columns) + indices[1]] = newValue
}
}

public subscript(intervals: IntervalType...) -> Slice {
get {
return self[intervals]
}
set {
self[intervals] = newValue
}
}

public subscript(intervals: [IntervalType]) -> Slice {
get {
assert(indexIsValidForRow(row, column: column))
return elements[(row * columns) + column]
let span = Span(dimensions: dimensions, intervals: intervals)
return self[span]
}
set {
assert(indexIsValidForRow(row, column: column))
elements[(row * columns) + column] = newValue
let span = Span(dimensions: dimensions, intervals: intervals)
self[span] = newValue
}
}

subscript(span: Span) -> Slice {
get {
return MatrixSlice(base: self, span: span)
}
set {
assert(span newValue.span)
for (lhsIndex, rhsIndex) in zip(span, newValue.span) {
self[lhsIndex] = newValue[rhsIndex]
}
}
}

Expand All @@ -113,7 +165,7 @@ public class Matrix<Element: Value> : MutableQuadraticType, Equatable, CustomStr
return Matrix(rows: rows, columns: columns, elements: copy)
}

private func indexIsValidForRow(row: Int, column: Int) -> Bool {
public func indexIsValidForRow(row: Int, column: Int) -> Bool {
return row >= 0 && row < rows && column >= 0 && column < columns
}

Expand Down Expand Up @@ -141,30 +193,44 @@ public class Matrix<Element: Value> : MutableQuadraticType, Equatable, CustomStr
}
}

// MARK: - SequenceType
// MARK: - Equatable

extension Matrix: SequenceType {
public func generate() -> AnyGenerator<MutableSlice<ValueArray<Element>>> {
let endIndex = rows * columns
var nextRowStartIndex = 0
public func ==<T>(lhs: Matrix<T>, rhs: Matrix<T>) -> Bool {
return lhs.elements == rhs.elements
}

return anyGenerator {
if nextRowStartIndex == endIndex {
return nil
}
public func ==<T>(lhs: Matrix<T>, rhs: MatrixSlice<T>) -> Bool {
assert(lhs.span rhs.span)
for (lhsIndex, rhsIndex) in zip(lhs.span, rhs.span) {
if lhs[lhsIndex] != rhs[rhsIndex] {
return false
}
}
return true
}

let currentRowStartIndex = nextRowStartIndex
nextRowStartIndex += self.columns
public func ==<T>(lhs: Matrix<T>, rhs: Tensor<T>) -> Bool {
return lhs.elements == rhs.elements
}

return self.elements[currentRowStartIndex..<nextRowStartIndex]
public func ==<T>(lhs: Matrix<T>, rhs: TensorSlice<T>) -> Bool {
assert(lhs.span rhs.span)
for (lhsIndex, rhsIndex) in zip(lhs.span, rhs.span) {
if lhs[lhsIndex] != rhs[rhsIndex] {
return false
}
}
return true
}

// MARK: - Equatable

public func ==<T>(lhs: Matrix<T>, rhs: Matrix<T>) -> Bool {
return lhs.elements == rhs.elements
public func ==<T>(lhs: Matrix<T>, rhs: TwoDimensionalTensorSlice<T>) -> Bool {
assert(lhs.span rhs.span)
for (lhsIndex, rhsIndex) in zip(lhs.span, rhs.span) {
if lhs[lhsIndex] != rhs[rhsIndex] {
return false
}
}
return true
}

// MARK: -
Expand Down
Loading

0 comments on commit 2e3b6a6

Please sign in to comment.