10 February 2013

A note on Matrix Multiplication in Java

I needed to compute B = A × Aᵀ for a task; here A is a matrix. Initially I considered a few open-source libraries, e.g. Colt, Parallel Colt and JAMA. I found JAMA a bit easier to use, and it also had better documentation than the other two. My matrix A was around 1700x2700, and JAMA took around 65 seconds to produce the result. That was fine until recently, when my matrix grew to 7000x14000. Let me make it clear up front that this matrix A is sparse and contains values in [0,2]. When I tried JAMA on this new matrix, it first threw an OutOfMemoryError; when I increased the heap to 4GB (the maximum I could allocate!), it churned for a while and then threw an OutOfMemoryError again.
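For concreteness, B = A × Aᵀ means B[i][j] is the dot product of rows i and j of A. A naive plain-Java reference implementation (my own sketch, far too slow for matrices of this size, but it pins down the operation the libraries below are computing):

```java
public class MultiplyByTranspose {
    // Naive reference implementation of B = A * A^T.
    // For an n x m matrix A, the result B is n x n and
    // B[i][j] is the dot product of rows i and j of A.
    static double[][] multiplyByTranspose(double[][] a) {
        int n = a.length;
        int m = a[0].length;
        double[][] b = new double[n][n];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                double sum = 0.0;
                for (int k = 0; k < m; k++) {
                    sum += a[i][k] * a[j][k];
                }
                b[i][j] = sum;
            }
        }
        return b;
    }

    public static void main(String[] args) {
        double[][] a = {{1, 2}, {3, 4}};
        double[][] b = multiplyByTranspose(a);
        System.out.println(b[0][0] + " " + b[0][1] + " " + b[1][1]);
    }
}
```

For the 7000x14000 case this is roughly 7000 × 7000 × 14000 ≈ 6.9 × 10¹¹ multiply-adds, which is why the choice of library matters so much.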

I then ran the program on a machine with more memory and allocated 10GB to the heap. Here is the JAMA code for the matrix multiplication:
import Jama.Matrix;

double[][] matrix; // your pre-populated matrix
Matrix aMatrix = new Matrix(matrix);
Matrix bMatrix = aMatrix.transpose();
Matrix finalMatrix = aMatrix.times(bMatrix);
double[][] finalArray = finalMatrix.getArray(); // optional: extract the raw array
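A back-of-the-envelope memory estimate (my own arithmetic, not a profiler trace) makes the earlier OutOfMemoryError plausible: the dense 7000x14000 input, its explicit transpose, and the 7000x7000 result already total close to 2GB, before counting any intermediate copies the library may allocate:

```java
public class MemoryEstimate {
    public static void main(String[] args) {
        long rows = 7000, cols = 14000;
        long bytesPerDouble = 8;
        long aBytes = rows * cols * bytesPerDouble;      // dense A
        long atBytes = aBytes;                           // explicit A^T copy
        long resultBytes = rows * rows * bytesPerDouble; // A * A^T
        System.out.printf("A: %.2f GB, A^T: %.2f GB, result: %.2f GB, total: %.2f GB%n",
                aBytes / 1e9, atBytes / 1e9, resultBytes / 1e9,
                (aBytes + atBytes + resultBytes) / 1e9);
    }
}
```

That is about 0.78 GB each for A and Aᵀ plus about 0.39 GB for the result, so a 4GB heap leaves little headroom once temporary allocations and GC overhead are added.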
On this higher-configuration machine, JAMA ran for around 3 hours before I killed it. I started looking for other libraries and other ways to compute this matrix in as little time as possible; anything between 15 and 30 minutes was acceptable to me. While googling I found this link: http://stackoverflow.com/questions/529457/performance-of-java-matrix-math-libraries, and through it this benchmark page: http://code.google.com/p/java-matrix-benchmark/, and decided to give EJML a try. The basic code for the multiplication is:
import org.ejml.simple.SimpleMatrix;

double[][] matrix; // your pre-populated matrix
SimpleMatrix aMatrix = new SimpleMatrix(matrix);
SimpleMatrix bMatrix = aMatrix.transpose();
SimpleMatrix cMatrix = aMatrix.mult(bMatrix);
EJML finished the smaller multiplication (1700x2700) in 6-7 seconds, way faster than JAMA, which took more than 60 seconds for the same calculation. So I went ahead and ran it on the bigger matrix, leaving it running on my home machine, which has 12GB RAM and an Intel i7 processor. When I returned from work after 9 hours, the program was still running; half an hour later the output appeared. EJML took 9 hours and 30 minutes to finish the calculation and used a total of 5.5 GB of RAM during the operation.
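One thing none of the code above exploits is that A is sparse: if each row is stored as sorted (column, value) pairs, every entry of B becomes a sparse-sparse dot product that skips the zeros entirely. A minimal plain-Java sketch of the idea (my own illustration, not part of any of the libraries discussed here):

```java
public class SparseRowProduct {
    // One row of a sparse matrix: sorted column indices with their values.
    static final class SparseRow {
        final int[] cols;
        final double[] vals;
        SparseRow(int[] cols, double[] vals) { this.cols = cols; this.vals = vals; }
    }

    // Dot product of two sparse rows: a two-pointer merge over sorted indices,
    // accumulating only where both rows have a nonzero entry.
    static double dot(SparseRow r, SparseRow s) {
        double sum = 0.0;
        int i = 0, j = 0;
        while (i < r.cols.length && j < s.cols.length) {
            if (r.cols[i] == s.cols[j]) {
                sum += r.vals[i++] * s.vals[j++];
            } else if (r.cols[i] < s.cols[j]) {
                i++;
            } else {
                j++;
            }
        }
        return sum;
    }

    // B = A * A^T over sparse rows: B[i][j] = dot(row i, row j).
    static double[][] multiplyByTranspose(SparseRow[] rows) {
        int n = rows.length;
        double[][] b = new double[n][n];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                b[i][j] = dot(rows[i], rows[j]);
            }
        }
        return b;
    }
}
```

The cost per entry drops from the full row length to the number of overlapping nonzeros, which can be a large saving when the matrix is mostly zeros.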

Meanwhile I was also looking at JBLAS, but because of its dependency on native libraries, I was initially unable to run it on production or on my home machine, which runs Windows 8; I had to install Cygwin to make it work there. I ran it against the small matrix and the output appeared on screen almost instantly: the 1700x2700 case took 0.6 seconds. It was exciting to see the result so quickly. Following is the code for matrix multiplication using JBLAS:
import org.jblas.DoubleMatrix;

double[][] matrix; // your pre-populated matrix
DoubleMatrix aDoubleMatrix = new DoubleMatrix(matrix);
DoubleMatrix bDoubleMatrix = aDoubleMatrix.transpose();
DoubleMatrix fDoubleMatrix = aDoubleMatrix.mmul(bDoubleMatrix);
I ran the program on the big matrix and got the result almost immediately, in 6 seconds. JBLAS uses the native libraries BLAS and LAPACK, which makes it so much faster than the pure-Java libraries, and it uses all the cores available on your machine. Support from the author of the library is also great: he went the extra mile to resolve an issue I faced when I used the library on a CentOS 6 box (due to an older GLIBC installation).
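Finally, a property worth noting regardless of which library you pick: B = A × Aᵀ is symmetric (B[i][j] = B[j][i], since both are the dot product of rows i and j), so at most half the dot products actually need computing. A plain-Java sketch of the idea (my own illustration; some BLAS routines exploit this automatically, the snippets above do not):

```java
public class SymmetricProduct {
    // B = A * A^T, computing only the upper triangle and mirroring it,
    // since B[i][j] and B[j][i] are the same dot product of rows i and j.
    static double[][] multiplyByTransposeSymmetric(double[][] a) {
        int n = a.length;
        int m = a[0].length;
        double[][] b = new double[n][n];
        for (int i = 0; i < n; i++) {
            for (int j = i; j < n; j++) {
                double sum = 0.0;
                for (int k = 0; k < m; k++) {
                    sum += a[i][k] * a[j][k];
                }
                b[i][j] = sum;
                b[j][i] = sum; // mirror into the lower triangle
            }
        }
        return b;
    }

    public static void main(String[] args) {
        double[][] a = {{1, 2}, {3, 4}};
        double[][] b = multiplyByTransposeSymmetric(a);
        System.out.println(b[0][1] + " == " + b[1][0]);
    }
}
```

This roughly halves the arithmetic compared to the naive approach, on top of whatever speedup the library itself provides.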