0%

OpenBLAS 中矩阵运算函数学习

GEMM 是矩阵乘法最成熟的优化计算方式,也有很多现成的优化好的库可以调用。

OpenBLAS 矩阵计算

OpenBLAS 库实现成熟优化的矩阵与矩阵乘法的函数 cblas_sgemm 和矩阵与向量乘法函数 cblas_sgemv,二者使用方法基本相同,参数较多,所以对参数的使用做个记录。

矩阵与矩阵乘法

cblas_sgemm 计算的矩阵公式:C=alpha*A*B+beta*C,其中 ABC 都是矩阵,C 初始中存放的可以是偏置值。

cblas_sgemm 函数定义:

cblas_sgemm(layout, transA, transB, M, N, K, alpha, A, LDA, B, LDB, beta, C, LDC);

  • layout:存储格式,有行主序(CblasRowMajor)和列主序(CblasColMajor),C/C++ 一般是行主序。
  • transAA 矩阵是否需要转置。
  • transBB 矩阵是否需要转置。
  • MNKA 矩阵经过 transA 之后的维度是 M*KB 矩阵经过 transB 之后的维度是 K*NC 矩阵的维度是 M*N
  • LDALDBLDC矩阵在 trans (如果需要转置)之前,在主维度方向的维度(如果是行主序,那这个参数就是列数)。

示例代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
#include <stdio.h>
#include <cblas.h>

int main() {
int i, j;
float a[6]={1,3,5,2,7,8};
float b[6]={5,3,7,2,4,2};
float c[6]={0,0,0,0,0,0};
cblas_sgemm(CblasRowMajor, CblasTrans, CblasTrans, 3, 3, 2, 1.0, a, 3, b, 2, 0.0, c, 3);
for(i = 0; i < 3; ++i){
for(j = 0; j < 3; ++j){
printf("%f ", c[i*3+j]);
}
printf("\n");
}
return 1;
}

矩阵与向量乘法

矩阵与向量乘法本质也是矩阵与矩阵,只不过 gemvgemm 要快一些,所以有时候也需要用 gemv。计算式:C=alpha*A*b+beta*C

cblas_sgemv 函数定义:

cblas_sgemv(layout, trans, M, N, alpha, A, LDA, b, 1, beta, C, 1)

参数的定义基本和 gemm 相同,MNA 的行数和列数,bC 的列数都是 1。