说明表达模板的基本示例

表达式模板是一种主要用于科学计算的编译时优化技术。它的主要目的是避免不必要的临时性并使用单次通过优化循环计算(通常在对数字聚合执行操作时)。最初设计表达模板是为了在实现数字 ArrayMatrix 类型时避免天真运算符重载的低效率。Bjarne Stroustrup 引入了表达模板的等价术语,他在他的书“C++编程语言”的最新版本中将其称为融合操作

在真正深入研究表达模板之前,你应该首先了解为什么需要它们。为了说明这一点,请考虑下面给出的非常简单的 Matrix 类:

template <typename T, std::size_t COL, std::size_t ROW>
class Matrix {
public:
    using value_type = T;

    Matrix() : values(COL * ROW) {}

    static size_t cols() { return COL; }
    static size_t rows() { return ROW; }

    const T& operator()(size_t x, size_t y) const { return values[y * COL + x]; }
    T& operator()(size_t x, size_t y) { return values[y * COL + x]; }

private:
    std::vector<T> values;
};

template <typename T, std::size_t COL, std::size_t ROW>
Matrix<T, COL, ROW>
operator+(const Matrix<T, COL, ROW>& lhs, const Matrix<T, COL, ROW>& rhs)
{
    Matrix<T, COL, ROW> result;

    for (size_t y = 0; y != lhs.rows(); ++y) {
        for (size_t x = 0; x != lhs.cols(); ++x) {
            result(x, y) = lhs(x, y) + rhs(x, y);
        }
    }
    return result;
}

给定以前的类定义,你现在可以编写 Matrix 表达式,例如:

const std::size_t cols = 2000;
const std::size_t rows = 1000;

Matrix<double, cols, rows> a, b, c;

// initialize a, b & c
for (std::size_t y = 0; y != rows; ++y) {
    for (std::size_t x = 0; x != cols; ++x) {
        a(x, y) = 1.0;
        b(x, y) = 2.0;
        c(x, y) = 3.0;
    }
}  

Matrix<double, cols, rows> d = a + b + c;  // d(x, y) = 6 

如上图所示,能够超载 operator+() 为你提供了一种模仿矩阵的自然数学符号的符号。

不幸的是,与等效的手工制作版本相比,先前的实现效率也非常低。

要理解为什么,你必须考虑当你写一个像 Matrix d = a + b + c 这样的表达时会发生什么。这实际上扩展到 ((a + b) + c)operator+(operator+(a, b), c)。换句话说,operator+() 内的循环执行两次,而它可以在一次通过中轻松执行。这也导致产生 2 个临时值,这进一步降低了性能。实质上,通过添加灵活性来使用接近其数学对应的符号,你还使得 Matrix 类效率非常低。

例如,如果没有运算符重载,你可以使用单个传递实现更高效的矩阵求和:

template<typename T, std::size_t COL, std::size_t ROW>
Matrix<T, COL, ROW> add3(const Matrix<T, COL, ROW>& a,
                         const Matrix<T, COL, ROW>& b,
                         const Matrix<T, COL, ROW>& c)
{
    Matrix<T, COL, ROW> result;
    for (size_t y = 0; y != ROW; ++y) {
        for (size_t x = 0; x != COL; ++x) {
            result(x, y) = a(x, y) + b(x, y) + c(x, y);
        }
    }
    return result;
}

然而,前面的示例有其自身的缺点,因为它为 Matrix 类创建了更复杂的接口(你必须考虑诸如 Matrix::add2()Matrix::AddMultiply() 等方法)。

相反,让我们退后一步,看看我们如何调整运算符重载以更有效的方式执行

问题源于这样一个事实,即在你有机会构建整个表达式树之前,表达式 Matrix d = a + b + c 的评估过于急切。换句话说,你真正想要实现的是在一次通过中评估 a + b + c,并且只有在你真正需要将结果表达式分配给 d 时。

这是表达式模板背后的核心思想:不是让 operator+() 立即评估添加两个 Matrix 实例的结果,而是在构建整个表达式树之后,它将返回一个 表达式模板 以供将来评估。

例如,以下是对应于 2 种类型总和的表达式模板的可能实现:

template <typename LHS, typename RHS>
class MatrixSum
{
public:
    using value_type = typename LHS::value_type;

    MatrixSum(const LHS& lhs, const RHS& rhs) : rhs(rhs), lhs(lhs) {}
    
    value_type operator() (int x, int y) const  {
        return lhs(x, y) + rhs(x, y);
    }
private:
    const LHS& lhs;
    const RHS& rhs;
};

这是 operator+() 的更新版本

template <typename LHS, typename RHS>
MatrixSum<LHS, RHS> operator+(const LHS& lhs, const LHS& rhs) {
    return MatrixSum<LHS, RHS>(lhs, rhs);
}

如你所见,operator+() 不再返回添加 2 个 Matrix 实例(这将是另一个 Matrix 实例)的结果的急切评估,而是返回表示添加操作的表达式模板。要记住的最重要的一点是表达式尚未被评估。它仅包含对其操作数的引用。

实际上,没有什么可以阻止你实例化 MatrixSum<> 表达式模板,如下所示:

MatrixSum<Matrix<double>, Matrix<double> > SumAB(a, b);

但是,在稍后阶段,当你确实需要求和的结果时,请按如下方式评估表达式 d = a + b

for (std::size_t y = 0; y != a.rows(); ++y) {
    for (std::size_t x = 0; x != a.cols(); ++x) {
        d(x, y) = SumAB(x, y);
    }
}

正如你所看到的,使用表达式模板的另一个好处是,你基本上已经设法评估 ab 的总和,并在一次传递中将其分配给 d

此外,没有什么能阻止你组合多个表达式模板。例如,a + b + c 将生成以下表达式模板:

MatrixSum<MatrixSum<Matrix<double>, Matrix<double> >, Matrix<double> > SumABC(SumAB, c);

再次,你可以使用一次通过评估最终结果:

for (std::size_t y = 0; y != a.rows(); ++y) {
    for (std::size_t x = 0; x != a.cols(); ++x) {
        d(x, y) = SumABC(x, y);
    }
}

最后,拼图的最后一部分是将表达式模板实际插入 Matrix 类。这主要是通过为 Matrix::operator=() 提供一个实现来实现的,Matrix::operator=() 将表达式模板作为参数并在一次传递中对其进行评估,就像之前手动一样:

template <typename T, std::size_t COL, std::size_t ROW>
class Matrix {
public:
    using value_type = T;

    Matrix() : values(COL * ROW) {}

    static size_t cols() { return COL; }
    static size_t rows() { return ROW; }

    const T& operator()(size_t x, size_t y) const { return values[y * COL + x]; }
    T& operator()(size_t x, size_t y) { return values[y * COL + x]; }

    template <typename E>
    Matrix<T, COL, ROW>& operator=(const E& expression) {
        for (std::size_t y = 0; y != rows(); ++y) {
            for (std::size_t x = 0; x != cols(); ++x) {
                values[y * COL + x] = expression(x, y);
            }
        }
        return *this;
    }

private:
    std::vector<T> values;
};