1 #ifndef HALIDE_IR_MATCH_H 2 #define HALIDE_IR_MATCH_H 37 bool expr_match(
const Expr &pattern,
const Expr &expr, std::vector<Expr> &result);
51 bool expr_match(
const Expr &pattern,
const Expr &expr, std::map<std::string, Expr> &result);
139 typename =
typename std::remove_reference<T>::type::pattern_tag>
146 constexpr
static uint32_t mask = std::remove_reference<T>::type::binds;
166 const int lanes = scalar_type.
lanes;
167 scalar_type.
lanes = 1;
170 switch (scalar_type.
code) {
198 ((a.type == b.type) &&
199 (a.node_type == b.node_type) &&
216 template<u
int32_t bound>
244 template<u
int32_t bound>
246 static_assert(i >= 0 && i <
max_wild,
"Wild with out-of-range index");
249 op = ((
const Broadcast *)op)->value.get();
258 state.get_bound_const(i, val, type);
261 state.set_bound_const(i, value, e.type);
265 template<u
int32_t bound>
267 static_assert(i >= 0 && i <
max_wild,
"Wild with out-of-range index");
271 state.get_bound_const(i, val, type);
272 return type == i64_type && value == val.
u.
i64;
274 state.set_bound_const(i, value, i64_type);
295 std::ostream &operator<<(std::ostream &s, const WildConstInt<i> &c) {
310 template<u
int32_t bound>
312 static_assert(i >= 0 && i <
max_wild,
"Wild with out-of-range index");
315 op = ((
const Broadcast *)op)->value.get();
324 state.get_bound_const(i, val, type);
327 state.set_bound_const(i, value, e.type);
343 state.get_bound_const(i, val, ty);
348 std::ostream &operator<<(std::ostream &s, const WildConstUInt<i> &c) {
363 template<u
int32_t bound>
365 static_assert(i >= 0 && i <
max_wild,
"Wild with out-of-range index");
368 op = ((
const Broadcast *)op)->value.get();
373 double value = ((
const FloatImm *)op)->value;
377 state.get_bound_const(i, val, type);
380 state.set_bound_const(i, value, e.type);
396 state.get_bound_const(i, val, ty);
401 std::ostream &operator<<(std::ostream &s, const WildConstFloat<i> &c) {
417 template<u
int32_t bound>
419 static_assert(i >= 0 && i <
max_wild,
"Wild with out-of-range index");
422 op = ((
const Broadcast *)op)->value.get();
436 template<u
int32_t bound>
438 static_assert(i >= 0 && i <
max_wild,
"Wild with out-of-range index");
454 state.get_bound_const(i, val, ty);
459 std::ostream &operator<<(std::ostream &s, const WildConst<i> &c) {
475 template<u
int32_t bound>
478 return equal(*state.get_binding(i), e);
480 state.set_binding(i, e);
493 std::ostream &operator<<(std::ostream &s, const Wild<i> &op) {
515 template<u
int32_t bound>
519 op = ((
const Broadcast *)op)->value.get();
527 return ((
const FloatImm *)op)->value == (
double)
v;
533 template<u
int32_t bound>
538 template<u
int32_t bound>
562 val.u.f64 = (double)
v;
578 typename =
typename std::decay<T>::type::pattern_tag>
589 static_assert(!std::is_same<
typename std::decay<T>::type,
Expr>::value || std::is_lvalue_reference<T>::value,
590 "Exprs are captured by reference by IRMatcher objects and so must be lvalues");
601 typename =
typename std::decay<T>::type::pattern_tag,
603 typename =
typename std::enable_if<!std::is_same<typename std::decay<T>::type, SpecificExpr>::value>::type>
618 template<
typename Op>
621 template<
typename Op>
624 template<
typename Op>
639 template<
typename Op,
typename A,
typename B>
654 A::canonical && B::canonical && (!
commutative(Op::_node_type) || (A::max_node_type >= B::min_node_type));
656 template<u
int32_t bound>
658 if (e.node_type != Op::_node_type) {
661 const Op &op = (
const Op &)e;
662 return (
a.template match<bound>(*op.a.get(), state) &&
666 template<u
int32_t bound,
typename Op2,
typename A2,
typename B2>
668 return (std::is_same<Op, Op2>::value &&
669 a.template match<bound>(
unwrap(op.a), state) &&
673 constexpr
static bool foldable = A::foldable && B::foldable;
678 if (std::is_same<A, IntLiteral>::value) {
679 b.make_folded_const(val_b, ty, state);
680 if ((std::is_same<Op, And>::value && val_b.
u.
u64 == 0) ||
681 (std::is_same<Op, Or>::value && val_b.
u.
u64 == 1)) {
687 a.make_folded_const(val_a, ty, state);
690 a.make_folded_const(val_a, ty, state);
691 if ((std::is_same<Op, And>::value && val_a.
u.
u64 == 0) ||
692 (std::is_same<Op, Or>::value && val_a.
u.
u64 == 1)) {
698 b.make_folded_const(val_b, ty, state);
703 val.u.i64 = constant_fold_bin_op<Op>(ty, val_a.
u.
i64, val_b.
u.
i64);
706 val.u.u64 = constant_fold_bin_op<Op>(ty, val_a.
u.
u64, val_b.
u.
u64);
710 val.u.f64 = constant_fold_bin_op<Op>(ty, val_a.
u.
f64, val_b.
u.
f64);
721 if (std::is_same<A, IntLiteral>::value) {
722 eb =
b.make(state, type_hint);
723 ea =
a.make(state, eb.
type());
725 ea =
a.make(state, type_hint);
726 eb =
b.make(state, ea.
type());
736 return Op::make(std::move(ea), std::move(eb));
740 template<
typename Op>
743 template<
typename Op>
746 template<
typename Op>
750 template<
typename Op,
typename A,
typename B>
762 (!
commutative(Op::_node_type) || A::max_node_type >= B::min_node_type) &&
766 template<u
int32_t bound>
768 if (e.node_type != Op::_node_type) {
771 const Op &op = (
const Op &)e;
772 return (
a.template match<bound>(*op.a.get(), state) &&
776 template<u
int32_t bound,
typename Op2,
typename A2,
typename B2>
778 return (std::is_same<Op, Op2>::value &&
779 a.template match<bound>(
unwrap(op.a), state) &&
783 constexpr
static bool foldable = A::foldable && B::foldable;
789 if (std::is_same<A, IntLiteral>::value) {
790 b.make_folded_const(val_b, ty, state);
792 a.make_folded_const(val_a, ty, state);
795 a.make_folded_const(val_a, ty, state);
797 b.make_folded_const(val_b, ty, state);
802 val.u.u64 = constant_fold_cmp_op<Op>(val_a.
u.
i64, val_b.
u.
i64);
805 val.u.u64 = constant_fold_cmp_op<Op>(val_a.
u.
u64, val_b.
u.
u64);
809 val.u.u64 = constant_fold_cmp_op<Op>(val_a.
u.
f64, val_b.
u.
f64);
823 if (std::is_same<A, IntLiteral>::value) {
824 eb =
b.make(state, {});
825 ea =
a.make(state, eb.
type());
827 ea =
a.make(state, {});
828 eb =
b.make(state, ea.
type());
838 return Op::make(std::move(ea), std::move(eb));
842 template<
typename A,
typename B>
843 std::ostream &operator<<(std::ostream &s, const BinOp<Add, A, B> &op) {
844 s <<
"(" << op.a <<
" + " << op.b <<
")";
848 template<
typename A,
typename B>
849 std::ostream &operator<<(std::ostream &s, const BinOp<Sub, A, B> &op) {
850 s <<
"(" << op.a <<
" - " << op.b <<
")";
854 template<
typename A,
typename B>
855 std::ostream &operator<<(std::ostream &s, const BinOp<Mul, A, B> &op) {
856 s <<
"(" << op.a <<
" * " << op.b <<
")";
860 template<
typename A,
typename B>
861 std::ostream &operator<<(std::ostream &s, const BinOp<Div, A, B> &op) {
862 s <<
"(" << op.a <<
" / " << op.b <<
")";
866 template<
typename A,
typename B>
867 std::ostream &operator<<(std::ostream &s, const BinOp<And, A, B> &op) {
868 s <<
"(" << op.a <<
" && " << op.b <<
")";
872 template<
typename A,
typename B>
873 std::ostream &operator<<(std::ostream &s, const BinOp<Or, A, B> &op) {
874 s <<
"(" << op.a <<
" || " << op.b <<
")";
878 template<
typename A,
typename B>
879 std::ostream &operator<<(std::ostream &s, const BinOp<Min, A, B> &op) {
880 s <<
"min(" << op.a <<
", " << op.b <<
")";
884 template<
typename A,
typename B>
885 std::ostream &operator<<(std::ostream &s, const BinOp<Max, A, B> &op) {
886 s <<
"max(" << op.a <<
", " << op.b <<
")";
890 template<
typename A,
typename B>
891 std::ostream &operator<<(std::ostream &s, const CmpOp<LE, A, B> &op) {
892 s <<
"(" << op.a <<
" <= " << op.b <<
")";
896 template<
typename A,
typename B>
897 std::ostream &operator<<(std::ostream &s, const CmpOp<LT, A, B> &op) {
898 s <<
"(" << op.a <<
" < " << op.b <<
")";
902 template<
typename A,
typename B>
903 std::ostream &operator<<(std::ostream &s, const CmpOp<GE, A, B> &op) {
904 s <<
"(" << op.a <<
" >= " << op.b <<
")";
908 template<
typename A,
typename B>
909 std::ostream &operator<<(std::ostream &s, const CmpOp<GT, A, B> &op) {
910 s <<
"(" << op.a <<
" > " << op.b <<
")";
914 template<
typename A,
typename B>
915 std::ostream &operator<<(std::ostream &s, const CmpOp<EQ, A, B> &op) {
916 s <<
"(" << op.a <<
" == " << op.b <<
")";
920 template<
typename A,
typename B>
921 std::ostream &operator<<(std::ostream &s, const CmpOp<NE, A, B> &op) {
922 s <<
"(" << op.a <<
" != " << op.b <<
")";
926 template<
typename A,
typename B>
927 std::ostream &operator<<(std::ostream &s, const BinOp<Mod, A, B> &op) {
928 s <<
"(" << op.a <<
" % " << op.b <<
")";
932 template<
typename A,
typename B>
934 assert_is_lvalue_if_expr<A>();
935 assert_is_lvalue_if_expr<B>();
939 template<
typename A,
typename B>
941 assert_is_lvalue_if_expr<A>();
942 assert_is_lvalue_if_expr<B>();
949 int dead_bits = 64 - t.bits;
957 return (a + b) & (ones >> (64 - t.bits));
965 template<
typename A,
typename B>
967 assert_is_lvalue_if_expr<A>();
968 assert_is_lvalue_if_expr<B>();
972 template<
typename A,
typename B>
974 assert_is_lvalue_if_expr<A>();
975 assert_is_lvalue_if_expr<B>();
983 int dead_bits = 64 - t.bits;
990 return (a - b) & (ones >> (64 - t.bits));
998 template<
typename A,
typename B>
1000 assert_is_lvalue_if_expr<A>();
1001 assert_is_lvalue_if_expr<B>();
1005 template<
typename A,
typename B>
1007 assert_is_lvalue_if_expr<A>();
1008 assert_is_lvalue_if_expr<B>();
1015 int dead_bits = 64 - t.bits;
1023 return (a * b) & (ones >> (64 - t.bits));
1031 template<
typename A,
typename B>
1033 assert_is_lvalue_if_expr<A>();
1034 assert_is_lvalue_if_expr<B>();
1038 template<
typename A,
typename B>
1058 template<
typename A,
typename B>
1060 assert_is_lvalue_if_expr<A>();
1061 assert_is_lvalue_if_expr<B>();
1065 template<
typename A,
typename B>
1067 assert_is_lvalue_if_expr<A>();
1068 assert_is_lvalue_if_expr<B>();
1087 template<
typename A,
typename B>
1089 assert_is_lvalue_if_expr<A>();
1090 assert_is_lvalue_if_expr<B>();
1109 template<
typename A,
typename B>
1111 assert_is_lvalue_if_expr<A>();
1112 assert_is_lvalue_if_expr<B>();
1131 template<
typename A,
typename B>
1136 template<
typename A,
typename B>
1156 template<
typename A,
typename B>
1161 template<
typename A,
typename B>
1181 template<
typename A,
typename B>
1186 template<
typename A,
typename B>
1206 template<
typename A,
typename B>
1211 template<
typename A,
typename B>
1231 template<
typename A,
typename B>
1236 template<
typename A,
typename B>
1256 template<
typename A,
typename B>
1261 template<
typename A,
typename B>
1281 template<
typename A,
typename B>
1286 template<
typename A,
typename B>
1307 template<
typename A,
typename B>
1312 template<
typename A,
typename B>
1337 template<
typename... Args>
1346 template<
typename... Args>
1353 return a < b ? a : b;
1356 template<
typename... Args>
1374 typename =
typename std::enable_if<(i <
sizeof...(Args))>::type>
1376 using T = decltype(std::get<i>(
args));
1377 return (std::get<i>(
args).
template match<bound>(*c.args[i].get(), state) &&
1381 template<
int i, u
int32_t binds>
1386 template<u
int32_t bound>
1394 match_args<0, bound>(0, c, state));
1398 typename =
typename std::enable_if<(i <
sizeof...(Args))>::type>
1400 s << std::get<i>(
args);
1401 if (i + 1 <
sizeof...(Args)) {
1404 print_args<i + 1>(0, s);
1413 print_args<0>(0, s);
1431 return absd(arg0, arg1);
1455 return arg0 << arg1;
1457 return arg0 >> arg1;
1491 if (arg1.
u.
i64 < 0) {
1494 val.u.i64 >>= -arg1.
u.
i64;
1497 val.u.u64 >>= -arg1.
u.
i64;
1500 val.u.u64 <<= arg1.
u.
i64;
1503 if (arg1.
u.
i64 > 0) {
1506 val.u.i64 >>= arg1.
u.
i64;
1509 val.u.u64 >>= arg1.
u.
i64;
1512 val.u.u64 <<= -arg1.
u.
i64;
1525 template<
typename... Args>
1533 template<
typename... Args>
1538 template<
typename A,
typename B>
1542 template<
typename A,
typename B>
1546 template<
typename A,
typename B>
1551 template<
typename A,
typename B>
1555 template<
typename A,
typename B>
1559 template<
typename A,
typename B>
1563 template<
typename A,
typename B>
1567 template<
typename A,
typename B>
1571 template<
typename A>
1577 template<
typename A,
typename B>
1581 template<
typename A,
typename B>
1585 template<
typename A,
typename B>
1589 template<
typename A,
typename B>
1593 template<
typename A,
typename B>
1597 template<
typename A,
typename B>
1601 template<
typename A,
typename B>
1605 template<
typename A,
typename B,
typename C>
1609 template<
typename A,
typename B,
typename C>
1614 template<
typename A>
1625 template<u
int32_t bound>
1630 const Not &op = (
const Not &)e;
1631 return (
a.template match<bound>(*op.
a.
get(), state));
1634 template<u
int32_t bound,
typename A2>
1636 return a.template match<bound>(
unwrap(op.a), state);
1646 template<
typename A1 = A>
1648 a.make_folded_const(val, ty, state);
1649 val.u.u64 = ~val.u.u64;
1654 template<
typename A>
1656 assert_is_lvalue_if_expr<A>();
1660 template<
typename A>
1662 assert_is_lvalue_if_expr<A>();
1666 template<
typename A>
1667 inline std::ostream &operator<<(std::ostream &s, const NotOp<A> &op) {
1668 s <<
"!(" << op.a <<
")";
1672 template<
typename C,
typename T,
typename F>
1684 constexpr
static bool canonical = C::canonical && T::canonical && F::canonical;
1686 template<u
int32_t bound>
1692 return (
c.template match<bound>(*op.
condition.
get(), state) &&
1696 template<u
int32_t bound,
typename C2,
typename T2,
typename F2>
1698 return (
c.template match<bound>(
unwrap(instance.c), state) &&
1705 return Select::make(
c.make(state, {}),
t.make(state, type_hint),
f.make(state, type_hint));
1708 constexpr
static bool foldable = C::foldable && T::foldable && F::foldable;
1710 template<
typename C1 = C>
1714 c.make_folded_const(c_val, c_ty, state);
1715 if ((c_val.
u.
u64 & 1) == 1) {
1716 t.make_folded_const(val, ty, state);
1718 f.make_folded_const(val, ty, state);
1724 template<
typename C,
typename T,
typename F>
1725 std::ostream &operator<<(std::ostream &s, const SelectOp<C, T, F> &op) {
1726 s <<
"select(" << op.c <<
", " << op.t <<
", " << op.f <<
")";
1730 template<
typename C,
typename T,
typename F>
1732 assert_is_lvalue_if_expr<C>();
1733 assert_is_lvalue_if_expr<T>();
1734 assert_is_lvalue_if_expr<F>();
1738 template<
typename A,
typename B>
1749 constexpr
static bool canonical = A::canonical && B::canonical;
1751 template<u
int32_t bound>
1755 if (
a.template match<bound>(*op.
value.
get(), state) &&
1756 lanes.template match<bound>(op.
lanes, state)) {
1763 template<u
int32_t bound,
typename A2,
typename B2>
1765 return (
a.template match<bound>(
unwrap(op.a), state) &&
1773 lanes.make_folded_const(lanes_val, ty, state);
1775 type_hint.
lanes /= l;
1776 Expr val =
a.make(state, type_hint);
1786 template<
typename A1 = A>
1790 lanes.make_folded_const(lanes_val, lanes_ty, state);
1792 a.make_folded_const(val, ty, state);
1797 template<
typename A,
typename B>
1798 inline std::ostream &operator<<(std::ostream &s, const BroadcastOp<A, B> &op) {
1799 s <<
"broadcast(" << op.a <<
", " << op.lanes <<
")";
1803 template<
typename A,
typename B>
1805 assert_is_lvalue_if_expr<A>();
1809 template<
typename A,
typename B,
typename C>
1821 constexpr
static bool canonical = A::canonical && B::canonical && C::canonical;
1823 template<u
int32_t bound>
1829 if (
a.template match<bound>(*op.
base.
get(), state) &&
1838 template<u
int32_t bound,
typename A2,
typename B2,
typename C2>
1840 return (
a.template match<bound>(
unwrap(op.a), state) &&
1849 lanes.make_folded_const(lanes_val, ty, state);
1851 type_hint.
lanes /= l;
1853 eb =
b.make(state, type_hint);
1854 ea =
a.make(state, eb.type());
1861 template<
typename A,
typename B,
typename C>
1862 std::ostream &operator<<(std::ostream &s, const RampOp<A, B, C> &op) {
1863 s <<
"ramp(" << op.a <<
", " << op.b <<
", " << op.lanes <<
")";
1867 template<
typename A,
typename B,
typename C>
1869 assert_is_lvalue_if_expr<A>();
1870 assert_is_lvalue_if_expr<B>();
1871 assert_is_lvalue_if_expr<C>();
1875 template<
typename A,
typename B, VectorReduce::Operator reduce_op>
1887 template<u
int32_t bound>
1891 if (op.
op == reduce_op &&
1892 a.template match<bound>(*op.
value.
get(), state) &&
1900 template<u
int32_t bound,
typename A2,
typename B2, VectorReduce::Operator reduce_op_2>
1902 return (reduce_op == reduce_op_2 &&
1903 a.template match<bound>(
unwrap(op.a), state) &&
1911 lanes.make_folded_const(lanes_val, ty, state);
1912 int l = (int)lanes_val.
u.
i64;
1919 template<
typename A,
typename B, VectorReduce::Operator reduce_op>
1920 inline std::ostream &operator<<(std::ostream &s, const VectorReduceOp<A, B, reduce_op> &op) {
1921 s <<
"vector_reduce(" << reduce_op <<
", " << op.a <<
", " << op.lanes <<
")";
1925 template<
typename A,
typename B>
1927 assert_is_lvalue_if_expr<A>();
1931 template<
typename A,
typename B>
1933 assert_is_lvalue_if_expr<A>();
1937 template<
typename A,
typename B>
1939 assert_is_lvalue_if_expr<A>();
1943 template<
typename A,
typename B>
1945 assert_is_lvalue_if_expr<A>();
1949 template<
typename A,
typename B>
1951 assert_is_lvalue_if_expr<A>();
1955 template<
typename A>
1967 template<u
int32_t bound>
1972 const Sub &op = (
const Sub &)e;
1973 return (
a.template match<bound>(*op.
b.
get(), state) &&
1977 template<u
int32_t bound,
typename A2>
1979 return a.template match<bound>(
unwrap(p.a), state);
1984 Expr ea =
a.make(state, type_hint);
1986 return Sub::make(std::move(z), std::move(ea));
1991 template<
typename A1 = A>
1993 a.make_folded_const(val, ty, state);
1994 int dead_bits = 64 - ty.bits;
1997 if (ty.bits >= 32 && val.u.u64 && (val.u.u64 << (65 - ty.bits)) == 0) {
2006 val.u.u64 = ((-val.u.u64) << dead_bits) >> dead_bits;
2010 val.u.f64 = -val.u.f64;
2019 template<
typename A>
2020 std::ostream &operator<<(std::ostream &s, const NegateOp<A> &op) {
2025 template<
typename A>
2027 assert_is_lvalue_if_expr<A>();
2031 template<
typename A>
2033 assert_is_lvalue_if_expr<A>();
2037 template<
typename A>
2049 template<u
int32_t bound>
2055 return (e.type ==
t &&
2056 a.template match<bound>(*op.
value.
get(), state));
2058 template<u
int32_t bound,
typename A2>
2060 return t == op.t &&
a.template match<bound>(
unwrap(op.a), state);
2065 return cast(
t,
a.make(state, {}));
2071 template<
typename A>
2072 std::ostream &operator<<(std::ostream &s, const CastOp<A> &op) {
2073 s <<
"cast(" << op.t <<
", " << op.a <<
")";
2077 template<
typename A>
2079 assert_is_lvalue_if_expr<A>();
2083 template<
typename Vec,
typename Base,
typename Str
ide,
typename Lanes>
2091 static constexpr
uint32_t binds = Vec::binds | Base::binds | Stride::binds | Lanes::binds;
2095 constexpr
static bool canonical = Vec::canonical && Base::canonical && Stride::canonical && Lanes::canonical;
2097 template<u
int32_t bound>
2103 return v.
vectors.size() == 1 &&
2105 vec.template match<bound>(*v.
vectors[0].get(), state) &&
2115 base.make_folded_const(base_val, ty, state);
2116 int b = (int)base_val.
u.
i64;
2117 stride.make_folded_const(stride_val, ty, state);
2118 int s = (int)stride_val.
u.
i64;
2119 lanes.make_folded_const(lanes_val, ty, state);
2120 int l = (int)lanes_val.
u.
i64;
2129 static_assert(Base::foldable,
"Base of slice should consist only of operations that constant-fold");
2130 static_assert(Stride::foldable,
"Stride of slice should consist only of operations that constant-fold");
2131 static_assert(Lanes::foldable,
"Lanes of slice should consist only of operations that constant-fold");
2135 template<
typename Vec,
typename Base,
typename Str
ide,
typename Lanes>
2136 std::ostream &operator<<(std::ostream &s, const SliceOp<Vec, Base, Stride, Lanes> &op) {
2137 s <<
"slice(" << op.vec <<
", " << op.base <<
", " << op.stride <<
", " << op.lanes <<
")";
2141 template<
typename Vec,
typename Base,
typename Str
ide,
typename Lanes>
2147 template<
typename A>
2162 a.make_folded_const(c, ty, state);
2168 if (type_hint.bits) {
2172 c.
u.
f64 = (double)x;
2174 ty.
code = type_hint.code;
2175 ty.
bits = type_hint.bits;
2184 template<
typename A1 = A>
2186 a.make_folded_const(val, ty, state);
2190 template<
typename A>
2192 assert_is_lvalue_if_expr<A>();
2196 template<
typename A>
2197 std::ostream &operator<<(std::ostream &s, const Fold<A> &op) {
2198 s <<
"fold(" << op.a <<
")";
2202 template<
typename A>
2217 template<
typename A1 = A>
2219 a.make_folded_const(val, ty, state);
2227 template<
typename A>
2229 assert_is_lvalue_if_expr<A>();
2233 template<
typename A>
2234 std::ostream &operator<<(std::ostream &s, const Overflows<A> &op) {
2235 s <<
"overflows(" << op.a <<
")";
2249 template<u
int32_t bound>
2278 template<
typename A>
2295 template<
typename A1 = A>
2297 Expr e =
a.make(state, {});
2309 template<
typename A>
2311 assert_is_lvalue_if_expr<A>();
2315 template<
typename A>
2317 assert_is_lvalue_if_expr<A>();
2321 template<
typename A>
2322 std::ostream &operator<<(std::ostream &s, const IsConst<A> &op) {
2324 s <<
"is_const(" << op.a <<
")";
2326 s <<
"is_const(" << op.a <<
", " << op.v <<
")";
2331 template<
typename A,
typename Prover>
2348 Expr condition =
a.make(state, {});
2349 condition =
prover->mutate(condition,
nullptr);
2357 template<
typename A,
typename Prover>
2359 assert_is_lvalue_if_expr<A>();
2363 template<
typename A,
typename Prover>
2364 std::ostream &operator<<(std::ostream &s, const CanProve<A, Prover> &op) {
2365 s <<
"can_prove(" << op.a <<
")";
2369 template<
typename A>
2386 Type t =
a.make(state, {}).type();
2394 template<
typename A>
2396 assert_is_lvalue_if_expr<A>();
2400 template<
typename A>
2401 std::ostream &operator<<(std::ostream &s, const IsFloat<A> &op) {
2402 s <<
"is_float(" << op.a <<
")";
2406 template<
typename A>
2424 Type t =
a.make(state, {}).type();
2432 template<
typename A>
2434 assert_is_lvalue_if_expr<A>();
2438 template<
typename A>
2439 std::ostream &operator<<(std::ostream &s, const IsInt<A> &op) {
2440 s <<
"is_int(" << op.a;
2442 s <<
", " << op.bits;
2445 s <<
", " << op.lanes;
2451 template<
typename A>
2469 Type t =
a.make(state, {}).type();
2477 template<
typename A>
2479 assert_is_lvalue_if_expr<A>();
2483 template<
typename A>
2484 std::ostream &operator<<(std::ostream &s, const IsUInt<A> &op) {
2485 s <<
"is_uint(" << op.a;
2487 s <<
", " << op.bits;
2490 s <<
", " << op.lanes;
2496 template<
typename A>
2513 Type t =
a.make(state, {}).type();
2521 template<
typename A>
2523 assert_is_lvalue_if_expr<A>();
2527 template<
typename A>
2528 std::ostream &operator<<(std::ostream &s, const IsScalar<A> &op) {
2529 s <<
"is_scalar(" << op.a <<
")";
2533 template<
typename A>
2550 a.make_folded_const(val, ty, state);
2553 val.
u.
u64 = (val.
u.
u64 == max_bits);
2562 template<
typename A>
2564 assert_is_lvalue_if_expr<A>();
2568 template<
typename A>
2569 std::ostream &operator<<(std::ostream &s, const IsMaxValue<A> &op) {
2570 s <<
"is_max_value(" << op.a <<
")";
2574 template<
typename A>
2591 a.make_folded_const(val, ty, state);
2594 val.
u.
u64 = (val.
u.
u64 == min_bits);
2605 template<
typename A>
2607 assert_is_lvalue_if_expr<A>();
2611 template<
typename A>
2612 std::ostream &operator<<(std::ostream &s, const IsMinValue<A> &op) {
2613 s <<
"is_min_value(" << op.a <<
")";
2617 template<
typename A>
2634 Type t =
a.make(state, {}).type();
2642 template<
typename A>
2644 assert_is_lvalue_if_expr<A>();
2648 template<
typename A>
2649 std::ostream &operator<<(std::ostream &s, const LanesOf<A> &op) {
2650 s <<
"lanes_of(" << op.a <<
")";
2655 template<
typename Before,
2658 typename =
typename std::enable_if<std::decay<Before>::type::foldable &&
2659 std::decay<After>::type::foldable>::type>
2664 wildcard_type.lanes = output_type.lanes = 1;
2667 static std::set<uint32_t> tested;
2669 if (!tested.insert(reinterpret_bits<uint32_t>(wildcard_type)).second) {
2674 debug(0) <<
"validate('" << before <<
"', '" << after <<
"', '" << pred <<
"', " <<
Type(wildcard_type) <<
", " <<
Type(output_type) <<
")\n";
2679 static std::mt19937_64 rng(0);
2684 for (
int trials = 0; trials < 100; trials++) {
2688 int shift = (int)(rng() & (wildcard_type.bits - 1));
2690 for (
int i = 0; i <
max_wild; i++) {
2692 switch (wildcard_type.code) {
2712 double val = ((
int64_t)(rng() & 15) - 8) / 2.0;
2714 val = ((
int64_t)(rng() & 15) - 8) / 2.0;
2728 before.make_folded_const(val_before, type, state);
2730 after.make_folded_const(val_after, type, state);
2731 lanes |= type.
lanes;
2738 switch (output_type.code) {
2753 ok &= (error < 0.01 ||
2754 val_before.
u.
u64 == val_after.
u.
u64 ||
2755 std::isnan(val_before.
u.
f64));
2763 debug(0) <<
"Fails with values:\n";
2764 for (
int i = 0; i <
max_wild; i++) {
2769 for (
int i = 0; i <
max_wild; i++) {
2774 debug(0) << val_before.
u.
u64 <<
" " << val_after.
u.
u64 <<
"\n";
2780 template<
typename Before,
2783 typename =
typename std::enable_if<!(std::decay<Before>::type::foldable &&
2784 std::decay<After>::type::foldable)>::type>
2795 template<
typename Pattern,
2796 typename =
typename enable_if_pattern<Pattern>::type>
2800 p.make_folded_const(c, ty, state);
2808 #define HALIDE_DEBUG_MATCHED_RULES 0 2809 #define HALIDE_DEBUG_UNMATCHED_RULES 0 2815 #define HALIDE_FUZZ_TEST_RULES 0 2817 template<
typename Instance>
2830 template<
typename After>
2835 template<
typename Before,
2840 static_assert((Before::binds & After::binds) == After::binds,
"Rule result uses unbound values");
2841 static_assert(Before::canonical,
"LHS of rewrite rule should be in canonical form");
2842 static_assert(After::canonical,
"RHS of rewrite rule should be in canonical form");
2843 #if HALIDE_FUZZ_TEST_RULES 2848 #if HALIDE_DEBUG_MATCHED_RULES 2853 #if HALIDE_DEBUG_UNMATCHED_RULES 2854 debug(0) <<
instance <<
" does not match " << before <<
"\n";
2860 template<
typename Before,
2863 static_assert(Before::canonical,
"LHS of rewrite rule should be in canonical form");
2866 #if HALIDE_DEBUG_MATCHED_RULES 2871 #if HALIDE_DEBUG_UNMATCHED_RULES 2872 debug(0) <<
instance <<
" does not match " << before <<
"\n";
2878 template<
typename Before,
2881 static_assert(Before::canonical,
"LHS of rewrite rule should be in canonical form");
2882 #if HALIDE_FUZZ_TEST_RULES 2887 #if HALIDE_DEBUG_MATCHED_RULES 2892 #if HALIDE_DEBUG_UNMATCHED_RULES 2893 debug(0) <<
instance <<
" does not match " << before <<
"\n";
2899 template<
typename Before,
2906 static_assert(Predicate::foldable,
"Predicates must consist only of operations that can constant-fold");
2907 static_assert((Before::binds & After::binds) == After::binds,
"Rule result uses unbound values");
2908 static_assert((Before::binds & Predicate::binds) == Predicate::binds,
"Rule predicate uses unbound values");
2909 static_assert(Before::canonical,
"LHS of rewrite rule should be in canonical form");
2910 static_assert(After::canonical,
"RHS of rewrite rule should be in canonical form");
2912 #if HALIDE_FUZZ_TEST_RULES 2918 #if HALIDE_DEBUG_MATCHED_RULES 2919 debug(0) <<
instance <<
" -> " <<
result <<
" via " << before <<
" -> " << after <<
" when " << pred <<
"\n";
2923 #if HALIDE_DEBUG_UNMATCHED_RULES 2924 debug(0) <<
instance <<
" does not match " << before <<
"\n";
2930 template<
typename Before,
2935 static_assert(Predicate::foldable,
"Predicates must consist only of operations that can constant-fold");
2936 static_assert(Before::canonical,
"LHS of rewrite rule should be in canonical form");
2941 #if HALIDE_DEBUG_MATCHED_RULES 2942 debug(0) <<
instance <<
" -> " <<
result <<
" via " << before <<
" -> " << after <<
" when " << pred <<
"\n";
2946 #if HALIDE_DEBUG_UNMATCHED_RULES 2947 debug(0) <<
instance <<
" does not match " << before <<
"\n";
2953 template<
typename Before,
2958 static_assert(Predicate::foldable,
"Predicates must consist only of operations that can constant-fold");
2959 static_assert(Before::canonical,
"LHS of rewrite rule should be in canonical form");
2960 #if HALIDE_FUZZ_TEST_RULES 2966 #if HALIDE_DEBUG_MATCHED_RULES 2967 debug(0) <<
instance <<
" -> " <<
result <<
" via " << before <<
" -> " << after <<
" when " << pred <<
"\n";
2971 #if HALIDE_DEBUG_UNMATCHED_RULES 2972 debug(0) <<
instance <<
" does not match " << before <<
"\n";
2996 template<
typename Instance,
2997 typename =
typename enable_if_pattern<Instance>::type>
2999 return {
pattern_arg(instance), output_type, wildcard_type};
3002 template<
typename Instance,
3003 typename =
typename enable_if_pattern<Instance>::type>
3005 return {
pattern_arg(instance), output_type, output_type};
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE bool match(const CmpOp< Op2, A2, B2 > &op, MatcherState &state) const noexcept
static constexpr IRNodeType max_node_type
static constexpr IRNodeType min_node_type
Unsigned integer constants.
static constexpr uint32_t binds
static const IRNodeType _node_type
The actual IR nodes begin here.
HALIDE_ALWAYS_INLINE bool match(int64_t e, MatcherState &state) const noexcept
static constexpr bool canonical
static constexpr bool canonical
HALIDE_ALWAYS_INLINE bool operator()(Before before, int64_t after) noexcept
static constexpr IRNodeType min_node_type
static constexpr uint32_t binds
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< EQ >(int64_t a, int64_t b) noexcept
constexpr bool and_reduce()
A fragment of Halide syntax.
bool is_slice() const
Check if this shuffle is a contiguous strict subset of the vector arguments, and if so...
HALIDE_ALWAYS_INLINE bool operator()(Before before, const Expr &after) noexcept
static constexpr IRNodeType max_node_type
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE void get_bound_const(int i, halide_scalar_value_t &val, halide_type_t &type) const noexcept
static constexpr IRNodeType min_node_type
static constexpr bool canonical
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
static constexpr IRNodeType max_node_type
static constexpr bool canonical
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
static Expr make(Expr condition, Expr true_value, Expr false_value)
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
static constexpr bool canonical
static constexpr bool canonical
static constexpr IRNodeType max_node_type
static constexpr bool foldable
std::tuple< Args... > args
HALIDE_ALWAYS_INLINE bool match(const VectorReduceOp< A2, B2, reduce_op_2 > &op, MatcherState &state) const noexcept
static constexpr IRNodeType max_node_type
static constexpr IRNodeType min_node_type
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
The difference of two expressions.
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
static const IRNodeType _node_type
static constexpr IRNodeType max_node_type
A vector with 'lanes' elements, in which every element is 'value'.
HALIDE_ALWAYS_INLINE bool match(const BroadcastOp< A2, B2 > &op, MatcherState &state) const noexcept
static constexpr bool foldable
HALIDE_ALWAYS_INLINE auto operator!(A &&a) noexcept -> NotOp< decltype(pattern_arg(a))>
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
static constexpr bool foldable
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
static constexpr bool canonical
static constexpr IRNodeType max_node_type
auto widening_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
static constexpr bool foldable
uint16_t lanes
How many elements in a vector.
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
static constexpr bool foldable
HALIDE_ALWAYS_INLINE bool operator()(Before before, After after)
Logical not - true if the expression false.
Methods to test Exprs and Stmts for equality of value.
Horizontally reduce a vector to a scalar or narrower vector using the given commutative and associati...
HALIDE_ALWAYS_INLINE int64_t unwrap(IntLiteral t)
static constexpr IRNodeType max_node_type
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
static constexpr bool canonical
HALIDE_ALWAYS_INLINE void set_bound_const(int i, int64_t s, halide_type_t t) noexcept
HALIDE_ALWAYS_INLINE bool match(const IntLiteral &b, MatcherState &state) const noexcept
static constexpr bool foldable
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
HALIDE_ALWAYS_INLINE auto max(A &&a, B &&b) noexcept -> BinOp< Max, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
static constexpr bool canonical
HALIDE_ALWAYS_INLINE bool evaluate_predicate(bool x, MatcherState &) noexcept
static constexpr uint32_t binds
static const UIntImm * make(Type t, uint64_t value)
IEEE floating point numbers.
HALIDE_ALWAYS_INLINE auto is_max_value(A &&a) noexcept -> IsMaxValue< decltype(pattern_arg(a))>
HALIDE_ALWAYS_INLINE auto operator!=(A &&a, B &&b) noexcept -> CmpOp< NE, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
static constexpr IRNodeType max_node_type
static constexpr uint32_t binds
Floating point constants.
HALIDE_ALWAYS_INLINE bool operator()(Before before, After after, Predicate pred)
HALIDE_ALWAYS_INLINE auto mul(A &&a, B &&b) -> decltype(IRMatcher::operator*(a, b))
static constexpr bool foldable
static constexpr bool foldable
uint8_t code
The basic type code: signed integer, unsigned integer, or floating point.
static constexpr IRNodeType min_node_type
HALIDE_ALWAYS_INLINE bool is_float() const
Is this type a floating point type (float or double).
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
static constexpr IRNodeType max_node_type
static constexpr IRNodeType min_node_type
static constexpr bool canonical
HALIDE_ALWAYS_INLINE auto h_add(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Add >
static constexpr bool foldable
HALIDE_ALWAYS_INLINE bool is_uint() const
Is this type an unsigned integer type?
HALIDE_ALWAYS_INLINE int lanes() const
Return the number of vector elements in this type.
HALIDE_ALWAYS_INLINE auto operator>=(A &&a, B &&b) noexcept -> CmpOp< GE, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
static constexpr uint32_t binds
Expr make_zero(Type t)
Construct the representation of zero in the given type.
A linear ramp vector node.
HALIDE_ALWAYS_INLINE auto div(A &&a, B &&b) -> decltype(IRMatcher::operator/(a, b))
HALIDE_ALWAYS_INLINE bool match(const BinOp< Op2, A2, B2 > &op, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE auto fold(A &&a) noexcept -> Fold< decltype(pattern_arg(a))>
static constexpr IRNodeType max_node_type
HALIDE_ALWAYS_INLINE auto ge(A &&a, B &&b) -> decltype(IRMatcher::operator>=(a, b))
static Expr make(Operator op, Expr vec, int lanes)
static constexpr bool foldable
static constexpr IRNodeType max_node_type
static const IRNodeType _node_type
static constexpr IRNodeType max_node_type
HALIDE_ALWAYS_INLINE auto h_and(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::And >
static constexpr bool canonical
This file defines the class FunctionDAG, which is our representation of a Halide pipeline, and contains methods that use Halide's bounds tools to query properties of it.
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
static constexpr IRNodeType max_node_type
const BaseExprNode & expr
auto widen_right_sub(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
static constexpr uint32_t binds
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Add >(halide_type_t &t, int64_t a, int64_t b) noexcept
auto saturating_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
auto rounding_shift_right(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
static constexpr IRNodeType min_node_type
HALIDE_ALWAYS_INLINE auto broadcast(A &&a, B lanes) noexcept -> BroadcastOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes))>
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
static constexpr bool foldable
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Min >(halide_type_t &t, int64_t a, int64_t b) noexcept
HALIDE_ALWAYS_INLINE auto operator-(A &&a, B &&b) noexcept -> BinOp< Sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
T mod_imp(T a, T b)
Implementations of division and mod that are specific to Halide.
static constexpr bool foldable
HALIDE_ALWAYS_INLINE auto rewriter(Instance instance, halide_type_t output_type, halide_type_t wildcard_type) noexcept -> Rewriter< decltype(pattern_arg(instance))>
Construct a rewriter for the given instance, which may be a pattern with concrete expressions as leav...
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< GE >(int64_t a, int64_t b) noexcept
static constexpr IRNodeType min_node_type
bool expr_match(const Expr &pattern, const Expr &expr, std::vector< Expr > &result)
Does the first expression have the same structure as the second? Variables in the first expression wi...
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Max >(halide_type_t &t, int64_t a, int64_t b) noexcept
HALIDE_ALWAYS_INLINE auto can_prove(A &&a, Prover *p) noexcept -> CanProve< decltype(pattern_arg(a)), Prover >
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
IRNodeType
All our IR node types get unique IDs for the purposes of RTTI.
A base class for expression nodes.
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
static constexpr bool canonical
static constexpr bool foldable
static constexpr bool foldable
static constexpr uint32_t binds
HALIDE_ALWAYS_INLINE bool match(NegateOp< A2 > &&p, MatcherState &state) const noexcept
int slice_stride() const
Check if this shuffle is a contiguous strict subset of the vector arguments, and if so...
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Sub >(halide_type_t &t, int64_t a, int64_t b) noexcept
static Expr make(Expr value, int lanes)
static constexpr bool foldable
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Or >(halide_type_t &t, int64_t a, int64_t b) noexcept
static constexpr uint32_t binds
Expr absd(Expr a, Expr b)
Return the absolute difference between two values.
HALIDE_ALWAYS_INLINE auto and_op(A &&a, B &&b) -> decltype(IRMatcher::operator&&(a, b))
HALIDE_ALWAYS_INLINE auto h_min(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Min >
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
static constexpr bool canonical
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
static const IntImm * make(Type t, int64_t value)
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
auto halving_sub(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
halide_scalar_value_t bound_const[max_wild]
HALIDE_ALWAYS_INLINE auto operator/(A &&a, B &&b) noexcept -> BinOp< Div, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
static constexpr uint32_t binds
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE bool is_vector() const
Is this type a vector type? (lanes() != 1).
static constexpr bool foldable
HALIDE_ALWAYS_INLINE auto h_or(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Or >
static constexpr uint16_t special_values_mask
HALIDE_ALWAYS_INLINE auto lanes_of(A &&a) noexcept -> LanesOf< decltype(pattern_arg(a))>
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
static constexpr bool canonical
static constexpr IRNodeType max_node_type
static constexpr IRNodeType min_node_type
static constexpr uint32_t binds
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
static constexpr uint32_t binds
HALIDE_ALWAYS_INLINE bool equal(const BaseExprNode &a, const BaseExprNode &b) noexcept
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< And >(halide_type_t &t, int64_t a, int64_t b) noexcept
HALIDE_ALWAYS_INLINE auto add(A &&a, B &&b) -> decltype(IRMatcher::operator+(a, b))
Expr abs(Expr a)
Returns the absolute value of a signed integer or floating-point expression.
std::vector< Expr > vectors
static constexpr IRNodeType min_node_type
IRNodeType node_type
Each IR node subclass has a unique identifier.
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< NE >(int64_t a, int64_t b) noexcept
static constexpr uint32_t binds
HALIDE_ALWAYS_INLINE Type type() const
Get the type of this expression node.
halide_scalar_value_t is a simple union able to represent all the well-known scalar values in a filte...
HALIDE_ALWAYS_INLINE auto not_op(A &&a) -> decltype(IRMatcher::operator!(a))
HALIDE_ALWAYS_INLINE Intrin(Call::IntrinsicOp intrin, Args... args) noexcept
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
HALIDE_ALWAYS_INLINE void set_bound_const(int i, uint64_t u, halide_type_t t) noexcept
static constexpr uint32_t binds
#define HALIDE_NEVER_INLINE
HALIDE_ALWAYS_INLINE auto intrin(Call::IntrinsicOp intrinsic_op, Args... args) noexcept -> Intrin< decltype(pattern_arg(args))... >
uint8_t bits
The number of bits of precision of a single scalar value of this type.
static constexpr bool canonical
For optional debugging during codegen, use the debug class as follows:
Defines various operator overloads and utility functions that make it more pleasant to work with Hali...
constexpr int const_min(int a, int b)
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
halide_type_t wildcard_type
static const IRNodeType _node_type
constexpr uint32_t bitwise_or_reduce()
HALIDE_ALWAYS_INLINE auto operator==(A &&a, B &&b) noexcept -> CmpOp< EQ, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Expr make_const(Type t, int64_t val)
Construct an immediate of the given type from any numeric C++ type.
HALIDE_ALWAYS_INLINE int bits() const
Return the bit size of a single element of this type.
HALIDE_ALWAYS_INLINE auto mod(A &&a, B &&b) -> decltype(IRMatcher::operator%(a, b))
static constexpr uint32_t binds
HALIDE_ALWAYS_INLINE auto ramp(A &&a, B &&b, C &&c) noexcept -> RampOp< decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))>
static constexpr IRNodeType max_node_type
static constexpr bool canonical
static constexpr IRNodeType min_node_type
auto widen_right_mul(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Guard the loads and stores in the loop with an if statement that prevents evaluation beyond the origi...
HALIDE_ALWAYS_INLINE const BaseExprNode * get_binding(int i) const noexcept
static constexpr uint32_t binds
static constexpr bool foldable
static constexpr uint32_t binds
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< LT >(int64_t a, int64_t b) noexcept
static constexpr IRNodeType min_node_type
HALIDE_NEVER_INLINE void build_replacement(After after)
unsigned __INT32_TYPE__ uint32_t
auto widening_mul(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
static constexpr IRNodeType min_node_type
union halide_scalar_value_t::@4 u
static constexpr uint32_t binds
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
static constexpr uint32_t binds
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE IntLiteral(int64_t v)
HALIDE_ALWAYS_INLINE auto cast(halide_type_t t, A &&a) noexcept -> CastOp< decltype(pattern_arg(a))>
static constexpr IRNodeType min_node_type
To save stack space, the matcher objects are largely stateless and immutable.
static constexpr bool canonical
Not visible externally, similar to 'static' linkage in C.
HALIDE_ALWAYS_INLINE void set_bound_const(int i, halide_scalar_value_t val, halide_type_t t) noexcept
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< LE >(int64_t a, int64_t b) noexcept
static constexpr bool canonical
HALIDE_ALWAYS_INLINE void set_bound_const(int i, double f, halide_type_t t) noexcept
static constexpr IRNodeType max_node_type
HALIDE_ALWAYS_INLINE Expr make_const_expr(halide_scalar_value_t val, halide_type_t ty)
HALIDE_ALWAYS_INLINE void assert_is_lvalue_if_expr()
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Mul >(halide_type_t &t, int64_t a, int64_t b) noexcept
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Div >(halide_type_t &t, int64_t a, int64_t b) noexcept
HALIDE_ALWAYS_INLINE auto operator%(A &&a, B &&b) noexcept -> BinOp< Mod, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
static constexpr uint32_t binds
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE auto operator+(A &&a, B &&b) noexcept -> BinOp< Add, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
static constexpr bool canonical
signed __INT64_TYPE__ int64_t
static constexpr bool canonical
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE auto gt(A &&a, B &&b) -> decltype(IRMatcher::operator>(a, b))
HALIDE_ALWAYS_INLINE bool is_scalar() const
Is this type a scalar type? (lanes() == 1).
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
bool mul_would_overflow(int bits, int64_t a, int64_t b)
Routines to test if math would overflow for signed integers with the given number of bits...
static constexpr IRNodeType min_node_type
static constexpr IRNodeType min_node_type
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
static constexpr uint32_t binds
static constexpr IRNodeType max_node_type
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
#define HALIDE_ALWAYS_INLINE
static constexpr uint32_t mask
static const FloatImm * make(Type t, double value)
static constexpr IRNodeType min_node_type
static constexpr IRNodeType min_node_type
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE bool match(int64_t val, MatcherState &state) const noexcept
static constexpr IRNodeType min_node_type
static constexpr bool foldable
HALIDE_ALWAYS_INLINE auto select(C &&c, T &&t, F &&f) noexcept -> SelectOp< decltype(pattern_arg(c)), decltype(pattern_arg(t)), decltype(pattern_arg(f))>
static constexpr uint32_t binds
static const IRNodeType _node_type
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const noexcept
HALIDE_ALWAYS_INLINE auto ne(A &&a, B &&b) -> decltype(IRMatcher::operator!=(a, b))
static constexpr IRNodeType min_node_type
HALIDE_ALWAYS_INLINE auto operator &&(A &&a, B &&b) noexcept -> BinOp< And, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE bool match(const NotOp< A2 > &op, MatcherState &state) const noexcept
static constexpr bool foldable
HALIDE_ALWAYS_INLINE auto is_const(A &&a) noexcept -> IsConst< decltype(pattern_arg(a))>
A runtime tag for a type in the halide type system.
static constexpr IRNodeType max_node_type
constexpr IRNodeType StrongestExprNodeType
HALIDE_NEVER_INLINE void fuzz_test_rule(Before &&before, After &&after, Predicate &&pred, halide_type_t wildcard_type, halide_type_t output_type) noexcept
static constexpr uint32_t binds
halide_type_t output_type
Subtypes for Halide expressions (Halide::Expr) and statements (Halide::Internal::Stmt) ...
auto shift_right(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
static constexpr bool foldable
static constexpr bool foldable
static constexpr bool canonical
static constexpr IRNodeType max_node_type
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
std::ostream & operator<<(std::ostream &s, const SpecificExpr &e)
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
static constexpr IRNodeType min_node_type
static constexpr IRNodeType max_node_type
bool is_const_zero(const Expr &e)
Is the expression a const (as defined by is_const), and also equal to zero (in all lanes...
static constexpr IRNodeType max_node_type
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
static constexpr IRNodeType min_node_type
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
auto widening_sub(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
auto rounding_mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))>
static constexpr bool canonical
Expr make_signed_integer_overflow(Type type)
Construct a unique signed_integer_overflow Expr.
HALIDE_ALWAYS_INLINE bool is_int() const
Is this type a signed integer type?
int64_t constant_fold_bin_op(halide_type_t &, int64_t, int64_t) noexcept
HALIDE_ALWAYS_INLINE auto min(A &&a, B &&b) noexcept -> BinOp< Min, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE bool match(const CastOp< A2 > &op, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE auto h_max(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Max >
unsigned __INT16_TYPE__ uint16_t
static const IRNodeType _node_type
auto saturating_cast(const Type &t, A &&a) noexcept -> Intrin< decltype(pattern_arg(a))>
Floating point numbers in the bfloat format.
static constexpr uint32_t binds
Types in the halide type system.
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE auto is_uint(A &&a, int bits=0, int lanes=0) noexcept -> IsUInt< decltype(pattern_arg(a))>
auto halving_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
constexpr bool commutative(IRNodeType t)
bool is_intrinsic() const
static constexpr uint32_t binds
static constexpr IRNodeType max_node_type
HALIDE_ALWAYS_INLINE auto or_op(A &&a, B &&b) -> decltype(IRMatcher::operator||(a, b))
HALIDE_ALWAYS_INLINE bool match_args(double, const Call &c, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE bool match_args(int, const Call &c, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE auto lt(A &&a, B &&b) -> decltype(IRMatcher::operator<(a, b))
HALIDE_ALWAYS_INLINE auto operator||(A &&a, B &&b) noexcept -> BinOp< Or, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE SliceOp(Vec v, Base b, Stride s, Lanes l)
static constexpr IRNodeType max_node_type
halide_type_t bound_const_type[max_wild]
static constexpr IRNodeType min_node_type
static constexpr IRNodeType min_node_type
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
static constexpr uint32_t binds
static constexpr bool foldable
bool is_const_one(const Expr &e)
Is the expression a const (as defined by is_const), and also equal to one (in all lanes...
HALIDE_ALWAYS_INLINE auto operator*(A &&a, B &&b) noexcept -> BinOp< Mul, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE bool match(int64_t value, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE auto overflows(A &&a) noexcept -> Overflows< decltype(pattern_arg(a))>
static constexpr bool canonical
static Expr make(Expr a, Expr b)
static constexpr bool canonical
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Mod >(halide_type_t &t, int64_t a, int64_t b) noexcept
static constexpr uint16_t signed_integer_overflow
Expr likely_if_innermost(Expr e)
Equivalent to likely, but only triggers a loop partitioning if found in an innermost loop...
static constexpr bool foldable
static constexpr IRNodeType max_node_type
static constexpr IRNodeType min_node_type
bool equal_helper(const BaseExprNode &a, const BaseExprNode &b) noexcept
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
auto rounding_shift_left(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
bool sub_would_overflow(int bits, int64_t a, int64_t b)
Routines to test if math would overflow for signed integers with the given number of bits...
HALIDE_ALWAYS_INLINE const Internal::BaseExprNode * get() const
Override get() to return a BaseExprNode * instead of an IRNode *.
auto mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))>
static constexpr IRNodeType min_node_type
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
static constexpr bool canonical
static constexpr bool foldable
HALIDE_ALWAYS_INLINE bool match(const SelectOp< C2, T2, F2 > &instance, MatcherState &state) const noexcept
Expr with_lanes(const Expr &x, int lanes)
Rewrite the expression x to have lanes lanes.
HALIDE_ALWAYS_INLINE void print_args(std::ostream &s) const
HALIDE_ALWAYS_INLINE T pattern_arg(T t)
HALIDE_ALWAYS_INLINE auto is_int(A &&a, int bits=0, int lanes=0) noexcept -> IsInt< decltype(pattern_arg(a))>
HALIDE_ALWAYS_INLINE auto operator<(A &&a, B &&b) noexcept -> CmpOp< LT, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
static const IRNodeType _node_type
HALIDE_ALWAYS_INLINE auto operator<=(A &&a, B &&b) noexcept -> CmpOp< LE, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
static constexpr bool canonical
HALIDE_ALWAYS_INLINE auto operator>(A &&a, B &&b) noexcept -> CmpOp< GT, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
auto rounding_halving_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
static constexpr uint32_t binds
static constexpr bool foldable
uint64_t constant_fold_cmp_op(int64_t, int64_t) noexcept
auto widen_right_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
static constexpr uint32_t binds
auto shift_left(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
static constexpr bool canonical
static Expr make(Expr base, Expr stride, int lanes)
HALIDE_ALWAYS_INLINE auto eq(A &&a, B &&b) -> decltype(IRMatcher::operator==(a, b))
static constexpr uint32_t binds
static constexpr bool canonical
static Expr make_slice(Expr vector, int begin, int stride, int size)
Convenience constructor for making a shuffle representing a contiguous subset of a vector...
static constexpr uint32_t binds
auto saturating_sub(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_NEVER_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
HALIDE_ALWAYS_INLINE auto slice(Vec vec, Base base, Stride stride, Lanes lanes) noexcept -> SliceOp< decltype(pattern_arg(vec)), decltype(pattern_arg(base)), decltype(pattern_arg(stride)), decltype(pattern_arg(lanes))>
unsigned __INT64_TYPE__ uint64_t
static constexpr bool foldable
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< GT >(int64_t a, int64_t b) noexcept
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const noexcept
HALIDE_ALWAYS_INLINE MatcherState() noexcept
HALIDE_ALWAYS_INLINE void set_binding(int i, const BaseExprNode &n) noexcept
HALIDE_ALWAYS_INLINE auto is_const(A &&a, int64_t value) noexcept -> IsConst< decltype(pattern_arg(a))>
HALIDE_ALWAYS_INLINE void print_args(int, std::ostream &s) const
HALIDE_ALWAYS_INLINE auto negate(A &&a) -> decltype(IRMatcher::operator-(a))
bool add_would_overflow(int bits, int64_t a, int64_t b)
Routines to test if math would overflow for signed integers with the given number of bits...
T div_imp(T a, T b)
Implementations of division and mod that are specific to Halide.
static constexpr IRNodeType max_node_type
HALIDE_ALWAYS_INLINE auto le(A &&a, B &&b) -> decltype(IRMatcher::operator<=(a, b))
static constexpr uint32_t binds
static constexpr IRNodeType min_node_type
static constexpr bool foldable
HALIDE_ALWAYS_INLINE bool match(const RampOp< A2, B2, C2 > &op, MatcherState &state) const noexcept
static constexpr IRNodeType max_node_type
static constexpr bool foldable
HALIDE_ALWAYS_INLINE bool operator()(Before before, int64_t after, Predicate pred)
Expr likely(Expr e)
Expressions tagged with this intrinsic are considered to be part of the steady state of some loop wit...
signed __INT32_TYPE__ int32_t
static constexpr bool canonical
static constexpr IRNodeType max_node_type
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
static constexpr IRNodeType min_node_type
HALIDE_ALWAYS_INLINE auto sub(A &&a, B &&b) -> decltype(IRMatcher::operator-(a, b))
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
HALIDE_NEVER_INLINE Expr make_const_special_expr(halide_type_t ty)
HALIDE_ALWAYS_INLINE auto is_min_value(A &&a) noexcept -> IsMinValue< decltype(pattern_arg(a))>
int slice_begin() const
Check if this shuffle is a contiguous strict subset of the vector arguments, and if so...
static constexpr IRNodeType min_node_type
HALIDE_ALWAYS_INLINE auto is_float(A &&a) noexcept -> IsFloat< decltype(pattern_arg(a))>
HALIDE_ALWAYS_INLINE bool operator()(Before before, const Expr &after, Predicate pred)
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
static constexpr bool canonical
static constexpr IRNodeType max_node_type
HALIDE_ALWAYS_INLINE auto is_scalar(A &&a) noexcept -> IsScalar< decltype(pattern_arg(a))>
static constexpr bool foldable
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
static constexpr IRNodeType max_node_type
HALIDE_ALWAYS_INLINE Rewriter(Instance instance, halide_type_t ot, halide_type_t wt)
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
static constexpr IRNodeType min_node_type
Construct a new vector by taking elements from another sequence of vectors.
static constexpr bool foldable
HALIDE_ALWAYS_INLINE void print_args(double, std::ostream &s) const