Halide  19.0.0
Halide compiler and libraries
IRMatch.h
Go to the documentation of this file.
1 #ifndef HALIDE_IR_MATCH_H
2 #define HALIDE_IR_MATCH_H
3 
4 /** \file
5  * Defines a method to match a fragment of IR against a pattern containing wildcards
6  */
7 
8 #include <map>
9 #include <random>
10 #include <set>
11 #include <vector>
12 
13 #include "IR.h"
14 #include "IREquality.h"
15 #include "IROperator.h"
16 
17 namespace Halide {
18 namespace Internal {
19 
20 /** Does the first expression have the same structure as the second?
21  * Variables in the first expression with the name * are interpreted
22  * as wildcards, and their matching equivalent in the second
23  * expression is placed in the vector given as the third argument.
24  * Wildcards require the types to match. For the type bits and width,
25  * a 0 indicates "match anything". So an Int(8, 0) will match 8-bit
26  * integer vectors of any width (including scalars), and a UInt(0, 0)
27  * will match any unsigned integer type.
28  *
29  * For example:
30  \code
31  Expr x = Variable::make(Int(32), "*");
32  expr_match(x + x, 3 + (2*k), result)
33  \endcode
34  * should return true, and set result[0] to 3 and
35  * result[1] to 2*k.
36  */
37 bool expr_match(const Expr &pattern, const Expr &expr, std::vector<Expr> &result);
38 
39 /** Does the first expression have the same structure as the second?
40  * Variables are matched consistently. The first time a variable is
41  * matched, it assumes the value of the matching part of the second
42  * expression. Subsequent matches must be equal to the first match.
43  *
44  * For example:
45  \code
46  Var x("x"), y("y");
47  expr_match(x*(x + y), a*(a + b), result)
48  \endcode
49  * should return true, and set result["x"] = a, and result["y"] = b.
50  */
51 bool expr_match(const Expr &pattern, const Expr &expr, std::map<std::string, Expr> &result);
52 
53 /** Rewrite the expression x to have `lanes` lanes. This is useful
54  * for substituting the results of expr_match into a pattern expression. */
55 Expr with_lanes(const Expr &x, int lanes);
56 
58 
59 /** An alternative template-metaprogramming approach to expression
60  * matching. Potentially more efficient. We lift the expression
61  * pattern into a type, and then use force-inlined functions to
62  * generate efficient matching and reconstruction code for any
63  * pattern. Pattern elements are either one of the classes in the
64  * namespace IRMatcher, or are non-null Exprs (represented as
65  * BaseExprNode &).
66  *
67  * Pattern elements that are fully specified by their pattern can be
68  * built into an expression using the make method. Some patterns,
69  * such as a broadcast that matches any number of lanes, don't have
70  * enough information to recreate an Expr.
71  */
72 namespace IRMatcher {
73 
// Exclusive upper bound on wildcard indices: the WildConst* matchers
// static_assert that their index i satisfies 0 <= i < max_wild.
74 constexpr int max_wild = 6;
75 
// Type recorded for a constant bound from a raw int64_t value instead of
// an IR node (see WildConstInt::match(int64_t, ...)).
// Presumably the initializer order is (code, bits, lanes) - confirm
// against the halide_type_t definition.
76 static const halide_type_t i64_type = {halide_type_int, 64, 1};
77 
78 /** To save stack space, the matcher objects are largely stateless and
79  * immutable. This state object is built up during matching and then
80  * consumed when constructing a replacement Expr.
81  */
// Mutable scratch state shared by a pattern match and the subsequent
// construction of the replacement Expr. Slots are addressed by wildcard
// index i; none of the accessors below range-check i.
82 struct MatcherState {
85 
86  // values of the lanes field with special meaning.
    // The signed-integer Add/Sub/Mul folds OR signed_integer_overflow into
    // a type's lanes when a >= 32-bit operation would overflow.
87  static constexpr uint16_t signed_integer_overflow = 0x8000;
88  static constexpr uint16_t special_values_mask = 0x8000; // currently only one
89 
91 
    // Bind expression slot i to n. Only the address of n is stored, so n
    // must stay alive for as long as this binding is read.
93  void set_binding(int i, const BaseExprNode &n) noexcept {
94  bindings[i] = &n;
95  }
96 
    // Return the node pointer previously stored in expression slot i.
98  const BaseExprNode *get_binding(int i) const noexcept {
99  return bindings[i];
100  }
101 
    // Bind constant slot i to a signed value: writes the i64 member of the
    // scalar union and records the accompanying type t.
103  void set_bound_const(int i, int64_t s, halide_type_t t) noexcept {
104  bound_const[i].u.i64 = s;
105  bound_const_type[i] = t;
106  }
107 
    // Same as above for an unsigned value (writes the u64 member).
109  void set_bound_const(int i, uint64_t u, halide_type_t t) noexcept {
110  bound_const[i].u.u64 = u;
111  bound_const_type[i] = t;
112  }
113 
    // Same as above for a floating-point value (writes the f64 member).
115  void set_bound_const(int i, double f, halide_type_t t) noexcept {
116  bound_const[i].u.f64 = f;
117  bound_const_type[i] = t;
118  }
119 
    // Bind constant slot i to an already-packed scalar value and type.
121  void set_bound_const(int i, halide_scalar_value_t val, halide_type_t t) noexcept {
122  bound_const[i] = val;
123  bound_const_type[i] = t;
124  }
125 
    // Copy the value and type stored in constant slot i into the two
    // output references.
127  void get_bound_const(int i, halide_scalar_value_t &val, halide_type_t &type) const noexcept {
128  val = bound_const[i];
129  type = bound_const_type[i];
130  }
131 
133  // NOLINTNEXTLINE(modernize-use-equals-default): Can't use `= default`; clang-tidy complains about noexcept mismatch
134  MatcherState() noexcept {
135  }
136 };
137 
138 template<typename T,
139  typename = typename std::remove_reference<T>::type::pattern_tag>
141  struct type {};
142 };
143 
/** Compile-time lookup of the wildcard bitmask a pattern type binds.
 * Any reference qualifier on T is stripped first, so the trait can be
 * applied directly to forwarded or referenced pattern types. */
template<typename T>
struct bindings {
    static constexpr uint32_t mask{std::remove_reference<T>::type::binds};
};
148 
151  ty.lanes &= ~MatcherState::special_values_mask;
153  return make_signed_integer_overflow(ty);
154  }
155  // unreachable
156  return Expr();
157 }
158 
161  halide_type_t scalar_type = ty;
162  if (scalar_type.lanes & MatcherState::special_values_mask) {
163  return make_const_special_expr(scalar_type);
164  }
165 
166  const int lanes = scalar_type.lanes;
167  scalar_type.lanes = 1;
168 
169  Expr e;
170  switch (scalar_type.code) {
171  case halide_type_int:
172  e = IntImm::make(scalar_type, val.u.i64);
173  break;
174  case halide_type_uint:
175  e = UIntImm::make(scalar_type, val.u.u64);
176  break;
177  case halide_type_float:
178  case halide_type_bfloat:
179  e = FloatImm::make(scalar_type, val.u.f64);
180  break;
181  default:
182  // Unreachable
183  return Expr();
184  }
185  if (lanes > 1) {
186  e = Broadcast::make(e, lanes);
187  }
188  return e;
189 }
190 
191 // A pattern that matches a specific expression
// It binds no wildcards (binds == 0); matching is a comparison of the
// candidate node against the wrapped node `expr` via equal().
192 struct SpecificExpr {
193  struct pattern_tag {};
194 
195  constexpr static uint32_t binds = 0;
196 
197  // What is the weakest and strongest IR node this could possibly be
200  constexpr static bool canonical = true;
201 
203 
    // True when e compares equal to the wrapped node. The matcher state
    // is not read or modified.
204  template<uint32_t bound>
205  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
206  return equal(expr, e);
207  }
208 
    // Rebuild an Expr from the address of the wrapped node; state and
    // type_hint are unused.
210  Expr make(MatcherState &state, halide_type_t type_hint) const {
211  return Expr(&expr);
212  }
213 
    // A specific expression takes no part in constant folding.
214  constexpr static bool foldable = false;
215 };
216 
217 inline std::ostream &operator<<(std::ostream &s, const SpecificExpr &e) {
218  s << Expr(&e.expr);
219  return s;
220 }
221 
222 template<int i>
223 struct WildConstInt {
224  struct pattern_tag {};
225 
226  constexpr static uint32_t binds = 1 << i;
227 
230  constexpr static bool canonical = true;
231 
232  template<uint32_t bound>
233  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
234  static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
235  const BaseExprNode *op = &e;
236  if (op->node_type == IRNodeType::Broadcast) {
237  op = ((const Broadcast *)op)->value.get();
238  }
239  if (op->node_type != IRNodeType::IntImm) {
240  return false;
241  }
242  int64_t value = ((const IntImm *)op)->value;
243  if (bound & binds) {
245  halide_type_t type;
246  state.get_bound_const(i, val, type);
247  return (halide_type_t)e.type == type && value == val.u.i64;
248  }
249  state.set_bound_const(i, value, e.type);
250  return true;
251  }
252 
253  template<uint32_t bound>
254  HALIDE_ALWAYS_INLINE bool match(int64_t value, MatcherState &state) const noexcept {
255  static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
256  if (bound & binds) {
258  halide_type_t type;
259  state.get_bound_const(i, val, type);
260  return type == i64_type && value == val.u.i64;
261  }
262  state.set_bound_const(i, value, i64_type);
263  return true;
264  }
265 
267  Expr make(MatcherState &state, halide_type_t type_hint) const {
269  halide_type_t type;
270  state.get_bound_const(i, val, type);
271  return make_const_expr(val, type);
272  }
273 
274  constexpr static bool foldable = true;
275 
278  state.get_bound_const(i, val, ty);
279  }
280 };
281 
282 template<int i>
283 std::ostream &operator<<(std::ostream &s, const WildConstInt<i> &c) {
284  s << "ci" << i;
285  return s;
286 }
287 
288 template<int i>
290  struct pattern_tag {};
291 
292  constexpr static uint32_t binds = 1 << i;
293 
296  constexpr static bool canonical = true;
297 
298  template<uint32_t bound>
299  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
300  static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
301  const BaseExprNode *op = &e;
302  if (op->node_type == IRNodeType::Broadcast) {
303  op = ((const Broadcast *)op)->value.get();
304  }
305  if (op->node_type != IRNodeType::UIntImm) {
306  return false;
307  }
308  uint64_t value = ((const UIntImm *)op)->value;
309  if (bound & binds) {
311  halide_type_t type;
312  state.get_bound_const(i, val, type);
313  return (halide_type_t)e.type == type && value == val.u.u64;
314  }
315  state.set_bound_const(i, value, e.type);
316  return true;
317  }
318 
320  Expr make(MatcherState &state, halide_type_t type_hint) const {
322  halide_type_t type;
323  state.get_bound_const(i, val, type);
324  return make_const_expr(val, type);
325  }
326 
327  constexpr static bool foldable = true;
328 
330  void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
331  state.get_bound_const(i, val, ty);
332  }
333 };
334 
335 template<int i>
336 std::ostream &operator<<(std::ostream &s, const WildConstUInt<i> &c) {
337  s << "cu" << i;
338  return s;
339 }
340 
341 template<int i>
343  struct pattern_tag {};
344 
345  constexpr static uint32_t binds = 1 << i;
346 
349  constexpr static bool canonical = true;
350 
351  template<uint32_t bound>
352  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
353  static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
354  const BaseExprNode *op = &e;
355  if (op->node_type == IRNodeType::Broadcast) {
356  op = ((const Broadcast *)op)->value.get();
357  }
358  if (op->node_type != IRNodeType::FloatImm) {
359  return false;
360  }
361  double value = ((const FloatImm *)op)->value;
362  if (bound & binds) {
364  halide_type_t type;
365  state.get_bound_const(i, val, type);
366  return (halide_type_t)e.type == type && value == val.u.f64;
367  }
368  state.set_bound_const(i, value, e.type);
369  return true;
370  }
371 
373  Expr make(MatcherState &state, halide_type_t type_hint) const {
375  halide_type_t type;
376  state.get_bound_const(i, val, type);
377  return make_const_expr(val, type);
378  }
379 
380  constexpr static bool foldable = true;
381 
383  void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
384  state.get_bound_const(i, val, ty);
385  }
386 };
387 
388 template<int i>
389 std::ostream &operator<<(std::ostream &s, const WildConstFloat<i> &c) {
390  s << "cf" << i;
391  return s;
392 }
393 
394 // Matches and binds to any constant Expr. Does not support constant-folding.
// NOTE(review): the sentence above conflicts with the code below, which
// sets foldable = true and defines make_folded_const returning the bound
// constant - confirm which is intended and update the stale one.
395 template<int i>
396 struct WildConst {
397  struct pattern_tag {};
398 
399  constexpr static uint32_t binds = 1 << i;
400 
403  constexpr static bool canonical = true;
404 
    // Looks through a Broadcast to find the constant's node type, then
    // delegates to the matching typed wildcard (WildConstInt/UInt/Float).
    // The delegate receives the original e, not the unwrapped node, and
    // unwraps the Broadcast itself.
405  template<uint32_t bound>
406  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
407  static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
408  const BaseExprNode *op = &e;
409  if (op->node_type == IRNodeType::Broadcast) {
410  op = ((const Broadcast *)op)->value.get();
411  }
412  switch (op->node_type) {
413  case IRNodeType::IntImm:
414  return WildConstInt<i>().template match<bound>(e, state);
415  case IRNodeType::UIntImm:
416  return WildConstUInt<i>().template match<bound>(e, state);
418  return WildConstFloat<i>().template match<bound>(e, state);
419  default:
420  return false;
421  }
422  }
423 
    // Raw int64_t values are handled exactly as WildConstInt handles them.
424  template<uint32_t bound>
425  HALIDE_ALWAYS_INLINE bool match(int64_t e, MatcherState &state) const noexcept {
426  static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
427  return WildConstInt<i>().template match<bound>(e, state);
428  }
429 
    // Rebuild the bound constant as an Expr from the value and type stored
    // in constant slot i; type_hint is unused.
431  Expr make(MatcherState &state, halide_type_t type_hint) const {
433  halide_type_t type;
434  state.get_bound_const(i, val, type);
435  return make_const_expr(val, type);
436  }
437 
438  constexpr static bool foldable = true;
439 
    // Folding a wildcard constant just reads back its bound value and type.
441  void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
442  state.get_bound_const(i, val, ty);
443  }
444 };
445 
446 template<int i>
447 std::ostream &operator<<(std::ostream &s, const WildConst<i> &c) {
448  s << "c" << i;
449  return s;
450 }
451 
452 // Matches and binds to any Expr
453 template<int i>
454 struct Wild {
455  struct pattern_tag {};
456 
    // Expression wildcards use bit (i + 16) of the bound mask, whereas the
    // WildConst* matchers use bit i, so the two kinds never share a bit.
457  constexpr static uint32_t binds = 1 << (i + 16);
458 
461  constexpr static bool canonical = true;
462 
    // If this wildcard is already bound (its bit is set in `bound`), the
    // candidate must compare equal to the stored node; otherwise the
    // candidate is recorded in slot i and the match succeeds.
463  template<uint32_t bound>
464  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
465  if (bound & binds) {
466  return equal(*state.get_binding(i), e);
467  }
468  state.set_binding(i, e);
469  return true;
470  }
471 
    // Return the node bound to slot i as an Expr; type_hint is unused.
473  Expr make(MatcherState &state, halide_type_t type_hint) const {
474  return state.get_binding(i);
475  }
476 
    // An arbitrary expression cannot take part in constant folding.
477  constexpr static bool foldable = false;
478 };
479 
480 template<int i>
481 std::ostream &operator<<(std::ostream &s, const Wild<i> &op) {
482  s << "_" << i;
483  return s;
484 }
485 
486 // Matches a specific constant or broadcast of that constant. The
487 // constant must be representable as an int64_t.
488 struct IntLiteral {
489  struct pattern_tag {};
491 
492  constexpr static uint32_t binds = 0;
493 
496  constexpr static bool canonical = true;
497 
499  explicit IntLiteral(int64_t v)
500  : v(v) {
501  }
502 
    // Compare the literal against an immediate node, looking through one
    // Broadcast. The literal is converted to the immediate's value type
    // before comparing; note that (uint64_t)v wraps for negative v, so a
    // negative literal compares equal to the corresponding large UIntImm.
    // Only values are compared; the node's type is not checked here.
503  template<uint32_t bound>
504  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
505  const BaseExprNode *op = &e;
506  if (e.node_type == IRNodeType::Broadcast) {
507  op = ((const Broadcast *)op)->value.get();
508  }
509  switch (op->node_type) {
510  case IRNodeType::IntImm:
511  return ((const IntImm *)op)->value == (int64_t)v;
512  case IRNodeType::UIntImm:
513  return ((const UIntImm *)op)->value == (uint64_t)v;
515  return ((const FloatImm *)op)->value == (double)v;
516  default:
517  return false;
518  }
519  }
520 
    // Match against a raw integer value.
521  template<uint32_t bound>
522  HALIDE_ALWAYS_INLINE bool match(int64_t val, MatcherState &state) const noexcept {
523  return v == val;
524  }
525 
    // Match against another literal pattern.
526  template<uint32_t bound>
527  HALIDE_ALWAYS_INLINE bool match(const IntLiteral &b, MatcherState &state) const noexcept {
528  return v == b.v;
529  }
530 
    // The literal carries no type of its own, so the constant is built
    // with the caller-provided type_hint.
532  Expr make(MatcherState &state, halide_type_t type_hint) const {
533  return make_const(type_hint, v);
534  }
535 
536  constexpr static bool foldable = true;
537 
    // Write the literal into the union member selected by ty.code. ty is
    // an input here (it is read, never written), so the caller must have
    // set it already - typically by folding the other operand first.
539  void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
540  // Assume type is already correct
541  switch (ty.code) {
542  case halide_type_int:
543  val.u.i64 = v;
544  break;
545  case halide_type_uint:
546  val.u.u64 = (uint64_t)v;
547  break;
548  case halide_type_float:
549  case halide_type_bfloat:
550  val.u.f64 = (double)v;
551  break;
552  default:
553  // Unreachable
554  ;
555  }
556  }
557 };
558 
560  return t.v;
561 }
562 
563 // Convert a provided pattern, expr, or constant int into the internal
564 // representation we use in the matcher trees.
565 template<typename T,
566  typename = typename std::decay<T>::type::pattern_tag>
568  return t;
569 }
572  return IntLiteral{x};
573 }
574 
575 template<typename T>
577  static_assert(!std::is_same<typename std::decay<T>::type, Expr>::value || std::is_lvalue_reference<T>::value,
578  "Exprs are captured by reference by IRMatcher objects and so must be lvalues");
579 }
580 
582  return {*e.get()};
583 }
584 
585 // Helpers to deref SpecificExprs to const BaseExprNode & rather than
586 // passing them by value anywhere (incurring lots of refcounting)
587 template<typename T,
588  // T must be a pattern node
589  typename = typename std::decay<T>::type::pattern_tag,
590  // But T may not be SpecificExpr
591  typename = typename std::enable_if<!std::is_same<typename std::decay<T>::type, SpecificExpr>::value>::type>
593  return t;
594 }
595 
597 const BaseExprNode &unwrap(const SpecificExpr &e) {
598  return e.expr;
599 }
600 
601 inline std::ostream &operator<<(std::ostream &s, const IntLiteral &op) {
602  s << op.v;
603  return s;
604 }
605 
606 template<typename Op>
608 
609 template<typename Op>
611 
612 template<typename Op>
613 double constant_fold_bin_op(halide_type_t &, double, double) noexcept;
614 
615 constexpr bool commutative(IRNodeType t) {
616  return (t == IRNodeType::Add ||
617  t == IRNodeType::Mul ||
618  t == IRNodeType::And ||
619  t == IRNodeType::Or ||
620  t == IRNodeType::Min ||
621  t == IRNodeType::Max ||
622  t == IRNodeType::EQ ||
623  t == IRNodeType::NE);
624 }
625 
626 // Matches one of the binary operators
627 template<typename Op, typename A, typename B>
628 struct BinOp {
629  struct pattern_tag {};
630  A a;
631  B b;
632 
634 
635  constexpr static IRNodeType min_node_type = Op::_node_type;
636  constexpr static IRNodeType max_node_type = Op::_node_type;
637 
638  // For commutative bin ops, we expect the weaker IR node type on
639  // the right. That is, for the rule to be canonical it must be
640  // possible that A is at least as strong as B.
641  constexpr static bool canonical =
642  A::canonical && B::canonical && (!commutative(Op::_node_type) || (A::max_node_type >= B::min_node_type));
643 
644  template<uint32_t bound>
645  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
646  if (e.node_type != Op::_node_type) {
647  return false;
648  }
649  const Op &op = (const Op &)e;
650  return (a.template match<bound>(*op.a.get(), state) &&
651  b.template match<bound | bindings<A>::mask>(*op.b.get(), state));
652  }
653 
654  template<uint32_t bound, typename Op2, typename A2, typename B2>
655  HALIDE_ALWAYS_INLINE bool match(const BinOp<Op2, A2, B2> &op, MatcherState &state) const noexcept {
656  return (std::is_same<Op, Op2>::value &&
657  a.template match<bound>(unwrap(op.a), state) &&
658  b.template match<bound | bindings<A>::mask>(unwrap(op.b), state));
659  }
660 
661  constexpr static bool foldable = A::foldable && B::foldable;
662 
// Constant-fold this binary pattern using the constants bound in state.
// On return val holds the folded value and ty its type; ty.lanes may also
// carry the MatcherState special-value (overflow) bits.
664  void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
665  halide_scalar_value_t val_a, val_b;
    // An IntLiteral has no type of its own and reads ty as an input, so
    // when A is a literal the B side is folded first to establish ty.
666  if (std::is_same<A, IntLiteral>::value) {
667  b.make_folded_const(val_b, ty, state);
668  if ((std::is_same<Op, And>::value && val_b.u.u64 == 0) ||
669  (std::is_same<Op, Or>::value && val_b.u.u64 == 1)) {
670  // Short circuit
671  val = val_b;
672  return;
673  }
    // Save the lanes field (including any overflow flag) before the second
    // operand's fold overwrites ty, then merge it back afterwards.
674  const uint16_t l = ty.lanes;
675  a.make_folded_const(val_a, ty, state);
676  ty.lanes |= l; // Make sure the overflow bits are sticky
677  } else {
678  a.make_folded_const(val_a, ty, state);
679  if ((std::is_same<Op, And>::value && val_a.u.u64 == 0) ||
680  (std::is_same<Op, Or>::value && val_a.u.u64 == 1)) {
681  // Short circuit
682  val = val_a;
683  return;
684  }
685  const uint16_t l = ty.lanes;
686  b.make_folded_const(val_b, ty, state);
687  ty.lanes |= l;
688  }
    // Dispatch on the operand type code to the matching typed fold; the
    // fold receives ty by reference and may set flags in it.
689  switch (ty.code) {
690  case halide_type_int:
691  val.u.i64 = constant_fold_bin_op<Op>(ty, val_a.u.i64, val_b.u.i64);
692  break;
693  case halide_type_uint:
694  val.u.u64 = constant_fold_bin_op<Op>(ty, val_a.u.u64, val_b.u.u64);
695  break;
696  case halide_type_float:
697  case halide_type_bfloat:
698  val.u.f64 = constant_fold_bin_op<Op>(ty, val_a.u.f64, val_b.u.f64);
699  break;
700  default:
701  // unreachable
702  ;
703  }
704  }
705 
707  Expr make(MatcherState &state, halide_type_t type_hint) const noexcept {
708  Expr ea, eb;
709  if (std::is_same<A, IntLiteral>::value) {
710  eb = b.make(state, type_hint);
711  ea = a.make(state, eb.type());
712  } else {
713  ea = a.make(state, type_hint);
714  eb = b.make(state, ea.type());
715  }
716  // We sometimes mix vectors and scalars in the rewrite rules,
717  // so insert a broadcast if necessary.
718  if (ea.type().is_vector() && !eb.type().is_vector()) {
719  eb = Broadcast::make(eb, ea.type().lanes());
720  }
721  if (eb.type().is_vector() && !ea.type().is_vector()) {
722  ea = Broadcast::make(ea, eb.type().lanes());
723  }
724  return Op::make(std::move(ea), std::move(eb));
725  }
726 };
727 
728 template<typename Op>
730 
731 template<typename Op>
733 
734 template<typename Op>
735 uint64_t constant_fold_cmp_op(double, double) noexcept;
736 
737 // Matches one of the comparison operators
738 template<typename Op, typename A, typename B>
739 struct CmpOp {
740  struct pattern_tag {};
741  A a;
742  B b;
743 
745 
746  constexpr static IRNodeType min_node_type = Op::_node_type;
747  constexpr static IRNodeType max_node_type = Op::_node_type;
748  constexpr static bool canonical = (A::canonical &&
749  B::canonical &&
750  (!commutative(Op::_node_type) || A::max_node_type >= B::min_node_type) &&
751  (Op::_node_type != IRNodeType::GE) &&
752  (Op::_node_type != IRNodeType::GT));
753 
754  template<uint32_t bound>
755  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
756  if (e.node_type != Op::_node_type) {
757  return false;
758  }
759  const Op &op = (const Op &)e;
760  return (a.template match<bound>(*op.a.get(), state) &&
761  b.template match<bound | bindings<A>::mask>(*op.b.get(), state));
762  }
763 
764  template<uint32_t bound, typename Op2, typename A2, typename B2>
765  HALIDE_ALWAYS_INLINE bool match(const CmpOp<Op2, A2, B2> &op, MatcherState &state) const noexcept {
766  return (std::is_same<Op, Op2>::value &&
767  a.template match<bound>(unwrap(op.a), state) &&
768  b.template match<bound | bindings<A>::mask>(unwrap(op.b), state));
769  }
770 
771  constexpr static bool foldable = A::foldable && B::foldable;
772 
// Constant-fold this comparison using the constants bound in state. The
// result is written to val.u.u64 and ty is rewritten to a 1-bit unsigned
// type; ty.lanes keeps any special-value bits set while folding operands.
774  void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
775  halide_scalar_value_t val_a, val_b;
776  // If one side is an untyped const, evaluate the other side first to get a type hint.
777  if (std::is_same<A, IntLiteral>::value) {
778  b.make_folded_const(val_b, ty, state);
    // Save lanes before the second fold overwrites ty, then merge it back.
779  const uint16_t l = ty.lanes;
780  a.make_folded_const(val_a, ty, state);
781  ty.lanes |= l;
782  } else {
783  a.make_folded_const(val_a, ty, state);
784  const uint16_t l = ty.lanes;
785  b.make_folded_const(val_b, ty, state);
786  ty.lanes |= l;
787  }
    // Compare in the operands' own domain, selected by the operand type code.
788  switch (ty.code) {
789  case halide_type_int:
790  val.u.u64 = constant_fold_cmp_op<Op>(val_a.u.i64, val_b.u.i64);
791  break;
792  case halide_type_uint:
793  val.u.u64 = constant_fold_cmp_op<Op>(val_a.u.u64, val_b.u.u64);
794  break;
795  case halide_type_float:
796  case halide_type_bfloat:
797  val.u.u64 = constant_fold_cmp_op<Op>(val_a.u.f64, val_b.u.f64);
798  break;
799  default:
800  // unreachable
801  ;
802  }
    // Only code and bits are changed here; lanes is left as computed above.
803  ty.code = halide_type_uint;
804  ty.bits = 1;
805  }
806 
808  Expr make(MatcherState &state, halide_type_t type_hint) const {
809  // If one side is an untyped const, evaluate the other side first to get a type hint.
810  Expr ea, eb;
811  if (std::is_same<A, IntLiteral>::value) {
812  eb = b.make(state, {});
813  ea = a.make(state, eb.type());
814  } else {
815  ea = a.make(state, {});
816  eb = b.make(state, ea.type());
817  }
818  // We sometimes mix vectors and scalars in the rewrite rules,
819  // so insert a broadcast if necessary.
820  if (ea.type().is_vector() && !eb.type().is_vector()) {
821  eb = Broadcast::make(eb, ea.type().lanes());
822  }
823  if (eb.type().is_vector() && !ea.type().is_vector()) {
824  ea = Broadcast::make(ea, eb.type().lanes());
825  }
826  return Op::make(std::move(ea), std::move(eb));
827  }
828 };
829 
830 template<typename A, typename B>
831 std::ostream &operator<<(std::ostream &s, const BinOp<Add, A, B> &op) {
832  s << "(" << op.a << " + " << op.b << ")";
833  return s;
834 }
835 
836 template<typename A, typename B>
837 std::ostream &operator<<(std::ostream &s, const BinOp<Sub, A, B> &op) {
838  s << "(" << op.a << " - " << op.b << ")";
839  return s;
840 }
841 
842 template<typename A, typename B>
843 std::ostream &operator<<(std::ostream &s, const BinOp<Mul, A, B> &op) {
844  s << "(" << op.a << " * " << op.b << ")";
845  return s;
846 }
847 
848 template<typename A, typename B>
849 std::ostream &operator<<(std::ostream &s, const BinOp<Div, A, B> &op) {
850  s << "(" << op.a << " / " << op.b << ")";
851  return s;
852 }
853 
854 template<typename A, typename B>
855 std::ostream &operator<<(std::ostream &s, const BinOp<And, A, B> &op) {
856  s << "(" << op.a << " && " << op.b << ")";
857  return s;
858 }
859 
860 template<typename A, typename B>
861 std::ostream &operator<<(std::ostream &s, const BinOp<Or, A, B> &op) {
862  s << "(" << op.a << " || " << op.b << ")";
863  return s;
864 }
865 
866 template<typename A, typename B>
867 std::ostream &operator<<(std::ostream &s, const BinOp<Min, A, B> &op) {
868  s << "min(" << op.a << ", " << op.b << ")";
869  return s;
870 }
871 
872 template<typename A, typename B>
873 std::ostream &operator<<(std::ostream &s, const BinOp<Max, A, B> &op) {
874  s << "max(" << op.a << ", " << op.b << ")";
875  return s;
876 }
877 
878 template<typename A, typename B>
879 std::ostream &operator<<(std::ostream &s, const CmpOp<LE, A, B> &op) {
880  s << "(" << op.a << " <= " << op.b << ")";
881  return s;
882 }
883 
884 template<typename A, typename B>
885 std::ostream &operator<<(std::ostream &s, const CmpOp<LT, A, B> &op) {
886  s << "(" << op.a << " < " << op.b << ")";
887  return s;
888 }
889 
890 template<typename A, typename B>
891 std::ostream &operator<<(std::ostream &s, const CmpOp<GE, A, B> &op) {
892  s << "(" << op.a << " >= " << op.b << ")";
893  return s;
894 }
895 
896 template<typename A, typename B>
897 std::ostream &operator<<(std::ostream &s, const CmpOp<GT, A, B> &op) {
898  s << "(" << op.a << " > " << op.b << ")";
899  return s;
900 }
901 
902 template<typename A, typename B>
903 std::ostream &operator<<(std::ostream &s, const CmpOp<EQ, A, B> &op) {
904  s << "(" << op.a << " == " << op.b << ")";
905  return s;
906 }
907 
908 template<typename A, typename B>
909 std::ostream &operator<<(std::ostream &s, const CmpOp<NE, A, B> &op) {
910  s << "(" << op.a << " != " << op.b << ")";
911  return s;
912 }
913 
914 template<typename A, typename B>
915 std::ostream &operator<<(std::ostream &s, const BinOp<Mod, A, B> &op) {
916  s << "(" << op.a << " % " << op.b << ")";
917  return s;
918 }
919 
920 template<typename A, typename B>
921 HALIDE_ALWAYS_INLINE auto operator+(A &&a, B &&b) noexcept -> BinOp<Add, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
922  assert_is_lvalue_if_expr<A>();
923  assert_is_lvalue_if_expr<B>();
924  return {pattern_arg(a), pattern_arg(b)};
925 }
926 
927 template<typename A, typename B>
928 HALIDE_ALWAYS_INLINE auto add(A &&a, B &&b) -> decltype(IRMatcher::operator+(a, b)) {
929  assert_is_lvalue_if_expr<A>();
930  assert_is_lvalue_if_expr<B>();
931  return IRMatcher::operator+(a, b);
932 }
933 
934 template<>
936  t.lanes |= ((t.bits >= 32) && add_would_overflow(t.bits, a, b)) ? MatcherState::signed_integer_overflow : 0;
937  int dead_bits = 64 - t.bits;
938  // Drop the high bits then sign-extend them back
939  return int64_t((uint64_t(a) + uint64_t(b)) << dead_bits) >> dead_bits;
940 }
941 
942 template<>
944  uint64_t ones = (uint64_t)(-1);
945  return (a + b) & (ones >> (64 - t.bits));
946 }
947 
948 template<>
949 HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Add>(halide_type_t &t, double a, double b) noexcept {
950  return a + b;
951 }
952 
953 template<typename A, typename B>
954 HALIDE_ALWAYS_INLINE auto operator-(A &&a, B &&b) noexcept -> BinOp<Sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
955  assert_is_lvalue_if_expr<A>();
956  assert_is_lvalue_if_expr<B>();
957  return {pattern_arg(a), pattern_arg(b)};
958 }
959 
960 template<typename A, typename B>
961 HALIDE_ALWAYS_INLINE auto sub(A &&a, B &&b) -> decltype(IRMatcher::operator-(a, b)) {
962  assert_is_lvalue_if_expr<A>();
963  assert_is_lvalue_if_expr<B>();
964  return IRMatcher::operator-(a, b);
965 }
966 
967 template<>
969  t.lanes |= ((t.bits >= 32) && sub_would_overflow(t.bits, a, b)) ? MatcherState::signed_integer_overflow : 0;
970  // Drop the high bits then sign-extend them back
971  int dead_bits = 64 - t.bits;
972  return int64_t((uint64_t(a) - uint64_t(b)) << dead_bits) >> dead_bits;
973 }
974 
975 template<>
977  uint64_t ones = (uint64_t)(-1);
978  return (a - b) & (ones >> (64 - t.bits));
979 }
980 
981 template<>
982 HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Sub>(halide_type_t &t, double a, double b) noexcept {
983  return a - b;
984 }
985 
986 template<typename A, typename B>
987 HALIDE_ALWAYS_INLINE auto operator*(A &&a, B &&b) noexcept -> BinOp<Mul, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
988  assert_is_lvalue_if_expr<A>();
989  assert_is_lvalue_if_expr<B>();
990  return {pattern_arg(a), pattern_arg(b)};
991 }
992 
993 template<typename A, typename B>
994 HALIDE_ALWAYS_INLINE auto mul(A &&a, B &&b) -> decltype(IRMatcher::operator*(a, b)) {
995  assert_is_lvalue_if_expr<A>();
996  assert_is_lvalue_if_expr<B>();
997  return IRMatcher::operator*(a, b);
998 }
999 
1000 template<>
1002  t.lanes |= ((t.bits >= 32) && mul_would_overflow(t.bits, a, b)) ? MatcherState::signed_integer_overflow : 0;
1003  int dead_bits = 64 - t.bits;
1004  // Drop the high bits then sign-extend them back
1005  return int64_t((uint64_t(a) * uint64_t(b)) << dead_bits) >> dead_bits;
1006 }
1007 
1008 template<>
1010  uint64_t ones = (uint64_t)(-1);
1011  return (a * b) & (ones >> (64 - t.bits));
1012 }
1013 
1014 template<>
1015 HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Mul>(halide_type_t &t, double a, double b) noexcept {
1016  return a * b;
1017 }
1018 
1019 template<typename A, typename B>
1020 HALIDE_ALWAYS_INLINE auto operator/(A &&a, B &&b) noexcept -> BinOp<Div, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1021  assert_is_lvalue_if_expr<A>();
1022  assert_is_lvalue_if_expr<B>();
1023  return {pattern_arg(a), pattern_arg(b)};
1024 }
1025 
1026 template<typename A, typename B>
1027 HALIDE_ALWAYS_INLINE auto div(A &&a, B &&b) -> decltype(IRMatcher::operator/(a, b)) {
1028  return IRMatcher::operator/(a, b);
1029 }
1030 
1031 template<>
1033  return div_imp(a, b);
1034 }
1035 
1036 template<>
1038  return div_imp(a, b);
1039 }
1040 
1041 template<>
1042 HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Div>(halide_type_t &t, double a, double b) noexcept {
1043  return div_imp(a, b);
1044 }
1045 
1046 template<typename A, typename B>
1047 HALIDE_ALWAYS_INLINE auto operator%(A &&a, B &&b) noexcept -> BinOp<Mod, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1048  assert_is_lvalue_if_expr<A>();
1049  assert_is_lvalue_if_expr<B>();
1050  return {pattern_arg(a), pattern_arg(b)};
1051 }
1052 
1053 template<typename A, typename B>
1054 HALIDE_ALWAYS_INLINE auto mod(A &&a, B &&b) -> decltype(IRMatcher::operator%(a, b)) {
1055  assert_is_lvalue_if_expr<A>();
1056  assert_is_lvalue_if_expr<B>();
1057  return IRMatcher::operator%(a, b);
1058 }
1059 
1060 template<>
1062  return mod_imp(a, b);
1063 }
1064 
1065 template<>
1067  return mod_imp(a, b);
1068 }
1069 
1070 template<>
1071 HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Mod>(halide_type_t &t, double a, double b) noexcept {
1072  return mod_imp(a, b);
1073 }
1074 
1075 template<typename A, typename B>
1076 HALIDE_ALWAYS_INLINE auto min(A &&a, B &&b) noexcept -> BinOp<Min, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1077  assert_is_lvalue_if_expr<A>();
1078  assert_is_lvalue_if_expr<B>();
1079  return {pattern_arg(a), pattern_arg(b)};
1080 }
1081 
1082 template<>
1084  return std::min(a, b);
1085 }
1086 
1087 template<>
1089  return std::min(a, b);
1090 }
1091 
1092 template<>
1093 HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Min>(halide_type_t &t, double a, double b) noexcept {
1094  return std::min(a, b);
1095 }
1096 
1097 template<typename A, typename B>
1098 HALIDE_ALWAYS_INLINE auto max(A &&a, B &&b) noexcept -> BinOp<Max, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1099  assert_is_lvalue_if_expr<A>();
1100  assert_is_lvalue_if_expr<B>();
1101  return {pattern_arg(std::forward<A>(a)), pattern_arg(std::forward<B>(b))};
1102 }
1103 
1104 template<>
1106  return std::max(a, b);
1107 }
1108 
1109 template<>
1111  return std::max(a, b);
1112 }
1113 
1114 template<>
1115 HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Max>(halide_type_t &t, double a, double b) noexcept {
1116  return std::max(a, b);
1117 }
1118 
1119 template<typename A, typename B>
1120 HALIDE_ALWAYS_INLINE auto operator<(A &&a, B &&b) noexcept -> CmpOp<LT, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1121  return {pattern_arg(a), pattern_arg(b)};
1122 }
1123 
1124 template<typename A, typename B>
1125 HALIDE_ALWAYS_INLINE auto lt(A &&a, B &&b) -> decltype(IRMatcher::operator<(a, b)) {
1126  return IRMatcher::operator<(a, b);
1127 }
1128 
1129 template<>
1131  return a < b;
1132 }
1133 
1134 template<>
1136  return a < b;
1137 }
1138 
1139 template<>
1141  return a < b;
1142 }
1143 
1144 template<typename A, typename B>
1145 HALIDE_ALWAYS_INLINE auto operator>(A &&a, B &&b) noexcept -> CmpOp<GT, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1146  return {pattern_arg(a), pattern_arg(b)};
1147 }
1148 
1149 template<typename A, typename B>
1150 HALIDE_ALWAYS_INLINE auto gt(A &&a, B &&b) -> decltype(IRMatcher::operator>(a, b)) {
1151  return IRMatcher::operator>(a, b);
1152 }
1153 
1154 template<>
1156  return a > b;
1157 }
1158 
1159 template<>
1161  return a > b;
1162 }
1163 
1164 template<>
1166  return a > b;
1167 }
1168 
1169 template<typename A, typename B>
1170 HALIDE_ALWAYS_INLINE auto operator<=(A &&a, B &&b) noexcept -> CmpOp<LE, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1171  return {pattern_arg(a), pattern_arg(b)};
1172 }
1173 
1174 template<typename A, typename B>
1175 HALIDE_ALWAYS_INLINE auto le(A &&a, B &&b) -> decltype(IRMatcher::operator<=(a, b)) {
1176  return IRMatcher::operator<=(a, b);
1177 }
1178 
1179 template<>
1181  return a <= b;
1182 }
1183 
1184 template<>
1186  return a <= b;
1187 }
1188 
1189 template<>
1191  return a <= b;
1192 }
1193 
1194 template<typename A, typename B>
1195 HALIDE_ALWAYS_INLINE auto operator>=(A &&a, B &&b) noexcept -> CmpOp<GE, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1196  return {pattern_arg(a), pattern_arg(b)};
1197 }
1198 
1199 template<typename A, typename B>
1200 HALIDE_ALWAYS_INLINE auto ge(A &&a, B &&b) -> decltype(IRMatcher::operator>=(a, b)) {
1201  return IRMatcher::operator>=(a, b);
1202 }
1203 
1204 template<>
1206  return a >= b;
1207 }
1208 
1209 template<>
1211  return a >= b;
1212 }
1213 
1214 template<>
1216  return a >= b;
1217 }
1218 
1219 template<typename A, typename B>
1220 HALIDE_ALWAYS_INLINE auto operator==(A &&a, B &&b) noexcept -> CmpOp<EQ, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1221  return {pattern_arg(a), pattern_arg(b)};
1222 }
1223 
1224 template<typename A, typename B>
1225 HALIDE_ALWAYS_INLINE auto eq(A &&a, B &&b) -> decltype(IRMatcher::operator==(a, b)) {
1226  return IRMatcher::operator==(a, b);
1227 }
1228 
1229 template<>
1231  return a == b;
1232 }
1233 
1234 template<>
1236  return a == b;
1237 }
1238 
1239 template<>
1241  return a == b;
1242 }
1243 
1244 template<typename A, typename B>
1245 HALIDE_ALWAYS_INLINE auto operator!=(A &&a, B &&b) noexcept -> CmpOp<NE, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1246  return {pattern_arg(a), pattern_arg(b)};
1247 }
1248 
1249 template<typename A, typename B>
1250 HALIDE_ALWAYS_INLINE auto ne(A &&a, B &&b) -> decltype(IRMatcher::operator!=(a, b)) {
1251  return IRMatcher::operator!=(a, b);
1252 }
1253 
1254 template<>
1256  return a != b;
1257 }
1258 
1259 template<>
1261  return a != b;
1262 }
1263 
1264 template<>
1266  return a != b;
1267 }
1268 
1269 template<typename A, typename B>
1270 HALIDE_ALWAYS_INLINE auto operator||(A &&a, B &&b) noexcept -> BinOp<Or, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1271  return {pattern_arg(a), pattern_arg(b)};
1272 }
1273 
1274 template<typename A, typename B>
1275 HALIDE_ALWAYS_INLINE auto or_op(A &&a, B &&b) -> decltype(IRMatcher::operator||(a, b)) {
1276  return IRMatcher::operator||(a, b);
1277 }
1278 
1279 template<>
1281  return (a | b) & 1;
1282 }
1283 
1284 template<>
1286  return (a | b) & 1;
1287 }
1288 
1289 template<>
1290 HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Or>(halide_type_t &t, double a, double b) noexcept {
1291  // Unreachable, as it would be a type mismatch.
1292  return 0;
1293 }
1294 
1295 template<typename A, typename B>
1296 HALIDE_ALWAYS_INLINE auto operator&&(A &&a, B &&b) noexcept -> BinOp<And, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1297  return {pattern_arg(a), pattern_arg(b)};
1298 }
1299 
1300 template<typename A, typename B>
1301 HALIDE_ALWAYS_INLINE auto and_op(A &&a, B &&b) -> decltype(IRMatcher::operator&&(a, b)) {
1302  return IRMatcher::operator&&(a, b);
1303 }
1304 
1305 template<>
1307  return a & b & 1;
1308 }
1309 
1310 template<>
1312  return a & b & 1;
1313 }
1314 
1315 template<>
1316 HALIDE_ALWAYS_INLINE double constant_fold_bin_op<And>(halide_type_t &t, double a, double b) noexcept {
1317  // Unreachable
1318  return 0;
1319 }
1320 
/** Bitwise-OR of an empty argument list: the identity element, zero. */
constexpr inline uint32_t bitwise_or_reduce() {
    return 0;
}

/** Bitwise-OR of one or more values, computed at compile time when the
 * arguments are constant. Recurses on the tail until the empty overload
 * supplies the terminating zero. */
template<typename... Args>
constexpr uint32_t bitwise_or_reduce(uint32_t first, Args... rest) {
    const uint32_t tail = bitwise_or_reduce(rest...);
    return first | tail;
}
1329 
/** Logical-AND of an empty argument list: vacuously true. */
constexpr inline bool and_reduce() {
    return true;
}

/** Logical-AND of one or more flags, computed at compile time when the
 * arguments are constant. A false head decides the result; otherwise the
 * answer is whatever the remaining flags reduce to. */
template<typename... Args>
constexpr bool and_reduce(bool first, Args... rest) {
    return first ? and_reduce(rest...) : false;
}
1338 
// TODO: this can be replaced with std::min() once we require C++14 or later
/** Compile-time minimum of two ints. */
constexpr int const_min(int a, int b) {
    return (b <= a) ? b : a;
}
1343 
1344 template<typename... Args>
1345 struct Intrin {
1346  struct pattern_tag {};
1348  std::tuple<Args...> args;
1349  // The type of the output of the intrinsic node.
1350  // Only necessary in cases where it can't be inferred
1351  // from the input types (e.g. saturating_cast).
1353 
1355 
1358  constexpr static bool canonical = and_reduce((Args::canonical)...);
1359 
1360  template<int i,
1361  uint32_t bound,
1362  typename = typename std::enable_if<(i < sizeof...(Args))>::type>
1363  HALIDE_ALWAYS_INLINE bool match_args(int, const Call &c, MatcherState &state) const noexcept {
1364  using T = decltype(std::get<i>(args));
1365  return (std::get<i>(args).template match<bound>(*c.args[i].get(), state) &&
1366  match_args<i + 1, bound | bindings<T>::mask>(0, c, state));
1367  }
1368 
1369  template<int i, uint32_t binds>
1370  HALIDE_ALWAYS_INLINE bool match_args(double, const Call &c, MatcherState &state) const noexcept {
1371  return true;
1372  }
1373 
1374  template<uint32_t bound>
1375  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1376  if (e.node_type != IRNodeType::Call) {
1377  return false;
1378  }
1379  const Call &c = (const Call &)e;
1380  return (c.is_intrinsic(intrin) &&
1381  ((optional_type_hint == Type()) || optional_type_hint == e.type) &&
1382  match_args<0, bound>(0, c, state));
1383  }
1384 
1385  template<int i,
1386  typename = typename std::enable_if<(i < sizeof...(Args))>::type>
1387  HALIDE_ALWAYS_INLINE void print_args(int, std::ostream &s) const {
1388  s << std::get<i>(args);
1389  if (i + 1 < sizeof...(Args)) {
1390  s << ", ";
1391  }
1392  print_args<i + 1>(0, s);
1393  }
1394 
1395  template<int i>
1396  HALIDE_ALWAYS_INLINE void print_args(double, std::ostream &s) const {
1397  }
1398 
1400  void print_args(std::ostream &s) const {
1401  print_args<0>(0, s);
1402  }
1403 
1405  Expr make(MatcherState &state, halide_type_t type_hint) const {
1406  Expr arg0 = std::get<0>(args).make(state, type_hint);
1407  if (intrin == Call::likely) {
1408  return likely(arg0);
1409  } else if (intrin == Call::likely_if_innermost) {
1410  return likely_if_innermost(arg0);
1411  } else if (intrin == Call::abs) {
1412  return abs(arg0);
1413  } else if (intrin == Call::saturating_cast) {
1414  return saturating_cast(optional_type_hint, arg0);
1415  }
1416 
1417  Expr arg1 = std::get<const_min(1, sizeof...(Args) - 1)>(args).make(state, type_hint);
1418  if (intrin == Call::absd) {
1419  return absd(arg0, arg1);
1420  } else if (intrin == Call::widen_right_add) {
1421  return widen_right_add(arg0, arg1);
1422  } else if (intrin == Call::widen_right_mul) {
1423  return widen_right_mul(arg0, arg1);
1424  } else if (intrin == Call::widen_right_sub) {
1425  return widen_right_sub(arg0, arg1);
1426  } else if (intrin == Call::widening_add) {
1427  return widening_add(arg0, arg1);
1428  } else if (intrin == Call::widening_sub) {
1429  return widening_sub(arg0, arg1);
1430  } else if (intrin == Call::widening_mul) {
1431  return widening_mul(arg0, arg1);
1432  } else if (intrin == Call::saturating_add) {
1433  return saturating_add(arg0, arg1);
1434  } else if (intrin == Call::saturating_sub) {
1435  return saturating_sub(arg0, arg1);
1436  } else if (intrin == Call::halving_add) {
1437  return halving_add(arg0, arg1);
1438  } else if (intrin == Call::halving_sub) {
1439  return halving_sub(arg0, arg1);
1440  } else if (intrin == Call::rounding_halving_add) {
1441  return rounding_halving_add(arg0, arg1);
1442  } else if (intrin == Call::shift_left) {
1443  return arg0 << arg1;
1444  } else if (intrin == Call::shift_right) {
1445  return arg0 >> arg1;
1446  } else if (intrin == Call::rounding_shift_left) {
1447  return rounding_shift_left(arg0, arg1);
1448  } else if (intrin == Call::rounding_shift_right) {
1449  return rounding_shift_right(arg0, arg1);
1450  }
1451 
1452  Expr arg2 = std::get<const_min(2, sizeof...(Args) - 1)>(args).make(state, type_hint);
1453  if (intrin == Call::mul_shift_right) {
1454  return mul_shift_right(arg0, arg1, arg2);
1455  } else if (intrin == Call::rounding_mul_shift_right) {
1456  return rounding_mul_shift_right(arg0, arg1, arg2);
1457  }
1458 
1459  internal_error << "Unhandled intrinsic in IRMatcher: " << intrin;
1460  return Expr();
1461  }
1462 
1463  constexpr static bool foldable = true;
1464 
1466  halide_scalar_value_t arg1;
1467  // Assuming the args have the same type as the intrinsic is incorrect in
1468  // general. But for the intrinsics we can fold (just shifts), the LHS
1469  // has the same type as the intrinsic, and we can always treat the RHS
1470  // as a signed int, because we're using 64 bits for it.
1471  std::get<0>(args).make_folded_const(val, ty, state);
1472  halide_type_t signed_ty = ty;
1473  signed_ty.code = halide_type_int;
1474  // We can just directly get the second arg here, because we only want to
1475  // instantiate this method for shifts, which have two args.
1476  std::get<1>(args).make_folded_const(arg1, signed_ty, state);
1477 
1478  if (intrin == Call::shift_left) {
1479  if (arg1.u.i64 < 0) {
1480  if (ty.code == halide_type_int) {
1481  // Arithmetic shift
1482  val.u.i64 >>= -arg1.u.i64;
1483  } else {
1484  // Logical shift
1485  val.u.u64 >>= -arg1.u.i64;
1486  }
1487  } else {
1488  val.u.u64 <<= arg1.u.i64;
1489  }
1490  } else if (intrin == Call::shift_right) {
1491  if (arg1.u.i64 > 0) {
1492  if (ty.code == halide_type_int) {
1493  // Arithmetic shift
1494  val.u.i64 >>= arg1.u.i64;
1495  } else {
1496  // Logical shift
1497  val.u.u64 >>= arg1.u.i64;
1498  }
1499  } else {
1500  val.u.u64 <<= -arg1.u.i64;
1501  }
1502  } else {
1503  internal_error << "Folding not implemented for intrinsic: " << intrin;
1504  }
1505  }
1506 
1509  : intrin(intrin), args(args...) {
1510  }
1511 };
1512 
1513 template<typename... Args>
1514 std::ostream &operator<<(std::ostream &s, const Intrin<Args...> &op) {
1515  s << op.intrin << "(";
1516  op.print_args(s);
1517  s << ")";
1518  return s;
1519 }
1520 
1521 template<typename... Args>
1522 HALIDE_ALWAYS_INLINE auto intrin(Call::IntrinsicOp intrinsic_op, Args... args) noexcept -> Intrin<decltype(pattern_arg(args))...> {
1523  return {intrinsic_op, pattern_arg(args)...};
1524 }
1525 
1526 template<typename A, typename B>
1527 auto widen_right_add(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1529 }
1530 template<typename A, typename B>
1531 auto widen_right_mul(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1533 }
1534 template<typename A, typename B>
1535 auto widen_right_sub(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1537 }
1538 
1539 template<typename A, typename B>
1540 auto widening_add(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1541  return {Call::widening_add, pattern_arg(a), pattern_arg(b)};
1542 }
1543 template<typename A, typename B>
1544 auto widening_sub(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1545  return {Call::widening_sub, pattern_arg(a), pattern_arg(b)};
1546 }
1547 template<typename A, typename B>
1548 auto widening_mul(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1549  return {Call::widening_mul, pattern_arg(a), pattern_arg(b)};
1550 }
1551 template<typename A, typename B>
1552 auto saturating_add(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1553  return {Call::saturating_add, pattern_arg(a), pattern_arg(b)};
1554 }
1555 template<typename A, typename B>
1556 auto saturating_sub(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1557  return {Call::saturating_sub, pattern_arg(a), pattern_arg(b)};
1558 }
1559 template<typename A>
1560 auto saturating_cast(const Type &t, A &&a) noexcept -> Intrin<decltype(pattern_arg(a))> {
1561  Intrin<decltype(pattern_arg(a))> p = {Call::saturating_cast, pattern_arg(a)};
1562  p.optional_type_hint = t;
1563  return p;
1564 }
1565 template<typename A, typename B>
1566 auto halving_add(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1567  return {Call::halving_add, pattern_arg(a), pattern_arg(b)};
1568 }
1569 template<typename A, typename B>
1570 auto halving_sub(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1571  return {Call::halving_sub, pattern_arg(a), pattern_arg(b)};
1572 }
1573 template<typename A, typename B>
1574 auto rounding_halving_add(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1576 }
1577 template<typename A, typename B>
1578 auto shift_left(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1579  return {Call::shift_left, pattern_arg(a), pattern_arg(b)};
1580 }
1581 template<typename A, typename B>
1582 auto shift_right(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1583  return {Call::shift_right, pattern_arg(a), pattern_arg(b)};
1584 }
1585 template<typename A, typename B>
1586 auto rounding_shift_left(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1588 }
1589 template<typename A, typename B>
1590 auto rounding_shift_right(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1592 }
1593 template<typename A, typename B, typename C>
1594 auto mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))> {
1596 }
1597 template<typename A, typename B, typename C>
1598 auto rounding_mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))> {
1600 }
1601 
1602 template<typename A>
1603 struct NotOp {
1604  struct pattern_tag {};
1605  A a;
1606 
1607  constexpr static uint32_t binds = bindings<A>::mask;
1608 
1611  constexpr static bool canonical = A::canonical;
1612 
1613  template<uint32_t bound>
1614  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1615  if (e.node_type != IRNodeType::Not) {
1616  return false;
1617  }
1618  const Not &op = (const Not &)e;
1619  return (a.template match<bound>(*op.a.get(), state));
1620  }
1621 
1622  template<uint32_t bound, typename A2>
1623  HALIDE_ALWAYS_INLINE bool match(const NotOp<A2> &op, MatcherState &state) const noexcept {
1624  return a.template match<bound>(unwrap(op.a), state);
1625  }
1626 
1628  Expr make(MatcherState &state, halide_type_t type_hint) const {
1629  return Not::make(a.make(state, type_hint));
1630  }
1631 
1632  constexpr static bool foldable = A::foldable;
1633 
1634  template<typename A1 = A>
1636  a.make_folded_const(val, ty, state);
1637  val.u.u64 = ~val.u.u64;
1638  val.u.u64 &= 1;
1639  }
1640 };
1641 
1642 template<typename A>
1643 HALIDE_ALWAYS_INLINE auto operator!(A &&a) noexcept -> NotOp<decltype(pattern_arg(a))> {
1644  assert_is_lvalue_if_expr<A>();
1645  return {pattern_arg(a)};
1646 }
1647 
1648 template<typename A>
1649 HALIDE_ALWAYS_INLINE auto not_op(A &&a) -> decltype(IRMatcher::operator!(a)) {
1650  assert_is_lvalue_if_expr<A>();
1651  return IRMatcher::operator!(a);
1652 }
1653 
1654 template<typename A>
1655 inline std::ostream &operator<<(std::ostream &s, const NotOp<A> &op) {
1656  s << "!(" << op.a << ")";
1657  return s;
1658 }
1659 
1660 template<typename C, typename T, typename F>
1661 struct SelectOp {
1662  struct pattern_tag {};
1663  C c;
1664  T t;
1665  F f;
1666 
1668 
1671 
1672  constexpr static bool canonical = C::canonical && T::canonical && F::canonical;
1673 
1674  template<uint32_t bound>
1675  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1676  if (e.node_type != Select::_node_type) {
1677  return false;
1678  }
1679  const Select &op = (const Select &)e;
1680  return (c.template match<bound>(*op.condition.get(), state) &&
1681  t.template match<bound | bindings<C>::mask>(*op.true_value.get(), state) &&
1682  f.template match<bound | bindings<C>::mask | bindings<T>::mask>(*op.false_value.get(), state));
1683  }
1684  template<uint32_t bound, typename C2, typename T2, typename F2>
1685  HALIDE_ALWAYS_INLINE bool match(const SelectOp<C2, T2, F2> &instance, MatcherState &state) const noexcept {
1686  return (c.template match<bound>(unwrap(instance.c), state) &&
1687  t.template match<bound | bindings<C>::mask>(unwrap(instance.t), state) &&
1688  f.template match<bound | bindings<C>::mask | bindings<T>::mask>(unwrap(instance.f), state));
1689  }
1690 
1692  Expr make(MatcherState &state, halide_type_t type_hint) const {
1693  return Select::make(c.make(state, {}), t.make(state, type_hint), f.make(state, type_hint));
1694  }
1695 
1696  constexpr static bool foldable = C::foldable && T::foldable && F::foldable;
1697 
1698  template<typename C1 = C>
1700  halide_scalar_value_t c_val, t_val, f_val;
1701  halide_type_t c_ty;
1702  c.make_folded_const(c_val, c_ty, state);
1703  if ((c_val.u.u64 & 1) == 1) {
1704  t.make_folded_const(val, ty, state);
1705  } else {
1706  f.make_folded_const(val, ty, state);
1707  }
1708  ty.lanes |= c_ty.lanes & MatcherState::special_values_mask;
1709  }
1710 };
1711 
1712 template<typename C, typename T, typename F>
1713 std::ostream &operator<<(std::ostream &s, const SelectOp<C, T, F> &op) {
1714  s << "select(" << op.c << ", " << op.t << ", " << op.f << ")";
1715  return s;
1716 }
1717 
1718 template<typename C, typename T, typename F>
1719 HALIDE_ALWAYS_INLINE auto select(C &&c, T &&t, F &&f) noexcept -> SelectOp<decltype(pattern_arg(c)), decltype(pattern_arg(t)), decltype(pattern_arg(f))> {
1720  assert_is_lvalue_if_expr<C>();
1721  assert_is_lvalue_if_expr<T>();
1722  assert_is_lvalue_if_expr<F>();
1723  return {pattern_arg(c), pattern_arg(t), pattern_arg(f)};
1724 }
1725 
1726 template<typename A, typename B>
1727 struct BroadcastOp {
1728  struct pattern_tag {};
1729  A a;
1731 
1733 
1736 
1737  constexpr static bool canonical = A::canonical && B::canonical;
1738 
1739  template<uint32_t bound>
1740  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1741  if (e.node_type == Broadcast::_node_type) {
1742  const Broadcast &op = (const Broadcast &)e;
1743  if (a.template match<bound>(*op.value.get(), state) &&
1744  lanes.template match<bound>(op.lanes, state)) {
1745  return true;
1746  }
1747  }
1748  return false;
1749  }
1750 
1751  template<uint32_t bound, typename A2, typename B2>
1752  HALIDE_ALWAYS_INLINE bool match(const BroadcastOp<A2, B2> &op, MatcherState &state) const noexcept {
1753  return (a.template match<bound>(unwrap(op.a), state) &&
1754  lanes.template match<bound | bindings<A>::mask>(unwrap(op.lanes), state));
1755  }
1756 
1758  Expr make(MatcherState &state, halide_type_t type_hint) const {
1759  halide_scalar_value_t lanes_val;
1760  halide_type_t ty;
1761  lanes.make_folded_const(lanes_val, ty, state);
1762  int32_t l = (int32_t)lanes_val.u.i64;
1763  type_hint.lanes /= l;
1764  Expr val = a.make(state, type_hint);
1765  if (l == 1) {
1766  return val;
1767  } else {
1768  return Broadcast::make(std::move(val), l);
1769  }
1770  }
1771 
1772  constexpr static bool foldable = false;
1773 
1774  template<typename A1 = A>
1776  halide_scalar_value_t lanes_val;
1777  halide_type_t lanes_ty;
1778  lanes.make_folded_const(lanes_val, lanes_ty, state);
1779  uint16_t l = (uint16_t)lanes_val.u.i64;
1780  a.make_folded_const(val, ty, state);
1781  ty.lanes = l | (ty.lanes & MatcherState::special_values_mask);
1782  }
1783 };
1784 
1785 template<typename A, typename B>
1786 inline std::ostream &operator<<(std::ostream &s, const BroadcastOp<A, B> &op) {
1787  s << "broadcast(" << op.a << ", " << op.lanes << ")";
1788  return s;
1789 }
1790 
1791 template<typename A, typename B>
1792 HALIDE_ALWAYS_INLINE auto broadcast(A &&a, B lanes) noexcept -> BroadcastOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes))> {
1793  assert_is_lvalue_if_expr<A>();
1794  return {pattern_arg(a), pattern_arg(lanes)};
1795 }
1796 
1797 template<typename A, typename B, typename C>
1798 struct RampOp {
1799  struct pattern_tag {};
1800  A a;
1801  B b;
1803 
1805 
1808 
1809  constexpr static bool canonical = A::canonical && B::canonical && C::canonical;
1810 
1811  template<uint32_t bound>
1812  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1813  if (e.node_type != Ramp::_node_type) {
1814  return false;
1815  }
1816  const Ramp &op = (const Ramp &)e;
1817  if (a.template match<bound>(*op.base.get(), state) &&
1818  b.template match<bound | bindings<A>::mask>(*op.stride.get(), state) &&
1819  lanes.template match<bound | bindings<A>::mask | bindings<B>::mask>(op.lanes, state)) {
1820  return true;
1821  } else {
1822  return false;
1823  }
1824  }
1825 
1826  template<uint32_t bound, typename A2, typename B2, typename C2>
1827  HALIDE_ALWAYS_INLINE bool match(const RampOp<A2, B2, C2> &op, MatcherState &state) const noexcept {
1828  return (a.template match<bound>(unwrap(op.a), state) &&
1829  b.template match<bound | bindings<A>::mask>(unwrap(op.b), state) &&
1830  lanes.template match<bound | bindings<A>::mask | bindings<B>::mask>(unwrap(op.lanes), state));
1831  }
1832 
1834  Expr make(MatcherState &state, halide_type_t type_hint) const {
1835  halide_scalar_value_t lanes_val;
1836  halide_type_t ty;
1837  lanes.make_folded_const(lanes_val, ty, state);
1838  int32_t l = (int32_t)lanes_val.u.i64;
1839  type_hint.lanes /= l;
1840  Expr ea, eb;
1841  eb = b.make(state, type_hint);
1842  ea = a.make(state, eb.type());
1843  return Ramp::make(ea, eb, l);
1844  }
1845 
1846  constexpr static bool foldable = false;
1847 };
1848 
1849 template<typename A, typename B, typename C>
1850 std::ostream &operator<<(std::ostream &s, const RampOp<A, B, C> &op) {
1851  s << "ramp(" << op.a << ", " << op.b << ", " << op.lanes << ")";
1852  return s;
1853 }
1854 
1855 template<typename A, typename B, typename C>
1856 HALIDE_ALWAYS_INLINE auto ramp(A &&a, B &&b, C &&c) noexcept -> RampOp<decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))> {
1857  assert_is_lvalue_if_expr<A>();
1858  assert_is_lvalue_if_expr<B>();
1859  assert_is_lvalue_if_expr<C>();
1860  return {pattern_arg(a), pattern_arg(b), pattern_arg(c)};
1861 }
1862 
1863 template<typename A, typename B, VectorReduce::Operator reduce_op>
1865  struct pattern_tag {};
1866  A a;
1868 
1870 
1873  constexpr static bool canonical = A::canonical;
1874 
1875  template<uint32_t bound>
1876  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1877  if (e.node_type == VectorReduce::_node_type) {
1878  const VectorReduce &op = (const VectorReduce &)e;
1879  if (op.op == reduce_op &&
1880  a.template match<bound>(*op.value.get(), state) &&
1881  lanes.template match<bound | bindings<A>::mask>(op.type.lanes(), state)) {
1882  return true;
1883  }
1884  }
1885  return false;
1886  }
1887 
1888  template<uint32_t bound, typename A2, typename B2, VectorReduce::Operator reduce_op_2>
1890  return (reduce_op == reduce_op_2 &&
1891  a.template match<bound>(unwrap(op.a), state) &&
1892  lanes.template match<bound | bindings<A>::mask>(unwrap(op.lanes), state));
1893  }
1894 
1896  Expr make(MatcherState &state, halide_type_t type_hint) const {
1897  halide_scalar_value_t lanes_val;
1898  halide_type_t ty;
1899  lanes.make_folded_const(lanes_val, ty, state);
1900  int l = (int)lanes_val.u.i64;
1901  return VectorReduce::make(reduce_op, a.make(state, type_hint), l);
1902  }
1903 
1904  constexpr static bool foldable = false;
1905 };
1906 
1907 template<typename A, typename B, VectorReduce::Operator reduce_op>
1908 inline std::ostream &operator<<(std::ostream &s, const VectorReduceOp<A, B, reduce_op> &op) {
1909  s << "vector_reduce(" << reduce_op << ", " << op.a << ", " << op.lanes << ")";
1910  return s;
1911 }
1912 
1913 template<typename A, typename B>
1914 HALIDE_ALWAYS_INLINE auto h_add(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Add> {
1915  assert_is_lvalue_if_expr<A>();
1916  return {pattern_arg(a), pattern_arg(lanes)};
1917 }
1918 
1919 template<typename A, typename B>
1920 HALIDE_ALWAYS_INLINE auto h_min(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Min> {
1921  assert_is_lvalue_if_expr<A>();
1922  return {pattern_arg(a), pattern_arg(lanes)};
1923 }
1924 
1925 template<typename A, typename B>
1926 HALIDE_ALWAYS_INLINE auto h_max(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Max> {
1927  assert_is_lvalue_if_expr<A>();
1928  return {pattern_arg(a), pattern_arg(lanes)};
1929 }
1930 
1931 template<typename A, typename B>
1932 HALIDE_ALWAYS_INLINE auto h_and(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::And> {
1933  assert_is_lvalue_if_expr<A>();
1934  return {pattern_arg(a), pattern_arg(lanes)};
1935 }
1936 
1937 template<typename A, typename B>
1938 HALIDE_ALWAYS_INLINE auto h_or(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Or> {
1939  assert_is_lvalue_if_expr<A>();
1940  return {pattern_arg(a), pattern_arg(lanes)};
1941 }
1942 
1943 template<typename A>
1944 struct NegateOp {
1945  struct pattern_tag {};
1946  A a;
1947 
1948  constexpr static uint32_t binds = bindings<A>::mask;
1949 
1952 
1953  constexpr static bool canonical = A::canonical;
1954 
1955  template<uint32_t bound>
1956  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1957  if (e.node_type != Sub::_node_type) {
1958  return false;
1959  }
1960  const Sub &op = (const Sub &)e;
1961  return (a.template match<bound>(*op.b.get(), state) &&
1962  is_const_zero(op.a));
1963  }
1964 
1965  template<uint32_t bound, typename A2>
1966  HALIDE_ALWAYS_INLINE bool match(NegateOp<A2> &&p, MatcherState &state) const noexcept {
1967  return a.template match<bound>(unwrap(p.a), state);
1968  }
1969 
1971  Expr make(MatcherState &state, halide_type_t type_hint) const {
1972  Expr ea = a.make(state, type_hint);
1973  Expr z = make_zero(ea.type());
1974  return Sub::make(std::move(z), std::move(ea));
1975  }
1976 
1977  constexpr static bool foldable = A::foldable;
1978 
1979  template<typename A1 = A>
1981  a.make_folded_const(val, ty, state);
1982  int dead_bits = 64 - ty.bits;
1983  switch (ty.code) {
1984  case halide_type_int:
1985  if (ty.bits >= 32 && val.u.u64 && (val.u.u64 << (65 - ty.bits)) == 0) {
1986  // Trying to negate the most negative signed int for a no-overflow type.
1988  } else {
1989  // Negate, drop the high bits, and then sign-extend them back
1990  val.u.i64 = int64_t(uint64_t(-val.u.i64) << dead_bits) >> dead_bits;
1991  }
1992  break;
1993  case halide_type_uint:
1994  val.u.u64 = ((-val.u.u64) << dead_bits) >> dead_bits;
1995  break;
1996  case halide_type_float:
1997  case halide_type_bfloat:
1998  val.u.f64 = -val.u.f64;
1999  break;
2000  default:
2001  // unreachable
2002  ;
2003  }
2004  }
2005 };
2006 
2007 template<typename A>
2008 std::ostream &operator<<(std::ostream &s, const NegateOp<A> &op) {
2009  s << "-" << op.a;
2010  return s;
2011 }
2012 
2013 template<typename A>
2014 HALIDE_ALWAYS_INLINE auto operator-(A &&a) noexcept -> NegateOp<decltype(pattern_arg(a))> {
2015  assert_is_lvalue_if_expr<A>();
2016  return {pattern_arg(a)};
2017 }
2018 
2019 template<typename A>
2020 HALIDE_ALWAYS_INLINE auto negate(A &&a) -> decltype(IRMatcher::operator-(a)) {
2021  assert_is_lvalue_if_expr<A>();
2022  return IRMatcher::operator-(a);
2023 }
2024 
2025 template<typename A>
2026 struct CastOp {
2027  struct pattern_tag {};
2029  A a;
2030 
2031  constexpr static uint32_t binds = bindings<A>::mask;
2032 
2035  constexpr static bool canonical = A::canonical;
2036 
2037  template<uint32_t bound>
2038  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
2039  if (e.node_type != Cast::_node_type) {
2040  return false;
2041  }
2042  const Cast &op = (const Cast &)e;
2043  return (e.type == t &&
2044  a.template match<bound>(*op.value.get(), state));
2045  }
2046  template<uint32_t bound, typename A2>
2047  HALIDE_ALWAYS_INLINE bool match(const CastOp<A2> &op, MatcherState &state) const noexcept {
2048  return t == op.t && a.template match<bound>(unwrap(op.a), state);
2049  }
2050 
2052  Expr make(MatcherState &state, halide_type_t type_hint) const {
2053  return cast(t, a.make(state, {}));
2054  }
2055 
2056  constexpr static bool foldable = false;
2057 };
2058 
2059 template<typename A>
2060 std::ostream &operator<<(std::ostream &s, const CastOp<A> &op) {
2061  s << "cast(" << op.t << ", " << op.a << ")";
2062  return s;
2063 }
2064 
2065 template<typename A>
2066 HALIDE_ALWAYS_INLINE auto cast(halide_type_t t, A &&a) noexcept -> CastOp<decltype(pattern_arg(a))> {
2067  assert_is_lvalue_if_expr<A>();
2068  return {t, pattern_arg(a)};
2069 }
2070 
// Pattern node for a widening cast: matches a Cast IR node whose result type
// is the widened form of its operand's type, with the operand matching `a`.
// NOTE(review): some lines of this struct appear to be elided in this listing.
2071 template<typename A>
2072 struct WidenOp {
2073  struct pattern_tag {};
2074  A a;
2075 
      // Wildcards bound by this pattern are exactly those bound by the operand.
2076  constexpr static uint32_t binds = bindings<A>::mask;
2077 
      // Canonical iff the operand pattern is canonical.
2080  constexpr static bool canonical = A::canonical;
2081 
      // Match against a concrete IR node: it must be a Cast, its type must be
      // exactly op.value.type().widen(), and the cast's value must match `a`.
2082  template<uint32_t bound>
2083  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
2084  if (e.node_type != Cast::_node_type) {
2085  return false;
2086  }
2087  const Cast &op = (const Cast &)e;
2088  return (e.type == op.value.type().widen() &&
2089  a.template match<bound>(*op.value.get(), state));
2090  }
      // Match against another WidenOp pattern: only the operands are compared.
2091  template<uint32_t bound, typename A2>
2092  HALIDE_ALWAYS_INLINE bool match(const WidenOp<A2> &op, MatcherState &state) const noexcept {
2093  return a.template match<bound>(unwrap(op.a), state);
2094  }
2095 
      // Build the replacement: construct the operand with an empty type hint,
      // then cast it to the widened version of its own type. The incoming
      // type_hint is not used here.
2097  Expr make(MatcherState &state, halide_type_t type_hint) const {
2098  Expr e = a.make(state, {});
2099  Type w = e.type().widen();
2100  return cast(w, std::move(e));
2101  }
2102 
      // Widening casts are never constant-folded by the matcher.
2103  constexpr static bool foldable = false;
2104 };
2105 
2106 template<typename A>
2107 std::ostream &operator<<(std::ostream &s, const WidenOp<A> &op) {
2108  s << "widen(" << op.a << ")";
2109  return s;
2110 }
2111 
2112 template<typename A>
2113 HALIDE_ALWAYS_INLINE auto widen(A &&a) noexcept -> WidenOp<decltype(pattern_arg(a))> {
2114  assert_is_lvalue_if_expr<A>();
2115  return {pattern_arg(a)};
2116 }
2117 
// Pattern node for a vector slice: matches a single-input Shuffle that is a
// slice, with sub-patterns for the source vector, the slice base, the slice
// stride and the lane count.
// NOTE(review): some lines of this struct appear to be elided in this listing.
2118 template<typename Vec, typename Base, typename Stride, typename Lanes>
2119 struct SliceOp {
2120  struct pattern_tag {};
2121  Vec vec;
2122  Base base;
2123  Stride stride;
2124  Lanes lanes;
2125 
      // Union of the wildcards bound by each of the four sub-patterns.
2126  static constexpr uint32_t binds = Vec::binds | Base::binds | Stride::binds | Lanes::binds;
2127 
      // Canonical only if every sub-pattern is canonical.
2130  constexpr static bool canonical = Vec::canonical && Base::canonical && Stride::canonical && Lanes::canonical;
2131 
      // Match against a concrete IR node. The node must be a Shuffle with
      // exactly one input vector that is a slice. Sub-patterns are matched in
      // order (vec, base, stride, ...), each one seeing the bindings of the
      // patterns matched before it added to `bound`.
      // NOTE(review): the final operand of the && chain below is not visible
      // in this listing.
2132  template<uint32_t bound>
2133  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
2134  if (e.node_type != IRNodeType::Shuffle) {
2135  return false;
2136  }
2137  const Shuffle &v = (const Shuffle &)e;
2138  return v.vectors.size() == 1 &&
2139  v.is_slice() &&
2140  vec.template match<bound>(*v.vectors[0].get(), state) &&
2141  base.template match<bound | bindings<Vec>::mask>(v.slice_begin(), state) &&
2142  stride.template match<bound | bindings<Vec>::mask | bindings<Base>::mask>(v.slice_stride(), state) &&
2144  }
2145 
      // Build the replacement: constant-fold base, stride and lanes (reading
      // each result as a 64-bit signed value truncated to int), build the
      // vector operand with the caller's type hint, and make a slice Shuffle.
2147  Expr make(MatcherState &state, halide_type_t type_hint) const {
2148  halide_scalar_value_t base_val, stride_val, lanes_val;
2149  halide_type_t ty;
2150  base.make_folded_const(base_val, ty, state);
2151  int b = (int)base_val.u.i64;
2152  stride.make_folded_const(stride_val, ty, state);
2153  int s = (int)stride_val.u.i64;
2154  lanes.make_folded_const(lanes_val, ty, state);
2155  int l = (int)lanes_val.u.i64;
2156  return Shuffle::make_slice(vec.make(state, type_hint), b, s, l);
2157  }
2158 
      // The slice itself is not constant-foldable.
2159  constexpr static bool foldable = false;
2160 
      // Stores the four sub-patterns. Base, Stride and Lanes must be foldable
      // because make() evaluates them with make_folded_const.
2162  SliceOp(Vec v, Base b, Stride s, Lanes l)
2163  : vec(v), base(b), stride(s), lanes(l) {
2164  static_assert(Base::foldable, "Base of slice should consist only of operations that constant-fold");
2165  static_assert(Stride::foldable, "Stride of slice should consist only of operations that constant-fold");
2166  static_assert(Lanes::foldable, "Lanes of slice should consist only of operations that constant-fold");
2167  }
2168 };
2169 
2170 template<typename Vec, typename Base, typename Stride, typename Lanes>
2171 std::ostream &operator<<(std::ostream &s, const SliceOp<Vec, Base, Stride, Lanes> &op) {
2172  s << "slice(" << op.vec << ", " << op.base << ", " << op.stride << ", " << op.lanes << ")";
2173  return s;
2174 }
2175 
2176 template<typename Vec, typename Base, typename Stride, typename Lanes>
2177 HALIDE_ALWAYS_INLINE auto slice(Vec vec, Base base, Stride stride, Lanes lanes) noexcept
2178  -> SliceOp<decltype(pattern_arg(vec)), decltype(pattern_arg(base)), decltype(pattern_arg(stride)), decltype(pattern_arg(lanes))> {
2179  return {pattern_arg(vec), pattern_arg(base), pattern_arg(stride), pattern_arg(lanes)};
2180 }
2181 
// Pattern node that constant-folds its operand pattern and produces the
// result as a constant expression.
// NOTE(review): some lines of this struct appear to be elided in this
// listing, including the declaration of `c` used in make() and the
// signature of the member whose body calls a.make_folded_const(val, ty, state).
2182 template<typename A>
2183 struct Fold {
2184  struct pattern_tag {};
2185  A a;
2186 
      // Wildcards bound by this pattern are exactly those bound by the operand.
2187  constexpr static uint32_t binds = bindings<A>::mask;
2188 
      // A fold is always treated as canonical.
2191  constexpr static bool canonical = true;
2192 
      // Fold the operand to a scalar constant, starting from the caller's type
      // hint, then turn it into a constant Expr.
2194  Expr make(MatcherState &state, halide_type_t type_hint) const noexcept {
2196  halide_type_t ty = type_hint;
2197  a.make_folded_const(c, ty, state);
2198 
2199  // The result of the fold may have an underspecified type
2200  // (e.g. because it's from an int literal). Make the type code
2201  // and bits match the required type, if there is one (we can
2202  // tell from the bits field).
2203  if (type_hint.bits) {
      // An integer result requested as a float must be converted by value
      // rather than reinterpreted bit-for-bit.
2204  if (((int)ty.code == (int)halide_type_int) &&
2205  ((int)type_hint.code == (int)halide_type_float)) {
2206  int64_t x = c.u.i64;
2207  c.u.f64 = (double)x;
2208  }
2209  ty.code = type_hint.code;
2210  ty.bits = type_hint.bits;
2211  }
2212 
2213  Expr e = make_const_expr(c, ty);
2214  return e;
2215  }
2216 
      // Foldable exactly when the operand is.
2217  constexpr static bool foldable = A::foldable;
2218 
      // Folding a Fold simply delegates to the operand.
2219  template<typename A1 = A>
2221  a.make_folded_const(val, ty, state);
2222  }
2223 };
2224 
2225 template<typename A>
2226 HALIDE_ALWAYS_INLINE auto fold(A &&a) noexcept -> Fold<decltype(pattern_arg(a))> {
2227  assert_is_lvalue_if_expr<A>();
2228  return {pattern_arg(a)};
2229 }
2230 
2231 template<typename A>
2232 std::ostream &operator<<(std::ostream &s, const Fold<A> &op) {
2233  s << "fold(" << op.a << ")";
2234  return s;
2235 }
2236 
// Predicate pattern: folds its operand and reports whether the fold set any
// of the "special value" bits in the resulting type's lanes field.
// NOTE(review): some lines of this struct appear to be elided in this
// listing, including the signature of the folding member below.
2237 template<typename A>
2238 struct Overflows {
2239  struct pattern_tag {};
2240  A a;
2241 
      // Wildcards bound by this pattern are exactly those bound by the operand.
2242  constexpr static uint32_t binds = bindings<A>::mask;
2243 
2244  // This rule is a predicate, so it always evaluates to a boolean,
2245  // which has IRNodeType UIntImm
2248  constexpr static bool canonical = true;
2249 
      // Foldable exactly when the operand is.
2250  constexpr static bool foldable = A::foldable;
2251 
      // Fold the operand, then overwrite the outputs: the value becomes 1 if
      // any special-value bit was set in ty.lanes and 0 otherwise, and the
      // type becomes a 64-bit unsigned scalar. ty.lanes is read before it is
      // reset to 1.
2252  template<typename A1 = A>
2254  a.make_folded_const(val, ty, state);
2255  ty.code = halide_type_uint;
2256  ty.bits = 64;
2257  val.u.u64 = (ty.lanes & MatcherState::special_values_mask) != 0;
2258  ty.lanes = 1;
2259  }
2260 };
2261 
2262 template<typename A>
2263 HALIDE_ALWAYS_INLINE auto overflows(A &&a) noexcept -> Overflows<decltype(pattern_arg(a))> {
2264  assert_is_lvalue_if_expr<A>();
2265  return {pattern_arg(a)};
2266 }
2267 
2268 template<typename A>
2269 std::ostream &operator<<(std::ostream &s, const Overflows<A> &op) {
2270  s << "overflows(" << op.a << ")";
2271  return s;
2272 }
2273 
// Pattern node with no operands and no wildcard bindings, standing for the
// overflow marker.
// NOTE(review): some lines of this struct appear to be elided in this
// listing, including the return statement of match() and the remainder of
// make_folded_const().
2274 struct Overflow {
2275  struct pattern_tag {};
2276 
      // Binds no wildcards.
2277  constexpr static uint32_t binds = 0;
2278 
2279  // Overflow is an intrinsic, represented as a Call node
2282  constexpr static bool canonical = true;
2283 
      // Match against a concrete IR node: anything that is not a Call is
      // rejected. The check applied to the Call itself is not visible here.
2284  template<uint32_t bound>
2285  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
2286  if (e.node_type != Call::_node_type) {
2287  return false;
2288  }
2289  const Call &op = (const Call &)e;
2291  }
2292 
      // Build the replacement via make_const_special_expr for the hinted type.
2294  Expr make(MatcherState &state, halide_type_t type_hint) const {
2296  return make_const_special_expr(type_hint);
2297  }
2298 
2299  constexpr static bool foldable = true;
2300 
      // Folding zeroes the value; any further effect on `ty` is not visible
      // in this listing.
2302  void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
2303  val.u.u64 = 0;
2305  }
2306 };
2307 
2308 inline std::ostream &operator<<(std::ostream &s, const Overflow &op) {
2309  s << "overflow()";
2310  return s;
2311 }
2312 
// Predicate pattern: true when the expression built from `a` is a constant,
// optionally a constant equal to a specific value.
// NOTE(review): some lines of this struct appear to be elided in this
// listing, including the declaration of `v` and the signature of the
// folding member below.
2313 template<typename A>
2314 struct IsConst {
2315  struct pattern_tag {};
2316 
      // Wildcards bound by this pattern are exactly those bound by the operand.
2317  constexpr static uint32_t binds = bindings<A>::mask;
2318 
2319  // This rule is a boolean-valued predicate. Bools have type UIntImm.
2322  constexpr static bool canonical = true;
2323 
2324  A a;
      // When true, the constant must also equal `v`; when false any constant
      // is accepted.
2325  bool check_v;
2327 
2328  constexpr static bool foldable = true;
2329 
      // Build the operand expression with an empty type hint and test it with
      // Halide::Internal::is_const. The result is written as a 64-bit
      // unsigned scalar holding 0 or 1.
2330  template<typename A1 = A>
2332  Expr e = a.make(state, {});
2333  ty.code = halide_type_uint;
2334  ty.bits = 64;
2335  ty.lanes = 1;
2336  if (check_v) {
2337  val.u.u64 = ::Halide::Internal::is_const(e, v) ? 1 : 0;
2338  } else {
2339  val.u.u64 = ::Halide::Internal::is_const(e) ? 1 : 0;
2340  }
2341  }
2342 };
2343 
2344 template<typename A>
2345 HALIDE_ALWAYS_INLINE auto is_const(A &&a) noexcept -> IsConst<decltype(pattern_arg(a))> {
2346  assert_is_lvalue_if_expr<A>();
2347  return {pattern_arg(a), false, 0};
2348 }
2349 
2350 template<typename A>
2351 HALIDE_ALWAYS_INLINE auto is_const(A &&a, int64_t value) noexcept -> IsConst<decltype(pattern_arg(a))> {
2352  assert_is_lvalue_if_expr<A>();
2353  return {pattern_arg(a), true, value};
2354 }
2355 
2356 template<typename A>
2357 std::ostream &operator<<(std::ostream &s, const IsConst<A> &op) {
2358  if (op.check_v) {
2359  s << "is_const(" << op.a << ")";
2360  } else {
2361  s << "is_const(" << op.a << ", " << op.v << ")";
2362  }
2363  return s;
2364 }
2365 
// Predicate pattern: builds the condition described by `a` and asks an
// external prover whether it simplifies to a constant one.
// NOTE(review): some lines of this struct appear to be elided in this
// listing, including the signature of the folding member below.
2366 template<typename A, typename Prover>
2367 struct CanProve {
2368  struct pattern_tag {};
2369  A a;
2370  Prover *prover; // An existing simplifying mutator
2371 
      // Wildcards bound by this pattern are exactly those bound by the operand.
2372  constexpr static uint32_t binds = bindings<A>::mask;
2373 
2374  // This rule is a boolean-valued predicate. Bools have type UIntImm.
2377  constexpr static bool canonical = true;
2378 
2379  constexpr static bool foldable = true;
2380 
2381  // Includes a raw call to an inlined make method, so don't inline.
      // Build the condition with an empty type hint, run it through
      // prover->mutate (second argument nullptr), and report whether the
      // result is a constant one. The output type is a 1-bit unsigned value
      // with the same lane count as the mutated condition.
2383  Expr condition = a.make(state, {});
2384  condition = prover->mutate(condition, nullptr);
2385  val.u.u64 = is_const_one(condition);
2386  ty.code = halide_type_uint;
2387  ty.bits = 1;
2388  ty.lanes = condition.type().lanes();
2389  }
2390 };
2391 
2392 template<typename A, typename Prover>
2393 HALIDE_ALWAYS_INLINE auto can_prove(A &&a, Prover *p) noexcept -> CanProve<decltype(pattern_arg(a)), Prover> {
2394  assert_is_lvalue_if_expr<A>();
2395  return {pattern_arg(a), p};
2396 }
2397 
2398 template<typename A, typename Prover>
2399 std::ostream &operator<<(std::ostream &s, const CanProve<A, Prover> &op) {
2400  s << "can_prove(" << op.a << ")";
2401  return s;
2402 }
2403 
// Predicate pattern: true when the expression built from `a` has a
// floating-point type.
// NOTE(review): some lines of this struct appear to be elided in this
// listing, including the signature of the folding member below.
2404 template<typename A>
2405 struct IsFloat {
2406  struct pattern_tag {};
2407  A a;
2408 
      // Wildcards bound by this pattern are exactly those bound by the operand.
2409  constexpr static uint32_t binds = bindings<A>::mask;
2410 
2411  // This rule is a boolean-valued predicate. Bools have type UIntImm.
2414  constexpr static bool canonical = true;
2415 
2416  constexpr static bool foldable = true;
2417 
      // Only the type of the built expression is inspected. The result is a
      // 1-bit unsigned value carrying the operand's lane count.
2420  // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2421  Type t = a.make(state, {}).type();
2422  val.u.u64 = t.is_float();
2423  ty.code = halide_type_uint;
2424  ty.bits = 1;
2425  ty.lanes = t.lanes();
2426  }
2427 };
2428 
2429 template<typename A>
2430 HALIDE_ALWAYS_INLINE auto is_float(A &&a) noexcept -> IsFloat<decltype(pattern_arg(a))> {
2431  assert_is_lvalue_if_expr<A>();
2432  return {pattern_arg(a)};
2433 }
2434 
2435 template<typename A>
2436 std::ostream &operator<<(std::ostream &s, const IsFloat<A> &op) {
2437  s << "is_float(" << op.a << ")";
2438  return s;
2439 }
2440 
// Predicate pattern: true when the expression built from `a` has a signed
// integer type, optionally with a specific bit width and/or lane count.
// NOTE(review): some lines of this struct appear to be elided in this
// listing, including the signature of the folding member below.
2441 template<typename A>
2442 struct IsInt {
2443  struct pattern_tag {};
2444  A a;
      // A value of 0 for either field means "do not constrain it".
2445  int bits, lanes;
2446 
      // Wildcards bound by this pattern are exactly those bound by the operand.
2447  constexpr static uint32_t binds = bindings<A>::mask;
2448 
2449  // This rule is a boolean-valued predicate. Bools have type UIntImm.
2452  constexpr static bool canonical = true;
2453 
2454  constexpr static bool foldable = true;
2455 
      // Only the type of the built expression is inspected. The result is a
      // 1-bit unsigned value carrying the operand's lane count.
2458  // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2459  Type t = a.make(state, {}).type();
2460  val.u.u64 = t.is_int() && (bits == 0 || t.bits() == bits) && (lanes == 0 || t.lanes() == lanes);
2461  ty.code = halide_type_uint;
2462  ty.bits = 1;
2463  ty.lanes = t.lanes();
2464  }
2465 };
2466 
2467 template<typename A>
2468 HALIDE_ALWAYS_INLINE auto is_int(A &&a, int bits = 0, int lanes = 0) noexcept -> IsInt<decltype(pattern_arg(a))> {
2469  assert_is_lvalue_if_expr<A>();
2470  return {pattern_arg(a), bits, lanes};
2471 }
2472 
2473 template<typename A>
2474 std::ostream &operator<<(std::ostream &s, const IsInt<A> &op) {
2475  s << "is_int(" << op.a;
2476  if (op.bits > 0) {
2477  s << ", " << op.bits;
2478  }
2479  if (op.lanes > 0) {
2480  s << ", " << op.lanes;
2481  }
2482  s << ")";
2483  return s;
2484 }
2485 
// Predicate pattern: true when the expression built from `a` has an unsigned
// integer type, optionally with a specific bit width and/or lane count.
// NOTE(review): some lines of this struct appear to be elided in this
// listing, including the signature of the folding member below.
2486 template<typename A>
2487 struct IsUInt {
2488  struct pattern_tag {};
2489  A a;
      // A value of 0 for either field means "do not constrain it".
2490  int bits, lanes;
2491 
      // Wildcards bound by this pattern are exactly those bound by the operand.
2492  constexpr static uint32_t binds = bindings<A>::mask;
2493 
2494  // This rule is a boolean-valued predicate. Bools have type UIntImm.
2497  constexpr static bool canonical = true;
2498 
2499  constexpr static bool foldable = true;
2500 
      // Only the type of the built expression is inspected. The result is a
      // 1-bit unsigned value carrying the operand's lane count.
2503  // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2504  Type t = a.make(state, {}).type();
2505  val.u.u64 = t.is_uint() && (bits == 0 || t.bits() == bits) && (lanes == 0 || t.lanes() == lanes);
2506  ty.code = halide_type_uint;
2507  ty.bits = 1;
2508  ty.lanes = t.lanes();
2509  }
2510 };
2511 
2512 template<typename A>
2513 HALIDE_ALWAYS_INLINE auto is_uint(A &&a, int bits = 0, int lanes = 0) noexcept -> IsUInt<decltype(pattern_arg(a))> {
2514  assert_is_lvalue_if_expr<A>();
2515  return {pattern_arg(a), bits, lanes};
2516 }
2517 
2518 template<typename A>
2519 std::ostream &operator<<(std::ostream &s, const IsUInt<A> &op) {
2520  s << "is_uint(" << op.a;
2521  if (op.bits > 0) {
2522  s << ", " << op.bits;
2523  }
2524  if (op.lanes > 0) {
2525  s << ", " << op.lanes;
2526  }
2527  s << ")";
2528  return s;
2529 }
2530 
// Predicate pattern: true when the expression built from `a` has a scalar
// type.
// NOTE(review): some lines of this struct appear to be elided in this
// listing, including the signature of the folding member below.
2531 template<typename A>
2532 struct IsScalar {
2533  struct pattern_tag {};
2534  A a;
2535 
      // Wildcards bound by this pattern are exactly those bound by the operand.
2536  constexpr static uint32_t binds = bindings<A>::mask;
2537 
2538  // This rule is a boolean-valued predicate. Bools have type UIntImm.
2541  constexpr static bool canonical = true;
2542 
2543  constexpr static bool foldable = true;
2544 
      // Only the type of the built expression is inspected. The result is a
      // 1-bit unsigned value carrying the operand's lane count.
2547  // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2548  Type t = a.make(state, {}).type();
2549  val.u.u64 = t.is_scalar();
2550  ty.code = halide_type_uint;
2551  ty.bits = 1;
2552  ty.lanes = t.lanes();
2553  }
2554 };
2555 
2556 template<typename A>
2557 HALIDE_ALWAYS_INLINE auto is_scalar(A &&a) noexcept -> IsScalar<decltype(pattern_arg(a))> {
2558  assert_is_lvalue_if_expr<A>();
2559  return {pattern_arg(a)};
2560 }
2561 
2562 template<typename A>
2563 std::ostream &operator<<(std::ostream &s, const IsScalar<A> &op) {
2564  s << "is_scalar(" << op.a << ")";
2565  return s;
2566 }
2567 
// Predicate pattern: true when the folded constant value of `a` equals the
// largest value representable in its (signed or unsigned integer) type.
// Non-integer types always yield false.
// NOTE(review): some lines of this struct appear to be elided in this
// listing, including the signature of the folding member below.
2568 template<typename A>
2569 struct IsMaxValue {
2570  struct pattern_tag {};
2571  A a;
2572 
      // Wildcards bound by this pattern are exactly those bound by the operand.
2573  constexpr static uint32_t binds = bindings<A>::mask;
2574 
2575  // This rule is a boolean-valued predicate. Bools have type UIntImm.
2578  constexpr static bool canonical = true;
2579 
2580  constexpr static bool foldable = true;
2581 
2584  // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2585  a.make_folded_const(val, ty, state);
      // All-ones shifted right so that only the low ty.bits bits remain set;
      // signed types shift by one more so the sign bit stays clear.
      // NOTE(review): assumes ty.bits >= 1 after folding; a shift by 64
      // would be undefined - TODO confirm.
2586  const uint64_t max_bits = (uint64_t)(-1) >> (64 - ty.bits + (ty.code == halide_type_int));
2587  if (ty.code == halide_type_uint || ty.code == halide_type_int) {
2588  val.u.u64 = (val.u.u64 == max_bits);
2589  } else {
2590  val.u.u64 = 0;
2591  }
      // The result is reported as a 1-bit unsigned value; ty.lanes is left as
      // produced by the fold.
2592  ty.code = halide_type_uint;
2593  ty.bits = 1;
2594  }
2595 };
2596 
2597 template<typename A>
2598 HALIDE_ALWAYS_INLINE auto is_max_value(A &&a) noexcept -> IsMaxValue<decltype(pattern_arg(a))> {
2599  assert_is_lvalue_if_expr<A>();
2600  return {pattern_arg(a)};
2601 }
2602 
2603 template<typename A>
2604 std::ostream &operator<<(std::ostream &s, const IsMaxValue<A> &op) {
2605  s << "is_max_value(" << op.a << ")";
2606  return s;
2607 }
2608 
// Predicate pattern: true when the folded constant value of `a` equals the
// smallest value representable in its (signed or unsigned integer) type.
// Non-integer types always yield false.
// NOTE(review): some lines of this struct appear to be elided in this
// listing, including the signature of the folding member below.
2609 template<typename A>
2610 struct IsMinValue {
2611  struct pattern_tag {};
2612  A a;
2613 
      // Wildcards bound by this pattern are exactly those bound by the operand.
2614  constexpr static uint32_t binds = bindings<A>::mask;
2615 
2616  // This rule is a boolean-valued predicate. Bools have type UIntImm.
2619  constexpr static bool canonical = true;
2620 
2621  constexpr static bool foldable = true;
2622 
2625  // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2626  a.make_folded_const(val, ty, state);
2627  if (ty.code == halide_type_int) {
      // All-ones with the low (ty.bits - 1) bits cleared, compared against
      // the raw 64-bit value.
      // NOTE(review): presumably relies on signed constants being stored
      // sign-extended to 64 bits - confirm.
2628  const uint64_t min_bits = (uint64_t)(-1) << (ty.bits - 1);
2629  val.u.u64 = (val.u.u64 == min_bits);
2630  } else if (ty.code == halide_type_uint) {
2631  val.u.u64 = (val.u.u64 == 0);
2632  } else {
2633  val.u.u64 = 0;
2634  }
      // The result is reported as a 1-bit unsigned value; ty.lanes is left as
      // produced by the fold.
2635  ty.code = halide_type_uint;
2636  ty.bits = 1;
2637  }
2638 };
2639 
2640 template<typename A>
2641 HALIDE_ALWAYS_INLINE auto is_min_value(A &&a) noexcept -> IsMinValue<decltype(pattern_arg(a))> {
2642  assert_is_lvalue_if_expr<A>();
2643  return {pattern_arg(a)};
2644 }
2645 
2646 template<typename A>
2647 std::ostream &operator<<(std::ostream &s, const IsMinValue<A> &op) {
2648  s << "is_min_value(" << op.a << ")";
2649  return s;
2650 }
2651 
// Foldable pattern that evaluates to the number of lanes in the type of the
// expression built from `a`.
// NOTE(review): some lines of this struct appear to be elided in this
// listing, including the signature of the folding member below.
2652 template<typename A>
2653 struct LanesOf {
2654  struct pattern_tag {};
2655  A a;
2656 
      // Wildcards bound by this pattern are exactly those bound by the operand.
2657  constexpr static uint32_t binds = bindings<A>::mask;
2658 
2659  // This rule folds to an unsigned integer (the lane count), not a boolean.
2662  constexpr static bool canonical = true;
2663 
2664  constexpr static bool foldable = true;
2665 
      // Only the type of the built expression is inspected. The lane count is
      // returned as a 32-bit unsigned scalar.
2668  // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2669  Type t = a.make(state, {}).type();
2670  val.u.u64 = t.lanes();
2671  ty.code = halide_type_uint;
2672  ty.bits = 32;
2673  ty.lanes = 1;
2674  }
2675 };
2676 
2677 template<typename A>
2678 HALIDE_ALWAYS_INLINE auto lanes_of(A &&a) noexcept -> LanesOf<decltype(pattern_arg(a))> {
2679  assert_is_lvalue_if_expr<A>();
2680  return {pattern_arg(a)};
2681 }
2682 
2683 template<typename A>
2684 std::ostream &operator<<(std::ostream &s, const LanesOf<A> &op) {
2685  s << "lanes_of(" << op.a << ")";
2686  return s;
2687 }
2688 
2689 // Verify properties of each rewrite rule. Currently just fuzz tests them.
// This overload is selected only when both Before and After are foldable.
// For each distinct wildcard type it prints the rule once, then runs 100
// trials: every wildcard slot is bound to random constants, the predicate is
// evaluated, and (if it holds) the folded values of `before` and `after` are
// compared under output_type.
// NOTE(review): some lines of this function appear to be elided in this
// listing (the declaration of `val` in the failure report and whatever
// follows the final debug line).
2690 template<typename Before,
2691  typename After,
2692  typename Predicate,
2693  typename = typename std::enable_if<std::decay<Before>::type::foldable &&
2694  std::decay<After>::type::foldable>::type>
2695 HALIDE_NEVER_INLINE void fuzz_test_rule(Before &&before, After &&after, Predicate &&pred,
2696  halide_type_t wildcard_type, halide_type_t output_type) noexcept {
2697 
2698  // We only validate the rules in the scalar case
2699  wildcard_type.lanes = output_type.lanes = 1;
2700 
2701  // Track which types this rule has been tested for before
      // The set is a function-local static, so there is one per template
      // instantiation; the key is the raw bit pattern of wildcard_type.
2702  static std::set<uint32_t> tested;
2703 
2704  if (!tested.insert(reinterpret_bits<uint32_t>(wildcard_type)).second) {
2705  return;
2706  }
2707 
2708  // Print it in a form where it can be piped into a python/z3 validator
2709  debug(0) << "validate('" << before << "', '" << after << "', '" << pred << "', " << Type(wildcard_type) << ", " << Type(output_type) << ")\n";
2710 
2711  // Substitute some random constants into the before and after
2712  // expressions and see if the rule holds true. This should catch
2713  // silly errors, but not necessarily corner cases.
      // Fixed seed: the generated sequence is the same on every run.
2714  static std::mt19937_64 rng(0);
2715  MatcherState state;
2716 
      // Keeps the bound expressions alive; the state is given references to
      // the nodes they own.
2717  Expr exprs[max_wild];
2718 
2719  for (int trials = 0; trials < 100; trials++) {
2720  // We want to test small constants more frequently than
2721  // large ones, otherwise we'll just get coverage of
2722  // overflow rules.
      // Masking with (bits - 1) keeps the shift below the bit width when the
      // width is a power of two.
2723  int shift = (int)(rng() & (wildcard_type.bits - 1));
2724 
2725  for (int i = 0; i < max_wild; i++) {
2726  // Bind all the exprs and constants
      // Each slot gets two independent random values: one as a bound
      // constant and one as a bound expression.
2727  switch (wildcard_type.code) {
2728  case halide_type_uint: {
2729  // Normalize to the type's range by adding zero
2730  uint64_t val = constant_fold_bin_op<Add>(wildcard_type, (uint64_t)rng() >> shift, 0);
2731  state.set_bound_const(i, val, wildcard_type);
2732  val = constant_fold_bin_op<Add>(wildcard_type, (uint64_t)rng() >> shift, 0);
2733  exprs[i] = make_const(wildcard_type, val);
      // NOTE(review): this binding is repeated unconditionally right after
      // the switch, so the call here looks redundant.
2734  state.set_binding(i, *exprs[i].get());
2735  } break;
2736  case halide_type_int: {
2737  int64_t val = constant_fold_bin_op<Add>(wildcard_type, (int64_t)rng() >> shift, 0);
2738  state.set_bound_const(i, val, wildcard_type);
2739  val = constant_fold_bin_op<Add>(wildcard_type, (int64_t)rng() >> shift, 0);
2740  exprs[i] = make_const(wildcard_type, val);
2741  } break;
2742  case halide_type_float:
2743  case halide_type_bfloat: {
2744  // Use a very narrow range of precise floats, so
2745  // that none of the rules a human is likely to
2746  // write have instabilities.
      // (rng() & 15) - 8 lies in [-8, 7]; halving gives multiples of 0.5
      // in [-4.0, 3.5].
2747  double val = ((int64_t)(rng() & 15) - 8) / 2.0;
2748  state.set_bound_const(i, val, wildcard_type);
2749  val = ((int64_t)(rng() & 15) - 8) / 2.0;
2750  exprs[i] = make_const(wildcard_type, val);
2751  } break;
2752  default:
2753  return; // Don't care about handles
2754  }
2755  state.set_binding(i, *exprs[i].get());
2756  }
2757 
2758  halide_scalar_value_t val_pred, val_before, val_after;
2759  halide_type_t type = output_type;
      // Trials whose random bindings do not satisfy the predicate are skipped.
2760  if (!evaluate_predicate(pred, state)) {
2761  continue;
2762  }
      // Fold both sides and accumulate the lanes fields so that a special
      // value flagged by either side is seen below.
2763  before.make_folded_const(val_before, type, state);
2764  uint16_t lanes = type.lanes;
2765  after.make_folded_const(val_after, type, state);
2766  lanes |= type.lanes;
2767 
      // Skip trials in which either side produced a special value.
2768  if (lanes & MatcherState::special_values_mask) {
2769  continue;
2770  }
2771 
2772  bool ok = true;
2773  switch (output_type.code) {
2774  case halide_type_uint:
2775  // Compare normalized representations
2776  ok &= (constant_fold_bin_op<Add>(output_type, val_before.u.u64, 0) ==
2777  constant_fold_bin_op<Add>(output_type, val_after.u.u64, 0));
2778  break;
2779  case halide_type_int:
2780  ok &= (constant_fold_bin_op<Add>(output_type, val_before.u.i64, 0) ==
2781  constant_fold_bin_op<Add>(output_type, val_after.u.i64, 0));
2782  break;
2783  case halide_type_float:
2784  case halide_type_bfloat: {
2785  double error = std::abs(val_before.u.f64 - val_after.u.f64);
2786  // We accept an equal bit pattern (e.g. inf vs inf),
2787  // a small floating point difference, or turning a nan into not-a-nan.
2788  ok &= (error < 0.01 ||
2789  val_before.u.u64 == val_after.u.u64 ||
2790  std::isnan(val_before.u.f64));
2791  break;
2792  }
2793  default:
2794  return;
2795  }
2796 
      // On a mismatch, dump every bound constant and expression together with
      // the two folded results.
2797  if (!ok) {
2798  debug(0) << "Fails with values:\n";
2799  for (int i = 0; i < max_wild; i++) {
      // NOTE(review): the declaration of `val` is not visible in this listing.
2801  state.get_bound_const(i, val, wildcard_type);
2802  debug(0) << " c" << i << ": " << make_const_expr(val, wildcard_type) << "\n";
2803  }
2804  for (int i = 0; i < max_wild; i++) {
2805  debug(0) << " _" << i << ": " << Expr(state.get_binding(i)) << "\n";
2806  }
2807  debug(0) << " Before: " << make_const_expr(val_before, output_type) << "\n";
2808  debug(0) << " After: " << make_const_expr(val_after, output_type) << "\n";
2809  debug(0) << val_before.u.u64 << " " << val_after.u.u64 << "\n";
2811  }
2812  }
2813 }
2814 
2815 template<typename Before,
2816  typename After,
2817  typename Predicate,
2818  typename = typename std::enable_if<!(std::decay<Before>::type::foldable &&
2819  std::decay<After>::type::foldable)>::type>
2820 HALIDE_ALWAYS_INLINE void fuzz_test_rule(Before &&before, After &&after, Predicate &&pred,
2821  halide_type_t, halide_type_t, int dummy = 0) noexcept {
2822  // We can't verify rewrite rules that can't be constant-folded.
2823 }
2824 
2826 bool evaluate_predicate(bool x, MatcherState &) noexcept {
2827  return x;
2828 }
2829 
// Overload for pattern predicates: folds the predicate pattern to a constant
// (starting from the bool type) and treats it as satisfied only if the value
// is non-zero and no special-value bit was set in the resulting lanes field.
// NOTE(review): the function signature and the declaration of `c` are not
// visible in this listing.
2830 template<typename Pattern,
2831  typename = typename enable_if_pattern<Pattern>::type>
2834  halide_type_t ty = halide_type_of<bool>();
2835  p.make_folded_const(c, ty, state);
2836  // Overflow counts as a failed predicate
2837  return (c.u.u64 != 0) && ((ty.lanes & MatcherState::special_values_mask) == 0);
2838 }
2839 
2840 // #defines for testing
2841 
2842 // Print all successful or failed matches
2843 #define HALIDE_DEBUG_MATCHED_RULES 0
2844 #define HALIDE_DEBUG_UNMATCHED_RULES 0
2845 
2846 // Set to true if you want to fuzz test every rewrite passed to
2847 // operator() to ensure the input and the output have the same value
2848 // for lots of random values of the wildcards. Run
2849 // correctness_simplify with this on.
2850 #define HALIDE_FUZZ_TEST_RULES 0
2851 
// Holds an instance (an expression or pattern) and tries rewrite rules
// against it through its operator() overloads. Each overload returns true
// and stores the replacement in `result` if the rule's left-hand side matches
// (and its predicate, if any, holds); otherwise it returns false.
// NOTE(review): some lines of this struct appear to be elided in this
// listing, including the declarations of `result`, `state`, `output_type`
// and `wildcard_type`, the constructor signature and the signature of
// build_replacement.
2852 template<typename Instance>
2853 struct Rewriter {
2854  Instance instance;
2858  bool validate;
2859 
      // Constructor initializer list: takes ownership of the instance and
      // records the output and wildcard types. `validate` is not initialized
      // here.
2862  : instance(std::move(instance)), output_type(ot), wildcard_type(wt) {
2863  }
2864 
      // Builds the right-hand side of a rule from the current match state,
      // using output_type as the type hint, and stores it in `result`.
2865  template<typename After>
2867  result = after.make(state, output_type);
2868  }
2869 
      // Rule of the form: pattern -> pattern, no predicate.
2870  template<typename Before,
2871  typename After,
2872  typename = typename enable_if_pattern<Before>::type,
2873  typename = typename enable_if_pattern<After>::type>
2874  HALIDE_ALWAYS_INLINE bool operator()(Before before, After after) {
      // The replacement may only use wildcards that the left-hand side binds.
2875  static_assert((Before::binds & After::binds) == After::binds, "Rule result uses unbound values");
2876  static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2877  static_assert(After::canonical, "RHS of rewrite rule should be in canonical form");
2878 #if HALIDE_FUZZ_TEST_RULES
2879  fuzz_test_rule(before, after, true, wildcard_type, output_type);
2880 #endif
2881  if (before.template match<0>(unwrap(instance), state)) {
2882  build_replacement(after);
2883 #if HALIDE_DEBUG_MATCHED_RULES
2884  debug(0) << instance << " -> " << result << " via " << before << " -> " << after << "\n";
2885 #endif
2886  return true;
2887  } else {
2888 #if HALIDE_DEBUG_UNMATCHED_RULES
2889  debug(0) << instance << " does not match " << before << "\n";
2890 #endif
2891  return false;
2892  }
2893  }
2894 
      // Rule of the form: pattern -> concrete Expr, no predicate. The Expr is
      // stored as-is; this form is not fuzz tested.
2895  template<typename Before,
2896  typename = typename enable_if_pattern<Before>::type>
2897  HALIDE_ALWAYS_INLINE bool operator()(Before before, const Expr &after) noexcept {
2898  static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2899  if (before.template match<0>(unwrap(instance), state)) {
2900  result = after;
2901 #if HALIDE_DEBUG_MATCHED_RULES
2902  debug(0) << instance << " -> " << result << " via " << before << " -> " << after << "\n";
2903 #endif
2904  return true;
2905  } else {
2906 #if HALIDE_DEBUG_UNMATCHED_RULES
2907  debug(0) << instance << " does not match " << before << "\n";
2908 #endif
2909  return false;
2910  }
2911  }
2912 
      // Rule of the form: pattern -> integer constant, no predicate. The
      // constant is materialized with output_type.
2913  template<typename Before,
2914  typename = typename enable_if_pattern<Before>::type>
2915  HALIDE_ALWAYS_INLINE bool operator()(Before before, int64_t after) noexcept {
2916  static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2917 #if HALIDE_FUZZ_TEST_RULES
2918  fuzz_test_rule(before, IntLiteral(after), true, wildcard_type, output_type);
2919 #endif
2920  if (before.template match<0>(unwrap(instance), state)) {
2921  result = make_const(output_type, after);
2922 #if HALIDE_DEBUG_MATCHED_RULES
2923  debug(0) << instance << " -> " << result << " via " << before << " -> " << after << "\n";
2924 #endif
2925  return true;
2926  } else {
2927 #if HALIDE_DEBUG_UNMATCHED_RULES
2928  debug(0) << instance << " does not match " << before << "\n";
2929 #endif
2930  return false;
2931  }
2932  }
2933 
      // Rule of the form: pattern -> pattern, guarded by a predicate. The
      // predicate is evaluated only after the left-hand side has matched.
2934  template<typename Before,
2935  typename After,
2936  typename Predicate,
2937  typename = typename enable_if_pattern<Before>::type,
2938  typename = typename enable_if_pattern<After>::type,
2939  typename = typename enable_if_pattern<Predicate>::type>
2940  HALIDE_ALWAYS_INLINE bool operator()(Before before, After after, Predicate pred) {
2941  static_assert(Predicate::foldable, "Predicates must consist only of operations that can constant-fold");
2942  static_assert((Before::binds & After::binds) == After::binds, "Rule result uses unbound values");
2943  static_assert((Before::binds & Predicate::binds) == Predicate::binds, "Rule predicate uses unbound values");
2944  static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2945  static_assert(After::canonical, "RHS of rewrite rule should be in canonical form");
2946 
2947 #if HALIDE_FUZZ_TEST_RULES
2948  fuzz_test_rule(before, after, pred, wildcard_type, output_type);
2949 #endif
2950  if (before.template match<0>(unwrap(instance), state) &&
2951  evaluate_predicate(pred, state)) {
2952  build_replacement(after);
2953 #if HALIDE_DEBUG_MATCHED_RULES
2954  debug(0) << instance << " -> " << result << " via " << before << " -> " << after << " when " << pred << "\n";
2955 #endif
2956  return true;
2957  } else {
2958 #if HALIDE_DEBUG_UNMATCHED_RULES
2959  debug(0) << instance << " does not match " << before << "\n";
2960 #endif
2961  return false;
2962  }
2963  }
2964 
      // Rule of the form: pattern -> concrete Expr, guarded by a predicate.
      // This form is not fuzz tested.
2965  template<typename Before,
2966  typename Predicate,
2967  typename = typename enable_if_pattern<Before>::type,
2968  typename = typename enable_if_pattern<Predicate>::type>
2969  HALIDE_ALWAYS_INLINE bool operator()(Before before, const Expr &after, Predicate pred) {
2970  static_assert(Predicate::foldable, "Predicates must consist only of operations that can constant-fold");
2971  static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2972 
2973  if (before.template match<0>(unwrap(instance), state) &&
2974  evaluate_predicate(pred, state)) {
2975  result = after;
2976 #if HALIDE_DEBUG_MATCHED_RULES
2977  debug(0) << instance << " -> " << result << " via " << before << " -> " << after << " when " << pred << "\n";
2978 #endif
2979  return true;
2980  } else {
2981 #if HALIDE_DEBUG_UNMATCHED_RULES
2982  debug(0) << instance << " does not match " << before << "\n";
2983 #endif
2984  return false;
2985  }
2986  }
2987 
      // Rule of the form: pattern -> integer constant, guarded by a predicate.
2988  template<typename Before,
2989  typename Predicate,
2990  typename = typename enable_if_pattern<Before>::type,
2991  typename = typename enable_if_pattern<Predicate>::type>
2992  HALIDE_ALWAYS_INLINE bool operator()(Before before, int64_t after, Predicate pred) {
2993  static_assert(Predicate::foldable, "Predicates must consist only of operations that can constant-fold");
2994  static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2995 #if HALIDE_FUZZ_TEST_RULES
2996  fuzz_test_rule(before, IntLiteral(after), pred, wildcard_type, output_type);
2997 #endif
2998  if (before.template match<0>(unwrap(instance), state) &&
2999  evaluate_predicate(pred, state)) {
3000  result = make_const(output_type, after);
3001 #if HALIDE_DEBUG_MATCHED_RULES
3002  debug(0) << instance << " -> " << result << " via " << before << " -> " << after << " when " << pred << "\n";
3003 #endif
3004  return true;
3005  } else {
3006 #if HALIDE_DEBUG_UNMATCHED_RULES
3007  debug(0) << instance << " does not match " << before << "\n";
3008 #endif
3009  return false;
3010  }
3011  }
3012 };
3013 
3014 /** Construct a rewriter for the given instance, which may be a pattern
3015  * with concrete expressions as leaves, or just an expression. The
3016  * second optional argument (wildcard_type) is a hint as to what the
3017  * type of the wildcards is likely to be. If omitted it uses the same
3018  * type as the expression itself. They are not required to be this
3019  * type, but the rule will only be tested for wildcards of that type
3020  * when testing is enabled.
3021  *
3022  * The rewriter can be used to check to see if the instance is one of
3023  * some number of patterns and if so rewrite it into another form,
3024  * using its operator() method. See Simplify.cpp for a bunch of
3025  * example usage.
3026  *
3027  * Important: Any Exprs in patterns are captured by reference, not by
3028  * value, so ensure they outlive the rewriter.
3029  */
3030 // @{
3031 template<typename Instance,
3032  typename = typename enable_if_pattern<Instance>::type>
3033 HALIDE_ALWAYS_INLINE auto rewriter(Instance instance, halide_type_t output_type, halide_type_t wildcard_type) noexcept -> Rewriter<decltype(pattern_arg(instance))> {
3034  return {pattern_arg(instance), output_type, wildcard_type};
3035 }
3036 
3037 template<typename Instance,
3038  typename = typename enable_if_pattern<Instance>::type>
3039 HALIDE_ALWAYS_INLINE auto rewriter(Instance instance, halide_type_t output_type) noexcept -> Rewriter<decltype(pattern_arg(instance))> {
3040  return {pattern_arg(instance), output_type, output_type};
3041 }
3042 
3044 auto rewriter(const Expr &e, halide_type_t wildcard_type) noexcept -> Rewriter<decltype(pattern_arg(e))> {
3045  return {pattern_arg(e), e.type(), wildcard_type};
3046 }
3047 
3049 auto rewriter(const Expr &e) noexcept -> Rewriter<decltype(pattern_arg(e))> {
3050  return {pattern_arg(e), e.type(), e.type()};
3051 }
3052 // @}
3053 
3054 } // namespace IRMatcher
3055 
3056 } // namespace Internal
3057 } // namespace Halide
3058 
3059 #endif
#define internal_error
Definition: Errors.h:23
@ halide_type_float
IEEE floating point numbers.
@ halide_type_bfloat
floating point numbers in the bfloat format
@ halide_type_int
signed integers
@ halide_type_uint
unsigned integers
#define HALIDE_NEVER_INLINE
Definition: HalideRuntime.h:50
#define HALIDE_ALWAYS_INLINE
Definition: HalideRuntime.h:49
Subtypes for Halide expressions (Halide::Expr) and statements (Halide::Internal::Stmt)
Methods to test Exprs and Stmts for equality of value.
Defines various operator overloads and utility functions that make it more pleasant to work with Hali...
For optional debugging during codegen, use the debug class as follows:
Definition: Debug.h:49
std::ostream & operator<<(std::ostream &s, const SpecificExpr &e)
Definition: IRMatch.h:217
auto rounding_shift_left(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1586
auto shift_left(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1578
HALIDE_ALWAYS_INLINE auto rewriter(Instance instance, halide_type_t output_type, halide_type_t wildcard_type) noexcept -> Rewriter< decltype(pattern_arg(instance))>
Construct a rewriter for the given instance, which may be a pattern with concrete expressions as leav...
Definition: IRMatch.h:3033
HALIDE_ALWAYS_INLINE T pattern_arg(T t)
Definition: IRMatch.h:567
auto widen_right_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1527
HALIDE_ALWAYS_INLINE auto or_op(A &&a, B &&b) -> decltype(IRMatcher::operator||(a, b))
Definition: IRMatch.h:1275
HALIDE_ALWAYS_INLINE auto operator!(A &&a) noexcept -> NotOp< decltype(pattern_arg(a))>
Definition: IRMatch.h:1643
HALIDE_ALWAYS_INLINE auto min(A &&a, B &&b) noexcept -> BinOp< Min, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1076
HALIDE_ALWAYS_INLINE bool evaluate_predicate(bool x, MatcherState &) noexcept
Definition: IRMatch.h:2826
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Div >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1032
HALIDE_ALWAYS_INLINE auto ne(A &&a, B &&b) -> decltype(IRMatcher::operator!=(a, b))
Definition: IRMatch.h:1250
HALIDE_ALWAYS_INLINE auto negate(A &&a) -> decltype(IRMatcher::operator-(a))
Definition: IRMatch.h:2020
uint64_t constant_fold_cmp_op(int64_t, int64_t) noexcept
HALIDE_ALWAYS_INLINE auto operator<=(A &&a, B &&b) noexcept -> CmpOp< LE, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1170
HALIDE_ALWAYS_INLINE auto operator+(A &&a, B &&b) noexcept -> BinOp< Add, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:921
HALIDE_ALWAYS_INLINE auto is_max_value(A &&a) noexcept -> IsMaxValue< decltype(pattern_arg(a))>
Definition: IRMatch.h:2598
HALIDE_ALWAYS_INLINE auto and_op(A &&a, B &&b) -> decltype(IRMatcher::operator&&(a, b))
Definition: IRMatch.h:1301
HALIDE_ALWAYS_INLINE auto h_and(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::And >
Definition: IRMatch.h:1932
HALIDE_ALWAYS_INLINE auto gt(A &&a, B &&b) -> decltype(IRMatcher::operator>(a, b))
Definition: IRMatch.h:1150
HALIDE_ALWAYS_INLINE auto is_const(A &&a) noexcept -> IsConst< decltype(pattern_arg(a))>
Definition: IRMatch.h:2345
HALIDE_ALWAYS_INLINE auto intrin(Call::IntrinsicOp intrinsic_op, Args... args) noexcept -> Intrin< decltype(pattern_arg(args))... >
Definition: IRMatch.h:1522
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< LE >(int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1180
HALIDE_ALWAYS_INLINE auto operator*(A &&a, B &&b) noexcept -> BinOp< Mul, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:987
auto rounding_halving_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1574
auto rounding_shift_right(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1590
auto widen_right_sub(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1535
HALIDE_ALWAYS_INLINE auto add(A &&a, B &&b) -> decltype(IRMatcher::operator+(a, b))
Definition: IRMatch.h:928
HALIDE_ALWAYS_INLINE auto div(A &&a, B &&b) -> decltype(IRMatcher::operator/(a, b))
Definition: IRMatch.h:1027
auto saturating_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1552
HALIDE_ALWAYS_INLINE auto mul(A &&a, B &&b) -> decltype(IRMatcher::operator*(a, b))
Definition: IRMatch.h:994
HALIDE_ALWAYS_INLINE auto max(A &&a, B &&b) noexcept -> BinOp< Max, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1098
HALIDE_ALWAYS_INLINE auto slice(Vec vec, Base base, Stride stride, Lanes lanes) noexcept -> SliceOp< decltype(pattern_arg(vec)), decltype(pattern_arg(base)), decltype(pattern_arg(stride)), decltype(pattern_arg(lanes))>
Definition: IRMatch.h:2177
HALIDE_ALWAYS_INLINE auto ramp(A &&a, B &&b, C &&c) noexcept -> RampOp< decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))>
Definition: IRMatch.h:1856
HALIDE_ALWAYS_INLINE auto operator/(A &&a, B &&b) noexcept -> BinOp< Div, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1020
HALIDE_ALWAYS_INLINE auto widen(A &&a) noexcept -> WidenOp< decltype(pattern_arg(a))>
Definition: IRMatch.h:2113
auto widening_mul(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1548
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Mod >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1061
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< And >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1306
HALIDE_ALWAYS_INLINE int64_t unwrap(IntLiteral t)
Definition: IRMatch.h:559
HALIDE_ALWAYS_INLINE auto operator>(A &&a, B &&b) noexcept -> CmpOp< GT, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1145
HALIDE_ALWAYS_INLINE auto cast(halide_type_t t, A &&a) noexcept -> CastOp< decltype(pattern_arg(a))>
Definition: IRMatch.h:2066
HALIDE_ALWAYS_INLINE auto overflows(A &&a) noexcept -> Overflows< decltype(pattern_arg(a))>
Definition: IRMatch.h:2263
auto widening_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1540
HALIDE_ALWAYS_INLINE void assert_is_lvalue_if_expr()
Definition: IRMatch.h:576
HALIDE_ALWAYS_INLINE auto operator%(A &&a, B &&b) noexcept -> BinOp< Mod, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1047
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Sub >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition: IRMatch.h:968
HALIDE_ALWAYS_INLINE auto is_scalar(A &&a) noexcept -> IsScalar< decltype(pattern_arg(a))>
Definition: IRMatch.h:2557
HALIDE_ALWAYS_INLINE auto fold(A &&a) noexcept -> Fold< decltype(pattern_arg(a))>
Definition: IRMatch.h:2226
HALIDE_ALWAYS_INLINE auto not_op(A &&a) -> decltype(IRMatcher::operator!(a))
Definition: IRMatch.h:1649
auto halving_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1566
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Max >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1105
constexpr bool and_reduce()
Definition: IRMatch.h:1330
HALIDE_ALWAYS_INLINE auto operator||(A &&a, B &&b) noexcept -> BinOp< Or, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1270
auto widening_sub(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1544
constexpr int max_wild
Definition: IRMatch.h:74
HALIDE_ALWAYS_INLINE auto operator!=(A &&a, B &&b) noexcept -> CmpOp< NE, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1245
HALIDE_ALWAYS_INLINE auto is_float(A &&a) noexcept -> IsFloat< decltype(pattern_arg(a))>
Definition: IRMatch.h:2430
HALIDE_ALWAYS_INLINE auto operator>=(A &&a, B &&b) noexcept -> CmpOp< GE, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1195
HALIDE_ALWAYS_INLINE auto operator<(A &&a, B &&b) noexcept -> CmpOp< LT, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1120
HALIDE_ALWAYS_INLINE auto operator&&(A &&a, B &&b) noexcept -> BinOp< And, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1296
HALIDE_ALWAYS_INLINE auto is_uint(A &&a, int bits=0, int lanes=0) noexcept -> IsUInt< decltype(pattern_arg(a))>
Definition: IRMatch.h:2513
HALIDE_ALWAYS_INLINE auto h_or(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Or >
Definition: IRMatch.h:1938
constexpr bool commutative(IRNodeType t)
Definition: IRMatch.h:615
auto widen_right_mul(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1531
HALIDE_ALWAYS_INLINE auto sub(A &&a, B &&b) -> decltype(IRMatcher::operator-(a, b))
Definition: IRMatch.h:961
HALIDE_ALWAYS_INLINE auto h_max(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Max >
Definition: IRMatch.h:1926
HALIDE_ALWAYS_INLINE auto broadcast(A &&a, B lanes) noexcept -> BroadcastOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes))>
Definition: IRMatch.h:1792
HALIDE_ALWAYS_INLINE auto is_int(A &&a, int bits=0, int lanes=0) noexcept -> IsInt< decltype(pattern_arg(a))>
Definition: IRMatch.h:2468
HALIDE_ALWAYS_INLINE auto select(C &&c, T &&t, F &&f) noexcept -> SelectOp< decltype(pattern_arg(c)), decltype(pattern_arg(t)), decltype(pattern_arg(f))>
Definition: IRMatch.h:1719
HALIDE_ALWAYS_INLINE auto is_min_value(A &&a) noexcept -> IsMinValue< decltype(pattern_arg(a))>
Definition: IRMatch.h:2641
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Min >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1083
HALIDE_NEVER_INLINE void fuzz_test_rule(Before &&before, After &&after, Predicate &&pred, halide_type_t wildcard_type, halide_type_t output_type) noexcept
Definition: IRMatch.h:2695
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< GT >(int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1155
auto halving_sub(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1570
auto saturating_sub(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1556
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Mul >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1001
auto mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))>
Definition: IRMatch.h:1594
auto shift_right(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1582
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< GE >(int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1205
HALIDE_ALWAYS_INLINE auto operator-(A &&a, B &&b) noexcept -> BinOp< Sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:954
HALIDE_ALWAYS_INLINE auto le(A &&a, B &&b) -> decltype(IRMatcher::operator<=(a, b))
Definition: IRMatch.h:1175
HALIDE_ALWAYS_INLINE auto lt(A &&a, B &&b) -> decltype(IRMatcher::operator<(a, b))
Definition: IRMatch.h:1125
HALIDE_ALWAYS_INLINE auto is_const(A &&a, int64_t value) noexcept -> IsConst< decltype(pattern_arg(a))>
Definition: IRMatch.h:2351
HALIDE_ALWAYS_INLINE auto lanes_of(A &&a) noexcept -> LanesOf< decltype(pattern_arg(a))>
Definition: IRMatch.h:2678
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< LT >(int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1130
HALIDE_ALWAYS_INLINE auto h_min(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Min >
Definition: IRMatch.h:1920
HALIDE_ALWAYS_INLINE auto h_add(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Add >
Definition: IRMatch.h:1914
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Or >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1280
HALIDE_ALWAYS_INLINE Expr make_const_expr(halide_scalar_value_t val, halide_type_t ty)
Definition: IRMatch.h:160
constexpr uint32_t bitwise_or_reduce()
Definition: IRMatch.h:1321
auto rounding_mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))>
Definition: IRMatch.h:1598
int64_t constant_fold_bin_op(halide_type_t &, int64_t, int64_t) noexcept
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< EQ >(int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1230
HALIDE_NEVER_INLINE Expr make_const_special_expr(halide_type_t ty)
Definition: IRMatch.h:149
HALIDE_ALWAYS_INLINE auto ge(A &&a, B &&b) -> decltype(IRMatcher::operator>=(a, b))
Definition: IRMatch.h:1200
auto saturating_cast(const Type &t, A &&a) noexcept -> Intrin< decltype(pattern_arg(a))>
Definition: IRMatch.h:1560
constexpr int const_min(int a, int b)
Definition: IRMatch.h:1340
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< NE >(int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1255
HALIDE_ALWAYS_INLINE auto mod(A &&a, B &&b) -> decltype(IRMatcher::operator%(a, b))
Definition: IRMatch.h:1054
HALIDE_ALWAYS_INLINE auto operator==(A &&a, B &&b) noexcept -> CmpOp< EQ, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1220
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Add >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition: IRMatch.h:935
HALIDE_ALWAYS_INLINE auto can_prove(A &&a, Prover *p) noexcept -> CanProve< decltype(pattern_arg(a)), Prover >
Definition: IRMatch.h:2393
HALIDE_ALWAYS_INLINE auto eq(A &&a, B &&b) -> decltype(IRMatcher::operator==(a, b))
Definition: IRMatch.h:1225
T div_imp(T a, T b)
Definition: IROperator.h:267
bool is_const_zero(const Expr &e)
Is the expression a const (as defined by is_const), and also equal to zero (in all lanes,...
Expr make_zero(Type t)
Construct the representation of zero in the given type.
void expr_match_test()
bool is_const_one(const Expr &e)
Is the expression a const (as defined by is_const), and also equal to one (in all lanes,...
bool equal(const RDom &bounds0, const RDom &bounds1)
Return true if bounds0 and bounds1 represent the same bounds.
constexpr IRNodeType StrongestExprNodeType
Definition: Expr.h:81
Expr make_const(Type t, int64_t val)
Construct an immediate of the given type from any numeric C++ type.
T mod_imp(T a, T b)
Implementations of division and mod that are specific to Halide.
Definition: IROperator.h:246
bool sub_would_overflow(int bits, int64_t a, int64_t b)
bool add_would_overflow(int bits, int64_t a, int64_t b)
Routines to test if math would overflow for signed integers with the given number of bits.
bool mul_would_overflow(int bits, int64_t a, int64_t b)
Expr with_lanes(const Expr &x, int lanes)
Rewrite the expression x to have lanes lanes.
bool expr_match(const Expr &pattern, const Expr &expr, std::vector< Expr > &result)
Does the first expression have the same structure as the second? Variables in the first expression wi...
ConstantInterval abs(const ConstantInterval &a)
Expr make_signed_integer_overflow(Type type)
Construct a unique signed_integer_overflow Expr.
IRNodeType
All our IR node types get unique IDs for the purposes of RTTI.
Definition: Expr.h:25
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
@ Internal
Not visible externally, similar to 'static' linkage in C.
@ Predicate
Guard the loads and stores in the loop with an if statement that prevents evaluation beyond the origi...
Expr absd(Expr a, Expr b)
Return the absolute difference between two values.
@ C
No name mangling.
Expr likely_if_innermost(Expr e)
Equivalent to likely, but only triggers a loop partitioning if found in an innermost loop.
Expr likely(Expr e)
Expressions tagged with this intrinsic are considered to be part of the steady state of some loop wit...
unsigned __INT64_TYPE__ uint64_t
signed __INT64_TYPE__ int64_t
signed __INT32_TYPE__ int32_t
unsigned __INT16_TYPE__ uint16_t
unsigned __INT32_TYPE__ uint32_t
A fragment of Halide syntax.
Definition: Expr.h:258
HALIDE_ALWAYS_INLINE Type type() const
Get the type of this expression node.
Definition: Expr.h:327
HALIDE_ALWAYS_INLINE const Internal::BaseExprNode * get() const
Override get() to return a BaseExprNode * instead of an IRNode *.
Definition: Expr.h:321
The sum of two expressions.
Definition: IR.h:56
Logical and - are both expressions true.
Definition: IR.h:175
A base class for expression nodes.
Definition: Expr.h:143
A vector with 'lanes' elements, in which every element is 'value'.
Definition: IR.h:259
static Expr make(Expr value, int lanes)
static const IRNodeType _node_type
Definition: IR.h:265
A function call.
Definition: IR.h:490
@ signed_integer_overflow
Definition: IR.h:595
@ rounding_mul_shift_right
Definition: IR.h:585
bool is_intrinsic() const
Definition: IR.h:714
static const IRNodeType _node_type
Definition: IR.h:759
The actual IR nodes begin here.
Definition: IR.h:30
static const IRNodeType _node_type
Definition: IR.h:35
The ratio of two expressions.
Definition: IR.h:83
Is the first expression equal to the second.
Definition: IR.h:121
Floating point constants.
Definition: Expr.h:236
static const FloatImm * make(Type t, double value)
Is the first expression greater than or equal to the second.
Definition: IR.h:166
Is the first expression greater than the second.
Definition: IR.h:157
constexpr static uint32_t binds
Definition: IRMatch.h:633
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:636
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:664
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:645
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:635
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const noexcept
Definition: IRMatch.h:707
HALIDE_ALWAYS_INLINE bool match(const BinOp< Op2, A2, B2 > &op, MatcherState &state) const noexcept
Definition: IRMatch.h:655
constexpr static bool canonical
Definition: IRMatch.h:641
constexpr static bool foldable
Definition: IRMatch.h:661
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:1758
HALIDE_ALWAYS_INLINE bool match(const BroadcastOp< A2, B2 > &op, MatcherState &state) const noexcept
Definition: IRMatch.h:1752
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:1740
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:1734
constexpr static uint32_t binds
Definition: IRMatch.h:1732
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:1775
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:1735
HALIDE_NEVER_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition: IRMatch.h:2382
constexpr static bool foldable
Definition: IRMatch.h:2379
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2375
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2376
constexpr static uint32_t binds
Definition: IRMatch.h:2372
constexpr static bool canonical
Definition: IRMatch.h:2377
constexpr static bool canonical
Definition: IRMatch.h:2035
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2034
constexpr static bool foldable
Definition: IRMatch.h:2056
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:2038
constexpr static uint32_t binds
Definition: IRMatch.h:2031
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2033
HALIDE_ALWAYS_INLINE bool match(const CastOp< A2 > &op, MatcherState &state) const noexcept
Definition: IRMatch.h:2047
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:2052
constexpr static bool canonical
Definition: IRMatch.h:748
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:808
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:746
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:747
constexpr static bool foldable
Definition: IRMatch.h:771
constexpr static uint32_t binds
Definition: IRMatch.h:744
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:755
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:774
HALIDE_ALWAYS_INLINE bool match(const CmpOp< Op2, A2, B2 > &op, MatcherState &state) const noexcept
Definition: IRMatch.h:765
constexpr static bool foldable
Definition: IRMatch.h:2217
constexpr static uint32_t binds
Definition: IRMatch.h:2187
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2189
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2190
constexpr static bool canonical
Definition: IRMatch.h:2191
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const noexcept
Definition: IRMatch.h:2194
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:2220
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:504
constexpr static bool canonical
Definition: IRMatch.h:496
constexpr static uint32_t binds
Definition: IRMatch.h:492
HALIDE_ALWAYS_INLINE IntLiteral(int64_t v)
Definition: IRMatch.h:499
HALIDE_ALWAYS_INLINE bool match(const IntLiteral &b, MatcherState &state) const noexcept
Definition: IRMatch.h:527
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:539
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:494
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:495
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:532
HALIDE_ALWAYS_INLINE bool match(int64_t val, MatcherState &state) const noexcept
Definition: IRMatch.h:522
constexpr static bool foldable
Definition: IRMatch.h:536
HALIDE_ALWAYS_INLINE bool match_args(double, const Call &c, MatcherState &state) const noexcept
Definition: IRMatch.h:1370
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:1405
constexpr static bool canonical
Definition: IRMatch.h:1358
HALIDE_ALWAYS_INLINE void print_args(std::ostream &s) const
Definition: IRMatch.h:1400
constexpr static bool foldable
Definition: IRMatch.h:1463
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:1465
static constexpr uint32_t binds
Definition: IRMatch.h:1354
HALIDE_ALWAYS_INLINE bool match_args(int, const Call &c, MatcherState &state) const noexcept
Definition: IRMatch.h:1363
HALIDE_ALWAYS_INLINE void print_args(int, std::ostream &s) const
Definition: IRMatch.h:1387
std::tuple< Args... > args
Definition: IRMatch.h:1348
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:1375
HALIDE_ALWAYS_INLINE void print_args(double, std::ostream &s) const
Definition: IRMatch.h:1396
HALIDE_ALWAYS_INLINE Intrin(Call::IntrinsicOp intrin, Args... args) noexcept
Definition: IRMatch.h:1508
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:1357
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:1356
constexpr static bool canonical
Definition: IRMatch.h:2322
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:2331
constexpr static bool foldable
Definition: IRMatch.h:2328
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2321
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2320
constexpr static uint32_t binds
Definition: IRMatch.h:2317
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2412
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition: IRMatch.h:2419
constexpr static bool canonical
Definition: IRMatch.h:2414
constexpr static uint32_t binds
Definition: IRMatch.h:2409
constexpr static bool foldable
Definition: IRMatch.h:2416
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2413
constexpr static uint32_t binds
Definition: IRMatch.h:2447
constexpr static bool foldable
Definition: IRMatch.h:2454
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2450
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition: IRMatch.h:2457
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2451
constexpr static bool canonical
Definition: IRMatch.h:2452
constexpr static bool canonical
Definition: IRMatch.h:2578
constexpr static bool foldable
Definition: IRMatch.h:2580
constexpr static uint32_t binds
Definition: IRMatch.h:2573
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2576
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2577
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition: IRMatch.h:2583
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2617
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2618
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition: IRMatch.h:2624
constexpr static bool canonical
Definition: IRMatch.h:2619
constexpr static uint32_t binds
Definition: IRMatch.h:2614
constexpr static bool foldable
Definition: IRMatch.h:2621
constexpr static bool foldable
Definition: IRMatch.h:2543
constexpr static bool canonical
Definition: IRMatch.h:2541
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition: IRMatch.h:2546
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2540
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2539
constexpr static uint32_t binds
Definition: IRMatch.h:2536
constexpr static bool canonical
Definition: IRMatch.h:2497
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition: IRMatch.h:2502
constexpr static uint32_t binds
Definition: IRMatch.h:2492
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2495
constexpr static bool foldable
Definition: IRMatch.h:2499
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2496
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition: IRMatch.h:2667
constexpr static bool foldable
Definition: IRMatch.h:2664
constexpr static bool canonical
Definition: IRMatch.h:2662
constexpr static uint32_t binds
Definition: IRMatch.h:2657
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2660
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2661
To save stack space, the matcher objects are largely stateless and immutable.
Definition: IRMatch.h:82
HALIDE_ALWAYS_INLINE void get_bound_const(int i, halide_scalar_value_t &val, halide_type_t &type) const noexcept
Definition: IRMatch.h:127
HALIDE_ALWAYS_INLINE void set_bound_const(int i, int64_t s, halide_type_t t) noexcept
Definition: IRMatch.h:103
HALIDE_ALWAYS_INLINE void set_bound_const(int i, double f, halide_type_t t) noexcept
Definition: IRMatch.h:115
static constexpr uint16_t special_values_mask
Definition: IRMatch.h:88
HALIDE_ALWAYS_INLINE void set_bound_const(int i, halide_scalar_value_t val, halide_type_t t) noexcept
Definition: IRMatch.h:121
halide_type_t bound_const_type[max_wild]
Definition: IRMatch.h:90
HALIDE_ALWAYS_INLINE void set_binding(int i, const BaseExprNode &n) noexcept
Definition: IRMatch.h:93
HALIDE_ALWAYS_INLINE MatcherState() noexcept
Definition: IRMatch.h:134
halide_scalar_value_t bound_const[max_wild]
Definition: IRMatch.h:84
HALIDE_ALWAYS_INLINE const BaseExprNode * get_binding(int i) const noexcept
Definition: IRMatch.h:98
HALIDE_ALWAYS_INLINE void set_bound_const(int i, uint64_t u, halide_type_t t) noexcept
Definition: IRMatch.h:109
static constexpr uint16_t signed_integer_overflow
Definition: IRMatch.h:87
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:1950
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:1951
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:1956
constexpr static uint32_t binds
Definition: IRMatch.h:1948
constexpr static bool canonical
Definition: IRMatch.h:1953
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:1971
HALIDE_ALWAYS_INLINE bool match(NegateOp< A2 > &&p, MatcherState &state) const noexcept
Definition: IRMatch.h:1966
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:1980
constexpr static bool foldable
Definition: IRMatch.h:1977
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:1609
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:1614
HALIDE_ALWAYS_INLINE bool match(const NotOp< A2 > &op, MatcherState &state) const noexcept
Definition: IRMatch.h:1623
constexpr static uint32_t binds
Definition: IRMatch.h:1607
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:1610
constexpr static bool foldable
Definition: IRMatch.h:1632
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:1628
constexpr static bool canonical
Definition: IRMatch.h:1611
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:1635
constexpr static bool canonical
Definition: IRMatch.h:2282
constexpr static bool foldable
Definition: IRMatch.h:2299
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:2285
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:2294
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2281
constexpr static uint32_t binds
Definition: IRMatch.h:2277
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:2302
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2280
constexpr static bool foldable
Definition: IRMatch.h:2250
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:2253
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2247
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2246
constexpr static uint32_t binds
Definition: IRMatch.h:2242
constexpr static bool canonical
Definition: IRMatch.h:2248
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:1834
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:1807
constexpr static bool canonical
Definition: IRMatch.h:1809
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:1806
constexpr static bool foldable
Definition: IRMatch.h:1846
HALIDE_ALWAYS_INLINE bool match(const RampOp< A2, B2, C2 > &op, MatcherState &state) const noexcept
Definition: IRMatch.h:1827
constexpr static uint32_t binds
Definition: IRMatch.h:1804
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:1812
HALIDE_NEVER_INLINE void build_replacement(After after)
Definition: IRMatch.h:2866
HALIDE_ALWAYS_INLINE bool operator()(Before before, After after, Predicate pred)
Definition: IRMatch.h:2940
HALIDE_ALWAYS_INLINE bool operator()(Before before, int64_t after) noexcept
Definition: IRMatch.h:2915
HALIDE_ALWAYS_INLINE Rewriter(Instance instance, halide_type_t ot, halide_type_t wt)
Definition: IRMatch.h:2861
HALIDE_ALWAYS_INLINE bool operator()(Before before, const Expr &after, Predicate pred)
Definition: IRMatch.h:2969
HALIDE_ALWAYS_INLINE bool operator()(Before before, const Expr &after) noexcept
Definition: IRMatch.h:2897
HALIDE_ALWAYS_INLINE bool operator()(Before before, int64_t after, Predicate pred)
Definition: IRMatch.h:2992
HALIDE_ALWAYS_INLINE bool operator()(Before before, After after)
Definition: IRMatch.h:2874
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:1670
constexpr static bool canonical
Definition: IRMatch.h:1672
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:1699
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:1669
constexpr static uint32_t binds
Definition: IRMatch.h:1667
constexpr static bool foldable
Definition: IRMatch.h:1696
HALIDE_ALWAYS_INLINE bool match(const SelectOp< C2, T2, F2 > &instance, MatcherState &state) const noexcept
Definition: IRMatch.h:1685
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:1675
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:1692
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2129
constexpr static bool foldable
Definition: IRMatch.h:2159
HALIDE_ALWAYS_INLINE SliceOp(Vec v, Base b, Stride s, Lanes l)
Definition: IRMatch.h:2162
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:2133
static constexpr uint32_t binds
Definition: IRMatch.h:2126
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2128
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:2147
constexpr static bool canonical
Definition: IRMatch.h:2130
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:198
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:205
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:210
constexpr static uint32_t binds
Definition: IRMatch.h:195
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:199
HALIDE_ALWAYS_INLINE bool match(const VectorReduceOp< A2, B2, reduce_op_2 > &op, MatcherState &state) const noexcept
Definition: IRMatch.h:1889
constexpr static uint32_t binds
Definition: IRMatch.h:1869
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:1872
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:1876
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:1871
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:1896
constexpr static bool foldable
Definition: IRMatch.h:2103
constexpr static bool canonical
Definition: IRMatch.h:2080
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:2097
HALIDE_ALWAYS_INLINE bool match(const WidenOp< A2 > &op, MatcherState &state) const noexcept
Definition: IRMatch.h:2092
constexpr static uint32_t binds
Definition: IRMatch.h:2076
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:2083
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2078
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2079
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:352
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:348
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:373
constexpr static uint32_t binds
Definition: IRMatch.h:345
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:347
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:383
constexpr static uint32_t binds
Definition: IRMatch.h:399
constexpr static bool foldable
Definition: IRMatch.h:438
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:431
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:402
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:406
constexpr static bool canonical
Definition: IRMatch.h:403
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:441
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:401
HALIDE_ALWAYS_INLINE bool match(int64_t e, MatcherState &state) const noexcept
Definition: IRMatch.h:425
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:228
constexpr static uint32_t binds
Definition: IRMatch.h:226
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:267
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition: IRMatch.h:277
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:233
HALIDE_ALWAYS_INLINE bool match(int64_t value, MatcherState &state) const noexcept
Definition: IRMatch.h:254
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:229
constexpr static uint32_t binds
Definition: IRMatch.h:292
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:295
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:294
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:299
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:330
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:320
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:460
constexpr static bool foldable
Definition: IRMatch.h:477
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:473
constexpr static bool canonical
Definition: IRMatch.h:461
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:459
constexpr static uint32_t binds
Definition: IRMatch.h:457
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:464
constexpr static uint32_t mask
Definition: IRMatch.h:146
IRNodeType node_type
Each IR node subclass has a unique identifier.
Definition: Expr.h:113
Integer constants.
Definition: Expr.h:218
static const IntImm * make(Type t, int64_t value)
Is the first expression less than or equal to the second.
Definition: IR.h:148
Is the first expression less than the second.
Definition: IR.h:139
The greater of two values.
Definition: IR.h:112
The lesser of two values.
Definition: IR.h:103
The remainder of a / b.
Definition: IR.h:94
The product of two expressions.
Definition: IR.h:74
Is the first expression not equal to the second.
Definition: IR.h:130
Logical not - true if the expression is false.
Definition: IR.h:193
static Expr make(Expr a)
Logical or - is at least one of the expressions true.
Definition: IR.h:184
A linear ramp vector node.
Definition: IR.h:247
static const IRNodeType _node_type
Definition: IR.h:253
static Expr make(Expr base, Expr stride, int lanes)
A ternary operator.
Definition: IR.h:204
static Expr make(Expr condition, Expr true_value, Expr false_value)
static const IRNodeType _node_type
Definition: IR.h:209
Construct a new vector by taking elements from another sequence of vectors.
Definition: IR.h:848
static Expr make_slice(Expr vector, int begin, int stride, int size)
Convenience constructor for making a shuffle representing a contiguous subset of a vector.
std::vector< Expr > vectors
Definition: IR.h:849
bool is_slice() const
Check if this shuffle is a contiguous strict subset of the vector arguments, and if so,...
int slice_stride() const
Check if this shuffle is a contiguous strict subset of the vector arguments, and if so,...
Definition: IR.h:902
int slice_begin() const
Check if this shuffle is a contiguous strict subset of the vector arguments, and if so,...
Definition: IR.h:899
The difference of two expressions.
Definition: IR.h:65
static const IRNodeType _node_type
Definition: IR.h:70
static Expr make(Expr a, Expr b)
Unsigned integer constants.
Definition: Expr.h:227
static const UIntImm * make(Type t, uint64_t value)
Horizontally reduce a vector to a scalar or narrower vector using the given commutative and associative binary operator.
Definition: IR.h:972
static const IRNodeType _node_type
Definition: IR.h:991
static Expr make(Operator op, Expr vec, int lanes)
Types in the halide type system.
Definition: Type.h:283
Type widen() const
Return Type with the same type code and number of lanes, but with at least twice as many bits.
Definition: Type.h:378
HALIDE_ALWAYS_INLINE bool is_int() const
Is this type a signed integer type?
Definition: Type.h:435
HALIDE_ALWAYS_INLINE int lanes() const
Return the number of vector elements in this type.
Definition: Type.h:355
HALIDE_ALWAYS_INLINE bool is_uint() const
Is this type an unsigned integer type?
Definition: Type.h:441
HALIDE_ALWAYS_INLINE int bits() const
Return the bit size of a single element of this type.
Definition: Type.h:349
HALIDE_ALWAYS_INLINE bool is_vector() const
Is this type a vector type? (lanes() != 1).
Definition: Type.h:410
HALIDE_ALWAYS_INLINE bool is_scalar() const
Is this type a scalar type? (lanes() == 1).
Definition: Type.h:417
HALIDE_ALWAYS_INLINE bool is_float() const
Is this type a floating point type (float or double).
Definition: Type.h:423
halide_scalar_value_t is a simple union able to represent all the well-known scalar values in a filter argument.
union halide_scalar_value_t::@3 u
A runtime tag for a type in the halide type system.
uint8_t bits
The number of bits of precision of a single scalar value of this type.
uint16_t lanes
How many elements in a vector.
uint8_t code
The basic type code: signed integer, unsigned integer, or floating point.