说明表达模板的基本示例
表达式模板是一种主要用于科学计算的编译时优化技术。它的主要目的是避免不必要的临时性并使用单次通过优化循环计算(通常在对数字聚合执行操作时)。最初设计表达模板是为了在实现数字 Array
或 Matrix
类型时避免天真运算符重载的低效率。Bjarne Stroustrup 引入了表达模板的等价术语,他在他的书“C++编程语言”的最新版本中将其称为融合操作。
在真正深入研究表达模板之前,你应该首先了解为什么需要它们。为了说明这一点,请考虑下面给出的非常简单的 Matrix 类:
template <typename T, std::size_t COL, std::size_t ROW>
class Matrix {
public:
using value_type = T;
Matrix() : values(COL * ROW) {}
static size_t cols() { return COL; }
static size_t rows() { return ROW; }
const T& operator()(size_t x, size_t y) const { return values[y * COL + x]; }
T& operator()(size_t x, size_t y) { return values[y * COL + x]; }
private:
std::vector<T> values;
};
template <typename T, std::size_t COL, std::size_t ROW>
Matrix<T, COL, ROW>
operator+(const Matrix<T, COL, ROW>& lhs, const Matrix<T, COL, ROW>& rhs)
{
Matrix<T, COL, ROW> result;
for (size_t y = 0; y != lhs.rows(); ++y) {
for (size_t x = 0; x != lhs.cols(); ++x) {
result(x, y) = lhs(x, y) + rhs(x, y);
}
}
return result;
}
给定以前的类定义,你现在可以编写 Matrix 表达式,例如:
const std::size_t cols = 2000;
const std::size_t rows = 1000;
Matrix<double, cols, rows> a, b, c;
// initialize a, b & c
for (std::size_t y = 0; y != rows; ++y) {
for (std::size_t x = 0; x != cols; ++x) {
a(x, y) = 1.0;
b(x, y) = 2.0;
c(x, y) = 3.0;
}
}
Matrix<double, cols, rows> d = a + b + c; // d(x, y) = 6
如上图所示,能够超载 operator+()
为你提供了一种模仿矩阵的自然数学符号的符号。
不幸的是,与等效的手工制作版本相比,先前的实现效率也非常低。
要理解为什么,你必须考虑当你写一个像 Matrix d = a + b + c
这样的表达时会发生什么。这实际上扩展到 ((a + b) + c)
或 operator+(operator+(a, b), c)
。换句话说,operator+()
内的循环执行两次,而它可以在一次通过中轻松执行。这也导致产生 2 个临时值,这进一步降低了性能。实质上,通过添加灵活性来使用接近其数学对应的符号,你还使得 Matrix
类效率非常低。
例如,如果没有运算符重载,你可以使用单个传递实现更高效的矩阵求和:
template<typename T, std::size_t COL, std::size_t ROW>
Matrix<T, COL, ROW> add3(const Matrix<T, COL, ROW>& a,
const Matrix<T, COL, ROW>& b,
const Matrix<T, COL, ROW>& c)
{
Matrix<T, COL, ROW> result;
for (size_t y = 0; y != ROW; ++y) {
for (size_t x = 0; x != COL; ++x) {
result(x, y) = a(x, y) + b(x, y) + c(x, y);
}
}
return result;
}
然而,前面的示例有其自身的缺点,因为它为 Matrix 类创建了更复杂的接口(你必须考虑诸如 Matrix::add2()
,Matrix::AddMultiply()
等方法)。
相反,让我们退后一步,看看我们如何调整运算符重载以更有效的方式执行
问题源于这样一个事实,即在你有机会构建整个表达式树之前,表达式 Matrix d = a + b + c
的评估过于急切。换句话说,你真正想要实现的是在一次通过中评估 a + b + c
,并且只有在你真正需要将结果表达式分配给 d
时。
这是表达式模板背后的核心思想:不是让 operator+()
立即评估添加两个 Matrix 实例的结果,而是在构建整个表达式树之后,它将返回一个 表达式模板 以供将来评估。
例如,以下是对应于 2 种类型总和的表达式模板的可能实现:
template <typename LHS, typename RHS>
class MatrixSum
{
public:
using value_type = typename LHS::value_type;
MatrixSum(const LHS& lhs, const RHS& rhs) : rhs(rhs), lhs(lhs) {}
value_type operator() (int x, int y) const {
return lhs(x, y) + rhs(x, y);
}
private:
const LHS& lhs;
const RHS& rhs;
};
这是 operator+()
的更新版本
template <typename LHS, typename RHS>
MatrixSum<LHS, RHS> operator+(const LHS& lhs, const LHS& rhs) {
return MatrixSum<LHS, RHS>(lhs, rhs);
}
如你所见,operator+()
不再返回添加 2 个 Matrix 实例(这将是另一个 Matrix 实例)的结果的急切评估,而是返回表示添加操作的表达式模板。要记住的最重要的一点是表达式尚未被评估。它仅包含对其操作数的引用。
实际上,没有什么可以阻止你实例化 MatrixSum<>
表达式模板,如下所示:
MatrixSum<Matrix<double>, Matrix<double> > SumAB(a, b);
但是,在稍后阶段,当你确实需要求和的结果时,请按如下方式评估表达式 d = a + b
:
for (std::size_t y = 0; y != a.rows(); ++y) {
for (std::size_t x = 0; x != a.cols(); ++x) {
d(x, y) = SumAB(x, y);
}
}
正如你所看到的,使用表达式模板的另一个好处是,你基本上已经设法评估 a
和 b
的总和,并在一次传递中将其分配给 d
。
此外,没有什么能阻止你组合多个表达式模板。例如,a + b + c
将生成以下表达式模板:
MatrixSum<MatrixSum<Matrix<double>, Matrix<double> >, Matrix<double> > SumABC(SumAB, c);
再次,你可以使用一次通过评估最终结果:
for (std::size_t y = 0; y != a.rows(); ++y) {
for (std::size_t x = 0; x != a.cols(); ++x) {
d(x, y) = SumABC(x, y);
}
}
最后,拼图的最后一部分是将表达式模板实际插入 Matrix
类。这主要是通过为 Matrix::operator=()
提供一个实现来实现的,Matrix::operator=()
将表达式模板作为参数并在一次传递中对其进行评估,就像之前手动一样:
template <typename T, std::size_t COL, std::size_t ROW>
class Matrix {
public:
using value_type = T;
Matrix() : values(COL * ROW) {}
static size_t cols() { return COL; }
static size_t rows() { return ROW; }
const T& operator()(size_t x, size_t y) const { return values[y * COL + x]; }
T& operator()(size_t x, size_t y) { return values[y * COL + x]; }
template <typename E>
Matrix<T, COL, ROW>& operator=(const E& expression) {
for (std::size_t y = 0; y != rows(); ++y) {
for (std::size_t x = 0; x != cols(); ++x) {
values[y * COL + x] = expression(x, y);
}
}
return *this;
}
private:
std::vector<T> values;
};