說明表達模板的基本示例

表示式模板是一種主要用於科學計算的編譯時優化技術。它的主要目的是避免不必要的臨時性並使用單次通過優化迴圈計算(通常在對數字聚合執行操作時)。最初設計表達模板是為了在實現數字 ArrayMatrix 型別時避免天真運算子過載的低效率。Bjarne Stroustrup 引入了表達模板的等價術語,他在他的書“C++程式語言”的最新版本中將其稱為融合操作

在真正深入研究表達模板之前,你應該首先了解為什麼需要它們。為了說明這一點,請考慮下面給出的非常簡單的 Matrix 類:

template <typename T, std::size_t COL, std::size_t ROW>
class Matrix {
public:
    using value_type = T;

    Matrix() : values(COL * ROW) {}

    static size_t cols() { return COL; }
    static size_t rows() { return ROW; }

    const T& operator()(size_t x, size_t y) const { return values[y * COL + x]; }
    T& operator()(size_t x, size_t y) { return values[y * COL + x]; }

private:
    std::vector<T> values;
};

template <typename T, std::size_t COL, std::size_t ROW>
Matrix<T, COL, ROW>
operator+(const Matrix<T, COL, ROW>& lhs, const Matrix<T, COL, ROW>& rhs)
{
    Matrix<T, COL, ROW> result;

    for (size_t y = 0; y != lhs.rows(); ++y) {
        for (size_t x = 0; x != lhs.cols(); ++x) {
            result(x, y) = lhs(x, y) + rhs(x, y);
        }
    }
    return result;
}

給定以前的類定義,你現在可以編寫 Matrix 表示式,例如:

const std::size_t cols = 2000;
const std::size_t rows = 1000;

Matrix<double, cols, rows> a, b, c;

// initialize a, b & c
for (std::size_t y = 0; y != rows; ++y) {
    for (std::size_t x = 0; x != cols; ++x) {
        a(x, y) = 1.0;
        b(x, y) = 2.0;
        c(x, y) = 3.0;
    }
}  

Matrix<double, cols, rows> d = a + b + c;  // d(x, y) = 6 

如上圖所示,能夠超載 operator+() 為你提供了一種模仿矩陣的自然數學符號的符號。

不幸的是,與等效的手工製作版本相比,先前的實現效率也非常低。

要理解為什麼,你必須考慮當你寫一個像 Matrix d = a + b + c 這樣的表達時會發生什麼。這實際上擴充套件到 ((a + b) + c)operator+(operator+(a, b), c)。換句話說,operator+() 內的迴圈執行兩次,而它可以在一次通過中輕鬆執行。這也導致產生 2 個臨時值,這進一步降低了效能。實質上,通過新增靈活性來使用接近其數學對應的符號,你還使得 Matrix 類效率非常低。

例如,如果沒有運算子過載,你可以使用單個傳遞實現更高效的矩陣求和:

template<typename T, std::size_t COL, std::size_t ROW>
Matrix<T, COL, ROW> add3(const Matrix<T, COL, ROW>& a,
                         const Matrix<T, COL, ROW>& b,
                         const Matrix<T, COL, ROW>& c)
{
    Matrix<T, COL, ROW> result;
    for (size_t y = 0; y != ROW; ++y) {
        for (size_t x = 0; x != COL; ++x) {
            result(x, y) = a(x, y) + b(x, y) + c(x, y);
        }
    }
    return result;
}

然而,前面的示例有其自身的缺點,因為它為 Matrix 類建立了更復雜的介面(你必須考慮諸如 Matrix::add2()Matrix::AddMultiply() 等方法)。

相反,讓我們退後一步,看看我們如何調整運算子過載以更有效的方式執行

問題源於這樣一個事實,即在你有機會構建整個表示式樹之前,表示式 Matrix d = a + b + c 的評估過於急切。換句話說,你真正想要實現的是在一次通過中評估 a + b + c,並且只有在你真正需要將結果表示式分配給 d 時。

這是表示式模板背後的核心思想:不是讓 operator+() 立即評估新增兩個 Matrix 例項的結果,而是在構建整個表示式樹之後,它將返回一個 表示式模板 以供將來評估。

例如,以下是對應於 2 種型別總和的表示式模板的可能實現:

template <typename LHS, typename RHS>
class MatrixSum
{
public:
    using value_type = typename LHS::value_type;

    MatrixSum(const LHS& lhs, const RHS& rhs) : rhs(rhs), lhs(lhs) {}
    
    value_type operator() (int x, int y) const  {
        return lhs(x, y) + rhs(x, y);
    }
private:
    const LHS& lhs;
    const RHS& rhs;
};

這是 operator+() 的更新版本

template <typename LHS, typename RHS>
MatrixSum<LHS, RHS> operator+(const LHS& lhs, const LHS& rhs) {
    return MatrixSum<LHS, RHS>(lhs, rhs);
}

如你所見,operator+() 不再返回新增 2 個 Matrix 例項(這將是另一個 Matrix 例項)的結果的急切評估,而是返回表示新增操作的表示式模板。要記住的最重要的一點是表示式尚未被評估。它僅包含對其運算元的引用。

實際上,沒有什麼可以阻止你例項化 MatrixSum<> 表示式模板,如下所示:

MatrixSum<Matrix<double>, Matrix<double> > SumAB(a, b);

但是,在稍後階段,當你確實需要求和的結果時,請按如下方式評估表示式 d = a + b

for (std::size_t y = 0; y != a.rows(); ++y) {
    for (std::size_t x = 0; x != a.cols(); ++x) {
        d(x, y) = SumAB(x, y);
    }
}

正如你所看到的,使用表示式模板的另一個好處是,你基本上已經設法評估 ab 的總和,並在一次傳遞中將其分配給 d

此外,沒有什麼能阻止你組合多個表示式模板。例如,a + b + c 將生成以下表示式模板:

MatrixSum<MatrixSum<Matrix<double>, Matrix<double> >, Matrix<double> > SumABC(SumAB, c);

再次,你可以使用一次通過評估最終結果:

for (std::size_t y = 0; y != a.rows(); ++y) {
    for (std::size_t x = 0; x != a.cols(); ++x) {
        d(x, y) = SumABC(x, y);
    }
}

最後,拼圖的最後一部分是將表示式模板實際插入 Matrix 類。這主要是通過為 Matrix::operator=() 提供一個實現來實現的,Matrix::operator=() 將表示式模板作為引數並在一次傳遞中對其進行評估,就像之前手動一樣:

template <typename T, std::size_t COL, std::size_t ROW>
class Matrix {
public:
    using value_type = T;

    Matrix() : values(COL * ROW) {}

    static size_t cols() { return COL; }
    static size_t rows() { return ROW; }

    const T& operator()(size_t x, size_t y) const { return values[y * COL + x]; }
    T& operator()(size_t x, size_t y) { return values[y * COL + x]; }

    template <typename E>
    Matrix<T, COL, ROW>& operator=(const E& expression) {
        for (std::size_t y = 0; y != rows(); ++y) {
            for (std::size_t x = 0; x != cols(); ++x) {
                values[y * COL + x] = expression(x, y);
            }
        }
        return *this;
    }

private:
    std::vector<T> values;
};