-
Notifications
You must be signed in to change notification settings - Fork 30
/
Copy pathprimal.lisp
298 lines (268 loc) · 13.6 KB
/
primal.lisp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
;;; -*- Mode:Lisp; Syntax:ANSI-Common-Lisp; Coding:utf-8; Package:Varray -*-
;;;; primal.lisp

(in-package #:varray)

;; File description. As a bare top-level string this form has no runtime effect;
;; it only documents the purpose of the file.
"Definitions of virtual arrays derived solely from their input parameters and not based on other arrays."
;; Arithmetic progression vector. Its generator yields, for row-major index i,
;; OFFSET + FACTOR × (ORIGIN + ⌊i ÷ REPEAT⌋), and its length is NUMBER × REPEAT.
(defclass vapri-apro-vector (varray-primal)
  ((%number :initarg :number
            :accessor vapip-number
            :initform 1
            :documentation "The number of values.")
   (%origin :initarg :origin
            :accessor vapip-origin
            :initform 0
            :documentation "The origin point - by default, the index origin.")
   (%offset :initarg :offset
            :accessor vapip-offset
            :initform 0
            :documentation "The offset - an amount added to or subtracted from each value.")
   (%factor :initarg :factor
            :accessor vapip-factor
            :initform 1
            :documentation "Factor of values.")
   (%repeat :initarg :repeat
            :accessor vapip-repeat
            :initform 1
            :documentation "Instances of each value."))
  (:documentation "Arithmetic progression vector - a series of numeric values generated by [⍳ index].")
  (:metaclass va-class))
(defmethod etype-of ((vvector vapri-apro-vector))
  "Element type of an arithmetic progression vector.
A float factor gives double-float elements and a ratio factor gives type T. Otherwise the
elements are the integers OFFSET + FACTOR × (ORIGIN + i) for i below the vector's number,
and the returned integer range is guaranteed to cover the first and last of them."
  (let ((factor (vapip-factor vvector)))
    (cond ((floatp factor) 'double-float)
          ((typep factor 'ratio) t)
          (t (let* ((length (first (shape-of vvector))) ; also renders the number parameter
                    (number (vapip-number vvector))
                    (origin (vapip-origin vvector))
                    (offset (vapip-offset vvector))
                    ;; exact endpoints of the progression; with a negative factor the last
                    ;; value is the smallest, which the former bounds could miss
                    (first-value (+ offset (* factor origin)))
                    (last-value (+ offset (* factor (+ origin (max 0 (1- number)))))))
               ;; the former heuristic bounds are kept so the type is never narrower than
               ;; before; the exact endpoints are folded in so the range always covers them
               (list 'integer
                     (min (* factor number) (+ offset origin) first-value last-value)
                     (max offset (+ origin offset (* factor length))
                          first-value last-value)))))))
(defmethod prototype-of ((vvector vapri-apro-vector))
  "An arithmetic progression holds numbers, so its prototype is always zero."
  (declare (ignore vvector))
  0)
;; the shape of an arithmetic progression vector is its number times its repetition
(defmethod shape-of ((vvector vapri-apro-vector))
  "Render and store the number parameter, check that it is a non-negative integer, and
return the (cached) one-element shape list NUMBER × REPEAT."
  ;; TODO: it's still possible to create something like ⍳¯5, the error doesn't happen
  ;; until it's rendered - is there a better way to implement this check?
  (let ((count (render (vapip-number vvector))))
    (setf (vapip-number vvector) count)
    (if (typep count '(integer 0))
        (get-promised (varray-shape vvector) (list (* count (vapip-repeat vvector))))
        (error "The argument to [⍳ index] must be an integer 0 or higher."))))
;; the arithmetic progression vector's parameters are used to index its contents
(defmethod generator-of ((vvector vapri-apro-vector) &optional indexers params)
  "Return a function mapping a row-major index i to OFFSET + FACTOR × (ORIGIN + ⌊i ÷ REPEAT⌋).
When the :FORMAT in PARAMS is :ENCODED, it is destructively changed to :LINEAR in the caller's
PARAMS plist and the linear generator is returned."
  (declare (ignore indexers) (optimize (speed 3) (safety 0)))
  (let* ((origin (the (unsigned-byte 62) (vapip-origin vvector)))
         (offset (the fixnum (vapip-offset vvector)))
         (factor (the real (vapip-factor vvector)))
         (repeat (the (unsigned-byte 62) (vapip-repeat vvector)))
         ;; INDEXER composes a scaling wrapper (first argument to funcall) around a
         ;; function mapping the index to ORIGIN + ⌊index ÷ REPEAT⌋ (second argument)
         (indexer (funcall (if (or (and (integerp factor) (= 1 factor))
                                   (and (typep factor 'single-float) (= 1.0 factor))
                                   (and (typep factor 'double-float) (= 1.0d0 factor)))
                               ;; a factor of one needs no multiplication, only the offset
                               (if (zerop offset)
                                   #'identity
                                   (lambda (fn)
                                     (declare (type function fn))
                                     (lambda (item) (+ offset (funcall fn item)))))
                               (if (integerp factor)
                                   (lambda (fn)
                                     (lambda (item)
                                       (declare (type (unsigned-byte 62) item))
                                       (+ offset (* factor (funcall (the function fn) item)))))
                                   ;; here the factor may be a float or a ratio, so it is only
                                   ;; declared REAL; declaring it FLOAT was false for ratios and,
                                   ;; under (safety 0), undefined behavior
                                   (lambda (fn)
                                     (lambda (item)
                                       (declare (type function fn))
                                       (+ offset (* (the real factor) (funcall fn item)))))))
                           (if (= 1 repeat)
                               (if (zerop origin)
                                   #'identity
                                   (lambda (index)
                                     (declare (type (unsigned-byte 62) index))
                                     (+ origin index)))
                               ;; each value appears REPEAT times in a row
                               (lambda (index)
                                 (declare (type (unsigned-byte 62) index))
                                 (+ origin (floor index repeat)))))))
    (case (getf params :format)
      ;; NOTE: mutates the caller's PARAMS plist (the :format key is known to be present here)
      (:encoded (setf (getf params :format) :linear)
       (generator-of vvector nil params))
      (:linear indexer)
      (t indexer))))
(deftype fast-iota-sum-fixnum ()
  "Non-negative integers no larger than the largest argument fast-iota-sum can take without a fixnum overflow."
  ;; the bound is computed when the type is expanded instead of at read time; the value is the same
  `(integer 0 ,(isqrt (* most-positive-fixnum 2))))
(declaim (ftype (function (fast-iota-sum-fixnum) fixnum) fast-iota-sum))
(defun fast-iota-sum (n)
  "Return 1 + 2 + ... + N, i.e. N(N+1)/2, for N of type fast-iota-sum-fixnum.
The halving is applied to whichever of N and N+1 is even, so every division is exact."
  (declare (optimize (speed 3) (safety 0)))
  (if (evenp n)
      (let ((half (the fixnum (/ n 2))))
        (+ half (* half n)))
      (* (the fixnum (/ (1+ n) 2)) n)))
(defun iota-sum (n index-origin)
  "Compute +/⍳N under INDEX-ORIGIN without building the vector.
Uses fixnum-only arithmetic when N is small enough, exact rational arithmetic otherwise."
  (cond ((minusp n)
         (error "The argument to [⍳ index] must be a positive integer, i.e. ⍳9, or a vector, i.e. ⍳2 3."))
        ((zerop n) 0)
        ((= 1 n) index-origin)
        ;; general closed form: N × (2·IO + N - 1) / 2
        ((not (typep n 'fast-iota-sum-fixnum))
         (* n (/ (+ n index-origin index-origin -1) 2)))
        ;; ⎕IO=1 sums 1..N; otherwise the sum is 0..N-1, i.e. 1..N-1
        ((= 1 index-origin) (fast-iota-sum n))
        (t (fast-iota-sum (1- n)))))
(defmethod get-reduced ((vvector vapri-apro-vector) function)
  "Reduce an arithmetic progression vector with FUNCTION.
Sums are computed in closed form for any progression. Products are computed in closed form
for a plain [⍳ index] vector (zero offset, unit factor, no repetition). Any other case applies
FUNCTION successively across every element produced by the vector's generator."
  (let* ((fn-meta (funcall function :get-metadata))
         (number (vapip-number vvector))
         (origin (vapip-origin vvector))
         (offset (vapip-offset vvector))
         (factor (vapip-factor vvector))
         (repeat (vapip-repeat vvector))
         ;; a plain [⍳ index] vector's elements are ORIGIN, ORIGIN+1, ..., ORIGIN+NUMBER-1
         (plain-p (and (zerop offset) (= 1 factor) (= 1 repeat))))
    (flet ((fold-elements ()
             ;; visit every element, including repetitions, so the full length is used
             (let* ((generator (generator-of vvector))
                    (output (funcall generator 0)))
               (loop :for i :from 1 :below (first (shape-of vvector))
                     :do (setf output (funcall function (funcall generator i) output)))
               output)))
      (case (getf fn-meta :lexical-reference)
        (#\+ (let ((sum (iota-sum number origin)))
               (if plain-p
                   sum
                   ;; REPEAT copies of each OFFSET + FACTOR × (ORIGIN + k), k below NUMBER
                   (* repeat (+ (* number offset)
                                (if (= 1 factor) sum (* factor sum)))))))
        (#\× (cond ((not plain-p) (fold-elements))
                   ;; with origin 0 the vector contains a zero, unless it is empty
                   ((zerop origin) (if (zerop number) 1 0))
                   (t (sprfact (+ number (- origin 1))))))
        (t (fold-elements))))))
;; A coordinate vector is defined by the array it points into and a row-major index
;; within that array; its elements are computed from the array's dimensional factors.
(defclass vapri-coordinate-vector (varray-primal)
  ((%reference :initarg :reference
               :accessor vacov-reference
               :initform nil
               :documentation "The array to which this coordinate vector belongs.")
   (%index :initarg :index
           :accessor vacov-index
           :initform 0
           :documentation "The row-major index of the referenced array this coordinate vector represents."))
  (:documentation "Coordinate vector - a vector of the integer coordinates corresponding to a given row-major index in an array.")
  (:metaclass va-class))
(defmethod etype-of ((vvector vapri-coordinate-vector))
  "The type of the coordinate vector: integers from 0 up to the largest dimension of the referenced array."
  ;; if this refers to a [⍸ where] invocation, it is based on the shape of the argument to [⍸ where];
  ;; it cannot directly reference that argument, because the dimensional factors are stored
  ;; along with the [⍸ where] object
  ;; :initial-value 0 keeps this valid for a rank-0 reference, whose empty shape would
  ;; otherwise make REDUCE call (max) with no arguments and signal an error
  (list 'integer 0 (reduce #'max (if (typep (vacov-reference vvector) 'vader-where)
                                     (shape-of (vader-base (vacov-reference vvector)))
                                     (shape-of (vacov-reference vvector)))
                           :initial-value 0)))
(defmethod prototype-of ((vvector vapri-coordinate-vector))
  "Coordinates are integers, so the prototype is zero."
  (declare (ignore vvector))
  0)
(defmethod shape-of ((vvector vapri-coordinate-vector))
  "One element per dimensional factor of the referenced array; the result is cached."
  (get-promised (varray-shape vvector)
                (let ((factors (vads-dfactors (vacov-reference vvector))))
                  (list (length factors)))))
(defmethod generator-of ((vvector vapri-coordinate-vector) &optional indexers params)
  "Split the row-major index into per-axis coordinates (each offset by the reference's index
origin) and return a function reading them by position. Returns NIL for :ENCODED or :LINEAR
base formats."
  (declare (ignore indexers))
  (let* ((reference (vacov-reference vvector))
         (factors (vads-dfactors reference))
         (coords (make-array (length factors) :element-type (etype-of vvector)))
         (leftover (vacov-index vvector)))
    ;; successive division by each dimensional factor peels off one coordinate per axis
    (loop :for factor :across factors
          :for axis :from 0
          :do (multiple-value-bind (quotient remainder) (floor leftover factor)
                (setf (aref coords axis) (+ quotient (vads-io reference))
                      leftover remainder)))
    (case (getf params :base-format)
      ((:encoded :linear) nil)
      (t (lambda (index) (aref coords index))))))
;; Note that the shape slot is initialized through the :number initarg.
(defclass vapri-coordinate-identity (vad-nested varray-primal vad-with-io vad-with-dfactors)
  ((%shape :initarg :number
           :accessor vapci-shape
           :initform 1
           :documentation "The shape of the array."))
  (:documentation "Coordinate identity array - an array of coordinate vectors generated by [⍳ index].")
  (:metaclass va-class))
(defmethod etype-of ((varray vapri-coordinate-identity))
  "Elements are coordinate vectors (nested arrays), so the element type is T."
  (declare (ignore varray))
  t)
(defmethod prototype-of ((varray vapri-coordinate-identity))
  "Prototype is an array of zeroes with length equal to the array's rank."
  ;; SHAPE-OF is used instead of reading VARRAY-SHAPE directly: the shape slot is only
  ;; populated once SHAPE-OF has rendered the requested shape, so reading it early
  ;; would give a rank of zero
  (make-array (length (shape-of varray)) :element-type 'bit :initial-element 0))
(defmethod shape-of ((varray vapri-coordinate-identity))
  "Return the array's explicit shape. On first use (no dimensional factors yet) the requested
shape is rendered into a list, stored as the shape, and its dimensional factors are derived."
  (when (null (vads-dfactors varray))
    (let ((dims (coerce (render (vapci-shape varray)) 'list)))
      (setf (varray-shape varray) dims)
      (setf (vads-dfactors varray) (strides-of dims t))))
  (varray-shape varray))
(defmethod generator-of ((varray vapri-coordinate-identity) &optional indexers params)
  "Return a function producing, for each row-major index, a coordinate vector referencing this
array. Returns NIL for :ENCODED or :LINEAR base formats."
  (declare (ignore indexers))
  (unless (member (getf params :base-format) '(:encoded :linear))
    (lambda (index)
      (make-instance 'vapri-coordinate-vector :reference varray :index index))))
;; The only parameter of a one-hot vector is the position of its single 1.
(defclass vapri-onehot-vector (varray-primal)
  ((%index :initarg :index
           :accessor vaohv-index
           :initform 0
           :documentation "The index equal to one."))
  (:documentation "One-hot vector - a binary vector where one element is set to 1.")
  (:metaclass va-class))
(defmethod etype-of ((varray vapri-onehot-vector))
  "Every element is 0 or 1, so the element type is BIT."
  (declare (ignore varray))
  'bit)
(defmethod prototype-of ((varray vapri-onehot-vector))
  "A binary vector's prototype is zero."
  (declare (ignore varray))
  0)
(defmethod shape-of ((varray vapri-onehot-vector))
  "Return the stored shape; a one-hot vector's shape is given explicitly, not derived."
  (let ((shape (varray-shape varray)))
    shape))
(defmethod generator-of ((varray vapri-onehot-vector) &optional indexers params)
  "Return a function yielding 1 at the vector's hot index and 0 everywhere else.
Returns NIL for :ENCODED or :LINEAR base formats."
  (declare (ignore indexers))
  (let ((base-format (getf params :base-format)))
    (unless (or (eq base-format :encoded) (eq base-format :linear))
      (lambda (index)
        (if (/= index (vaohv-index varray)) 0 1)))))
;; An axis vector is a view onto a reference array: it reads the elements lying along AXIS
;; (optionally only a WINDOW of them) for the position given by INDEX.
(defclass vapri-axis-vector (vad-nested varray-primal vad-with-io vad-with-dfactors)
  ((%reference :initarg :reference
               :accessor vaxv-reference
               :initform nil
               :documentation "The array to which this axis vector belongs.")
   (%axis :initarg :axis
          :accessor vaxv-axis
          :initform nil
          :documentation "The axis along which the axis vector leads.")
   (%window :initarg :window
            :accessor vaxv-window
            :initform nil
            :documentation "The window of division along the axis.")
   (%index :initarg :index
           :accessor vaxv-index
           :initform nil
           :documentation "This axis vector's index within the reference array reduced along the axis."))
  (:documentation "A sub-vector along an axis of an array.")
  (:metaclass va-class))
(defmethod etype-of ((varray vapri-axis-vector))
  "An axis vector's elements come from its reference array, so it shares that array's element type."
  (let ((reference (vaxv-reference varray)))
    (etype-of reference)))
(defmethod shape-of ((varray vapri-axis-vector))
  "The vector's length is its window when one is set, otherwise the reference array's
length along the axis; the result is cached."
  (get-promised (varray-shape varray)
                (let ((window (vaxv-window varray)))
                  (list (if window
                            window
                            (nth (vaxv-axis varray) (shape-of (vaxv-reference varray))))))))
(defmethod generator-of ((varray vapri-axis-vector) &optional indexers params)
  "Return a function mapping a position along this axis vector to the corresponding element of
the reference array, read through the reference's own generator. Returns NIL for :ENCODED or
:LINEAR base formats."
  (declare (ignore indexers))
  (let* ((axis (vaxv-axis varray))
         (window (vaxv-window varray))
         ;; number of window positions along the axis (set below only when a window exists)
         (wsegment)
         (ref-index (vaxv-index varray))
         (ref-indexer (generator-of (vaxv-reference varray)))
         (irank (rank-of (vaxv-reference varray)))
         (idims (shape-of (vaxv-reference varray)))
         ;; length of the reference array along the axis
         (rlen (nth axis idims))
         ;; product of the dimensions after the axis: the row-major distance between
         ;; consecutive elements along the axis
         (increment (reduce #'* (nthcdr (1+ axis) idims))))
    ;; equivalent to setting WSEGMENT to RLEN - (WINDOW - 1) when a window is given
    (loop :for dim :in idims :for dx :from 0
          :when (and window (= dx axis))
            :do (setq wsegment (- dim (1- window))))
    ;; DELTA is the row-major index, in the reference array, of this vector's first element.
    ;; Without a window, REF-INDEX = outer × INCREMENT + inner, and the start is
    ;; outer × RLEN × INCREMENT + inner, which is what the two summands add up to.
    ;; NOTE(review): with a window and INCREMENT /= 1 (a window along a non-final axis), the
    ;; sum is RLEN × ⌊REF-INDEX ÷ WSEGMENT⌋ + REF-INDEX; that does not obviously match a
    ;; windowed layout of the reference - confirm how callers use windows on such axes.
    (let ((delta (+ (if window (* rlen (floor ref-index wsegment))
                        (if (= 1 increment)
                            0 (* (floor ref-index increment)
                                 (- (* increment rlen) increment))))
                    (if (/= 1 increment) ref-index
                        ;; final axis: with a window, the offset within the current row;
                        ;; without one, each vector starts RLEN elements after the last
                        (if window (if (>= 1 irank) ref-index
                                       (mod ref-index wsegment))
                            (* ref-index rlen))))))
      (case (getf params :base-format)
        (:encoded)
        (:linear)
        ;; step along the axis by INCREMENT from the starting offset
        (t (lambda (index) (funcall ref-indexer (+ delta (* index increment)))))))))