基于MPI的大规模方矩阵乘法

最近写 **libgwmodel** 的 MPI 模式的时候遇到了一个问题:在 MGWR 里面,最令人头疼的就是计算 $S$ 这个 $n \times n$ 的矩阵,几个 $S$ 矩阵相乘极其容易造成性能瓶颈,也对内存造成了很大的压力。一组2万多样本11个变量的数据,就大概需要 83G 的内存才行。这就使得 MPI 模式下 MGWR 的并行计算出现了一个问题:能否将 $S$ 矩阵分散在各个进程中分块式存储?由此引发的问题是,如何让分块矩阵的加减乘也以分块的模式在不同进程中进行?

矩阵加减是逐元素的(element-wise),因此各个进程各自执行加减即可。唯独矩阵乘法非常复杂,每个进程都需要用到其他所有进程中的数据。这给我们带来一个挑战:**实现基于MPI的大规模方矩阵乘法**。

# 原理分析

假设现在有两个 $n\times n$ 的方阵 $A$ 和 $B$ ,我们要计算 $C=A \times B$ 。现在已经21世纪20年代了,我们没必要从计算 $C$ 的每个元素开始写起,有很多库已经帮我们解决了。我们只需要将这些矩阵理解成是分块方式存储的,研究如何计算矩阵块之间的运算即可。也就是说

$$
A=\pmatrix{A_1\\A_2\\\vdots\\A_P},B=\pmatrix{B_1\\B_2\\\vdots\\B_P}
$$

每个进程 $p(p=1,2,\cdots,P)$ 中存储一个 $A_p$ 和 $B_p$ ,且 $C$ 也是分块存储的,那么

$$
C=\pmatrix{C_1\\C_2\\\vdots\\C_P}
=\pmatrix{A_1\\A_2\\\vdots\\A_P} \times B
=\pmatrix{A_1B\\A_2B\\\vdots\\A_PB}
$$

所以 $C_p=A_pB$ 。现在的问题是,如何计算 $A_pB$ 呢?

因为 $B$ 是分块存储的,所以我们需要知道 $B_p$ 究竟是和 $A_p$ 的什么部分相乘的。进一步分解矩阵,可以发现

$$
\begin{aligned}
A_pB & = \pmatrix{A_{p1}&A_{p2}&\cdots&A_{pP}}\pmatrix{B_1\\B_2\\\vdots\\B_P}\\
&=\sum_{q=1}^P A_{pq}B_q
\end{aligned}
$$

由于每个 $A_{pq}$ 的行数是相同的,每个 $B_q$ 的 列数是相同的,所以 $A_{pq}B_q$ 的形状是相同的。

# 算法思路

根据上述原理,我们可以有这样一个思路:每个进程将自己的矩阵 $A_p$ 分发给所有进程(包括自己),所有进程得到 $A_{pq}$ 后与自己的 $B_q$ 矩阵相乘,并将所有进程各自计算得到的 $A_{pq}B_q$ 矩阵相加,就得到了 $C_p$ 矩阵,存放在进程 $p$ 的适当位置上。

从上面的思路可以发现,这是一个类似于 Map-Reduce 的模式。MPI 中有两个接口用于实现这一过程,一个是 `MPI_Scatter`/`MPI_Scatterv` ,另一个是 `MPI_Reduce` 。

```cpp
int MPI_Scatter(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
     void *recvbuf, int recvcount, MPI_Datatype recvtype, int root,
     MPI_Comm comm)

int MPI_Scatterv(const void *sendbuf, const int sendcounts[], const int displs[],
     MPI_Datatype sendtype, void *recvbuf, int recvcount,
     MPI_Datatype recvtype, int root, MPI_Comm comm)

int MPI_Reduce(const void *sendbuf, void *recvbuf, int count,
               MPI_Datatype datatype, MPI_Op op, int root,
               MPI_Comm comm)
```

接口 `MPI_Scatter`/`MPI_Scatterv` 用于将 `root` 进程中的一段数据分段发送给其他进程(包括跟),每个进程根据其 `rank` 得到对应的片段。前者每个进程收到的数据量是一样的,后者可以是不一样的。先说 `MPI_Scatter`,如果我们将下面的数据分发给4个进程 P0, P1, P2, P3

$$
\begin{bmatrix}
a_0 & a_1 & a_2 & a_3
\end{bmatrix}
$$

那么 P0 得到 $a_0$,P1 得到 $a_1$,以此类推。函数中的 `sendcount` 和 `recvcount` 分别是**实际向每个进程发送**和**实际每个进程接收**的数据数量,而不是发送或者接收的总量。有一个很现实的问题是,如果进程有3个,数据却有4个怎么办?也就是说如果数据量不是进程数的整数倍,怎么办?这时就需要用到 `MPI_Scatterv` 了。

接口 `MPI_Scatterv` 中的 `sendcounts[]` 是一个数组,因此可以自定义给每个进程发送的数据量,`displs[]` 用于指定给每个进程发送数据的起始偏移量。例如,假设还是给3个进程发上述的4个数据,并且我们想前两个进程得到1个,第三个进程得到2个数据,那么

```cpp
sendcounts[] = {1, 1, 2}
displs[] = {0, 1, 2}
```

但是实际编码过程中,一般是从各个进程中获取元素数量,并计算偏移量的。获取可以使用 `MPI_Gather` 函数,其与 `MPI_Scatter` 互为逆操作,也就是从各个进程中收集一定量的数据并组装在一起。其参数和 `MPI_Scatter` 几乎完全一样。但是该函数只会将数据收集到 `root` 指定的进程中。如果需要让所有进程都得到各个进程的元素数量,可以在 `MPI_Gather` 之后通过 `MPI_Bcast` 同步,也可以用 `MPI_Allgather` 函数,两种方式几乎是等价的。

当各个进程计算完毕,就可以用 `MPI_Reduce` 进行归约,得到我们想要的结果。进程间能归约的数据一定是形状一样的,也就是可以进行逐元素操作的数据。比如 P0 和 P1 中分别有 $[a_0,a_1]$ 和 $[b_0,b_1]$ ,那么这两个数组可以归约得到 $[a_0+b_0,a_1+b_1]$ 的。但如果两个进程中的数据是$[a_0, a_1]$ 和 $[b_0,b_1,b_2]$ ,那只能归约前两个元素。接口参数中的 `count` 就是用来指定归约元素量的。

# 算法实现

这里我们用 Armadillo 库作为矩阵计算的基础库进行矩阵块之间的运算。这个库默认是列优先存储的,比较有利于我们分散 $A_p$ 矩阵进行传输,因为我们要分散矩阵 $A_p$ 的列并进行分发的。

```cpp
#include <mpi.h>
#include <armadillo>

using namespace std;
using namespace arma;

void mat_mul_mpi(mat& a, mat& b, mat& c, const int ip, const int np, const size_t range)
{
    auto m = a.n_rows, n = b.n_cols;
    arma::uvec b_rows(np, arma::fill::zeros);
    MPI_Allgather(&b.n_rows, 1, MPI_UNSIGNED_LONG_LONG, b_rows.memptr(), 1, MPI_UNSIGNED_LONG_LONG, MPI_COMM_WORLD);
    c = mat(m, n, arma::fill::zeros);
    mat a_buf;
    for (size_t pi = 0; pi < np; pi++)
    {
        arma::Col<int> a_counts = b_rows(pi) * arma::conv_to<arma::Col<int>>::from(b_rows);
        arma::Col<int> a_disp = arma::cumsum(a_counts) - a_counts;
        a_buf.resize(b_rows(pi), b_rows(ip));
        MPI_Scatterv(a.memptr(), a_counts.mem, a_disp.mem, MPI_DOUBLE, a_buf.memptr(), a_buf.n_elem, MPI_DOUBLE, pi, MPI_COMM_WORLD);
        mat ci = a_buf * b;
        MPI_Reduce(ci.memptr(), c.memptr(), ci.n_elem, MPI_DOUBLE, MPI_SUM, pi, MPI_COMM_WORLD);
    }
}
```