元素代数表达式的基本表达模板
介绍和动机
表达模板 ( 在下面表示为 ET )是一种功能强大的模板元编程技术,用于加速有时非常昂贵的表达式的计算。它广泛用于不同的领域,例如在线性代数库的实现中。
对于此示例,请考虑线性代数计算的上下文。更具体地,仅涉及逐元素操作的计算。这种计算是 ET 的最基本的应用,它们可以很好地介绍 ET 如何在内部工作。
让我们来看一个激励人心的例子。考虑表达式的计算:
Vector vec_1, vec_2, vec_3;
// Initializing vec_1, vec_2 and vec_3.
Vector result = vec_1 + vec_2*vec_3;
这里为了简单起见,我假设类 Vector
和操作+(向量加:元素加操作)和操作*(这里表示向量内积:也是元素操作)都正确实现,如它们应该如何,数学上。
在不使用 ET (或其他类似技术) 的传统实现中,至少发生了五个 Vector
实例的构造以获得最终的 result
:
- 对应于
vec_1
,vec_2
和vec_3
的三个实例。 - 一个临时的
Vector
实例_tmp
,代表_tmp = vec_2*vec_3;
的结果。 - 最后正确使用返回值优化,在
result = vec_1 + _tmp;
中构建最终的result
。
使用 ET 的实现可以消除 2 中临时 Vector _tmp
的创建,因此只留下 Vector
实例的四个构造。更有趣的是,请考虑以下更复杂的表达式:
Vector result = vec_1 + (vec_2*vec3 + vec_1)*(vec_2 + vec_3*vec_1);
总共还有四个 Vector
个实例:vec_1, vec_2, vec_3
和 result
。换句话说,在此示例中,仅涉及按元素操作,保证不会从中间计算创建临时对象。
ET 如何运作
基本上,任何代数计算的 ET 都包含两个构建块:
- 纯代数表达式 ( PAE ):它们是代数表达式的代理/抽象。纯代数不进行实际计算,它们仅仅是计算工作流的抽象/建模。PAE 可以是任何代数表达式的输入或输出的模型。 PAE 的实例通常被认为是便宜的复制。
- 懒惰评估 :这是实际计算的实现。在下面的示例中,我们将看到对于仅涉及逐元素操作的表达式,延迟评估可以在最终结果的索引访问操作内实现实际计算,从而创建按需评估方案:不执行计算只有在访问/要求最终结果时。
那么,具体如何在这个例子中实现 ET ?我们现在来看看吧。
请始终考虑以下代码段:
Vector vec_1, vec_2, vec_3;
// Initializing vec_1, vec_2 and vec_3.
Vector result = vec_1 + vec_2*vec_3;
计算结果的表达式可以进一步分解为两个子表达式:
- 向量加表达式(表示为 plus_expr )
- 向量内积表达式(表示为 innerprod_expr )。
什么外星人做的是以下几点:
-
ET 不是立即计算每个子表达式,而是首先使用图形结构对整个表达式进行建模。图中的每个节点代表 PAE 。节点的边缘连接表示实际的计算流程。因此,对于上面的表达式,我们获得以下图表:
result = plus_expr( vec_1, innerprod_expr(vec_2, vec_3) ) / \ / \ / \ / innerprod_expr( vec_2, vec_3 ) / / \ / / \ / / \ vec_1 vec_2 vec_3
-
最后的计算是通过查看图层次结构来实现的 :因为这里我们只处理逐元素操作,
result
中每个索引值的计算可以独立完成 :result
的最终评估可以被懒惰地推迟到元素 - 对这个元素的明智评价 19。换句话说,由于result
,elem_res
的元素的计算可以使用vec_1
(elem_1
),vec_2
(elem_2
)和vec_3
(elem_3
)中的相应元素表示为:elem_res = elem_1 + elem_2*elem_3;
因此,不需要创建临时 Vector
来存储中间内积的结果: 一个元素的整个计算可以完全完成,并在索引访问操作中编码。
以下是实际操作中的示例代码
文件 vec.hh:std::vector 的包装器,用于在调用构造时显示日志
#ifndef EXPR_VEC
# define EXPR_VEC
# include <vector>
# include <cassert>
# include <utility>
# include <iostream>
# include <algorithm>
# include <functional>
///
/// This is a wrapper for std::vector. It's only purpose is to print out a log when a
/// vector constructions in called.
/// It wraps the indexed access operator [] and the size() method, which are
/// important for later ETs implementation.
///
// std::vector wrapper.
template<typename ScalarType> class Vector
{
public:
explicit Vector() { std::cout << "ctor called.\n"; };
explicit Vector(int size): _vec(size) { std::cout << "ctor called.\n"; };
explicit Vector(const std::vector<ScalarType> &vec): _vec(vec)
{ std::cout << "ctor called.\n"; };
Vector(const Vector<ScalarType> & vec): _vec{vec()}
{ std::cout << "copy ctor called.\n"; };
Vector(Vector<ScalarType> && vec): _vec(std::move(vec()))
{ std::cout << "move ctor called.\n"; };
Vector<ScalarType> & operator=(const Vector<ScalarType> &) = default;
Vector<ScalarType> & operator=(Vector<ScalarType> &&) = default;
decltype(auto) operator[](int indx) { return _vec[indx]; }
decltype(auto) operator[](int indx) const { return _vec[indx]; }
decltype(auto) operator()() & { return (_vec); };
decltype(auto) operator()() const & { return (_vec); };
Vector<ScalarType> && operator()() && { return std::move(*this); }
int size() const { return _vec.size(); }
private:
std::vector<ScalarType> _vec;
};
///
/// These are conventional overloads of operator + (the vector plus operation)
/// and operator * (the vector inner product operation) without using the expression
/// templates. They are later used for bench-marking purpose.
///
// + (vector plus) operator.
template<typename ScalarType>
auto operator+(const Vector<ScalarType> &lhs, const Vector<ScalarType> &rhs)
{
assert(lhs().size() == rhs().size() &&
"error: ops plus -> lhs and rhs size mismatch.");
std::vector<ScalarType> _vec;
_vec.resize(lhs().size());
std::transform(std::cbegin(lhs()), std::cend(lhs()),
std::cbegin(rhs()), std::begin(_vec),
std::plus<>());
return Vector<ScalarType>(std::move(_vec));
}
// * (vector inner product) operator.
template<typename ScalarType>
auto operator*(const Vector<ScalarType> &lhs, const Vector<ScalarType> &rhs)
{
assert(lhs().size() == rhs().size() &&
"error: ops multiplies -> lhs and rhs size mismatch.");
std::vector<ScalarType> _vec;
_vec.resize(lhs().size());
std::transform(std::cbegin(lhs()), std::cend(lhs()),
std::cbegin(rhs()), std::begin(_vec),
std::multiplies<>());
return Vector<ScalarType>(std::move(_vec));
}
#endif //!EXPR_VEC
File expr.hh:用于逐元素操作的表达式模板的实现(vector plus 和 vector inner product)
让我们把它分解成各个部分。
- 第 1 节为所有表达式实现了一个基类。它采用了奇怪的重复模板模式 ( CRTP )。
- 第 2 节实现了第一个 PAE :一个终端,它只是一个输入数据结构的包装器(const 引用),包含用于计算的实际输入值。
- 第 3 节实现了第二个 PAE : binary_operation ,它是一个稍后用于 vector_plus 和 vector_innerprod 的类模板。它由操作类型,左侧 PAE 和右侧 PAE 参数化。实际计算在索引访问运算符中编码。
- 第 4 节将 vector_plus 和 vector_innerprod 操作定义为元素操作。它还会为 PAE s 重载 operator +和* :这样这两个操作也会返回 PAE 。
#ifndef EXPR_EXPR
# define EXPR_EXPR
/// Fwd declaration.
template<typename> class Vector;
namespace expr
{
/// -----------------------------------------
///
/// Section 1.
///
/// The first section is a base class template for all kinds of expression. It
/// employs the Curiously Recurring Template Pattern, which enables its instantiation
/// to any kind of expression structure inheriting from it.
///
/// -----------------------------------------
/// Base class for all expressions.
template<typename Expr> class expr_base
{
public:
const Expr& self() const { return static_cast<const Expr&>(*this); }
Expr& self() { return static_cast<Expr&>(*this); }
protected:
explicit expr_base() {};
int size() const { return self().size_impl(); }
auto operator[](int indx) const { return self().at_impl(indx); }
auto operator()() const { return self()(); };
};
/// -----------------------------------------
///
/// The following section 2 & 3 are abstractions of pure algebraic expressions (PAE).
/// Any PAE can be converted to a real object instance using operator(): it is in
/// this conversion process, where the real computations are done.
///
/// Section 2. Terminal
///
/// A terminal is an abstraction wrapping a const reference to the Vector data
/// structure. It inherits from expr_base, therefore providing a unified interface
/// wrapping a Vector into a PAE.
///
/// It provides the size() method, indexed access through at_impl() and a conversion
/// to referenced object through () operator.
///
/// It might no be necessary for user defined data structures to have a terminal
/// wrapper, since user defined structure can inherit expr_base, therefore eliminates
/// the need to provide such terminal wrapper.
///
/// -----------------------------------------
/// Generic wrapper for underlying data structure.
template<typename DataType> class terminal: expr_base<terminal<DataType>>
{
public:
using base_type = expr_base<terminal<DataType>>;
using base_type::size;
using base_type::operator[];
friend base_type;
explicit terminal(const DataType &val): _val(val) {}
int size_impl() const { return _val.size(); };
auto at_impl(int indx) const { return _val[indx]; };
decltype(auto) operator()() const { return (_val); }
private:
const DataType &_val;
};
/// -----------------------------------------
///
/// Section 3. Binary operation expression.
///
/// This is a PAE abstraction of any binary expression. Similarly it inherits from
/// expr_base.
///
/// It provides the size() method, indexed access through at_impl() and a conversion
/// to referenced object through () operator. Each call to the at_impl() method is
/// a element wise computation.
///
/// -----------------------------------------
/// Generic wrapper for binary operations (that are element-wise).
template<typename Ops, typename lExpr, typename rExpr>
class binary_ops: public expr_base<binary_ops<Ops,lExpr,rExpr>>
{
public:
using base_type = expr_base<binary_ops<Ops,lExpr,rExpr>>;
using base_type::size;
using base_type::operator[];
friend base_type;
explicit binary_ops(const Ops &ops, const lExpr &lxpr, const rExpr &rxpr)
: _ops(ops), _lxpr(lxpr), _rxpr(rxpr) {};
int size_impl() const { return _lxpr.size(); };
/// This does the element-wise computation for index indx.
auto at_impl(int indx) const { return _ops(_lxpr[indx], _rxpr[indx]); };
/// Conversion from arbitrary expr to concrete data type. It evaluates
/// element-wise computations for all indices.
template<typename DataType> operator DataType()
{
DataType _vec(size());
for(int _ind = 0; _ind < _vec.size(); ++_ind)
_vec[_ind] = (*this)[_ind];
return _vec;
}
private: /// Ops and expr are assumed cheap to copy.
Ops _ops;
lExpr _lxpr;
rExpr _rxpr;
};
/// -----------------------------------------
/// Section 4.
///
/// The following two structs defines algebraic operations on PAEs: here only vector
/// plus and vector inner product are implemented.
///
/// First, some element-wise operations are defined : in other words, vec_plus and
/// vec_prod acts on elements in Vectors, but not whole Vectors.
///
/// Then, operator + & * are overloaded on PAEs, such that: + & * operations on PAEs
/// also return PAEs.
///
/// -----------------------------------------
/// Element-wise plus operation.
struct vec_plus_t
{
constexpr explicit vec_plus_t() = default;
template<typename LType, typename RType>
auto operator()(const LType &lhs, const RType &rhs) const
{ return lhs+rhs; }
};
/// Element-wise inner product operation.
struct vec_prod_t
{
constexpr explicit vec_prod_t() = default;
template<typename LType, typename RType>
auto operator()(const LType &lhs, const RType &rhs) const
{ return lhs*rhs; }
};
/// Constant plus and inner product operator objects.
constexpr vec_plus_t vec_plus{};
constexpr vec_prod_t vec_prod{};
/// Plus operator overload on expressions: return binary expression.
template<typename lExpr, typename rExpr>
auto operator+(const lExpr &lhs, const rExpr &rhs)
{ return binary_ops<vec_plus_t,lExpr,rExpr>(vec_plus,lhs,rhs); }
/// Inner prod operator overload on expressions: return binary expression.
template<typename lExpr, typename rExpr>
auto operator*(const lExpr &lhs, const rExpr &rhs)
{ return binary_ops<vec_prod_t,lExpr,rExpr>(vec_prod,lhs,rhs); }
} //!expr
#endif //!EXPR_EXPR
文件 main.cc:测试 src 文件
# include <chrono>
# include <iomanip>
# include <iostream>
# include "vec.hh"
# include "expr.hh"
# include "boost/core/demangle.hpp"
int main()
{
using dtype = float;
constexpr int size = 5e7;
std::vector<dtype> _vec1(size);
std::vector<dtype> _vec2(size);
std::vector<dtype> _vec3(size);
// ... Initialize vectors' contents.
Vector<dtype> vec1(std::move(_vec1));
Vector<dtype> vec2(std::move(_vec2));
Vector<dtype> vec3(std::move(_vec3));
unsigned long start_ms_no_ets =
std::chrono::duration_cast<std::chrono::milliseconds>
(std::chrono::system_clock::now().time_since_epoch()).count();
std::cout << "\nNo-ETs evaluation starts.\n";
Vector<dtype> result_no_ets = vec1 + (vec2*vec3);
unsigned long stop_ms_no_ets =
std::chrono::duration_cast<std::chrono::milliseconds>
(std::chrono::system_clock::now().time_since_epoch()).count();
std::cout << std::setprecision(6) << std::fixed
<< "No-ETs. Time eclapses: " << (stop_ms_no_ets-start_ms_no_ets)/1000.0
<< " s.\n" << std::endl;
unsigned long start_ms_ets =
std::chrono::duration_cast<std::chrono::milliseconds>
(std::chrono::system_clock::now().time_since_epoch()).count();
std::cout << "Evaluation using ETs starts.\n";
expr::terminal<Vector<dtype>> vec4(vec1);
expr::terminal<Vector<dtype>> vec5(vec2);
expr::terminal<Vector<dtype>> vec6(vec3);
Vector<dtype> result_ets = (vec4 + vec5*vec6);
unsigned long stop_ms_ets =
std::chrono::duration_cast<std::chrono::milliseconds>
(std::chrono::system_clock::now().time_since_epoch()).count();
std::cout << std::setprecision(6) << std::fixed
<< "With ETs. Time eclapses: " << (stop_ms_ets-start_ms_ets)/1000.0
<< " s.\n" << std::endl;
auto ets_ret_type = (vec4 + vec5*vec6);
std::cout << "\nETs result's type:\n";
std::cout << boost::core::demangle( typeid(decltype(ets_ret_type)).name() ) << '\n';
return 0;
}
使用 GCC 5.3 使用 -O3 -std=c++14
编译时,这是一个可能的输出:
ctor called.
ctor called.
ctor called.
No-ETs evaluation starts.
ctor called.
ctor called.
No-ETs. Time eclapses: 0.571000 s.
Evaluation using ETs starts.
ctor called.
With ETs. Time eclapses: 0.164000 s.
ETs result's type:
expr::binary_ops<expr::vec_plus_t, expr::terminal<Vector<float> >, expr::binary_ops<expr::vec_prod_t, expr::terminal<Vector<float> >, expr::terminal<Vector<float> > > >
观察结果如下:
- **在这种情况下,**使用 ET 可以实现相当显着的性能提升 (> 3x)。 ****
- 消除了临时 Vector 对象的创建。与 ETs 一样,ctor 只被调用一次。
- Boost::demangle 用于可视化转换前 ET 返回的类型:它清楚地构建了与上面演示的完全相同的表达图。
缺点和警告
-
ET 的一个明显缺点是学习曲线,实施的复杂性和代码维护难度。在上面仅考虑元素操作的示例中,实现包含了大量的样板,更不用说在现实世界中,每个计算中都会出现更复杂的代数表达式,并且元素方面的独立性不再成立(例如矩阵乘法) ),难度将是指数级的。
-
使用 ET 的另一个警告是它们与
auto
关键字配合得很好。如上所述, PAE 本质上是代理:并且代理基本上不能与auto
一起使用。请考虑以下示例:auto result = ...; // Some expensive expression: // auto returns the expr graph, // NOT the computed value. for(auto i = 0; i < 100; ++i) ScalrType value = result* ... // Some other expensive computations using result.
在 for 循环的每次迭代中,将重新计算结果,因为表达式图形而不是计算值被传递给 for 循环。
实现 ET 的现有库 ****
- boost::proto 是一个功能强大的库,允许你为自己的表达式定义自己的规则和语法,并使用 ET 执行。
- Eigen 是一个线性代数库,可以使用 ET 有效地实现各种代数计算。