diff --git a/src/main/java/org/apache/sysds/hops/estim/EstimatorRowWise.java b/src/main/java/org/apache/sysds/hops/estim/EstimatorRowWise.java
new file mode 100644
index 00000000000..5fd85783ba8
--- /dev/null
+++ b/src/main/java/org/apache/sysds/hops/estim/EstimatorRowWise.java
@@ -0,0 +1,389 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.estim;
+
+import org.apache.commons.lang3.ArrayUtils;
+import org.apache.commons.lang3.NotImplementedException;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.data.SparseRow;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.meta.DataCharacteristics;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+
+import java.util.stream.DoubleStream;
+import java.util.stream.IntStream;
+
+/**
+ * This estimator implements an approach based on row-wise sparsity estimation,
+ * introduced in
+ * Lin, Chunxu, Wensheng Luo, Yixiang Fang, Chenhao Ma, Xilin Liu and Yuchi Ma:
+ * On Efficient Large Sparse Matrix Chain Multiplication.
+ * Proceedings of the ACM on Management of Data 2 (2024): 1 - 27.
+ *
+ * The synopsis of a matrix is a vector holding, per row, the fraction of
+ * non-zero cells; the overall sparsity is the mean of that vector.
+ */
+public class EstimatorRowWise extends SparsityEstimator {
+	/**
+	 * Estimates the output of the operator DAG rooted at {@code root} and stores
+	 * the derived characteristics on the root node.
+	 *
+	 * @param root root of the operator DAG
+	 * @return the data characteristics set on {@code root}
+	 */
+	@Override
+	public DataCharacteristics estim(MMNode root) {
+		estimInternChain(root);
+		double sparsity = average((double[]) root.getSynopsis());
+		return root.setDataCharacteristics(deriveOutputCharacteristics(root, sparsity));
+	}
+
+	@Override
+	public double estim(MatrixBlock m1, MatrixBlock m2) {
+		return estim(m1, m2, OpCode.MM);
+	}
+
+	@Override
+	public double estim(MatrixBlock m1, MatrixBlock m2, OpCode op) {
+		if( isExactMetadataOp(op, m1.getNumColumns()) ) {
+			return estimExactMetaData(m1.getDataCharacteristics(),
+				m2.getDataCharacteristics(), op).getSparsity();
+		}
+		return average(estimIntern(m1, m2, op));
+	}
+
+	@Override
+	public double estim(MatrixBlock m1, OpCode op) {
+		if( isExactMetadataOp(op, m1.getNumColumns()) )
+			return estimExactMetaData(m1.getDataCharacteristics(), null, op).getSparsity();
+		return average(estimIntern(m1, op));
+	}
+
+	/** Mean of a row sparsity vector, i.e., the overall sparsity (0 for an empty vector). */
+	private static double average(double[] rs) {
+		return DoubleStream.of(rs).average().orElse(0);
+	}
+
+	private void estimInternChain(MMNode node) {
+		estimInternChain(node, null, null);
+	}
+
+	/**
+	 * Recursively estimates the row sparsity vector of {@code node} and stores it
+	 * as the node's synopsis, together with the derived data characteristics.
+	 *
+	 * @param node current node of the operator DAG
+	 * @param rsRightNeighbor row sparsity vector of the operand this node is
+	 *        combined with on its right, or null if there is none
+	 * @param opRightNeighbor operation combining this node with the right neighbor
+	 */
+	private void estimInternChain(MMNode node, double[] rsRightNeighbor, OpCode opRightNeighbor) {
+		double[] rsOut;
+		if(node.isLeaf()) {
+			MatrixBlock mb = node.getData();
+			rsOut = (rsRightNeighbor != null) ?
+				estimIntern(mb, rsRightNeighbor, opRightNeighbor) :
+				getRowWiseSparsityVector(mb);
+		}
+		else {
+			switch(node.getOp()) {
+				case MM:
+					// right-to-left: the right operand's vector is folded into the left subtree
+					estimInternChain(node.getRight(), rsRightNeighbor, opRightNeighbor);
+					estimInternChain(node.getLeft(), (double[]) node.getRight().getSynopsis(), node.getOp());
+					rsOut = (double[]) node.getLeft().getSynopsis();
+					break;
+				case CBIND:
+				case RBIND:
+				case PLUS:
+				case MULT:
+					// NOTE: considering the current node as new DAG for estimation (cut), since the
+					// row sparsity of the right neighbor cannot be aggregated into a bind or
+					// element-wise operation when having only row sparsity vectors
+					estimInternChain(node.getLeft());
+					estimInternChain(node.getRight());
+					rsOut = applyRightNeighbor(estimInternCut(node), rsRightNeighbor, opRightNeighbor);
+					break;
+				default:
+					throw new NotImplementedException("Chain estimation for operator " + node.getOp().toString() +
+						" is not supported yet.");
+			}
+		}
+		node.setSynopsis(rsOut);
+		node.setDataCharacteristics(deriveOutputCharacteristics(node, average(rsOut)));
+	}
+
+	/**
+	 * Combines the already estimated children of a bind or element-wise node.
+	 *
+	 * @param node non-leaf node whose children carry their row sparsity synopsis
+	 * @return row sparsity vector of the node's own output
+	 */
+	private double[] estimInternCut(MMNode node) {
+		MMNode left = node.getLeft();
+		MMNode right = node.getRight();
+		double[] rsLeft = (double[]) left.getSynopsis();
+		double[] rsRight = (double[]) right.getSynopsis();
+		switch(node.getOp()) {
+			case CBIND:
+				return estimInternCBind(rsLeft, rsRight, left.getCols(), right.getCols());
+			case RBIND:
+				return estimInternRBind(rsLeft, rsRight);
+			case PLUS:
+				return estimInternPlus(rsLeft, rsRight);
+			case MULT:
+				return estimInternMult(rsLeft, rsRight);
+			default:
+				throw new NotImplementedException("Chain estimation for operator " + node.getOp().toString() +
+					" is not supported yet.");
+		}
+	}
+
+	/**
+	 * Folds the right neighbor into an intermediate whose cells are unknown.
+	 *
+	 * @param rsNode row sparsity vector of the intermediate
+	 * @param rsRightNeighbor row sparsity vector of the right neighbor, or null
+	 * @param opRightNeighbor operation applied with the right neighbor
+	 * @return {@code rsNode} if there is no right neighbor, else the fallback estimate
+	 */
+	private double[] applyRightNeighbor(double[] rsNode, double[] rsRightNeighbor, OpCode opRightNeighbor) {
+		if(rsRightNeighbor == null)
+			return rsNode;
+		// validate before computing, the fallback below is only defined for MM
+		if(opRightNeighbor != OpCode.MM)
+			throw new NotImplementedException("Fallback sparsity estimation has only been " +
+				"considered for MM operation w/ right neighbor yet.");
+		return estimInternMMFallback(rsNode, rsRightNeighbor);
+	}
+
+	private double[] estimIntern(MatrixBlock m1, MatrixBlock m2, OpCode op) {
+		double[] rsM2 = getRowWiseSparsityVector(m2);
+		// both widths are known here, so cbind can be weighted by column counts
+		if(op == OpCode.CBIND)
+			return estimInternCBind(getRowWiseSparsityVector(m1), rsM2,
+				m1.getNumColumns(), m2.getNumColumns());
+		return estimIntern(m1, rsM2, op);
+	}
+
+	private double[] estimIntern(MatrixBlock m1, double[] rsM2, OpCode op) {
+		switch(op) {
+			case MM:
+				return estimInternMM(m1, rsM2);
+			case CBIND:
+				return estimInternCBind(getRowWiseSparsityVector(m1), rsM2);
+			case RBIND:
+				return estimInternRBind(getRowWiseSparsityVector(m1), rsM2);
+			case PLUS:
+				return estimInternPlus(getRowWiseSparsityVector(m1), rsM2);
+			case MULT:
+				return estimInternMult(getRowWiseSparsityVector(m1), rsM2);
+			default:
+				throw new NotImplementedException("Sparsity estimation for operation " + op.toString() + " not supported yet.");
+		}
+	}
+
+	private double[] estimIntern(MatrixBlock mb, OpCode op) {
+		switch(op) {
+			case DIAG:
+				return estimInternDiag(mb);
+			default:
+				throw new NotImplementedException("Sparsity estimation for operation " + op.toString() + " not supported yet.");
+		}
+	}
+
+	// Corresponds to Algorithm 1 in the publication
+	private double[] estimInternMM(MatrixBlock m1, double[] rsM2) {
+		// per output row: 1 - product over the non-zero columns c of m1's row of (1 - rsM2[c])
+		return IntStream.range(0, m1.getNumRows()).mapToDouble(
+			r -> 1d - IntStream.of(getNonZeroColumnIndices(m1, r))
+				.mapToDouble(c -> 1d - rsM2[c])
+				.reduce(1d, (currentVal, val) -> currentVal * val))
+			.toArray();
+	}
+
+	// NOTE: this is the best estimation possible when we only have the two row sparsity vectors
+	private double[] estimInternMMFallback(double[] rsM1, double[] rsM2) {
+		// NOTE: Considering the average would probably not be far off while saving computing time
+		// double avgRsM2 = DoubleStream.of(rsM2).average().orElse(0);
+		// double[] rsOut = DoubleStream.of(rsM1).map(
+		//     rsM1I -> (double) 1 - Math.pow((double) 1 - (rsM1I * avgRsM2), rsM2.length)).toArray();
+		return DoubleStream.of(rsM1).map(
+			rsM1I -> 1d - DoubleStream.of(rsM2).reduce(1d,
+				(currentVal, rsM2J) -> currentVal * (1d - (rsM1I * rsM2J)))).toArray();
+	}
+
+	/**
+	 * Row sparsity of cbind(M1, M2) when the input widths are unknown; both
+	 * inputs are then assumed to have the same number of columns.
+	 */
+	private double[] estimInternCBind(double[] rsM1, double[] rsM2) {
+		return estimInternCBind(rsM1, rsM2, 1, 1);
+	}
+
+	/**
+	 * Row sparsity of cbind(M1, M2): per row, the non-zeros of both inputs
+	 * over the combined width.
+	 *
+	 * @param rsM1 row sparsity vector of the left input
+	 * @param rsM2 row sparsity vector of the right input
+	 * @param ncolM1 number of columns of the left input
+	 * @param ncolM2 number of columns of the right input
+	 * @return row sparsity vector of the output
+	 */
+	private double[] estimInternCBind(double[] rsM1, double[] rsM2, long ncolM1, long ncolM2) {
+		final double ncolOut = (double) ncolM1 + ncolM2;
+		if(ncolOut <= 0)
+			return new double[rsM1.length]; // no columns, hence no non-zeros
+		return IntStream.range(0, rsM1.length).mapToDouble(
+			idx -> (rsM1[idx] * ncolM1 + rsM2[idx] * ncolM2) / ncolOut).toArray();
+	}
+
+	private double[] estimInternRBind(double[] rsM1, double[] rsM2) {
+		return ArrayUtils.addAll(rsM1, rsM2);
+	}
+
+	private double[] estimInternPlus(double[] rsM1, double[] rsM2) {
+		// row-wise average case estimates
+		// rsM1 + rsM2 - (rsM1 * rsM2)
+		return IntStream.range(0, rsM1.length).mapToDouble(
+			idx -> rsM1[idx] + rsM2[idx] - (rsM1[idx] * rsM2[idx])).toArray();
+	}
+
+	private double[] estimInternMult(double[] rsM1, double[] rsM2) {
+		// row-wise average case estimates
+		// rsM1 * rsM2
+		return IntStream.range(0, rsM1.length).mapToDouble(
+			idx -> rsM1[idx] * rsM2[idx]).toArray();
+	}
+
+	private double[] estimInternDiag(MatrixBlock mb) {
+		// output row r is 1 iff the diagonal cell mb[r,r] is non-zero
+		return IntStream.range(0, mb.getNumRows()).mapToDouble(
+			rIdx -> (mb.get(rIdx, rIdx) == 0) ? 0d : 1d)
+			.toArray();
+	}
+
+	/**
+	 * Computes, per row, the fraction of non-zero cells of {@code mb}.
+	 *
+	 * @param mb input matrix
+	 * @return vector of length {@code mb.getNumRows()}; all zeros if the
+	 *         block has no columns or no allocated data
+	 */
+	private double[] getRowWiseSparsityVector(MatrixBlock mb) {
+		final int numRows = mb.getNumRows();
+		final int numCols = mb.getNumColumns();
+		final double[] rsArray = new double[numRows];
+		if(numCols == 0)
+			return rsArray; // avoid 0/0
+		if(mb.isInSparseFormat()) {
+			// an unallocated block carries no non-zeros
+			if(mb.getSparseBlock() == null)
+				return rsArray;
+			for(int counter = 0; counter < numRows; counter++) {
+				SparseRow sparseRow = mb.getSparseBlock().get(counter);
+				rsArray[counter] = (sparseRow == null) ? 0 : (double) sparseRow.size() / numCols;
+			}
+		}
+		else {
+			if(mb.getDenseBlock() == null)
+				return rsArray;
+			for(int rIdx = 0; rIdx < numRows; rIdx++)
+				rsArray[rIdx] = (double) mb.getDenseBlock().countNonZeros(rIdx) / numCols;
+		}
+		return rsArray;
+	}
+
+	/**
+	 * Returns the column indexes of the non-zero cells in row {@code rIdx}.
+	 *
+	 * @param mb input matrix
+	 * @param rIdx row index
+	 * @return column indexes, one entry per non-zero cell
+	 */
+	private int[] getNonZeroColumnIndices(MatrixBlock mb, final int rIdx) {
+		if(mb.isInSparseFormat()) {
+			if(mb.getSparseBlock() == null)
+				return new int[0];
+			SparseRow sparseRow = mb.getSparseBlock().get(rIdx);
+			if(sparseRow == null)
+				return new int[0];
+			// indexes() hands out the backing array, which may be longer than size();
+			// only the first size() entries are valid column indexes
+			return ArrayUtils.subarray(sparseRow.indexes(), 0, sparseRow.size());
+		}
+		return IntStream.range(0, mb.getNumColumns())
+			.filter(cIdx -> mb.get(rIdx, cIdx) != 0).toArray();
+	}
+
+	/**
+	 * Derives the output dimensions and non-zero count of {@code node} from its
+	 * children and the estimated output sparsity.
+	 *
+	 * @param node node to derive the characteristics for
+	 * @param spOut estimated sparsity of the node's output
+	 * @return the node's existing characteristics if it is a leaf or already
+	 *         has a known non-zero count, else newly derived characteristics
+	 */
+	public static DataCharacteristics deriveOutputCharacteristics(MMNode node, double spOut) {
+		if(node.isLeaf() ||
+			(node.getDataCharacteristics() != null && node.getDataCharacteristics().getNonZeros() != -1)) {
+			return node.getDataCharacteristics();
+		}
+
+		MMNode nodeLeft = node.getLeft();
+		MMNode nodeRight = node.getRight();
+		int leftNRow = nodeLeft.getRows();
+		int leftNCol = nodeLeft.getCols();
+		// the right child may be absent; NOTE(review): presumably for unary ops (DIAG, TRANS, ...) -- confirm
+		int rightNRow = (nodeRight != null) ? nodeRight.getRows() : 0;
+		int rightNCol = (nodeRight != null) ? nodeRight.getCols() : 0;
+		switch(node.getOp()) {
+			case MM:
+				return new MatrixCharacteristics(leftNRow, rightNCol,
+					OptimizerUtils.getNnz(leftNRow, rightNCol, spOut));
+			case MULT:
+			case PLUS:
+			case NEQZERO:
+			case EQZERO:
+				return new MatrixCharacteristics(leftNRow, leftNCol,
+					OptimizerUtils.getNnz(leftNRow, leftNCol, spOut));
+			case RBIND:
+				return new MatrixCharacteristics(leftNRow+rightNRow, leftNCol,
+					OptimizerUtils.getNnz(leftNRow+rightNRow, leftNCol, spOut));
+			case CBIND:
+				return new MatrixCharacteristics(leftNRow, leftNCol+rightNCol,
+					OptimizerUtils.getNnz(leftNRow, leftNCol+rightNCol, spOut));
+			case DIAG:
+				int ncol = (leftNCol == 1) ? leftNRow : 1;
+				return new MatrixCharacteristics(leftNRow, ncol,
+					OptimizerUtils.getNnz(leftNRow, ncol, spOut));
+			case TRANS:
+				return new MatrixCharacteristics(leftNCol, leftNRow,
+					OptimizerUtils.getNnz(leftNCol, leftNRow, spOut));
+			case RESHAPE:
+				throw new NotImplementedException("Characteristics derivation for " + node.getOp() +" has not been " +
+					"implemented yet, but could be implemented similar to EstimatorMatrixHistogram.java");
+			default:
+				throw new NotImplementedException();
+		}
+	}
+}
diff --git a/src/test/java/org/apache/sysds/test/component/estim/OpBindChainTest.java b/src/test/java/org/apache/sysds/test/component/estim/OpBindChainTest.java index 35efedaf625..05fd9d32c8b 100644 --- a/src/test/java/org/apache/sysds/test/component/estim/OpBindChainTest.java +++ b/src/test/java/org/apache/sysds/test/component/estim/OpBindChainTest.java @@ -24,6 +24,7 @@ import org.apache.sysds.hops.estim.EstimatorBasicWorst; import org.apache.sysds.hops.estim.EstimatorBitsetMM; import org.apache.sysds.hops.estim.EstimatorMatrixHistogram; +import org.apache.sysds.hops.estim.EstimatorRowWise; import org.apache.sysds.hops.estim.EstimatorLayeredGraph; import org.apache.sysds.hops.estim.MMNode; import org.apache.sysds.hops.estim.SparsityEstimator; @@ -35,7 +36,7 @@ import org.apache.commons.lang3.NotImplementedException; /** - * this is the basic operation check for all estimators with single operations + * this is the basic operation check for all estimators with chains of operations including binding operations */ public class OpBindChainTest extends AutomatedTestBase { @@ -127,41 +128,41 @@ public void testLGCasecbind() { new EstimatorLayeredGraph(EstimatorLayeredGraph.ROUNDS, 3), m, k, n, sparsity, cbind); } - - + + // Row Wise Sparsity Estimator + @Test + public void testRowWiseRbind() { + runSparsityEstimateTest(new EstimatorRowWise(), m, k, n, sparsity, rbind); + } + + @Test + public void testRowWiseCbind() { + runSparsityEstimateTest(new EstimatorRowWise(), m, 
k, n, sparsity, cbind); + } + + private static void runSparsityEstimateTest(SparsityEstimator estim, int m, int k, int n, double[] sp, OpCode op) { - MatrixBlock m1; + MatrixBlock m1 = MatrixBlock.randOperations(m, k, sp[0], 1, 1, "uniform", 3); MatrixBlock m2; MatrixBlock m3 = new MatrixBlock(); MatrixBlock m4; - MatrixBlock m5 = new MatrixBlock(); - double est = 0; switch(op) { case RBIND: - m1 = MatrixBlock.randOperations(m, k, sp[0], 1, 1, "uniform", 3); m2 = MatrixBlock.randOperations(n, k, sp[1], 1, 1, "uniform", 7); m1.append(m2, m3, false); m4 = MatrixBlock.randOperations(k, m, sp[1], 1, 1, "uniform", 5); - m5 = m3.aggregateBinaryOperations(m3, m4, - new MatrixBlock(), InstructionUtils.getMatMultOperator(1)); - est = estim.estim(new MMNode(new MMNode(new MMNode(m1), new MMNode(m2), op), new MMNode(m4), OpCode.MM)).getSparsity(); - //System.out.println(est); - //System.out.println(m5.getSparsity()); break; case CBIND: - m1 = MatrixBlock.randOperations(m, k, sp[0], 1, 1, "uniform", 3); m2 = MatrixBlock.randOperations(m, n, sp[1], 1, 1, "uniform", 7); m1.append(m2, m3, true); m4 = MatrixBlock.randOperations(k+n, m, sp[1], 1, 1, "uniform", 5); - m5 = m3.aggregateBinaryOperations(m3, m4, - new MatrixBlock(), InstructionUtils.getMatMultOperator(1)); - est = estim.estim(new MMNode(new MMNode(new MMNode(m1), new MMNode(m2), op), new MMNode(m4), OpCode.MM)).getSparsity(); - //System.out.println(est); - //System.out.println(m5.getSparsity()); break; default: throw new NotImplementedException(); } + MatrixBlock m5 = m3.aggregateBinaryOperations(m3, m4, + new MatrixBlock(), InstructionUtils.getMatMultOperator(1)); + double est = estim.estim(new MMNode(new MMNode(new MMNode(m1), new MMNode(m2), op), new MMNode(m4), OpCode.MM)).getSparsity(); //compare estimated and real sparsity TestUtils.compareScalars(est, m5.getSparsity(), (estim instanceof EstimatorBasicWorst) ? 
5e-1 : diff --git a/src/test/java/org/apache/sysds/test/component/estim/OpBindTest.java b/src/test/java/org/apache/sysds/test/component/estim/OpBindTest.java index 3e7ad24fe86..97e7fec06ed 100644 --- a/src/test/java/org/apache/sysds/test/component/estim/OpBindTest.java +++ b/src/test/java/org/apache/sysds/test/component/estim/OpBindTest.java @@ -24,6 +24,7 @@ import org.apache.sysds.hops.estim.EstimatorBasicWorst; import org.apache.sysds.hops.estim.EstimatorBitsetMM; import org.apache.sysds.hops.estim.EstimatorMatrixHistogram; +import org.apache.sysds.hops.estim.EstimatorRowWise; import org.apache.sysds.hops.estim.EstimatorLayeredGraph; import org.apache.sysds.hops.estim.SparsityEstimator; import org.apache.sysds.hops.estim.SparsityEstimator.OpCode; @@ -33,7 +34,7 @@ import org.apache.commons.lang3.NotImplementedException; /** - * this is the basic operation check for all estimators with single operations + * this is the basic operation check for all estimators with binding operations */ public class OpBindTest extends AutomatedTestBase { @@ -132,33 +133,38 @@ public void testSampleCaserbind() { public void testSampleCasecbind() { runSparsityEstimateTest(new EstimatorSample(), m, k, n, sparsity, cbind); }*/ - + + // Row Wise Sparsity Estimator + @Test + public void testRowWiseRbind() { + runSparsityEstimateTest(new EstimatorRowWise(), m, k, n, sparsity, rbind); + } + + @Test + public void testRowWiseCbind() { + runSparsityEstimateTest(new EstimatorRowWise(), m, k, n, sparsity, cbind); + } + private static void runSparsityEstimateTest(SparsityEstimator estim, int m, int k, int n, double[] sp, OpCode op) { MatrixBlock m1; MatrixBlock m2; MatrixBlock m3 = new MatrixBlock(); - double est = 0; switch(op) { case RBIND: m1 = MatrixBlock.randOperations(m, k, sp[0], 1, 1, "uniform", 3); m2 = MatrixBlock.randOperations(n, k, sp[1], 1, 1, "uniform", 3); m1.append(m2, m3, false); - est = estim.estim(m1, m2, op); - // System.out.println(est); - // 
System.out.println(m3.getSparsity()); break; case CBIND: m1 = MatrixBlock.randOperations(10, 130, sp[0], 1, 1, "uniform", 3); m2 = MatrixBlock.randOperations(10, 70, sp[1], 1, 1, "uniform", 3); m1.append(m2, m3); - est = estim.estim(m1, m2, op); - // System.out.println(est); - // System.out.println(m3.getSparsity()); break; default: throw new NotImplementedException(); } + double est = estim.estim(m1, m2, op); //compare estimated and real sparsity TestUtils.compareScalars(est, m3.getSparsity(), (estim instanceof EstimatorBasicWorst) ? 5e-1 : diff --git a/src/test/java/org/apache/sysds/test/component/estim/OpElemWChainTest.java b/src/test/java/org/apache/sysds/test/component/estim/OpElemWChainTest.java index a1b6594a927..e61c25a67bc 100644 --- a/src/test/java/org/apache/sysds/test/component/estim/OpElemWChainTest.java +++ b/src/test/java/org/apache/sysds/test/component/estim/OpElemWChainTest.java @@ -25,6 +25,7 @@ import org.apache.sysds.hops.estim.EstimatorBitsetMM; import org.apache.sysds.hops.estim.EstimatorLayeredGraph; import org.apache.sysds.hops.estim.EstimatorMatrixHistogram; +import org.apache.sysds.hops.estim.EstimatorRowWise; import org.apache.sysds.hops.estim.EstimatorDensityMap; import org.apache.sysds.hops.estim.MMNode; import org.apache.sysds.hops.estim.SparsityEstimator; @@ -39,7 +40,7 @@ import org.apache.commons.lang3.NotImplementedException; /** - * this is the basic operation check for all estimators with single operations + * this is the basic operation check for all estimators with chains of operations including element-wise operations */ public class OpElemWChainTest extends AutomatedTestBase { @@ -118,38 +119,39 @@ public void testLGCasemult() { public void testLGCaseplus() { runSparsityEstimateTest(new EstimatorLayeredGraph(), m, n, sparsity, plus); } - - + + // Row Wise Sparsity Estimator + @Test + public void testRowWiseCaseMult() { + runSparsityEstimateTest(new EstimatorRowWise(), m, n, sparsity, mult); + } + + @Test + public void 
testRowWiseCasePlus() { + runSparsityEstimateTest(new EstimatorRowWise(), m, n, sparsity, plus); + } + private static void runSparsityEstimateTest(SparsityEstimator estim, int m, int n, double[] sp, OpCode op) { MatrixBlock m1 = MatrixBlock.randOperations(m, n, sp[0], 1, 1, "uniform", 3); MatrixBlock m2 = MatrixBlock.randOperations(m, n, sp[1], 1, 1, "uniform", 5); MatrixBlock m3 = MatrixBlock.randOperations(n, m, sp[1], 1, 1, "uniform", 7); MatrixBlock m4 = new MatrixBlock(); - MatrixBlock m5 = new MatrixBlock(); BinaryOperator bOp; - double est = 0; switch(op) { case MULT: bOp = new BinaryOperator(Multiply.getMultiplyFnObject()); - m1.binaryOperations(bOp, m2, m4); - m5 = m4.aggregateBinaryOperations(m4, m3, - new MatrixBlock(), InstructionUtils.getMatMultOperator(1)); - est = estim.estim(new MMNode(new MMNode(new MMNode(m1), new MMNode(m2), op), new MMNode(m3), OpCode.MM)).getSparsity(); - // System.out.println(m5.getSparsity()); - // System.out.println(est); break; case PLUS: bOp = new BinaryOperator(Plus.getPlusFnObject()); - m1.binaryOperations(bOp, m2, m4); - m5 = m4.aggregateBinaryOperations(m4, m3, - new MatrixBlock(), InstructionUtils.getMatMultOperator(1)); - est = estim.estim(new MMNode(new MMNode(new MMNode(m1), new MMNode(m2), op), new MMNode(m3), OpCode.MM)).getSparsity(); - // System.out.println(m5.getSparsity()); - // System.out.println(est); break; default: throw new NotImplementedException(); } + m1.binaryOperations(bOp, m2, m4); + MatrixBlock m5 = m4.aggregateBinaryOperations(m4, m3, + new MatrixBlock(), InstructionUtils.getMatMultOperator(1)); + double est = estim.estim(new MMNode(new MMNode(new MMNode(m1), new MMNode(m2), op), new MMNode(m3), OpCode.MM)).getSparsity(); + //compare estimated and real sparsity TestUtils.compareScalars(est, m5.getSparsity(), (estim instanceof EstimatorBasicWorst) ? 9e-1 : (estim instanceof EstimatorLayeredGraph) ? 
7e-2 : 1e-2); diff --git a/src/test/java/org/apache/sysds/test/component/estim/OpElemWTest.java b/src/test/java/org/apache/sysds/test/component/estim/OpElemWTest.java index f8ddb91bcef..5dc7d407220 100644 --- a/src/test/java/org/apache/sysds/test/component/estim/OpElemWTest.java +++ b/src/test/java/org/apache/sysds/test/component/estim/OpElemWTest.java @@ -25,6 +25,7 @@ import org.apache.sysds.hops.estim.EstimatorBitsetMM; import org.apache.sysds.hops.estim.EstimatorDensityMap; import org.apache.sysds.hops.estim.EstimatorMatrixHistogram; +import org.apache.sysds.hops.estim.EstimatorRowWise; import org.apache.sysds.hops.estim.EstimatorLayeredGraph; import org.apache.sysds.hops.estim.EstimatorSample; import org.apache.sysds.hops.estim.SparsityEstimator; @@ -38,7 +39,7 @@ import org.apache.commons.lang3.NotImplementedException; /** - * this is the basic operation check for all estimators with single operations + * this is the basic operation check for all estimators with element-wise operations */ public class OpElemWTest extends AutomatedTestBase { @@ -128,31 +129,35 @@ public void testSampleMult() { public void testSamplePlus() { runSparsityEstimateTest(new EstimatorSample(), m, n, sparsity, plus); } - + + // Row Wise Sparsity Estimator + @Test + public void testRowWiseMult() { + runSparsityEstimateTest(new EstimatorRowWise(), m, n, sparsity, mult); + } + + @Test + public void testRowWisePlus() { + runSparsityEstimateTest(new EstimatorRowWise(), m, n, sparsity, plus); + } + private static void runSparsityEstimateTest(SparsityEstimator estim, int m, int n, double[] sp, OpCode op) { MatrixBlock m1 = MatrixBlock.randOperations(m, n, sp[0], 1, 1, "uniform", 3); MatrixBlock m2 = MatrixBlock.randOperations(m, n, sp[1], 1, 1, "uniform", 7); MatrixBlock m3 = new MatrixBlock(); BinaryOperator bOp; - double est = 0; switch(op) { case MULT: bOp = new BinaryOperator(Multiply.getMultiplyFnObject()); - m1.binaryOperations(bOp, m2, m3); - est = estim.estim(m1, m2, op); - // 
System.out.println(est); - // System.out.println(m3.getSparsity()); break; case PLUS: bOp = new BinaryOperator(Plus.getPlusFnObject()); - m1.binaryOperations(bOp, m2, m3); - est = estim.estim(m1, m2, op); - // System.out.println(est); - // System.out.println(m3.getSparsity()); break; - default: - throw new NotImplementedException(); + default: + throw new NotImplementedException(); } + m1.binaryOperations(bOp, m2, m3); + double est = estim.estim(m1, m2, op); //compare estimated and real sparsity TestUtils.compareScalars(est, m3.getSparsity(), (estim instanceof EstimatorBasicWorst) ? 5e-1 : (estim instanceof EstimatorLayeredGraph) ? 3e-2 : 5e-3); diff --git a/src/test/java/org/apache/sysds/test/component/estim/OpSingleTest.java b/src/test/java/org/apache/sysds/test/component/estim/OpSingleTest.java index d40f84c4fb3..02284eeb449 100644 --- a/src/test/java/org/apache/sysds/test/component/estim/OpSingleTest.java +++ b/src/test/java/org/apache/sysds/test/component/estim/OpSingleTest.java @@ -26,6 +26,7 @@ import org.apache.sysds.hops.estim.EstimatorBasicWorst; import org.apache.sysds.hops.estim.EstimatorBitsetMM; import org.apache.sysds.hops.estim.EstimatorLayeredGraph; +import org.apache.sysds.hops.estim.EstimatorRowWise; import org.apache.sysds.hops.estim.SparsityEstimator; import org.apache.sysds.hops.estim.SparsityEstimator.OpCode; import org.apache.sysds.runtime.matrix.data.MatrixBlock; @@ -40,7 +41,7 @@ public class OpSingleTest extends AutomatedTestBase private final static int m = 600; private final static int k = 300; private final static double sparsity = 0.2; -// private final static OpCode eqzero = OpCode.EQZERO; + private final static OpCode eqzero = OpCode.EQZERO; private final static OpCode diag = OpCode.DIAG; private final static OpCode neqzero = OpCode.NEQZERO; private final static OpCode trans = OpCode.TRANS; @@ -237,37 +238,64 @@ public void testLGCasetrans() { // public void testSampleCasereshape() { // runSparsityEstimateTest(new EstimatorSample(), 
m, k, sparsity, reshape); // } - + + // Row Wise Sparsity Estimator + @Test + public void testRowWiseEqzero() { + runSparsityEstimateTest(new EstimatorRowWise(), m, k, sparsity, eqzero); + } + + @Test + public void testRowWiseDiagMV() { + runSparsityEstimateTest(new EstimatorRowWise(), m, m, sparsity, diag); + } + + @Test + public void testRowWiseDiagVM() { + runSparsityEstimateTest(new EstimatorRowWise(), m, 1, sparsity, diag); + } + + @Test + public void testRowWiseNeqzero() { + runSparsityEstimateTest(new EstimatorRowWise(), m, k, sparsity, neqzero); + } + + @Test + public void testRowWiseTrans() { + runSparsityEstimateTest(new EstimatorRowWise(), m, k, sparsity, trans); + } + + @Test + public void testRowWiseReshape() { + runSparsityEstimateTest(new EstimatorRowWise(), m, k, sparsity, reshape); + } + private static void runSparsityEstimateTest(SparsityEstimator estim, int m, int k, double sp, OpCode op) { MatrixBlock m1 = MatrixBlock.randOperations(m, k, sp, 1, 1, "uniform", 3); - MatrixBlock m2 = new MatrixBlock(); - double est = 0; + MatrixBlock m2; + double ref = -1; switch(op) { case EQZERO: - //TODO find out how to do eqzero + ref = 1 - m1.getSparsity(); + break; case DIAG: m2 = m1.getNumColumns() == 1 ? LibMatrixReorg.diag(m1, new MatrixBlock(m1.getNumRows(), m1.getNumRows(), false)) : LibMatrixReorg.diag(m1, new MatrixBlock(m1.getNumRows(), 1, false)); - est = estim.estim(m1, op); + ref = m2.getSparsity(); break; case NEQZERO: - m2 = m1; - est = estim.estim(m1, op); - break; case TRANS: - m2 = m1; - est = estim.estim(m1, op); - break; case RESHAPE: m2 = m1; - est = estim.estim(m1, op); + ref = m2.getSparsity(); break; default: throw new NotImplementedException(); } + double est = estim.estim(m1, op); //compare estimated and real sparsity - TestUtils.compareScalars(est, m2.getSparsity(), + TestUtils.compareScalars(est, ref, (estim instanceof EstimatorBasicWorst) ? 5e-1 : (estim instanceof EstimatorLayeredGraph) ? 
3e-2 : 2e-2); } diff --git a/src/test/java/org/apache/sysds/test/component/estim/OuterProductTest.java b/src/test/java/org/apache/sysds/test/component/estim/OuterProductTest.java index fdc33d878db..f71d9989ccd 100644 --- a/src/test/java/org/apache/sysds/test/component/estim/OuterProductTest.java +++ b/src/test/java/org/apache/sysds/test/component/estim/OuterProductTest.java @@ -26,6 +26,7 @@ import org.apache.sysds.hops.estim.EstimatorDensityMap; import org.apache.sysds.hops.estim.EstimatorMatrixHistogram; import org.apache.sysds.hops.estim.EstimatorLayeredGraph; +import org.apache.sysds.hops.estim.EstimatorRowWise; import org.apache.sysds.hops.estim.EstimatorSample; import org.apache.sysds.hops.estim.SparsityEstimator; import org.apache.sysds.runtime.instructions.InstructionUtils; @@ -150,6 +151,16 @@ public void testLayeredGraphCase2() { runSparsityEstimateTest(new EstimatorLayeredGraph(), m, k, n, case2); } + @Test + public void testRowWiseCase1() { + runSparsityEstimateTest(new EstimatorRowWise(), m, k, n, case1); + } + + @Test + public void testRowWiseCase2() { + runSparsityEstimateTest(new EstimatorRowWise(), m, k, n, case2); + } + private static void runSparsityEstimateTest(SparsityEstimator estim, int m, int k, int n, double[] sp) { MatrixBlock m1 = MatrixBlock.randOperations(m, k, sp[0], 1, 1, "uniform", 3); MatrixBlock m2 = MatrixBlock.randOperations(k, n, sp[1], 1, 1, "uniform", 3); diff --git a/src/test/java/org/apache/sysds/test/component/estim/SelfProductTest.java b/src/test/java/org/apache/sysds/test/component/estim/SelfProductTest.java index d99f38d939b..2feeae6fc37 100644 --- a/src/test/java/org/apache/sysds/test/component/estim/SelfProductTest.java +++ b/src/test/java/org/apache/sysds/test/component/estim/SelfProductTest.java @@ -28,6 +28,7 @@ import org.apache.sysds.hops.estim.EstimatorDensityMap; import org.apache.sysds.hops.estim.EstimatorLayeredGraph; import org.apache.sysds.hops.estim.EstimatorMatrixHistogram; +import 
org.apache.sysds.hops.estim.EstimatorRowWise; import org.apache.sysds.hops.estim.EstimatorSample; import org.apache.sysds.hops.estim.EstimatorSampleRa; import org.apache.sysds.hops.estim.SparsityEstimator; @@ -156,7 +157,15 @@ public void testLayeredGraphCase1() { public void testLayeredGraphCase2() { runSparsityEstimateTest(new EstimatorLayeredGraph(), m, sparsity2); } - + + @Test + public void testRowWiseCase() { + runSparsityEstimateTest(new EstimatorRowWise(), m/4, sparsity0); + runSparsityEstimateTest(new EstimatorRowWise(), m/2, sparsity1); + runSparsityEstimateTest(new EstimatorRowWise(), m, sparsity2); + runSparsityEstimateTest(new EstimatorRowWise(), m, sparsity3); + } + private static void runSparsityEstimateTest(SparsityEstimator estim, int n, double sp) { MatrixBlock m1 = MatrixBlock.randOperations(n, n, sp, 1, 1, "uniform", 3); MatrixBlock m3 = m1.aggregateBinaryOperations(m1, m1, diff --git a/src/test/java/org/apache/sysds/test/component/estim/SquaredProductChainTest.java b/src/test/java/org/apache/sysds/test/component/estim/SquaredProductChainTest.java index f799b02c96d..502ed62de29 100644 --- a/src/test/java/org/apache/sysds/test/component/estim/SquaredProductChainTest.java +++ b/src/test/java/org/apache/sysds/test/component/estim/SquaredProductChainTest.java @@ -26,6 +26,7 @@ import org.apache.sysds.hops.estim.EstimatorDensityMap; import org.apache.sysds.hops.estim.EstimatorLayeredGraph; import org.apache.sysds.hops.estim.EstimatorMatrixHistogram; +import org.apache.sysds.hops.estim.EstimatorRowWise; import org.apache.sysds.hops.estim.MMNode; import org.apache.sysds.hops.estim.SparsityEstimator; import org.apache.sysds.hops.estim.SparsityEstimator.OpCode; @@ -146,7 +147,17 @@ public void testLayeredGraph32Case1() { public void testLayeredGraph32Case2() { runSparsityEstimateTest(new EstimatorLayeredGraph(32), m, k, n, n2, case2); } - + + @Test + public void testRowWiseCase1() { + runSparsityEstimateTest(new EstimatorRowWise(), m, k, n, n2, case1); + 
} + + @Test + public void testRowWiseCase2() { + runSparsityEstimateTest(new EstimatorRowWise(), m, k, n, n2, case2); + } + private static void runSparsityEstimateTest(SparsityEstimator estim, int m, int k, int n, int n2, double[] sp) { MatrixBlock m1 = MatrixBlock.randOperations(m, k, sp[0], 1, 1, "uniform", 1); MatrixBlock m2 = MatrixBlock.randOperations(k, n, sp[1], 1, 1, "uniform", 2); diff --git a/src/test/java/org/apache/sysds/test/component/estim/SquaredProductTest.java b/src/test/java/org/apache/sysds/test/component/estim/SquaredProductTest.java index 2a898f9c39f..678c5daa31a 100644 --- a/src/test/java/org/apache/sysds/test/component/estim/SquaredProductTest.java +++ b/src/test/java/org/apache/sysds/test/component/estim/SquaredProductTest.java @@ -25,6 +25,7 @@ import org.apache.sysds.hops.estim.EstimatorBitsetMM; import org.apache.sysds.hops.estim.EstimatorDensityMap; import org.apache.sysds.hops.estim.EstimatorMatrixHistogram; +import org.apache.sysds.hops.estim.EstimatorRowWise; import org.apache.sysds.hops.estim.EstimatorLayeredGraph; import org.apache.sysds.hops.estim.EstimatorSample; import org.apache.sysds.hops.estim.SparsityEstimator; @@ -154,7 +155,17 @@ public void testLayeredGraphCase1() { public void testLayeredGraphCase2() { runSparsityEstimateTest(new EstimatorLayeredGraph(), m, k, n, case2); } - + + @Test + public void testRowWiseCase1() { + runSparsityEstimateTest(new EstimatorRowWise(), m, k, n, case1); + } + + @Test + public void testRowWiseCase2() { + runSparsityEstimateTest(new EstimatorRowWise(), m, k, n, case2); + } + private static void runSparsityEstimateTest(SparsityEstimator estim, int m, int k, int n, double[] sp) { MatrixBlock m1 = MatrixBlock.randOperations(m, k, sp[0], 1, 1, "uniform", 3); MatrixBlock m2 = MatrixBlock.randOperations(k, n, sp[1], 1, 1, "uniform", 7);