forked from rust-ndarray/ndarray
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbroadcast.rs
100 lines (91 loc) · 2.52 KB
/
broadcast.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
use ndarray::prelude::*;
/// Checks which target shapes `broadcast` accepts or rejects for a few source shapes.
#[test]
#[cfg(feature = "std")]
fn broadcast_1()
{
    /// Builds an array of `shape` from `linspace(0, 1, shape.size())`, reshaped in place.
    fn ramp<D: Dimension>(shape: D) -> ArcArray<f64, D>
    {
        let count = shape.size();
        ArcArray::linspace(0., 1., count)
            .into_shape_with_order(shape)
            .unwrap()
    }

    // Length-1 axes (1 and 3) stretch to match the larger 4-d shape.
    let big = ramp(Dim([2, 4, 2, 2]));
    let sparse = ramp(Dim([2, 1, 2, 1]));
    assert!(sparse.broadcast(big.dim()).is_some());

    let col = ramp(Dim([2, 1]));
    // Targets with fewer axes than the source are rejected.
    assert!(col.broadcast(1).is_none());
    assert!(col.broadcast(()).is_none());
    // Same shape, a stretched length-1 axis, and an extra leading axis are accepted.
    assert!(col.broadcast((2, 1)).is_some());
    assert!(col.broadcast((2, 2)).is_some());
    assert!(col.broadcast((32, 2, 1)).is_some());
    // Trailing axes must line up: (2, 1) cannot fit under (.., 1, 2).
    assert!(col.broadcast((32, 1, 2)).is_none());

    /* () can be broadcast to anything */
    let scalar = ArcArray::<f32, _>::zeros(());
    assert!(scalar.broadcast(()).is_some());
    assert!(scalar.broadcast(1).is_some());
    assert!(scalar.broadcast(3).is_some());
    assert!(scalar.broadcast((7, 2, 9)).is_some());
}
/// In-place `+=` broadcasts the right-hand operand to the left-hand shape. This is
/// checked both for a same-rank array with length-1 axes and for a 0-d array, and
/// every resulting element is verified.
#[test]
#[cfg(feature = "std")]
fn test_add()
{
    let a_dim = Dim([2, 4, 2, 2]);
    let b_dim = Dim([2, 1, 2, 1]);
    // The element type is spelled out so the expected-value arithmetic below is
    // unambiguous. It is the same `f32` that adding `t` already requires.
    let mut a = ArcArray::<f32, _>::linspace(0.0, 1., a_dim.size())
        .into_shape_with_order(a_dim)
        .unwrap();
    let b = ArcArray::<f32, _>::linspace(0.0, 1., b_dim.size())
        .into_shape_with_order(b_dim)
        .unwrap();
    // Independent owned copy of the original values, used to compute expectations.
    let before = a.to_owned();
    a += &b;
    let t = ArcArray::from_elem((), 1.0f32);
    a += &t;
    // `b` has length 1 along axes 1 and 3, so those indices are pinned to 0.
    // The 0-d `t` adds its single value to every element. The order of the
    // additions matches the two `+=` above, so exact float equality holds.
    for ((i, j, k, l), &x) in a.indexed_iter() {
        assert_eq!(x, before[[i, j, k, l]] + b[[i, 0, k, 0]] + 1.0);
    }
}
/// Adding an array whose shape cannot be broadcast to the left-hand shape must
/// panic with ndarray's broadcast error, not with some unrelated failure.
#[test]
#[should_panic(expected = "could not broadcast")]
#[cfg(feature = "std")]
fn test_add_incompat()
{
    let a_dim = Dim([2, 4, 2, 2]);
    let mut a = ArcArray::linspace(0.0, 1., a_dim.size())
        .into_shape_with_order(a_dim)
        .unwrap();
    // Shape [3] lines up against the trailing axis of length 2, which is incompatible.
    let incompat = ArcArray::from_elem(3, 1.0f32);
    a += &incompat;
}
/// Three ways of producing an (n, k) array of ones must compare equal.
#[test]
fn test_broadcast()
{
    let (n, k) = (16, 16);
    let fill = 1.;

    // b0: a single-element array stretched to (n, k).
    let single = Array::from(vec![fill]);
    let b0 = single.broadcast((n, k)).unwrap();

    // b1: a length-n vector stretched to (n, k). Broadcasting aligns trailing
    // axes, so the vector is matched against the `k` axis; this works only
    // because n == k here.
    let row = Array::from_elem(n, fill);
    let b1 = row.broadcast((n, k)).unwrap();

    // b2: the (n, k) array built directly.
    let b2 = Array::from_elem((n, k), fill);

    println!("b0=\n{:?}", b0);
    println!("b1=\n{:?}", b1);
    println!("b2=\n{:?}", b2);
    assert_eq!(b0, b1);
    assert_eq!(b0, b2);
}
/// A one-element array broadcast to length `n` equals `n` copies of that element.
#[test]
fn test_broadcast_1d()
{
    let n = 16;
    let fill = 1.;
    // The owned source must outlive the view that `broadcast` returns.
    let source = Array::from(vec![fill]);
    let stretched = source.broadcast(n).unwrap();
    let expected = Array::from_elem(n, fill);
    println!("b0=\n{:?}", stretched);
    println!("b2=\n{:?}", expected);
    assert_eq!(stretched, expected);
}