說明表達模板的基本示例
表示式模板是一種主要用於科學計算的編譯時優化技術。它的主要目的是避免不必要的臨時性並使用單次通過優化迴圈計算(通常在對數字聚合執行操作時)。最初設計表達模板是為了在實現數字 Array
或 Matrix
型別時避免天真運算子過載的低效率。Bjarne Stroustrup 引入了表達模板的等價術語,他在他的書“C++程式語言”的最新版本中將其稱為融合操作。
在真正深入研究表達模板之前,你應該首先了解為什麼需要它們。為了說明這一點,請考慮下面給出的非常簡單的 Matrix 類:
template <typename T, std::size_t COL, std::size_t ROW>
class Matrix {
public:
using value_type = T;
Matrix() : values(COL * ROW) {}
static size_t cols() { return COL; }
static size_t rows() { return ROW; }
const T& operator()(size_t x, size_t y) const { return values[y * COL + x]; }
T& operator()(size_t x, size_t y) { return values[y * COL + x]; }
private:
std::vector<T> values;
};
template <typename T, std::size_t COL, std::size_t ROW>
Matrix<T, COL, ROW>
operator+(const Matrix<T, COL, ROW>& lhs, const Matrix<T, COL, ROW>& rhs)
{
Matrix<T, COL, ROW> result;
for (size_t y = 0; y != lhs.rows(); ++y) {
for (size_t x = 0; x != lhs.cols(); ++x) {
result(x, y) = lhs(x, y) + rhs(x, y);
}
}
return result;
}
給定以前的類定義,你現在可以編寫 Matrix 表示式,例如:
const std::size_t cols = 2000;
const std::size_t rows = 1000;
Matrix<double, cols, rows> a, b, c;
// initialize a, b & c
for (std::size_t y = 0; y != rows; ++y) {
for (std::size_t x = 0; x != cols; ++x) {
a(x, y) = 1.0;
b(x, y) = 2.0;
c(x, y) = 3.0;
}
}
Matrix<double, cols, rows> d = a + b + c; // d(x, y) = 6
如上圖所示,能夠超載 operator+()
為你提供了一種模仿矩陣的自然數學符號的符號。
不幸的是,與等效的手工製作版本相比,先前的實現效率也非常低。
要理解為什麼,你必須考慮當你寫一個像 Matrix d = a + b + c
這樣的表達時會發生什麼。這實際上擴充套件到 ((a + b) + c)
或 operator+(operator+(a, b), c)
。換句話說,operator+()
內的迴圈執行兩次,而它可以在一次通過中輕鬆執行。這也導致產生 2 個臨時值,這進一步降低了效能。實質上,通過新增靈活性來使用接近其數學對應的符號,你還使得 Matrix
類效率非常低。
例如,如果沒有運算子過載,你可以使用單個傳遞實現更高效的矩陣求和:
template<typename T, std::size_t COL, std::size_t ROW>
Matrix<T, COL, ROW> add3(const Matrix<T, COL, ROW>& a,
const Matrix<T, COL, ROW>& b,
const Matrix<T, COL, ROW>& c)
{
Matrix<T, COL, ROW> result;
for (size_t y = 0; y != ROW; ++y) {
for (size_t x = 0; x != COL; ++x) {
result(x, y) = a(x, y) + b(x, y) + c(x, y);
}
}
return result;
}
然而,前面的示例有其自身的缺點,因為它為 Matrix 類建立了更復雜的介面(你必須考慮諸如 Matrix::add2()
,Matrix::AddMultiply()
等方法)。
相反,讓我們退後一步,看看我們如何調整運算子過載以更有效的方式執行
問題源於這樣一個事實,即在你有機會構建整個表示式樹之前,表示式 Matrix d = a + b + c
的評估過於急切。換句話說,你真正想要實現的是在一次通過中評估 a + b + c
,並且只有在你真正需要將結果表示式分配給 d
時。
這是表示式模板背後的核心思想:不是讓 operator+()
立即評估新增兩個 Matrix 例項的結果,而是在構建整個表示式樹之後,它將返回一個 表示式模板 以供將來評估。
例如,以下是對應於 2 種型別總和的表示式模板的可能實現:
template <typename LHS, typename RHS>
class MatrixSum
{
public:
using value_type = typename LHS::value_type;
MatrixSum(const LHS& lhs, const RHS& rhs) : rhs(rhs), lhs(lhs) {}
value_type operator() (int x, int y) const {
return lhs(x, y) + rhs(x, y);
}
private:
const LHS& lhs;
const RHS& rhs;
};
這是 operator+()
的更新版本
template <typename LHS, typename RHS>
MatrixSum<LHS, RHS> operator+(const LHS& lhs, const LHS& rhs) {
return MatrixSum<LHS, RHS>(lhs, rhs);
}
如你所見,operator+()
不再返回新增 2 個 Matrix 例項(這將是另一個 Matrix 例項)的結果的急切評估,而是返回表示新增操作的表示式模板。要記住的最重要的一點是表示式尚未被評估。它僅包含對其運算元的引用。
實際上,沒有什麼可以阻止你例項化 MatrixSum<>
表示式模板,如下所示:
MatrixSum<Matrix<double>, Matrix<double> > SumAB(a, b);
但是,在稍後階段,當你確實需要求和的結果時,請按如下方式評估表示式 d = a + b
:
for (std::size_t y = 0; y != a.rows(); ++y) {
for (std::size_t x = 0; x != a.cols(); ++x) {
d(x, y) = SumAB(x, y);
}
}
正如你所看到的,使用表示式模板的另一個好處是,你基本上已經設法評估 a
和 b
的總和,並在一次傳遞中將其分配給 d
。
此外,沒有什麼能阻止你組合多個表示式模板。例如,a + b + c
將生成以下表示式模板:
MatrixSum<MatrixSum<Matrix<double>, Matrix<double> >, Matrix<double> > SumABC(SumAB, c);
再次,你可以使用一次通過評估最終結果:
for (std::size_t y = 0; y != a.rows(); ++y) {
for (std::size_t x = 0; x != a.cols(); ++x) {
d(x, y) = SumABC(x, y);
}
}
最後,拼圖的最後一部分是將表示式模板實際插入 Matrix
類。這主要是通過為 Matrix::operator=()
提供一個實現來實現的,Matrix::operator=()
將表示式模板作為引數並在一次傳遞中對其進行評估,就像之前手動一樣:
template <typename T, std::size_t COL, std::size_t ROW>
class Matrix {
public:
using value_type = T;
Matrix() : values(COL * ROW) {}
static size_t cols() { return COL; }
static size_t rows() { return ROW; }
const T& operator()(size_t x, size_t y) const { return values[y * COL + x]; }
T& operator()(size_t x, size_t y) { return values[y * COL + x]; }
template <typename E>
Matrix<T, COL, ROW>& operator=(const E& expression) {
for (std::size_t y = 0; y != rows(); ++y) {
for (std::size_t x = 0; x != cols(); ++x) {
values[y * COL + x] = expression(x, y);
}
}
return *this;
}
private:
std::vector<T> values;
};