1 #ifndef HALIDE_IR_MATCH_H
2 #define HALIDE_IR_MATCH_H
139 typename =
typename std::remove_reference<T>::type::pattern_tag>
146 constexpr
static uint32_t mask = std::remove_reference<T>::type::binds;
166 const int lanes = scalar_type.
lanes;
167 scalar_type.
lanes = 1;
170 switch (scalar_type.
code) {
204 template<u
int32_t bound>
232 template<u
int32_t bound>
234 static_assert(i >= 0 && i <
max_wild,
"Wild with out-of-range index");
237 op = ((
const Broadcast *)op)->value.get();
246 state.get_bound_const(i, val, type);
249 state.set_bound_const(i, value, e.type);
253 template<u
int32_t bound>
255 static_assert(i >= 0 && i <
max_wild,
"Wild with out-of-range index");
259 state.get_bound_const(i, val, type);
260 return type == i64_type && value == val.
u.
i64;
262 state.set_bound_const(i, value, i64_type);
298 template<u
int32_t bound>
300 static_assert(i >= 0 && i <
max_wild,
"Wild with out-of-range index");
303 op = ((
const Broadcast *)op)->value.get();
312 state.get_bound_const(i, val, type);
315 state.set_bound_const(i, value, e.type);
331 state.get_bound_const(i, val, ty);
351 template<u
int32_t bound>
353 static_assert(i >= 0 && i <
max_wild,
"Wild with out-of-range index");
356 op = ((
const Broadcast *)op)->value.get();
361 double value = ((
const FloatImm *)op)->value;
365 state.get_bound_const(i, val, type);
368 state.set_bound_const(i, value, e.type);
384 state.get_bound_const(i, val, ty);
405 template<u
int32_t bound>
407 static_assert(i >= 0 && i <
max_wild,
"Wild with out-of-range index");
410 op = ((
const Broadcast *)op)->value.get();
424 template<u
int32_t bound>
426 static_assert(i >= 0 && i <
max_wild,
"Wild with out-of-range index");
442 state.get_bound_const(i, val, ty);
463 template<u
int32_t bound>
466 return equal(*state.get_binding(i), e);
468 state.set_binding(i, e);
503 template<u
int32_t bound>
507 op = ((
const Broadcast *)op)->value.get();
515 return ((
const FloatImm *)op)->value == (
double)
v;
521 template<u
int32_t bound>
526 template<u
int32_t bound>
550 val.u.f64 = (double)
v;
566 typename =
typename std::decay<T>::type::pattern_tag>
577 static_assert(!std::is_same<
typename std::decay<T>::type,
Expr>::value || std::is_lvalue_reference<T>::value,
578 "Exprs are captured by reference by IRMatcher objects and so must be lvalues");
589 typename =
typename std::decay<T>::type::pattern_tag,
591 typename =
typename std::enable_if<!std::is_same<typename std::decay<T>::type, SpecificExpr>::value>::type>
606 template<
typename Op>
609 template<
typename Op>
612 template<
typename Op>
627 template<
typename Op,
typename A,
typename B>
642 A::canonical && B::canonical && (!
commutative(Op::_node_type) || (A::max_node_type >= B::min_node_type));
644 template<u
int32_t bound>
646 if (e.node_type != Op::_node_type) {
649 const Op &op = (
const Op &)e;
650 return (
a.template match<bound>(*op.a.get(), state) &&
651 b.template match<bound | bindings<A>::mask>(*op.b.get(), state));
654 template<u
int32_t bound,
typename Op2,
typename A2,
typename B2>
656 return (std::is_same<Op, Op2>::value &&
657 a.template match<bound>(
unwrap(op.a), state) &&
661 constexpr
static bool foldable = A::foldable && B::foldable;
666 if (std::is_same<A, IntLiteral>::value) {
667 b.make_folded_const(val_b, ty, state);
668 if ((std::is_same<Op, And>::value && val_b.
u.
u64 == 0) ||
669 (std::is_same<Op, Or>::value && val_b.
u.
u64 == 1)) {
675 a.make_folded_const(val_a, ty, state);
678 a.make_folded_const(val_a, ty, state);
679 if ((std::is_same<Op, And>::value && val_a.
u.
u64 == 0) ||
680 (std::is_same<Op, Or>::value && val_a.
u.
u64 == 1)) {
686 b.make_folded_const(val_b, ty, state);
691 val.u.i64 = constant_fold_bin_op<Op>(ty, val_a.
u.
i64, val_b.
u.
i64);
694 val.u.u64 = constant_fold_bin_op<Op>(ty, val_a.
u.
u64, val_b.
u.
u64);
698 val.u.f64 = constant_fold_bin_op<Op>(ty, val_a.
u.
f64, val_b.
u.
f64);
709 if (std::is_same<A, IntLiteral>::value) {
710 eb =
b.make(state, type_hint);
711 ea =
a.make(state, eb.
type());
713 ea =
a.make(state, type_hint);
714 eb =
b.make(state, ea.
type());
724 return Op::make(std::move(ea), std::move(eb));
728 template<
typename Op>
731 template<
typename Op>
734 template<
typename Op>
738 template<
typename Op,
typename A,
typename B>
750 (!
commutative(Op::_node_type) || A::max_node_type >= B::min_node_type) &&
754 template<u
int32_t bound>
756 if (e.node_type != Op::_node_type) {
759 const Op &op = (
const Op &)e;
760 return (
a.template match<bound>(*op.a.get(), state) &&
761 b.template match<bound | bindings<A>::mask>(*op.b.get(), state));
764 template<u
int32_t bound,
typename Op2,
typename A2,
typename B2>
766 return (std::is_same<Op, Op2>::value &&
767 a.template match<bound>(
unwrap(op.a), state) &&
771 constexpr
static bool foldable = A::foldable && B::foldable;
777 if (std::is_same<A, IntLiteral>::value) {
778 b.make_folded_const(val_b, ty, state);
780 a.make_folded_const(val_a, ty, state);
783 a.make_folded_const(val_a, ty, state);
785 b.make_folded_const(val_b, ty, state);
790 val.u.u64 = constant_fold_cmp_op<Op>(val_a.
u.
i64, val_b.
u.
i64);
793 val.u.u64 = constant_fold_cmp_op<Op>(val_a.
u.
u64, val_b.
u.
u64);
797 val.u.u64 = constant_fold_cmp_op<Op>(val_a.
u.
f64, val_b.
u.
f64);
811 if (std::is_same<A, IntLiteral>::value) {
812 eb =
b.make(state, {});
813 ea =
a.make(state, eb.
type());
815 ea =
a.make(state, {});
816 eb =
b.make(state, ea.
type());
826 return Op::make(std::move(ea), std::move(eb));
830 template<
typename A,
typename B>
832 s <<
"(" << op.
a <<
" + " << op.
b <<
")";
836 template<
typename A,
typename B>
838 s <<
"(" << op.
a <<
" - " << op.
b <<
")";
842 template<
typename A,
typename B>
844 s <<
"(" << op.
a <<
" * " << op.
b <<
")";
848 template<
typename A,
typename B>
850 s <<
"(" << op.
a <<
" / " << op.
b <<
")";
854 template<
typename A,
typename B>
856 s <<
"(" << op.
a <<
" && " << op.
b <<
")";
860 template<
typename A,
typename B>
862 s <<
"(" << op.
a <<
" || " << op.
b <<
")";
866 template<
typename A,
typename B>
868 s <<
"min(" << op.
a <<
", " << op.
b <<
")";
872 template<
typename A,
typename B>
874 s <<
"max(" << op.
a <<
", " << op.
b <<
")";
878 template<
typename A,
typename B>
880 s <<
"(" << op.
a <<
" <= " << op.
b <<
")";
884 template<
typename A,
typename B>
886 s <<
"(" << op.
a <<
" < " << op.
b <<
")";
890 template<
typename A,
typename B>
892 s <<
"(" << op.
a <<
" >= " << op.
b <<
")";
896 template<
typename A,
typename B>
898 s <<
"(" << op.
a <<
" > " << op.
b <<
")";
902 template<
typename A,
typename B>
904 s <<
"(" << op.
a <<
" == " << op.
b <<
")";
908 template<
typename A,
typename B>
910 s <<
"(" << op.
a <<
" != " << op.
b <<
")";
914 template<
typename A,
typename B>
916 s <<
"(" << op.
a <<
" % " << op.
b <<
")";
920 template<
typename A,
typename B>
922 assert_is_lvalue_if_expr<A>();
923 assert_is_lvalue_if_expr<B>();
927 template<
typename A,
typename B>
929 assert_is_lvalue_if_expr<A>();
930 assert_is_lvalue_if_expr<B>();
937 int dead_bits = 64 - t.bits;
945 return (a + b) & (ones >> (64 - t.bits));
953 template<
typename A,
typename B>
955 assert_is_lvalue_if_expr<A>();
956 assert_is_lvalue_if_expr<B>();
960 template<
typename A,
typename B>
962 assert_is_lvalue_if_expr<A>();
963 assert_is_lvalue_if_expr<B>();
971 int dead_bits = 64 - t.bits;
978 return (a - b) & (ones >> (64 - t.bits));
986 template<
typename A,
typename B>
988 assert_is_lvalue_if_expr<A>();
989 assert_is_lvalue_if_expr<B>();
993 template<
typename A,
typename B>
995 assert_is_lvalue_if_expr<A>();
996 assert_is_lvalue_if_expr<B>();
1003 int dead_bits = 64 - t.bits;
1011 return (a * b) & (ones >> (64 - t.bits));
1019 template<
typename A,
typename B>
1021 assert_is_lvalue_if_expr<A>();
1022 assert_is_lvalue_if_expr<B>();
1026 template<
typename A,
typename B>
1046 template<
typename A,
typename B>
1048 assert_is_lvalue_if_expr<A>();
1049 assert_is_lvalue_if_expr<B>();
1053 template<
typename A,
typename B>
1055 assert_is_lvalue_if_expr<A>();
1056 assert_is_lvalue_if_expr<B>();
1075 template<
typename A,
typename B>
1077 assert_is_lvalue_if_expr<A>();
1078 assert_is_lvalue_if_expr<B>();
1097 template<
typename A,
typename B>
1099 assert_is_lvalue_if_expr<A>();
1100 assert_is_lvalue_if_expr<B>();
1119 template<
typename A,
typename B>
1124 template<
typename A,
typename B>
1144 template<
typename A,
typename B>
1149 template<
typename A,
typename B>
1169 template<
typename A,
typename B>
1174 template<
typename A,
typename B>
1194 template<
typename A,
typename B>
1199 template<
typename A,
typename B>
1219 template<
typename A,
typename B>
1224 template<
typename A,
typename B>
1244 template<
typename A,
typename B>
1249 template<
typename A,
typename B>
1269 template<
typename A,
typename B>
1274 template<
typename A,
typename B>
1295 template<
typename A,
typename B>
1300 template<
typename A,
typename B>
1325 template<
typename... Args>
1334 template<
typename... Args>
1341 return a < b ? a : b;
1344 template<
typename... Args>
1362 typename =
typename std::enable_if<(i <
sizeof...(Args))>::type>
1364 using T = decltype(std::get<i>(
args));
1365 return (std::get<i>(
args).
template match<bound>(*c.args[i].get(), state) &&
1369 template<
int i, u
int32_t binds>
1374 template<u
int32_t bound>
1382 match_args<0, bound>(0, c, state));
1386 typename =
typename std::enable_if<(i <
sizeof...(Args))>::type>
1388 s << std::get<i>(
args);
1389 if (i + 1 <
sizeof...(Args)) {
1392 print_args<i + 1>(0, s);
1401 print_args<0>(0, s);
1406 Expr arg0 = std::get<0>(
args).make(state, type_hint);
1419 return absd(arg0, arg1);
1443 return arg0 << arg1;
1445 return arg0 >> arg1;
1471 std::get<0>(
args).make_folded_const(val, ty, state);
1476 std::get<1>(
args).make_folded_const(arg1, signed_ty, state);
1479 if (arg1.
u.
i64 < 0) {
1482 val.u.i64 >>= -arg1.
u.
i64;
1485 val.u.u64 >>= -arg1.
u.
i64;
1488 val.u.u64 <<= arg1.
u.
i64;
1491 if (arg1.
u.
i64 > 0) {
1494 val.u.i64 >>= arg1.
u.
i64;
1497 val.u.u64 >>= arg1.
u.
i64;
1500 val.u.u64 <<= -arg1.
u.
i64;
1513 template<
typename... Args>
1521 template<
typename... Args>
1526 template<
typename A,
typename B>
1530 template<
typename A,
typename B>
1534 template<
typename A,
typename B>
1539 template<
typename A,
typename B>
1543 template<
typename A,
typename B>
1547 template<
typename A,
typename B>
1551 template<
typename A,
typename B>
1555 template<
typename A,
typename B>
1559 template<
typename A>
1565 template<
typename A,
typename B>
1569 template<
typename A,
typename B>
1573 template<
typename A,
typename B>
1577 template<
typename A,
typename B>
1581 template<
typename A,
typename B>
1585 template<
typename A,
typename B>
1589 template<
typename A,
typename B>
1593 template<
typename A,
typename B,
typename C>
1597 template<
typename A,
typename B,
typename C>
1602 template<
typename A>
1613 template<u
int32_t bound>
1618 const Not &op = (
const Not &)e;
1619 return (
a.template match<bound>(*op.
a.
get(), state));
1622 template<u
int32_t bound,
typename A2>
1624 return a.template match<bound>(
unwrap(op.a), state);
1634 template<
typename A1 = A>
1636 a.make_folded_const(val, ty, state);
1637 val.u.u64 = ~val.u.u64;
1642 template<
typename A>
1644 assert_is_lvalue_if_expr<A>();
1648 template<
typename A>
1650 assert_is_lvalue_if_expr<A>();
1654 template<
typename A>
1656 s <<
"!(" << op.
a <<
")";
1660 template<
typename C,
typename T,
typename F>
1672 constexpr
static bool canonical = C::canonical && T::canonical && F::canonical;
1674 template<u
int32_t bound>
1680 return (
c.template match<bound>(*op.
condition.
get(), state) &&
1681 t.template match<bound | bindings<C>::mask>(*op.
true_value.
get(), state) &&
1684 template<u
int32_t bound,
typename C2,
typename T2,
typename F2>
1686 return (
c.template match<bound>(
unwrap(instance.c), state) &&
1693 return Select::make(
c.make(state, {}),
t.make(state, type_hint),
f.make(state, type_hint));
1696 constexpr
static bool foldable = C::foldable && T::foldable && F::foldable;
1698 template<
typename C1 = C>
1702 c.make_folded_const(c_val, c_ty, state);
1703 if ((c_val.
u.
u64 & 1) == 1) {
1704 t.make_folded_const(val, ty, state);
1706 f.make_folded_const(val, ty, state);
1712 template<
typename C,
typename T,
typename F>
1714 s <<
"select(" << op.
c <<
", " << op.
t <<
", " << op.
f <<
")";
1718 template<
typename C,
typename T,
typename F>
1720 assert_is_lvalue_if_expr<C>();
1721 assert_is_lvalue_if_expr<T>();
1722 assert_is_lvalue_if_expr<F>();
1726 template<
typename A,
typename B>
1737 constexpr
static bool canonical = A::canonical && B::canonical;
1739 template<u
int32_t bound>
1743 if (
a.template match<bound>(*op.
value.
get(), state) &&
1744 lanes.template match<bound>(op.
lanes, state)) {
1751 template<u
int32_t bound,
typename A2,
typename B2>
1753 return (
a.template match<bound>(
unwrap(op.a), state) &&
1761 lanes.make_folded_const(lanes_val, ty, state);
1763 type_hint.
lanes /= l;
1764 Expr val =
a.make(state, type_hint);
1774 template<
typename A1 = A>
1778 lanes.make_folded_const(lanes_val, lanes_ty, state);
1780 a.make_folded_const(val, ty, state);
1785 template<
typename A,
typename B>
1787 s <<
"broadcast(" << op.
a <<
", " << op.
lanes <<
")";
1791 template<
typename A,
typename B>
1793 assert_is_lvalue_if_expr<A>();
1797 template<
typename A,
typename B,
typename C>
1809 constexpr
static bool canonical = A::canonical && B::canonical && C::canonical;
1811 template<u
int32_t bound>
1817 if (
a.template match<bound>(*op.
base.
get(), state) &&
1818 b.template match<bound | bindings<A>::mask>(*op.
stride.
get(), state) &&
1826 template<u
int32_t bound,
typename A2,
typename B2,
typename C2>
1828 return (
a.template match<bound>(
unwrap(op.a), state) &&
1837 lanes.make_folded_const(lanes_val, ty, state);
1839 type_hint.
lanes /= l;
1841 eb =
b.make(state, type_hint);
1842 ea =
a.make(state, eb.type());
1849 template<
typename A,
typename B,
typename C>
1851 s <<
"ramp(" << op.
a <<
", " << op.
b <<
", " << op.
lanes <<
")";
1855 template<
typename A,
typename B,
typename C>
1857 assert_is_lvalue_if_expr<A>();
1858 assert_is_lvalue_if_expr<B>();
1859 assert_is_lvalue_if_expr<C>();
1863 template<
typename A,
typename B, VectorReduce::Operator reduce_op>
1875 template<u
int32_t bound>
1879 if (op.
op == reduce_op &&
1880 a.template match<bound>(*op.
value.
get(), state) &&
1881 lanes.template match<bound | bindings<A>::mask>(op.
type.
lanes(), state)) {
1888 template<u
int32_t bound,
typename A2,
typename B2, VectorReduce::Operator reduce_op_2>
1890 return (reduce_op == reduce_op_2 &&
1891 a.template match<bound>(
unwrap(op.a), state) &&
1899 lanes.make_folded_const(lanes_val, ty, state);
1900 int l = (int)lanes_val.
u.
i64;
1907 template<
typename A,
typename B, VectorReduce::Operator reduce_op>
1909 s <<
"vector_reduce(" << reduce_op <<
", " << op.
a <<
", " << op.
lanes <<
")";
1913 template<
typename A,
typename B>
1915 assert_is_lvalue_if_expr<A>();
1919 template<
typename A,
typename B>
1921 assert_is_lvalue_if_expr<A>();
1925 template<
typename A,
typename B>
1927 assert_is_lvalue_if_expr<A>();
1931 template<
typename A,
typename B>
1933 assert_is_lvalue_if_expr<A>();
1937 template<
typename A,
typename B>
1939 assert_is_lvalue_if_expr<A>();
1943 template<
typename A>
1955 template<u
int32_t bound>
1960 const Sub &op = (
const Sub &)e;
1961 return (
a.template match<bound>(*op.
b.
get(), state) &&
1965 template<u
int32_t bound,
typename A2>
1967 return a.template match<bound>(
unwrap(p.a), state);
1972 Expr ea =
a.make(state, type_hint);
1974 return Sub::make(std::move(z), std::move(ea));
1979 template<
typename A1 = A>
1981 a.make_folded_const(val, ty, state);
1982 int dead_bits = 64 - ty.bits;
1985 if (ty.bits >= 32 && val.u.u64 && (val.u.u64 << (65 - ty.bits)) == 0) {
1994 val.u.u64 = ((-val.u.u64) << dead_bits) >> dead_bits;
1998 val.u.f64 = -val.u.f64;
2007 template<
typename A>
2013 template<
typename A>
2015 assert_is_lvalue_if_expr<A>();
2019 template<
typename A>
2021 assert_is_lvalue_if_expr<A>();
2025 template<
typename A>
2037 template<u
int32_t bound>
2043 return (e.type ==
t &&
2044 a.template match<bound>(*op.
value.
get(), state));
2046 template<u
int32_t bound,
typename A2>
2048 return t == op.t &&
a.template match<bound>(
unwrap(op.a), state);
2053 return cast(
t,
a.make(state, {}));
2059 template<
typename A>
2061 s <<
"cast(" << op.
t <<
", " << op.
a <<
")";
2065 template<
typename A>
2067 assert_is_lvalue_if_expr<A>();
2071 template<
typename A>
2082 template<u
int32_t bound>
2089 a.template match<bound>(*op.
value.
get(), state));
2091 template<u
int32_t bound,
typename A2>
2093 return a.template match<bound>(
unwrap(op.a), state);
2098 Expr e =
a.make(state, {});
2100 return cast(w, std::move(e));
2106 template<
typename A>
2108 s <<
"widen(" << op.
a <<
")";
2112 template<
typename A>
2114 assert_is_lvalue_if_expr<A>();
2118 template<
typename Vec,
typename Base,
typename Str
ide,
typename Lanes>
2126 static constexpr
uint32_t binds = Vec::binds | Base::binds | Stride::binds | Lanes::binds;
2130 constexpr
static bool canonical = Vec::canonical && Base::canonical && Stride::canonical && Lanes::canonical;
2132 template<u
int32_t bound>
2138 return v.
vectors.size() == 1 &&
2140 vec.template match<bound>(*v.
vectors[0].get(), state) &&
2141 base.template match<bound | bindings<Vec>::mask>(v.
slice_begin(), state) &&
2150 base.make_folded_const(base_val, ty, state);
2151 int b = (int)base_val.
u.
i64;
2152 stride.make_folded_const(stride_val, ty, state);
2153 int s = (int)stride_val.
u.
i64;
2154 lanes.make_folded_const(lanes_val, ty, state);
2155 int l = (int)lanes_val.
u.
i64;
2164 static_assert(Base::foldable,
"Base of slice should consist only of operations that constant-fold");
2165 static_assert(Stride::foldable,
"Stride of slice should consist only of operations that constant-fold");
2166 static_assert(Lanes::foldable,
"Lanes of slice should consist only of operations that constant-fold");
2170 template<
typename Vec,
typename Base,
typename Str
ide,
typename Lanes>
2172 s <<
"slice(" << op.
vec <<
", " << op.
base <<
", " << op.
stride <<
", " << op.
lanes <<
")";
2176 template<
typename Vec,
typename Base,
typename Str
ide,
typename Lanes>
2182 template<
typename A>
2197 a.make_folded_const(c, ty, state);
2203 if (type_hint.bits) {
2207 c.
u.
f64 = (double)x;
2209 ty.
code = type_hint.code;
2210 ty.
bits = type_hint.bits;
2219 template<
typename A1 = A>
2221 a.make_folded_const(val, ty, state);
2225 template<
typename A>
2227 assert_is_lvalue_if_expr<A>();
2231 template<
typename A>
2233 s <<
"fold(" << op.
a <<
")";
2237 template<
typename A>
2252 template<
typename A1 = A>
2254 a.make_folded_const(val, ty, state);
2262 template<
typename A>
2264 assert_is_lvalue_if_expr<A>();
2268 template<
typename A>
2270 s <<
"overflows(" << op.
a <<
")";
2284 template<u
int32_t bound>
2313 template<
typename A>
2330 template<
typename A1 = A>
2332 Expr e =
a.make(state, {});
2344 template<
typename A>
2346 assert_is_lvalue_if_expr<A>();
2350 template<
typename A>
2352 assert_is_lvalue_if_expr<A>();
2356 template<
typename A>
2359 s <<
"is_const(" << op.
a <<
")";
2361 s <<
"is_const(" << op.
a <<
", " << op.
v <<
")";
2366 template<
typename A,
typename Prover>
2383 Expr condition =
a.make(state, {});
2384 condition =
prover->mutate(condition,
nullptr);
2392 template<
typename A,
typename Prover>
2394 assert_is_lvalue_if_expr<A>();
2398 template<
typename A,
typename Prover>
2400 s <<
"can_prove(" << op.
a <<
")";
2404 template<
typename A>
2421 Type t =
a.make(state, {}).type();
2429 template<
typename A>
2431 assert_is_lvalue_if_expr<A>();
2435 template<
typename A>
2437 s <<
"is_float(" << op.
a <<
")";
2441 template<
typename A>
2459 Type t =
a.make(state, {}).type();
2467 template<
typename A>
2469 assert_is_lvalue_if_expr<A>();
2473 template<
typename A>
2475 s <<
"is_int(" << op.
a;
2477 s <<
", " << op.
bits;
2480 s <<
", " << op.
lanes;
2486 template<
typename A>
2504 Type t =
a.make(state, {}).type();
2512 template<
typename A>
2514 assert_is_lvalue_if_expr<A>();
2518 template<
typename A>
2520 s <<
"is_uint(" << op.
a;
2522 s <<
", " << op.
bits;
2525 s <<
", " << op.
lanes;
2531 template<
typename A>
2548 Type t =
a.make(state, {}).type();
2556 template<
typename A>
2558 assert_is_lvalue_if_expr<A>();
2562 template<
typename A>
2564 s <<
"is_scalar(" << op.
a <<
")";
2568 template<
typename A>
2585 a.make_folded_const(val, ty, state);
2588 val.
u.
u64 = (val.
u.
u64 == max_bits);
2597 template<
typename A>
2599 assert_is_lvalue_if_expr<A>();
2603 template<
typename A>
2605 s <<
"is_max_value(" << op.
a <<
")";
2609 template<
typename A>
2626 a.make_folded_const(val, ty, state);
2629 val.
u.
u64 = (val.
u.
u64 == min_bits);
2640 template<
typename A>
2642 assert_is_lvalue_if_expr<A>();
2646 template<
typename A>
2648 s <<
"is_min_value(" << op.
a <<
")";
2652 template<
typename A>
2669 Type t =
a.make(state, {}).type();
2677 template<
typename A>
2679 assert_is_lvalue_if_expr<A>();
2683 template<
typename A>
2685 s <<
"lanes_of(" << op.
a <<
")";
2690 template<
typename Before,
2693 typename =
typename std::enable_if<std::decay<Before>::type::foldable &&
2694 std::decay<After>::type::foldable>::type>
2699 wildcard_type.lanes = output_type.lanes = 1;
2702 static std::set<uint32_t> tested;
2704 if (!tested.insert(reinterpret_bits<uint32_t>(wildcard_type)).second) {
2709 debug(0) <<
"validate('" << before <<
"', '" << after <<
"', '" << pred <<
"', " <<
Type(wildcard_type) <<
", " <<
Type(output_type) <<
")\n";
2714 static std::mt19937_64 rng(0);
2719 for (
int trials = 0; trials < 100; trials++) {
2723 int shift = (int)(rng() & (wildcard_type.bits - 1));
2725 for (
int i = 0; i <
max_wild; i++) {
2727 switch (wildcard_type.code) {
2747 double val = ((
int64_t)(rng() & 15) - 8) / 2.0;
2749 val = ((
int64_t)(rng() & 15) - 8) / 2.0;
2763 before.make_folded_const(val_before, type, state);
2765 after.make_folded_const(val_after, type, state);
2766 lanes |= type.
lanes;
2773 switch (output_type.code) {
2788 ok &= (error < 0.01 ||
2789 val_before.
u.
u64 == val_after.
u.
u64 ||
2790 std::isnan(val_before.
u.
f64));
2798 debug(0) <<
"Fails with values:\n";
2799 for (
int i = 0; i <
max_wild; i++) {
2804 for (
int i = 0; i <
max_wild; i++) {
2809 debug(0) << val_before.
u.
u64 <<
" " << val_after.
u.
u64 <<
"\n";
2815 template<
typename Before,
2818 typename =
typename std::enable_if<!(std::decay<Before>::type::foldable &&
2819 std::decay<After>::type::foldable)>::type>
2830 template<
typename Pattern,
2831 typename =
typename enable_if_pattern<Pattern>::type>
2835 p.make_folded_const(c, ty, state);
2843 #define HALIDE_DEBUG_MATCHED_RULES 0
2844 #define HALIDE_DEBUG_UNMATCHED_RULES 0
2850 #define HALIDE_FUZZ_TEST_RULES 0
2852 template<
typename Instance>
2865 template<
typename After>
2870 template<
typename Before,
2875 static_assert((Before::binds & After::binds) == After::binds,
"Rule result uses unbound values");
2876 static_assert(Before::canonical,
"LHS of rewrite rule should be in canonical form");
2877 static_assert(After::canonical,
"RHS of rewrite rule should be in canonical form");
2878 #if HALIDE_FUZZ_TEST_RULES
2883 #if HALIDE_DEBUG_MATCHED_RULES
2888 #if HALIDE_DEBUG_UNMATCHED_RULES
2889 debug(0) <<
instance <<
" does not match " << before <<
"\n";
2895 template<
typename Before,
2898 static_assert(Before::canonical,
"LHS of rewrite rule should be in canonical form");
2901 #if HALIDE_DEBUG_MATCHED_RULES
2906 #if HALIDE_DEBUG_UNMATCHED_RULES
2907 debug(0) <<
instance <<
" does not match " << before <<
"\n";
2913 template<
typename Before,
2916 static_assert(Before::canonical,
"LHS of rewrite rule should be in canonical form");
2917 #if HALIDE_FUZZ_TEST_RULES
2922 #if HALIDE_DEBUG_MATCHED_RULES
2927 #if HALIDE_DEBUG_UNMATCHED_RULES
2928 debug(0) <<
instance <<
" does not match " << before <<
"\n";
2934 template<
typename Before,
2941 static_assert(Predicate::foldable,
"Predicates must consist only of operations that can constant-fold");
2942 static_assert((Before::binds & After::binds) == After::binds,
"Rule result uses unbound values");
2943 static_assert((Before::binds & Predicate::binds) == Predicate::binds,
"Rule predicate uses unbound values");
2944 static_assert(Before::canonical,
"LHS of rewrite rule should be in canonical form");
2945 static_assert(After::canonical,
"RHS of rewrite rule should be in canonical form");
2947 #if HALIDE_FUZZ_TEST_RULES
2953 #if HALIDE_DEBUG_MATCHED_RULES
2954 debug(0) <<
instance <<
" -> " <<
result <<
" via " << before <<
" -> " << after <<
" when " << pred <<
"\n";
2958 #if HALIDE_DEBUG_UNMATCHED_RULES
2959 debug(0) <<
instance <<
" does not match " << before <<
"\n";
2965 template<
typename Before,
2970 static_assert(Predicate::foldable,
"Predicates must consist only of operations that can constant-fold");
2971 static_assert(Before::canonical,
"LHS of rewrite rule should be in canonical form");
2976 #if HALIDE_DEBUG_MATCHED_RULES
2977 debug(0) <<
instance <<
" -> " <<
result <<
" via " << before <<
" -> " << after <<
" when " << pred <<
"\n";
2981 #if HALIDE_DEBUG_UNMATCHED_RULES
2982 debug(0) <<
instance <<
" does not match " << before <<
"\n";
2988 template<
typename Before,
2993 static_assert(Predicate::foldable,
"Predicates must consist only of operations that can constant-fold");
2994 static_assert(Before::canonical,
"LHS of rewrite rule should be in canonical form");
2995 #if HALIDE_FUZZ_TEST_RULES
3001 #if HALIDE_DEBUG_MATCHED_RULES
3002 debug(0) <<
instance <<
" -> " <<
result <<
" via " << before <<
" -> " << after <<
" when " << pred <<
"\n";
3006 #if HALIDE_DEBUG_UNMATCHED_RULES
3007 debug(0) <<
instance <<
" does not match " << before <<
"\n";
3031 template<
typename Instance,
3032 typename =
typename enable_if_pattern<Instance>::type>
3034 return {
pattern_arg(instance), output_type, wildcard_type};
3037 template<
typename Instance,
3038 typename =
typename enable_if_pattern<Instance>::type>
3040 return {
pattern_arg(instance), output_type, output_type};
@ halide_type_float
IEEE floating point numbers.
@ halide_type_bfloat
floating point numbers in the bfloat format
@ halide_type_int
signed integers
@ halide_type_uint
unsigned integers
#define HALIDE_NEVER_INLINE
#define HALIDE_ALWAYS_INLINE
Subtypes for Halide expressions (Halide::Expr) and statements (Halide::Internal::Stmt)
Methods to test Exprs and Stmts for equality of value.
Defines various operator overloads and utility functions that make it more pleasant to work with Hali...
For optional debugging during codegen, use the debug class as follows:
std::ostream & operator<<(std::ostream &s, const SpecificExpr &e)
auto rounding_shift_left(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
auto shift_left(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto rewriter(Instance instance, halide_type_t output_type, halide_type_t wildcard_type) noexcept -> Rewriter< decltype(pattern_arg(instance))>
Construct a rewriter for the given instance, which may be a pattern with concrete expressions as leav...
HALIDE_ALWAYS_INLINE T pattern_arg(T t)
auto widen_right_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto or_op(A &&a, B &&b) -> decltype(IRMatcher::operator||(a, b))
HALIDE_ALWAYS_INLINE auto operator!(A &&a) noexcept -> NotOp< decltype(pattern_arg(a))>
HALIDE_ALWAYS_INLINE auto min(A &&a, B &&b) noexcept -> BinOp< Min, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE bool evaluate_predicate(bool x, MatcherState &) noexcept
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Div >(halide_type_t &t, int64_t a, int64_t b) noexcept
HALIDE_ALWAYS_INLINE auto ne(A &&a, B &&b) -> decltype(IRMatcher::operator!=(a, b))
HALIDE_ALWAYS_INLINE auto negate(A &&a) -> decltype(IRMatcher::operator-(a))
uint64_t constant_fold_cmp_op(int64_t, int64_t) noexcept
HALIDE_ALWAYS_INLINE auto operator<=(A &&a, B &&b) noexcept -> CmpOp< LE, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto operator+(A &&a, B &&b) noexcept -> BinOp< Add, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto is_max_value(A &&a) noexcept -> IsMaxValue< decltype(pattern_arg(a))>
HALIDE_ALWAYS_INLINE auto and_op(A &&a, B &&b) -> decltype(IRMatcher::operator&&(a, b))
HALIDE_ALWAYS_INLINE auto h_and(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::And >
HALIDE_ALWAYS_INLINE auto gt(A &&a, B &&b) -> decltype(IRMatcher::operator>(a, b))
HALIDE_ALWAYS_INLINE auto is_const(A &&a) noexcept -> IsConst< decltype(pattern_arg(a))>
HALIDE_ALWAYS_INLINE auto intrin(Call::IntrinsicOp intrinsic_op, Args... args) noexcept -> Intrin< decltype(pattern_arg(args))... >
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< LE >(int64_t a, int64_t b) noexcept
HALIDE_ALWAYS_INLINE auto operator*(A &&a, B &&b) noexcept -> BinOp< Mul, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
auto rounding_halving_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
auto rounding_shift_right(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
auto widen_right_sub(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto add(A &&a, B &&b) -> decltype(IRMatcher::operator+(a, b))
HALIDE_ALWAYS_INLINE auto div(A &&a, B &&b) -> decltype(IRMatcher::operator/(a, b))
auto saturating_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto mul(A &&a, B &&b) -> decltype(IRMatcher::operator*(a, b))
HALIDE_ALWAYS_INLINE auto max(A &&a, B &&b) noexcept -> BinOp< Max, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto slice(Vec vec, Base base, Stride stride, Lanes lanes) noexcept -> SliceOp< decltype(pattern_arg(vec)), decltype(pattern_arg(base)), decltype(pattern_arg(stride)), decltype(pattern_arg(lanes))>
HALIDE_ALWAYS_INLINE auto ramp(A &&a, B &&b, C &&c) noexcept -> RampOp< decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))>
HALIDE_ALWAYS_INLINE auto operator/(A &&a, B &&b) noexcept -> BinOp< Div, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto widen(A &&a) noexcept -> WidenOp< decltype(pattern_arg(a))>
auto widening_mul(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Mod >(halide_type_t &t, int64_t a, int64_t b) noexcept
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< And >(halide_type_t &t, int64_t a, int64_t b) noexcept
HALIDE_ALWAYS_INLINE int64_t unwrap(IntLiteral t)
HALIDE_ALWAYS_INLINE auto operator>(A &&a, B &&b) noexcept -> CmpOp< GT, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto cast(halide_type_t t, A &&a) noexcept -> CastOp< decltype(pattern_arg(a))>
HALIDE_ALWAYS_INLINE auto overflows(A &&a) noexcept -> Overflows< decltype(pattern_arg(a))>
auto widening_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE void assert_is_lvalue_if_expr()
HALIDE_ALWAYS_INLINE auto operator%(A &&a, B &&b) noexcept -> BinOp< Mod, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Sub >(halide_type_t &t, int64_t a, int64_t b) noexcept
HALIDE_ALWAYS_INLINE auto is_scalar(A &&a) noexcept -> IsScalar< decltype(pattern_arg(a))>
HALIDE_ALWAYS_INLINE auto fold(A &&a) noexcept -> Fold< decltype(pattern_arg(a))>
HALIDE_ALWAYS_INLINE auto not_op(A &&a) -> decltype(IRMatcher::operator!(a))
auto halving_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Max >(halide_type_t &t, int64_t a, int64_t b) noexcept
constexpr bool and_reduce()
HALIDE_ALWAYS_INLINE auto operator||(A &&a, B &&b) noexcept -> BinOp< Or, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
auto widening_sub(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto operator!=(A &&a, B &&b) noexcept -> CmpOp< NE, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto is_float(A &&a) noexcept -> IsFloat< decltype(pattern_arg(a))>
HALIDE_ALWAYS_INLINE auto operator>=(A &&a, B &&b) noexcept -> CmpOp< GE, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto operator<(A &&a, B &&b) noexcept -> CmpOp< LT, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto operator&&(A &&a, B &&b) noexcept -> BinOp< And, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto is_uint(A &&a, int bits=0, int lanes=0) noexcept -> IsUInt< decltype(pattern_arg(a))>
HALIDE_ALWAYS_INLINE auto h_or(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Or >
constexpr bool commutative(IRNodeType t)
auto widen_right_mul(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto sub(A &&a, B &&b) -> decltype(IRMatcher::operator-(a, b))
HALIDE_ALWAYS_INLINE auto h_max(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Max >
HALIDE_ALWAYS_INLINE auto broadcast(A &&a, B lanes) noexcept -> BroadcastOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes))>
HALIDE_ALWAYS_INLINE auto is_int(A &&a, int bits=0, int lanes=0) noexcept -> IsInt< decltype(pattern_arg(a))>
HALIDE_ALWAYS_INLINE auto select(C &&c, T &&t, F &&f) noexcept -> SelectOp< decltype(pattern_arg(c)), decltype(pattern_arg(t)), decltype(pattern_arg(f))>
HALIDE_ALWAYS_INLINE auto is_min_value(A &&a) noexcept -> IsMinValue< decltype(pattern_arg(a))>
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Min >(halide_type_t &t, int64_t a, int64_t b) noexcept
HALIDE_NEVER_INLINE void fuzz_test_rule(Before &&before, After &&after, Predicate &&pred, halide_type_t wildcard_type, halide_type_t output_type) noexcept
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< GT >(int64_t a, int64_t b) noexcept
auto halving_sub(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
auto saturating_sub(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Mul >(halide_type_t &t, int64_t a, int64_t b) noexcept
auto mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))>
auto shift_right(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< GE >(int64_t a, int64_t b) noexcept
HALIDE_ALWAYS_INLINE auto operator-(A &&a, B &&b) noexcept -> BinOp< Sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto le(A &&a, B &&b) -> decltype(IRMatcher::operator<=(a, b))
HALIDE_ALWAYS_INLINE auto lt(A &&a, B &&b) -> decltype(IRMatcher::operator<(a, b))
HALIDE_ALWAYS_INLINE auto is_const(A &&a, int64_t value) noexcept -> IsConst< decltype(pattern_arg(a))>
HALIDE_ALWAYS_INLINE auto lanes_of(A &&a) noexcept -> LanesOf< decltype(pattern_arg(a))>
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< LT >(int64_t a, int64_t b) noexcept
HALIDE_ALWAYS_INLINE auto h_min(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Min >
HALIDE_ALWAYS_INLINE auto h_add(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Add >
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Or >(halide_type_t &t, int64_t a, int64_t b) noexcept
HALIDE_ALWAYS_INLINE Expr make_const_expr(halide_scalar_value_t val, halide_type_t ty)
constexpr uint32_t bitwise_or_reduce()
auto rounding_mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))>
int64_t constant_fold_bin_op(halide_type_t &, int64_t, int64_t) noexcept
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< EQ >(int64_t a, int64_t b) noexcept
HALIDE_NEVER_INLINE Expr make_const_special_expr(halide_type_t ty)
HALIDE_ALWAYS_INLINE auto ge(A &&a, B &&b) -> decltype(IRMatcher::operator>=(a, b))
auto saturating_cast(const Type &t, A &&a) noexcept -> Intrin< decltype(pattern_arg(a))>
constexpr int const_min(int a, int b)
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< NE >(int64_t a, int64_t b) noexcept
HALIDE_ALWAYS_INLINE auto mod(A &&a, B &&b) -> decltype(IRMatcher::operator%(a, b))
HALIDE_ALWAYS_INLINE auto operator==(A &&a, B &&b) noexcept -> CmpOp< EQ, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Add >(halide_type_t &t, int64_t a, int64_t b) noexcept
HALIDE_ALWAYS_INLINE auto can_prove(A &&a, Prover *p) noexcept -> CanProve< decltype(pattern_arg(a)), Prover >
HALIDE_ALWAYS_INLINE auto eq(A &&a, B &&b) -> decltype(IRMatcher::operator==(a, b))
bool is_const_zero(const Expr &e)
Is the expression a const (as defined by is_const), and also equal to zero (in all lanes,...
Expr make_zero(Type t)
Construct the representation of zero in the given type.
bool is_const_one(const Expr &e)
Is the expression a const (as defined by is_const), and also equal to one (in all lanes,...
bool equal(const RDom &bounds0, const RDom &bounds1)
Return true if bounds0 and bounds1 represent the same bounds.
constexpr IRNodeType StrongestExprNodeType
Expr make_const(Type t, int64_t val)
Construct an immediate of the given type from any numeric C++ type.
T mod_imp(T a, T b)
Implementations of division and mod that are specific to Halide.
bool sub_would_overflow(int bits, int64_t a, int64_t b)
bool add_would_overflow(int bits, int64_t a, int64_t b)
Routines to test if math would overflow for signed integers with the given number of bits.
bool mul_would_overflow(int bits, int64_t a, int64_t b)
Expr with_lanes(const Expr &x, int lanes)
Rewrite the expression x to have 'lanes' lanes.
bool expr_match(const Expr &pattern, const Expr &expr, std::vector< Expr > &result)
Does the first expression have the same structure as the second? Variables in the first expression wi...
ConstantInterval abs(const ConstantInterval &a)
Expr make_signed_integer_overflow(Type type)
Construct a unique signed_integer_overflow Expr.
IRNodeType
All our IR node types get unique IDs for the purposes of RTTI.
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
@ Internal
Not visible externally, similar to 'static' linkage in C.
@ Predicate
Guard the loads and stores in the loop with an if statement that prevents evaluation beyond the origi...
Expr absd(Expr a, Expr b)
Return the absolute difference between two values.
Expr likely_if_innermost(Expr e)
Equivalent to likely, but only triggers a loop partitioning if found in an innermost loop.
Expr likely(Expr e)
Expressions tagged with this intrinsic are considered to be part of the steady state of some loop wit...
unsigned __INT64_TYPE__ uint64_t
signed __INT64_TYPE__ int64_t
signed __INT32_TYPE__ int32_t
unsigned __INT16_TYPE__ uint16_t
unsigned __INT32_TYPE__ uint32_t
A fragment of Halide syntax.
HALIDE_ALWAYS_INLINE Type type() const
Get the type of this expression node.
HALIDE_ALWAYS_INLINE const Internal::BaseExprNode * get() const
Override get() to return a BaseExprNode * instead of an IRNode *.
The sum of two expressions.
Logical and - are both expressions true.
A base class for expression nodes.
A vector with 'lanes' elements, in which every element is 'value'.
static Expr make(Expr value, int lanes)
static const IRNodeType _node_type
@ signed_integer_overflow
@ rounding_mul_shift_right
bool is_intrinsic() const
static const IRNodeType _node_type
The actual IR nodes begin here.
static const IRNodeType _node_type
The ratio of two expressions.
Is the first expression equal to the second.
Floating point constants.
static const FloatImm * make(Type t, double value)
Is the first expression greater than or equal to the second.
Is the first expression greater than the second.
constexpr static uint32_t binds
constexpr static IRNodeType max_node_type
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
constexpr static IRNodeType min_node_type
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const noexcept
HALIDE_ALWAYS_INLINE bool match(const BinOp< Op2, A2, B2 > &op, MatcherState &state) const noexcept
constexpr static bool canonical
constexpr static bool foldable
constexpr static bool foldable
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
HALIDE_ALWAYS_INLINE bool match(const BroadcastOp< A2, B2 > &op, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
constexpr static IRNodeType min_node_type
constexpr static uint32_t binds
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
constexpr static IRNodeType max_node_type
constexpr static bool canonical
HALIDE_NEVER_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
constexpr static bool foldable
constexpr static IRNodeType min_node_type
constexpr static IRNodeType max_node_type
constexpr static uint32_t binds
constexpr static bool canonical
constexpr static bool canonical
constexpr static IRNodeType max_node_type
constexpr static bool foldable
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
constexpr static uint32_t binds
constexpr static IRNodeType min_node_type
HALIDE_ALWAYS_INLINE bool match(const CastOp< A2 > &op, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
constexpr static bool canonical
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
constexpr static IRNodeType min_node_type
constexpr static IRNodeType max_node_type
constexpr static bool foldable
constexpr static uint32_t binds
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE bool match(const CmpOp< Op2, A2, B2 > &op, MatcherState &state) const noexcept
constexpr static bool foldable
constexpr static uint32_t binds
constexpr static IRNodeType min_node_type
constexpr static IRNodeType max_node_type
constexpr static bool canonical
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const noexcept
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
constexpr static bool canonical
constexpr static uint32_t binds
HALIDE_ALWAYS_INLINE IntLiteral(int64_t v)
HALIDE_ALWAYS_INLINE bool match(const IntLiteral &b, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
constexpr static IRNodeType min_node_type
constexpr static IRNodeType max_node_type
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
HALIDE_ALWAYS_INLINE bool match(int64_t val, MatcherState &state) const noexcept
constexpr static bool foldable
HALIDE_ALWAYS_INLINE bool match_args(double, const Call &c, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
constexpr static bool canonical
HALIDE_ALWAYS_INLINE void print_args(std::ostream &s) const
constexpr static bool foldable
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
static constexpr uint32_t binds
HALIDE_ALWAYS_INLINE bool match_args(int, const Call &c, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE void print_args(int, std::ostream &s) const
std::tuple< Args... > args
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE void print_args(double, std::ostream &s) const
HALIDE_ALWAYS_INLINE Intrin(Call::IntrinsicOp intrin, Args... args) noexcept
constexpr static IRNodeType max_node_type
constexpr static IRNodeType min_node_type
constexpr static bool canonical
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
constexpr static bool foldable
constexpr static IRNodeType max_node_type
constexpr static IRNodeType min_node_type
constexpr static uint32_t binds
constexpr static IRNodeType min_node_type
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
constexpr static bool canonical
constexpr static uint32_t binds
constexpr static bool foldable
constexpr static IRNodeType max_node_type
constexpr static uint32_t binds
constexpr static bool foldable
constexpr static IRNodeType min_node_type
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
constexpr static IRNodeType max_node_type
constexpr static bool canonical
constexpr static bool canonical
constexpr static bool foldable
constexpr static uint32_t binds
constexpr static IRNodeType min_node_type
constexpr static IRNodeType max_node_type
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
constexpr static IRNodeType min_node_type
constexpr static IRNodeType max_node_type
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
constexpr static bool canonical
constexpr static uint32_t binds
constexpr static bool foldable
constexpr static bool foldable
constexpr static bool canonical
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
constexpr static IRNodeType max_node_type
constexpr static IRNodeType min_node_type
constexpr static uint32_t binds
constexpr static bool canonical
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
constexpr static uint32_t binds
constexpr static IRNodeType min_node_type
constexpr static bool foldable
constexpr static IRNodeType max_node_type
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
constexpr static bool foldable
constexpr static bool canonical
constexpr static uint32_t binds
constexpr static IRNodeType min_node_type
constexpr static IRNodeType max_node_type
To save stack space, the matcher objects are largely stateless and immutable.
HALIDE_ALWAYS_INLINE void get_bound_const(int i, halide_scalar_value_t &val, halide_type_t &type) const noexcept
HALIDE_ALWAYS_INLINE void set_bound_const(int i, int64_t s, halide_type_t t) noexcept
HALIDE_ALWAYS_INLINE void set_bound_const(int i, double f, halide_type_t t) noexcept
static constexpr uint16_t special_values_mask
HALIDE_ALWAYS_INLINE void set_bound_const(int i, halide_scalar_value_t val, halide_type_t t) noexcept
halide_type_t bound_const_type[max_wild]
HALIDE_ALWAYS_INLINE void set_binding(int i, const BaseExprNode &n) noexcept
HALIDE_ALWAYS_INLINE MatcherState() noexcept
halide_scalar_value_t bound_const[max_wild]
HALIDE_ALWAYS_INLINE const BaseExprNode * get_binding(int i) const noexcept
HALIDE_ALWAYS_INLINE void set_bound_const(int i, uint64_t u, halide_type_t t) noexcept
static constexpr uint16_t signed_integer_overflow
constexpr static IRNodeType min_node_type
constexpr static IRNodeType max_node_type
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
constexpr static uint32_t binds
constexpr static bool canonical
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
HALIDE_ALWAYS_INLINE bool match(NegateOp< A2 > &&p, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
constexpr static bool foldable
constexpr static IRNodeType min_node_type
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE bool match(const NotOp< A2 > &op, MatcherState &state) const noexcept
constexpr static uint32_t binds
constexpr static IRNodeType max_node_type
constexpr static bool foldable
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
constexpr static bool canonical
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
constexpr static bool canonical
constexpr static bool foldable
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
constexpr static IRNodeType max_node_type
constexpr static uint32_t binds
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
constexpr static IRNodeType min_node_type
constexpr static bool foldable
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
constexpr static IRNodeType max_node_type
constexpr static IRNodeType min_node_type
constexpr static uint32_t binds
constexpr static bool canonical
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
constexpr static IRNodeType max_node_type
constexpr static bool canonical
constexpr static IRNodeType min_node_type
constexpr static bool foldable
HALIDE_ALWAYS_INLINE bool match(const RampOp< A2, B2, C2 > &op, MatcherState &state) const noexcept
constexpr static uint32_t binds
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
HALIDE_NEVER_INLINE void build_replacement(After after)
HALIDE_ALWAYS_INLINE bool operator()(Before before, After after, Predicate pred)
HALIDE_ALWAYS_INLINE bool operator()(Before before, int64_t after) noexcept
HALIDE_ALWAYS_INLINE Rewriter(Instance instance, halide_type_t ot, halide_type_t wt)
HALIDE_ALWAYS_INLINE bool operator()(Before before, const Expr &after, Predicate pred)
HALIDE_ALWAYS_INLINE bool operator()(Before before, const Expr &after) noexcept
HALIDE_ALWAYS_INLINE bool operator()(Before before, int64_t after, Predicate pred)
HALIDE_ALWAYS_INLINE bool operator()(Before before, After after)
halide_type_t wildcard_type
halide_type_t output_type
constexpr static IRNodeType max_node_type
constexpr static bool canonical
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
constexpr static IRNodeType min_node_type
constexpr static uint32_t binds
constexpr static bool foldable
HALIDE_ALWAYS_INLINE bool match(const SelectOp< C2, T2, F2 > &instance, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
constexpr static IRNodeType max_node_type
constexpr static bool foldable
HALIDE_ALWAYS_INLINE SliceOp(Vec v, Base b, Stride s, Lanes l)
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
static constexpr uint32_t binds
constexpr static IRNodeType min_node_type
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
constexpr static bool canonical
constexpr static IRNodeType min_node_type
constexpr static bool canonical
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
constexpr static uint32_t binds
const BaseExprNode & expr
constexpr static IRNodeType max_node_type
constexpr static bool foldable
HALIDE_ALWAYS_INLINE bool match(const VectorReduceOp< A2, B2, reduce_op_2 > &op, MatcherState &state) const noexcept
constexpr static uint32_t binds
constexpr static IRNodeType max_node_type
constexpr static bool canonical
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
constexpr static IRNodeType min_node_type
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
constexpr static bool foldable
constexpr static bool foldable
constexpr static bool canonical
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
HALIDE_ALWAYS_INLINE bool match(const WidenOp< A2 > &op, MatcherState &state) const noexcept
constexpr static uint32_t binds
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
constexpr static IRNodeType min_node_type
constexpr static IRNodeType max_node_type
constexpr static bool foldable
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
constexpr static IRNodeType max_node_type
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
constexpr static uint32_t binds
constexpr static bool canonical
constexpr static IRNodeType min_node_type
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
constexpr static uint32_t binds
constexpr static bool foldable
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
constexpr static IRNodeType max_node_type
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
constexpr static bool canonical
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
constexpr static IRNodeType min_node_type
HALIDE_ALWAYS_INLINE bool match(int64_t e, MatcherState &state) const noexcept
constexpr static bool canonical
constexpr static bool foldable
constexpr static IRNodeType min_node_type
constexpr static uint32_t binds
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE bool match(int64_t value, MatcherState &state) const noexcept
constexpr static IRNodeType max_node_type
constexpr static uint32_t binds
constexpr static IRNodeType max_node_type
constexpr static IRNodeType min_node_type
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
constexpr static bool foldable
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
constexpr static bool canonical
constexpr static IRNodeType max_node_type
constexpr static bool foldable
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
constexpr static bool canonical
constexpr static IRNodeType min_node_type
constexpr static uint32_t binds
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
constexpr static uint32_t mask
IRNodeType node_type
Each IR node subclass has a unique identifier.
static const IntImm * make(Type t, int64_t value)
Is the first expression less than or equal to the second.
Is the first expression less than the second.
The greater of two values.
The lesser of two values.
The product of two expressions.
Is the first expression not equal to the second.
Logical not - true if the expression is false.
Logical or - is at least one of the expressions true.
A linear ramp vector node.
static const IRNodeType _node_type
static Expr make(Expr base, Expr stride, int lanes)
static Expr make(Expr condition, Expr true_value, Expr false_value)
static const IRNodeType _node_type
Construct a new vector by taking elements from another sequence of vectors.
static Expr make_slice(Expr vector, int begin, int stride, int size)
Convenience constructor for making a shuffle representing a contiguous subset of a vector.
std::vector< Expr > vectors
bool is_slice() const
Check if this shuffle is a contiguous strict subset of the vector arguments, and if so,...
int slice_stride() const
Check if this shuffle is a contiguous strict subset of the vector arguments, and if so,...
int slice_begin() const
Check if this shuffle is a contiguous strict subset of the vector arguments, and if so,...
The difference of two expressions.
static const IRNodeType _node_type
static Expr make(Expr a, Expr b)
Unsigned integer constants.
static const UIntImm * make(Type t, uint64_t value)
Horizontally reduce a vector to a scalar or narrower vector using the given commutative and associati...
static const IRNodeType _node_type
static Expr make(Operator op, Expr vec, int lanes)
Types in the halide type system.
Type widen() const
Return Type with the same type code and number of lanes, but with at least twice as many bits.
HALIDE_ALWAYS_INLINE bool is_int() const
Is this type a signed integer type?
HALIDE_ALWAYS_INLINE int lanes() const
Return the number of vector elements in this type.
HALIDE_ALWAYS_INLINE bool is_uint() const
Is this type an unsigned integer type?
HALIDE_ALWAYS_INLINE int bits() const
Return the bit size of a single element of this type.
HALIDE_ALWAYS_INLINE bool is_vector() const
Is this type a vector type? (lanes() != 1).
HALIDE_ALWAYS_INLINE bool is_scalar() const
Is this type a scalar type? (lanes() == 1).
HALIDE_ALWAYS_INLINE bool is_float() const
Is this type a floating point type (float or double).
halide_scalar_value_t is a simple union able to represent all the well-known scalar values in a filte...
union halide_scalar_value_t::@3 u
A runtime tag for a type in the halide type system.
uint8_t bits
The number of bits of precision of a single scalar value of this type.
uint16_t lanes
How many elements in a vector.
uint8_t code
The basic type code: signed integer, unsigned integer, or floating point.