Skip to content

Commit

Permalink
Add support for subcube binary operations
Browse files Browse the repository at this point in the history
If a binary operation involves a cube which is a strict sub-cube of
another then we can support the same semantics that numpy does in these
instances.
  • Loading branch information
iamsrp-deshaw committed Aug 22, 2024
1 parent 6f19d5c commit 186fb01
Show file tree
Hide file tree
Showing 5 changed files with 378 additions and 67 deletions.
152 changes: 124 additions & 28 deletions java/src/main/java/com/deshaw/hypercube/CubeMath.java
Original file line number Diff line number Diff line change
Expand Up @@ -4694,12 +4694,6 @@ private static <T> Hypercube<T> binaryOp(final Hypercube<T> a,
if (b == null) {
throw new NullPointerException("Given a null cube, 'b'");
}
if (!a.matches(b)) {
throw new IllegalArgumentException("Given incompatible cubes");
}
if (w != null && !a.matchesInShape(w)) {
throw new IllegalArgumentException("Given an incompatible 'w' cube");
}

// See if we can do it
if (a.getElementType().equals(Double.class)) {
Expand Down Expand Up @@ -4764,9 +4758,81 @@ private static <T> Hypercube<T> binaryOp(final Hypercube<T> a,
if (r == null) {
throw new NullPointerException("Given a null cube, 'r'");
}

if (!a.matches(b) || !a.matches(r)) {
throw new IllegalArgumentException("Given incompatible cubes");
// The don't exactly match, so let's see if one is a sub-cube of the
// other and that we can map into the result nicely
if (a.submatches(b) && a.matches(r)) {
// If we have weights then they need to match 'a' in shape
if (w != null && !a.matchesInShape(w)) {
throw new IllegalArgumentException(
"Given an incompatible 'w' cube"
);
}

// How we will slice up the 'a' and 'r' cubes so that we can
// recurse using 'b' directly
final Dimension<?>[] aDims = a.getDimensions();
final Dimension.Accessor<?>[] aAccessors =
new Dimension.Accessor<?>[aDims.length];
final Dimension<?>[] wDims =
(w == null) ? null : w.getDimensions();
final Dimension.Accessor<?>[] wAccessors =
(w == null) ? null : new Dimension.Accessor<?>[wDims.length];
final Dimension<?> aDim = aDims[0];
final Dimension<?> wDim = (w == null) ? null : wDims[0];

// Slice with each part of the dimension and recurse on the
// resultant subcubes
for (long i=0; i < aDim.length(); i++) {
aAccessors[0] = aDim.at(i);
if (w != null) {
wAccessors[0] = wDim.at(i);
}
binaryOp(
a.slice(aAccessors),
b,
r.slice(aAccessors),
(w == null) ? null : w.slice(wAccessors),
op
);
}
return r;
}

// The converse case to the above
if (b.submatches(a) && b.matches(r)) {
if (w != null && !b.matchesInShape(w)) {
throw new IllegalArgumentException(
"Given an incompatible 'w' cube"
);
}
final Dimension<?>[] bDims = b.getDimensions();
final Dimension.Accessor<?>[] bAccessors =
new Dimension.Accessor<?>[bDims.length];
final Dimension<?>[] wDims =
(w == null) ? null : w.getDimensions();
final Dimension.Accessor<?>[] wAccessors =
(w == null) ? null : new Dimension.Accessor<?>[wDims.length];
final Dimension<?> bDim = bDims[0];
final Dimension<?> wDim = (w == null) ? null : wDims[0];
for (long i=0; i < bDim.length(); i++) {
bAccessors[0] = bDim.at(i);
if (w != null) {
wAccessors[0] = wDim.at(i);
}
binaryOp(
a,
b.slice(bAccessors),
r.slice(bAccessors),
(w == null) ? null : w.slice(wAccessors),
op
);
}
return r;
}
}

if (w != null && !a.matchesInShape(w)) {
throw new IllegalArgumentException("Given an incompatible 'w' cube");
}
Expand Down Expand Up @@ -9370,12 +9436,18 @@ private static Hypercube<Boolean> booleanBinaryOp(
if (b == null) {
throw new NullPointerException("Given a null cube, 'b'");
}
if (!a.matches(b)) {
throw new IllegalArgumentException("Given incompatible cubes");

// Depending on which cube is a non-strict supercube of the other, create a simple
// BitSet destination one
if (a.submatches(b)) {
return binaryOp(a, b, new BooleanBitSetHypercube(a.getDimensions()), dw, op);
}
if (b.submatches(a)) {
return binaryOp(a, b, new BooleanBitSetHypercube(b.getDimensions()), dw, op);
}

// Create the destination, a simple BitSet one by default
return binaryOp(a, b, new BooleanBitSetHypercube(a.getDimensions()), dw, op);
// No match between the cubes
throw new IllegalArgumentException("Given incompatible cubes");
}

/**
Expand Down Expand Up @@ -10906,12 +10978,18 @@ private static Hypercube<Integer> intBinaryOp(
if (b == null) {
throw new NullPointerException("Given a null cube, 'b'");
}
if (!a.matches(b)) {
throw new IllegalArgumentException("Given incompatible cubes");

// Depending on which cube is a non-strict supercube of the other, create a simple
// Array destination one
if (a.submatches(b)) {
return binaryOp(a, b, new IntegerArrayHypercube(a.getDimensions()), dw, op);
}
if (b.submatches(a)) {
return binaryOp(a, b, new IntegerArrayHypercube(b.getDimensions()), dw, op);
}

// Create the destination, a simple Array one by default
return binaryOp(a, b, new IntegerArrayHypercube(a.getDimensions()), dw, op);
// No match between the cubes
throw new IllegalArgumentException("Given incompatible cubes");
}

/**
Expand Down Expand Up @@ -12454,12 +12532,18 @@ private static Hypercube<Long> longBinaryOp(
if (b == null) {
throw new NullPointerException("Given a null cube, 'b'");
}
if (!a.matches(b)) {
throw new IllegalArgumentException("Given incompatible cubes");

// Depending on which cube is a non-strict supercube of the other, create a simple
// Array destination one
if (a.submatches(b)) {
return binaryOp(a, b, new LongArrayHypercube(a.getDimensions()), dw, op);
}
if (b.submatches(a)) {
return binaryOp(a, b, new LongArrayHypercube(b.getDimensions()), dw, op);
}

// Create the destination, a simple Array one by default
return binaryOp(a, b, new LongArrayHypercube(a.getDimensions()), dw, op);
// No match between the cubes
throw new IllegalArgumentException("Given incompatible cubes");
}

/**
Expand Down Expand Up @@ -13999,12 +14083,18 @@ private static Hypercube<Float> floatBinaryOp(
if (b == null) {
throw new NullPointerException("Given a null cube, 'b'");
}
if (!a.matches(b)) {
throw new IllegalArgumentException("Given incompatible cubes");

// Depending on which cube is a non-strict supercube of the other, create a simple
// Array destination one
if (a.submatches(b)) {
return binaryOp(a, b, new FloatArrayHypercube(a.getDimensions()), dw, op);
}
if (b.submatches(a)) {
return binaryOp(a, b, new FloatArrayHypercube(b.getDimensions()), dw, op);
}

// Create the destination, a simple Array one by default
return binaryOp(a, b, new FloatArrayHypercube(a.getDimensions()), dw, op);
// No match between the cubes
throw new IllegalArgumentException("Given incompatible cubes");
}

/**
Expand Down Expand Up @@ -15563,12 +15653,18 @@ private static Hypercube<Double> doubleBinaryOp(
if (b == null) {
throw new NullPointerException("Given a null cube, 'b'");
}
if (!a.matches(b)) {
throw new IllegalArgumentException("Given incompatible cubes");

// Depending on which cube is a non-strict supercube of the other, create a simple
// Array destination one
if (a.submatches(b)) {
return binaryOp(a, b, new DoubleArrayHypercube(a.getDimensions()), dw, op);
}
if (b.submatches(a)) {
return binaryOp(a, b, new DoubleArrayHypercube(b.getDimensions()), dw, op);
}

// Create the destination, a simple Array one by default
return binaryOp(a, b, new DoubleArrayHypercube(a.getDimensions()), dw, op);
// No match between the cubes
throw new IllegalArgumentException("Given incompatible cubes");
}

/**
Expand Down Expand Up @@ -16572,4 +16668,4 @@ private static Hypercube<Double> doubleExtract(
}
}

// [[[end]]] (checksum: 7d429fc88498b76e3459643a3e08a022)
// [[[end]]] (checksum: d6f3bef31cdaaf690de998752896891e)
32 changes: 32 additions & 0 deletions java/src/main/java/com/deshaw/hypercube/Hypercube.java
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,38 @@ public default boolean matches(final Hypercube<T> that)
that.getShape()));
}

/**
* Whether this hypercube instance matches the given one in the higher
* dimensions. This is essentially saying whether {@code that} is a
* compatible non-strict subcube of this one.
*/
public default boolean submatches(final Hypercube<T> that)
{
// Similar to matches(), since simple checks
if (this == that) {
return true;
}
if (that == null ||
!getElementType().equals(that.getElementType()) ||
getNDim() < that.getNDim())
{
return false;
}

// Now we want to compare the higher dimensions
final Dimension<?>[] thisDim = this.getDimensions();
final Dimension<?>[] thatDim = that.getDimensions();
for (int i = 1; i <= thatDim.length; i++) {
if (!thisDim[thisDim.length-i].equals(
thatDim[thatDim.length-i]
))
{
return false;
}
}
return true;
}

/**
* Whether this hypercube instance matches the given one in shape. For this
* to be true, the following properties must match between the two cubes:<ol>
Expand Down
Loading

0 comments on commit 186fb01

Please sign in to comment.