Halide  17.0.2
Halide compiler and libraries
IRMatch.h
Go to the documentation of this file.
1 #ifndef HALIDE_IR_MATCH_H
2 #define HALIDE_IR_MATCH_H
3 
4 /** \file
5  * Defines a method to match a fragment of IR against a pattern containing wildcards
6  */
7 
8 #include <map>
9 #include <random>
10 #include <set>
11 #include <vector>
12 
13 #include "IR.h"
14 #include "IREquality.h"
15 #include "IROperator.h"
16 
17 namespace Halide {
18 namespace Internal {
19 
20 /** Does the first expression have the same structure as the second?
21  * Variables in the first expression with the name * are interpreted
22  * as wildcards, and their matching equivalent in the second
23  * expression is placed in the vector give as the third argument.
24  * Wildcards require the types to match. For the type bits and width,
25  * a 0 indicates "match anything". So an Int(8, 0) will match 8-bit
26  * integer vectors of any width (including scalars), and a UInt(0, 0)
27  * will match any unsigned integer type.
28  *
29  * For example:
30  \code
31  Expr x = Variable::make(Int(32), "*");
32  match(x + x, 3 + (2*k), result)
33  \endcode
34  * should return true, and set result[0] to 3 and
35  * result[1] to 2*k.
36  */
37 bool expr_match(const Expr &pattern, const Expr &expr, std::vector<Expr> &result);
38 
39 /** Does the first expression have the same structure as the second?
40  * Variables are matched consistently. The first time a variable is
41  * matched, it assumes the value of the matching part of the second
42  * expression. Subsequent matches must be equal to the first match.
43  *
44  * For example:
45  \code
46  Var x("x"), y("y");
47  match(x*(x + y), a*(a + b), result)
48  \endcode
49  * should return true, and set result["x"] = a, and result["y"] = b.
50  */
51 bool expr_match(const Expr &pattern, const Expr &expr, std::map<std::string, Expr> &result);
52 
53 /** Rewrite the expression x to have `lanes` lanes. This is useful
54  * for substituting the results of expr_match into a pattern expression.
 *
 * \param x The expression to rewrite.
 * \param lanes The number of lanes the rewritten expression should have.
 * \return The rewritten expression. */
55 Expr with_lanes(const Expr &x, int lanes);
56 
/** NOTE(review): no documentation or definition is visible here; judging only
 * by the name this is a self-test for expr_match -- confirm before relying on it. */
57 void expr_match_test();
58 
59 /** An alternative template-metaprogramming approach to expression
60  * matching. Potentially more efficient. We lift the expression
61  * pattern into a type, and then use force-inlined functions to
62  * generate efficient matching and reconstruction code for any
63  * pattern. Pattern elements are either one of the classes in the
64  * namespace IRMatcher, or are non-null Exprs (represented as
65  * BaseExprNode &).
66  *
67  * Pattern elements that are fully specified by their pattern can be
68  * built into an expression using the make method. Some patterns,
69  * such as a broadcast that matches any number of lanes, don't have
70  * enough information to recreate an Expr.
71  */
72 namespace IRMatcher {
73 
// Number of wildcard slots a pattern may bind; wildcard indices must lie in
// [0, max_wild).
constexpr int max_wild{6};
75 
76 static const halide_type_t i64_type = {halide_type_int, 64, 1};
77 
78 /** To save stack space, the matcher objects are largely stateless and
79  * immutable. This state object is built up during matching and then
80  * consumed when constructing a replacement Expr.
81  */
82 struct MatcherState {
85 
86  // values of the lanes field with special meaning.
87  static constexpr uint16_t signed_integer_overflow = 0x8000;
88  static constexpr uint16_t special_values_mask = 0x8000; // currently only one
89 
91 
93  void set_binding(int i, const BaseExprNode &n) noexcept {
94  bindings[i] = &n;
95  }
96 
98  const BaseExprNode *get_binding(int i) const noexcept {
99  return bindings[i];
100  }
101 
103  void set_bound_const(int i, int64_t s, halide_type_t t) noexcept {
104  bound_const[i].u.i64 = s;
105  bound_const_type[i] = t;
106  }
107 
109  void set_bound_const(int i, uint64_t u, halide_type_t t) noexcept {
110  bound_const[i].u.u64 = u;
111  bound_const_type[i] = t;
112  }
113 
115  void set_bound_const(int i, double f, halide_type_t t) noexcept {
116  bound_const[i].u.f64 = f;
117  bound_const_type[i] = t;
118  }
119 
121  void set_bound_const(int i, halide_scalar_value_t val, halide_type_t t) noexcept {
122  bound_const[i] = val;
123  bound_const_type[i] = t;
124  }
125 
127  void get_bound_const(int i, halide_scalar_value_t &val, halide_type_t &type) const noexcept {
128  val = bound_const[i];
129  type = bound_const_type[i];
130  }
131 
133  // NOLINTNEXTLINE(modernize-use-equals-default): Can't use `= default`; clang-tidy complains about noexcept mismatch
134  MatcherState() noexcept {
135  }
136 };
137 
138 template<typename T,
139  typename = typename std::remove_reference<T>::type::pattern_tag>
141  struct type {};
142 };
143 
// Compile-time bitmask of the wildcard slots bound by the pattern type T.
// T may be a reference type; the reference is stripped before reading `binds`.
template<typename T>
struct bindings {
    static constexpr uint32_t mask = std::remove_reference_t<T>::binds;
};
148 
151  ty.lanes &= ~MatcherState::special_values_mask;
153  return make_signed_integer_overflow(ty);
154  }
155  // unreachable
156  return Expr();
157 }
158 
161  halide_type_t scalar_type = ty;
162  if (scalar_type.lanes & MatcherState::special_values_mask) {
163  return make_const_special_expr(scalar_type);
164  }
165 
166  const int lanes = scalar_type.lanes;
167  scalar_type.lanes = 1;
168 
169  Expr e;
170  switch (scalar_type.code) {
171  case halide_type_int:
172  e = IntImm::make(scalar_type, val.u.i64);
173  break;
174  case halide_type_uint:
175  e = UIntImm::make(scalar_type, val.u.u64);
176  break;
177  case halide_type_float:
178  case halide_type_bfloat:
179  e = FloatImm::make(scalar_type, val.u.f64);
180  break;
181  default:
182  // Unreachable
183  return Expr();
184  }
185  if (lanes > 1) {
186  e = Broadcast::make(e, lanes);
187  }
188  return e;
189 }
190 
/** Structural comparison of two IR nodes. equal() calls this only after its
 * identity, type and node_type checks, so both nodes are known to share a type
 * and node kind. NOTE(review): the definition is not visible here -- presumably
 * it compares the nodes' fields/children recursively; confirm. */
191 bool equal_helper(const BaseExprNode &a, const BaseExprNode &b) noexcept;
192 
193 // A fast version of expression equality that assumes a well-typed non-null expression tree.
195 bool equal(const BaseExprNode &a, const BaseExprNode &b) noexcept {
196  // Early out
197  return (&a == &b) ||
198  ((a.type == b.type) &&
199  (a.node_type == b.node_type) &&
200  equal_helper(a, b));
201 }
202 
203 // A pattern that matches a specific expression
204 struct SpecificExpr {
205  struct pattern_tag {};
206 
207  constexpr static uint32_t binds = 0;
208 
209  // What is the weakest and strongest IR node this could possibly be
212  constexpr static bool canonical = true;
213 
215 
216  template<uint32_t bound>
217  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
218  return equal(expr, e);
219  }
220 
222  Expr make(MatcherState &state, halide_type_t type_hint) const {
223  return Expr(&expr);
224  }
225 
226  constexpr static bool foldable = false;
227 };
228 
229 inline std::ostream &operator<<(std::ostream &s, const SpecificExpr &e) {
230  s << Expr(&e.expr);
231  return s;
232 }
233 
234 template<int i>
235 struct WildConstInt {
236  struct pattern_tag {};
237 
238  constexpr static uint32_t binds = 1 << i;
239 
242  constexpr static bool canonical = true;
243 
244  template<uint32_t bound>
245  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
246  static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
247  const BaseExprNode *op = &e;
248  if (op->node_type == IRNodeType::Broadcast) {
249  op = ((const Broadcast *)op)->value.get();
250  }
251  if (op->node_type != IRNodeType::IntImm) {
252  return false;
253  }
254  int64_t value = ((const IntImm *)op)->value;
255  if (bound & binds) {
257  halide_type_t type;
258  state.get_bound_const(i, val, type);
259  return (halide_type_t)e.type == type && value == val.u.i64;
260  }
261  state.set_bound_const(i, value, e.type);
262  return true;
263  }
264 
265  template<uint32_t bound>
266  HALIDE_ALWAYS_INLINE bool match(int64_t value, MatcherState &state) const noexcept {
267  static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
268  if (bound & binds) {
270  halide_type_t type;
271  state.get_bound_const(i, val, type);
272  return type == i64_type && value == val.u.i64;
273  }
274  state.set_bound_const(i, value, i64_type);
275  return true;
276  }
277 
279  Expr make(MatcherState &state, halide_type_t type_hint) const {
281  halide_type_t type;
282  state.get_bound_const(i, val, type);
283  return make_const_expr(val, type);
284  }
285 
286  constexpr static bool foldable = true;
287 
290  state.get_bound_const(i, val, ty);
291  }
292 };
293 
294 template<int i>
295 std::ostream &operator<<(std::ostream &s, const WildConstInt<i> &c) {
296  s << "ci" << i;
297  return s;
298 }
299 
300 template<int i>
302  struct pattern_tag {};
303 
304  constexpr static uint32_t binds = 1 << i;
305 
308  constexpr static bool canonical = true;
309 
310  template<uint32_t bound>
311  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
312  static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
313  const BaseExprNode *op = &e;
314  if (op->node_type == IRNodeType::Broadcast) {
315  op = ((const Broadcast *)op)->value.get();
316  }
317  if (op->node_type != IRNodeType::UIntImm) {
318  return false;
319  }
320  uint64_t value = ((const UIntImm *)op)->value;
321  if (bound & binds) {
323  halide_type_t type;
324  state.get_bound_const(i, val, type);
325  return (halide_type_t)e.type == type && value == val.u.u64;
326  }
327  state.set_bound_const(i, value, e.type);
328  return true;
329  }
330 
332  Expr make(MatcherState &state, halide_type_t type_hint) const {
334  halide_type_t type;
335  state.get_bound_const(i, val, type);
336  return make_const_expr(val, type);
337  }
338 
339  constexpr static bool foldable = true;
340 
342  void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
343  state.get_bound_const(i, val, ty);
344  }
345 };
346 
347 template<int i>
348 std::ostream &operator<<(std::ostream &s, const WildConstUInt<i> &c) {
349  s << "cu" << i;
350  return s;
351 }
352 
353 template<int i>
355  struct pattern_tag {};
356 
357  constexpr static uint32_t binds = 1 << i;
358 
361  constexpr static bool canonical = true;
362 
363  template<uint32_t bound>
364  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
365  static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
366  const BaseExprNode *op = &e;
367  if (op->node_type == IRNodeType::Broadcast) {
368  op = ((const Broadcast *)op)->value.get();
369  }
370  if (op->node_type != IRNodeType::FloatImm) {
371  return false;
372  }
373  double value = ((const FloatImm *)op)->value;
374  if (bound & binds) {
376  halide_type_t type;
377  state.get_bound_const(i, val, type);
378  return (halide_type_t)e.type == type && value == val.u.f64;
379  }
380  state.set_bound_const(i, value, e.type);
381  return true;
382  }
383 
385  Expr make(MatcherState &state, halide_type_t type_hint) const {
387  halide_type_t type;
388  state.get_bound_const(i, val, type);
389  return make_const_expr(val, type);
390  }
391 
392  constexpr static bool foldable = true;
393 
395  void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
396  state.get_bound_const(i, val, ty);
397  }
398 };
399 
400 template<int i>
401 std::ostream &operator<<(std::ostream &s, const WildConstFloat<i> &c) {
402  s << "cf" << i;
403  return s;
404 }
405 
406 // Matches and binds to any constant Expr. Does not support constant-folding.
407 template<int i>
408 struct WildConst {
409  struct pattern_tag {};
410 
411  constexpr static uint32_t binds = 1 << i;
412 
415  constexpr static bool canonical = true;
416 
417  template<uint32_t bound>
418  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
419  static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
420  const BaseExprNode *op = &e;
421  if (op->node_type == IRNodeType::Broadcast) {
422  op = ((const Broadcast *)op)->value.get();
423  }
424  switch (op->node_type) {
425  case IRNodeType::IntImm:
426  return WildConstInt<i>().template match<bound>(e, state);
427  case IRNodeType::UIntImm:
428  return WildConstUInt<i>().template match<bound>(e, state);
430  return WildConstFloat<i>().template match<bound>(e, state);
431  default:
432  return false;
433  }
434  }
435 
436  template<uint32_t bound>
437  HALIDE_ALWAYS_INLINE bool match(int64_t e, MatcherState &state) const noexcept {
438  static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
439  return WildConstInt<i>().template match<bound>(e, state);
440  }
441 
443  Expr make(MatcherState &state, halide_type_t type_hint) const {
445  halide_type_t type;
446  state.get_bound_const(i, val, type);
447  return make_const_expr(val, type);
448  }
449 
450  constexpr static bool foldable = true;
451 
453  void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
454  state.get_bound_const(i, val, ty);
455  }
456 };
457 
458 template<int i>
459 std::ostream &operator<<(std::ostream &s, const WildConst<i> &c) {
460  s << "c" << i;
461  return s;
462 }
463 
464 // Matches and binds to any Expr
465 template<int i>
466 struct Wild {
467  struct pattern_tag {};
468 
469  constexpr static uint32_t binds = 1 << (i + 16);
470 
473  constexpr static bool canonical = true;
474 
475  template<uint32_t bound>
476  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
477  if (bound & binds) {
478  return equal(*state.get_binding(i), e);
479  }
480  state.set_binding(i, e);
481  return true;
482  }
483 
485  Expr make(MatcherState &state, halide_type_t type_hint) const {
486  return state.get_binding(i);
487  }
488 
489  constexpr static bool foldable = false;
490 };
491 
492 template<int i>
493 std::ostream &operator<<(std::ostream &s, const Wild<i> &op) {
494  s << "_" << i;
495  return s;
496 }
497 
498 // Matches a specific constant or broadcast of that constant. The
499 // constant must be representable as an int64_t.
500 struct IntLiteral {
501  struct pattern_tag {};
503 
504  constexpr static uint32_t binds = 0;
505 
508  constexpr static bool canonical = true;
509 
511  explicit IntLiteral(int64_t v)
512  : v(v) {
513  }
514 
515  template<uint32_t bound>
516  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
517  const BaseExprNode *op = &e;
518  if (e.node_type == IRNodeType::Broadcast) {
519  op = ((const Broadcast *)op)->value.get();
520  }
521  switch (op->node_type) {
522  case IRNodeType::IntImm:
523  return ((const IntImm *)op)->value == (int64_t)v;
524  case IRNodeType::UIntImm:
525  return ((const UIntImm *)op)->value == (uint64_t)v;
527  return ((const FloatImm *)op)->value == (double)v;
528  default:
529  return false;
530  }
531  }
532 
533  template<uint32_t bound>
534  HALIDE_ALWAYS_INLINE bool match(int64_t val, MatcherState &state) const noexcept {
535  return v == val;
536  }
537 
538  template<uint32_t bound>
539  HALIDE_ALWAYS_INLINE bool match(const IntLiteral &b, MatcherState &state) const noexcept {
540  return v == b.v;
541  }
542 
544  Expr make(MatcherState &state, halide_type_t type_hint) const {
545  return make_const(type_hint, v);
546  }
547 
548  constexpr static bool foldable = true;
549 
551  void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
552  // Assume type is already correct
553  switch (ty.code) {
554  case halide_type_int:
555  val.u.i64 = v;
556  break;
557  case halide_type_uint:
558  val.u.u64 = (uint64_t)v;
559  break;
560  case halide_type_float:
561  case halide_type_bfloat:
562  val.u.f64 = (double)v;
563  break;
564  default:
565  // Unreachable
566  ;
567  }
568  }
569 };
570 
572  return t.v;
573 }
574 
575 // Convert a provided pattern, expr, or constant int into the internal
576 // representation we use in the matcher trees.
577 template<typename T,
578  typename = typename std::decay<T>::type::pattern_tag>
580  return t;
581 }
584  return IntLiteral{x};
585 }
586 
587 template<typename T>
589  static_assert(!std::is_same<typename std::decay<T>::type, Expr>::value || std::is_lvalue_reference<T>::value,
590  "Exprs are captured by reference by IRMatcher objects and so must be lvalues");
591 }
592 
594  return {*e.get()};
595 }
596 
597 // Helpers to deref SpecificExprs to const BaseExprNode & rather than
598 // passing them by value anywhere (incurring lots of refcounting)
599 template<typename T,
600  // T must be a pattern node
601  typename = typename std::decay<T>::type::pattern_tag,
602  // But T may not be SpecificExpr
603  typename = typename std::enable_if<!std::is_same<typename std::decay<T>::type, SpecificExpr>::value>::type>
605  return t;
606 }
607 
609 const BaseExprNode &unwrap(const SpecificExpr &e) {
610  return e.expr;
611 }
612 
613 inline std::ostream &operator<<(std::ostream &s, const IntLiteral &op) {
614  s << op.v;
615  return s;
616 }
617 
618 template<typename Op>
620 
621 template<typename Op>
623 
// Floating-point constant folding of the binary operator Op. Explicit
// specializations for the individual operators (Add, Sub, Mul, Div, Mod, Min,
// Max) appear later in this file. The halide_type_t is taken by non-const
// reference; the integer Add/Sub/Mul folds use it to flag
// signed_integer_overflow in its lanes field.
624 template<typename Op>
625 double constant_fold_bin_op(halide_type_t &, double, double) noexcept;
626 
627 constexpr bool commutative(IRNodeType t) {
628  return (t == IRNodeType::Add ||
629  t == IRNodeType::Mul ||
630  t == IRNodeType::And ||
631  t == IRNodeType::Or ||
632  t == IRNodeType::Min ||
633  t == IRNodeType::Max ||
634  t == IRNodeType::EQ ||
635  t == IRNodeType::NE);
636 }
637 
638 // Matches one of the binary operators
639 template<typename Op, typename A, typename B>
640 struct BinOp {
641  struct pattern_tag {};
642  A a;
643  B b;
644 
646 
647  constexpr static IRNodeType min_node_type = Op::_node_type;
648  constexpr static IRNodeType max_node_type = Op::_node_type;
649 
650  // For commutative bin ops, we expect the weaker IR node type on
651  // the right. That is, for the rule to be canonical it must be
652  // possible that A is at least as strong as B.
653  constexpr static bool canonical =
654  A::canonical && B::canonical && (!commutative(Op::_node_type) || (A::max_node_type >= B::min_node_type));
655 
656  template<uint32_t bound>
657  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
658  if (e.node_type != Op::_node_type) {
659  return false;
660  }
661  const Op &op = (const Op &)e;
662  return (a.template match<bound>(*op.a.get(), state) &&
663  b.template match<bound | bindings<A>::mask>(*op.b.get(), state));
664  }
665 
666  template<uint32_t bound, typename Op2, typename A2, typename B2>
667  HALIDE_ALWAYS_INLINE bool match(const BinOp<Op2, A2, B2> &op, MatcherState &state) const noexcept {
668  return (std::is_same<Op, Op2>::value &&
669  a.template match<bound>(unwrap(op.a), state) &&
670  b.template match<bound | bindings<A>::mask>(unwrap(op.b), state));
671  }
672 
673  constexpr static bool foldable = A::foldable && B::foldable;
674 
676  void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
677  halide_scalar_value_t val_a, val_b;
678  if (std::is_same<A, IntLiteral>::value) {
679  b.make_folded_const(val_b, ty, state);
680  if ((std::is_same<Op, And>::value && val_b.u.u64 == 0) ||
681  (std::is_same<Op, Or>::value && val_b.u.u64 == 1)) {
682  // Short circuit
683  val = val_b;
684  return;
685  }
686  const uint16_t l = ty.lanes;
687  a.make_folded_const(val_a, ty, state);
688  ty.lanes |= l; // Make sure the overflow bits are sticky
689  } else {
690  a.make_folded_const(val_a, ty, state);
691  if ((std::is_same<Op, And>::value && val_a.u.u64 == 0) ||
692  (std::is_same<Op, Or>::value && val_a.u.u64 == 1)) {
693  // Short circuit
694  val = val_a;
695  return;
696  }
697  const uint16_t l = ty.lanes;
698  b.make_folded_const(val_b, ty, state);
699  ty.lanes |= l;
700  }
701  switch (ty.code) {
702  case halide_type_int:
703  val.u.i64 = constant_fold_bin_op<Op>(ty, val_a.u.i64, val_b.u.i64);
704  break;
705  case halide_type_uint:
706  val.u.u64 = constant_fold_bin_op<Op>(ty, val_a.u.u64, val_b.u.u64);
707  break;
708  case halide_type_float:
709  case halide_type_bfloat:
710  val.u.f64 = constant_fold_bin_op<Op>(ty, val_a.u.f64, val_b.u.f64);
711  break;
712  default:
713  // unreachable
714  ;
715  }
716  }
717 
719  Expr make(MatcherState &state, halide_type_t type_hint) const noexcept {
720  Expr ea, eb;
721  if (std::is_same<A, IntLiteral>::value) {
722  eb = b.make(state, type_hint);
723  ea = a.make(state, eb.type());
724  } else {
725  ea = a.make(state, type_hint);
726  eb = b.make(state, ea.type());
727  }
728  // We sometimes mix vectors and scalars in the rewrite rules,
729  // so insert a broadcast if necessary.
730  if (ea.type().is_vector() && !eb.type().is_vector()) {
731  eb = Broadcast::make(eb, ea.type().lanes());
732  }
733  if (eb.type().is_vector() && !ea.type().is_vector()) {
734  ea = Broadcast::make(ea, eb.type().lanes());
735  }
736  return Op::make(std::move(ea), std::move(eb));
737  }
738 };
739 
740 template<typename Op>
742 
743 template<typename Op>
745 
// Floating-point constant folding of the comparison operator Op. The result is
// a 0/1 truth value (the comparison specializations return e.g. `a < b`), which
// CmpOp::make_folded_const stores with a 1-bit unsigned type.
746 template<typename Op>
747 uint64_t constant_fold_cmp_op(double, double) noexcept;
748 
749 // Matches one of the comparison operators
750 template<typename Op, typename A, typename B>
751 struct CmpOp {
752  struct pattern_tag {};
753  A a;
754  B b;
755 
757 
758  constexpr static IRNodeType min_node_type = Op::_node_type;
759  constexpr static IRNodeType max_node_type = Op::_node_type;
760  constexpr static bool canonical = (A::canonical &&
761  B::canonical &&
762  (!commutative(Op::_node_type) || A::max_node_type >= B::min_node_type) &&
763  (Op::_node_type != IRNodeType::GE) &&
764  (Op::_node_type != IRNodeType::GT));
765 
766  template<uint32_t bound>
767  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
768  if (e.node_type != Op::_node_type) {
769  return false;
770  }
771  const Op &op = (const Op &)e;
772  return (a.template match<bound>(*op.a.get(), state) &&
773  b.template match<bound | bindings<A>::mask>(*op.b.get(), state));
774  }
775 
776  template<uint32_t bound, typename Op2, typename A2, typename B2>
777  HALIDE_ALWAYS_INLINE bool match(const CmpOp<Op2, A2, B2> &op, MatcherState &state) const noexcept {
778  return (std::is_same<Op, Op2>::value &&
779  a.template match<bound>(unwrap(op.a), state) &&
780  b.template match<bound | bindings<A>::mask>(unwrap(op.b), state));
781  }
782 
783  constexpr static bool foldable = A::foldable && B::foldable;
784 
786  void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
787  halide_scalar_value_t val_a, val_b;
788  // If one side is an untyped const, evaluate the other side first to get a type hint.
789  if (std::is_same<A, IntLiteral>::value) {
790  b.make_folded_const(val_b, ty, state);
791  const uint16_t l = ty.lanes;
792  a.make_folded_const(val_a, ty, state);
793  ty.lanes |= l;
794  } else {
795  a.make_folded_const(val_a, ty, state);
796  const uint16_t l = ty.lanes;
797  b.make_folded_const(val_b, ty, state);
798  ty.lanes |= l;
799  }
800  switch (ty.code) {
801  case halide_type_int:
802  val.u.u64 = constant_fold_cmp_op<Op>(val_a.u.i64, val_b.u.i64);
803  break;
804  case halide_type_uint:
805  val.u.u64 = constant_fold_cmp_op<Op>(val_a.u.u64, val_b.u.u64);
806  break;
807  case halide_type_float:
808  case halide_type_bfloat:
809  val.u.u64 = constant_fold_cmp_op<Op>(val_a.u.f64, val_b.u.f64);
810  break;
811  default:
812  // unreachable
813  ;
814  }
815  ty.code = halide_type_uint;
816  ty.bits = 1;
817  }
818 
820  Expr make(MatcherState &state, halide_type_t type_hint) const {
821  // If one side is an untyped const, evaluate the other side first to get a type hint.
822  Expr ea, eb;
823  if (std::is_same<A, IntLiteral>::value) {
824  eb = b.make(state, {});
825  ea = a.make(state, eb.type());
826  } else {
827  ea = a.make(state, {});
828  eb = b.make(state, ea.type());
829  }
830  // We sometimes mix vectors and scalars in the rewrite rules,
831  // so insert a broadcast if necessary.
832  if (ea.type().is_vector() && !eb.type().is_vector()) {
833  eb = Broadcast::make(eb, ea.type().lanes());
834  }
835  if (eb.type().is_vector() && !ea.type().is_vector()) {
836  ea = Broadcast::make(ea, eb.type().lanes());
837  }
838  return Op::make(std::move(ea), std::move(eb));
839  }
840 };
841 
842 template<typename A, typename B>
843 std::ostream &operator<<(std::ostream &s, const BinOp<Add, A, B> &op) {
844  s << "(" << op.a << " + " << op.b << ")";
845  return s;
846 }
847 
848 template<typename A, typename B>
849 std::ostream &operator<<(std::ostream &s, const BinOp<Sub, A, B> &op) {
850  s << "(" << op.a << " - " << op.b << ")";
851  return s;
852 }
853 
854 template<typename A, typename B>
855 std::ostream &operator<<(std::ostream &s, const BinOp<Mul, A, B> &op) {
856  s << "(" << op.a << " * " << op.b << ")";
857  return s;
858 }
859 
860 template<typename A, typename B>
861 std::ostream &operator<<(std::ostream &s, const BinOp<Div, A, B> &op) {
862  s << "(" << op.a << " / " << op.b << ")";
863  return s;
864 }
865 
866 template<typename A, typename B>
867 std::ostream &operator<<(std::ostream &s, const BinOp<And, A, B> &op) {
868  s << "(" << op.a << " && " << op.b << ")";
869  return s;
870 }
871 
872 template<typename A, typename B>
873 std::ostream &operator<<(std::ostream &s, const BinOp<Or, A, B> &op) {
874  s << "(" << op.a << " || " << op.b << ")";
875  return s;
876 }
877 
878 template<typename A, typename B>
879 std::ostream &operator<<(std::ostream &s, const BinOp<Min, A, B> &op) {
880  s << "min(" << op.a << ", " << op.b << ")";
881  return s;
882 }
883 
884 template<typename A, typename B>
885 std::ostream &operator<<(std::ostream &s, const BinOp<Max, A, B> &op) {
886  s << "max(" << op.a << ", " << op.b << ")";
887  return s;
888 }
889 
890 template<typename A, typename B>
891 std::ostream &operator<<(std::ostream &s, const CmpOp<LE, A, B> &op) {
892  s << "(" << op.a << " <= " << op.b << ")";
893  return s;
894 }
895 
896 template<typename A, typename B>
897 std::ostream &operator<<(std::ostream &s, const CmpOp<LT, A, B> &op) {
898  s << "(" << op.a << " < " << op.b << ")";
899  return s;
900 }
901 
902 template<typename A, typename B>
903 std::ostream &operator<<(std::ostream &s, const CmpOp<GE, A, B> &op) {
904  s << "(" << op.a << " >= " << op.b << ")";
905  return s;
906 }
907 
908 template<typename A, typename B>
909 std::ostream &operator<<(std::ostream &s, const CmpOp<GT, A, B> &op) {
910  s << "(" << op.a << " > " << op.b << ")";
911  return s;
912 }
913 
914 template<typename A, typename B>
915 std::ostream &operator<<(std::ostream &s, const CmpOp<EQ, A, B> &op) {
916  s << "(" << op.a << " == " << op.b << ")";
917  return s;
918 }
919 
920 template<typename A, typename B>
921 std::ostream &operator<<(std::ostream &s, const CmpOp<NE, A, B> &op) {
922  s << "(" << op.a << " != " << op.b << ")";
923  return s;
924 }
925 
926 template<typename A, typename B>
927 std::ostream &operator<<(std::ostream &s, const BinOp<Mod, A, B> &op) {
928  s << "(" << op.a << " % " << op.b << ")";
929  return s;
930 }
931 
932 template<typename A, typename B>
934  assert_is_lvalue_if_expr<A>();
935  assert_is_lvalue_if_expr<B>();
936  return {pattern_arg(a), pattern_arg(b)};
937 }
938 
939 template<typename A, typename B>
940 HALIDE_ALWAYS_INLINE auto add(A &&a, B &&b) -> decltype(IRMatcher::operator+(a, b)) {
941  assert_is_lvalue_if_expr<A>();
942  assert_is_lvalue_if_expr<B>();
943  return IRMatcher::operator+(a, b);
944 }
945 
946 template<>
948  t.lanes |= ((t.bits >= 32) && add_would_overflow(t.bits, a, b)) ? MatcherState::signed_integer_overflow : 0;
949  int dead_bits = 64 - t.bits;
950  // Drop the high bits then sign-extend them back
951  return int64_t((uint64_t(a) + uint64_t(b)) << dead_bits) >> dead_bits;
952 }
953 
954 template<>
956  uint64_t ones = (uint64_t)(-1);
957  return (a + b) & (ones >> (64 - t.bits));
958 }
959 
960 template<>
961 HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Add>(halide_type_t &t, double a, double b) noexcept {
962  return a + b;
963 }
964 
965 template<typename A, typename B>
967  assert_is_lvalue_if_expr<A>();
968  assert_is_lvalue_if_expr<B>();
969  return {pattern_arg(a), pattern_arg(b)};
970 }
971 
972 template<typename A, typename B>
973 HALIDE_ALWAYS_INLINE auto sub(A &&a, B &&b) -> decltype(IRMatcher::operator-(a, b)) {
974  assert_is_lvalue_if_expr<A>();
975  assert_is_lvalue_if_expr<B>();
976  return IRMatcher::operator-(a, b);
977 }
978 
979 template<>
981  t.lanes |= ((t.bits >= 32) && sub_would_overflow(t.bits, a, b)) ? MatcherState::signed_integer_overflow : 0;
982  // Drop the high bits then sign-extend them back
983  int dead_bits = 64 - t.bits;
984  return int64_t((uint64_t(a) - uint64_t(b)) << dead_bits) >> dead_bits;
985 }
986 
987 template<>
989  uint64_t ones = (uint64_t)(-1);
990  return (a - b) & (ones >> (64 - t.bits));
991 }
992 
993 template<>
994 HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Sub>(halide_type_t &t, double a, double b) noexcept {
995  return a - b;
996 }
997 
998 template<typename A, typename B>
1000  assert_is_lvalue_if_expr<A>();
1001  assert_is_lvalue_if_expr<B>();
1002  return {pattern_arg(a), pattern_arg(b)};
1003 }
1004 
1005 template<typename A, typename B>
1006 HALIDE_ALWAYS_INLINE auto mul(A &&a, B &&b) -> decltype(IRMatcher::operator*(a, b)) {
1007  assert_is_lvalue_if_expr<A>();
1008  assert_is_lvalue_if_expr<B>();
1009  return IRMatcher::operator*(a, b);
1010 }
1011 
1012 template<>
1014  t.lanes |= ((t.bits >= 32) && mul_would_overflow(t.bits, a, b)) ? MatcherState::signed_integer_overflow : 0;
1015  int dead_bits = 64 - t.bits;
1016  // Drop the high bits then sign-extend them back
1017  return int64_t((uint64_t(a) * uint64_t(b)) << dead_bits) >> dead_bits;
1018 }
1019 
1020 template<>
1022  uint64_t ones = (uint64_t)(-1);
1023  return (a * b) & (ones >> (64 - t.bits));
1024 }
1025 
1026 template<>
1027 HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Mul>(halide_type_t &t, double a, double b) noexcept {
1028  return a * b;
1029 }
1030 
1031 template<typename A, typename B>
1033  assert_is_lvalue_if_expr<A>();
1034  assert_is_lvalue_if_expr<B>();
1035  return {pattern_arg(a), pattern_arg(b)};
1036 }
1037 
1038 template<typename A, typename B>
1039 HALIDE_ALWAYS_INLINE auto div(A &&a, B &&b) -> decltype(IRMatcher::operator/(a, b)) {
1040  return IRMatcher::operator/(a, b);
1041 }
1042 
1043 template<>
1045  return div_imp(a, b);
1046 }
1047 
1048 template<>
1050  return div_imp(a, b);
1051 }
1052 
1053 template<>
1054 HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Div>(halide_type_t &t, double a, double b) noexcept {
1055  return div_imp(a, b);
1056 }
1057 
1058 template<typename A, typename B>
1060  assert_is_lvalue_if_expr<A>();
1061  assert_is_lvalue_if_expr<B>();
1062  return {pattern_arg(a), pattern_arg(b)};
1063 }
1064 
1065 template<typename A, typename B>
1066 HALIDE_ALWAYS_INLINE auto mod(A &&a, B &&b) -> decltype(IRMatcher::operator%(a, b)) {
1067  assert_is_lvalue_if_expr<A>();
1068  assert_is_lvalue_if_expr<B>();
1069  return IRMatcher::operator%(a, b);
1070 }
1071 
1072 template<>
1074  return mod_imp(a, b);
1075 }
1076 
1077 template<>
1079  return mod_imp(a, b);
1080 }
1081 
1082 template<>
1083 HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Mod>(halide_type_t &t, double a, double b) noexcept {
1084  return mod_imp(a, b);
1085 }
1086 
1087 template<typename A, typename B>
1089  assert_is_lvalue_if_expr<A>();
1090  assert_is_lvalue_if_expr<B>();
1091  return {pattern_arg(a), pattern_arg(b)};
1092 }
1093 
1094 template<>
1096  return std::min(a, b);
1097 }
1098 
1099 template<>
1101  return std::min(a, b);
1102 }
1103 
1104 template<>
1105 HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Min>(halide_type_t &t, double a, double b) noexcept {
1106  return std::min(a, b);
1107 }
1108 
// Pattern builder for max. The signature line is missing from this listing;
// presumably IRMatcher::max(a, b) — confirm.
// NOTE(review): unlike the sibling builders, which pass a and b to pattern_arg
// as lvalues, this one std::forward's them. Worth checking for consistency.
1109 template<typename A, typename B>
1111  assert_is_lvalue_if_expr<A>();
1112  assert_is_lvalue_if_expr<B>();
1113  return {pattern_arg(std::forward<A>(a)), pattern_arg(std::forward<B>(b))};
1114 }
1115 
// Integer constant-fold specializations of Max. The signature lines are missing
// from this listing; presumably int64_t and uint64_t — confirm. Both return
// std::max(a, b).
1116 template<>
1118  return std::max(a, b);
1119 }
1120 
1121 template<>
1123  return std::max(a, b);
1124 }
1125 
1126 template<>
1127 HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Max>(halide_type_t &t, double a, double b) noexcept {
1128  return std::max(a, b);
1129 }
1130 
// Pattern builder for less-than. The signature line is missing from this
// listing; presumably IRMatcher::operator<, which lt() forwards to — confirm.
// No assert_is_lvalue_if_expr calls are visible here, unlike the arithmetic builders.
1131 template<typename A, typename B>
1133  return {pattern_arg(a), pattern_arg(b)};
1134 }
1135 
1136 template<typename A, typename B>
1137 HALIDE_ALWAYS_INLINE auto lt(A &&a, B &&b) -> decltype(IRMatcher::operator<(a, b)) {
1138  return IRMatcher::operator<(a, b);
1139 }
1140 
// Constant-fold specializations of LT. The signature lines are missing from this
// listing; presumably the int64_t, uint64_t and double variants — confirm.
// Each returns a < b.
1141 template<>
1143  return a < b;
1144 }
1145 
1146 template<>
1148  return a < b;
1149 }
1150 
1151 template<>
1153  return a < b;
1154 }
1155 
// Pattern builder for greater-than. The signature line is missing from this
// listing; presumably IRMatcher::operator>, which gt() forwards to — confirm.
1156 template<typename A, typename B>
1158  return {pattern_arg(a), pattern_arg(b)};
1159 }
1160 
1161 template<typename A, typename B>
1162 HALIDE_ALWAYS_INLINE auto gt(A &&a, B &&b) -> decltype(IRMatcher::operator>(a, b)) {
1163  return IRMatcher::operator>(a, b);
1164 }
1165 
// Constant-fold specializations of GT. The signature lines are missing from this
// listing; presumably the int64_t, uint64_t and double variants — confirm.
// Each returns a > b.
1166 template<>
1168  return a > b;
1169 }
1170 
1171 template<>
1173  return a > b;
1174 }
1175 
1176 template<>
1178  return a > b;
1179 }
1180 
// Pattern builder for less-or-equal. The signature line is missing from this
// listing; presumably IRMatcher::operator<=, which le() forwards to — confirm.
1181 template<typename A, typename B>
1183  return {pattern_arg(a), pattern_arg(b)};
1184 }
1185 
1186 template<typename A, typename B>
1187 HALIDE_ALWAYS_INLINE auto le(A &&a, B &&b) -> decltype(IRMatcher::operator<=(a, b)) {
1188  return IRMatcher::operator<=(a, b);
1189 }
1190 
// Constant-fold specializations of LE. The signature lines are missing from this
// listing; presumably the int64_t, uint64_t and double variants — confirm.
// Each returns a <= b.
1191 template<>
1193  return a <= b;
1194 }
1195 
1196 template<>
1198  return a <= b;
1199 }
1200 
1201 template<>
1203  return a <= b;
1204 }
1205 
// Pattern builder for greater-or-equal. The signature line is missing from this
// listing; presumably IRMatcher::operator>=, which ge() forwards to — confirm.
1206 template<typename A, typename B>
1208  return {pattern_arg(a), pattern_arg(b)};
1209 }
1210 
1211 template<typename A, typename B>
1212 HALIDE_ALWAYS_INLINE auto ge(A &&a, B &&b) -> decltype(IRMatcher::operator>=(a, b)) {
1213  return IRMatcher::operator>=(a, b);
1214 }
1215 
// Constant-fold specializations of GE. The signature lines are missing from this
// listing; presumably the int64_t, uint64_t and double variants — confirm.
// Each returns a >= b.
1216 template<>
1218  return a >= b;
1219 }
1220 
1221 template<>
1223  return a >= b;
1224 }
1225 
1226 template<>
1228  return a >= b;
1229 }
1230 
// Pattern builder for equality. The signature line is missing from this
// listing; presumably IRMatcher::operator==, which eq() forwards to — confirm.
1231 template<typename A, typename B>
1233  return {pattern_arg(a), pattern_arg(b)};
1234 }
1235 
1236 template<typename A, typename B>
1237 HALIDE_ALWAYS_INLINE auto eq(A &&a, B &&b) -> decltype(IRMatcher::operator==(a, b)) {
1238  return IRMatcher::operator==(a, b);
1239 }
1240 
// Constant-fold specializations of EQ. The signature lines are missing from this
// listing; presumably the int64_t, uint64_t and double variants — confirm.
// Each returns a == b.
1241 template<>
1243  return a == b;
1244 }
1245 
1246 template<>
1248  return a == b;
1249 }
1250 
1251 template<>
1253  return a == b;
1254 }
1255 
// Pattern builder for inequality. The signature line is missing from this
// listing; presumably IRMatcher::operator!=, which ne() forwards to — confirm.
1256 template<typename A, typename B>
1258  return {pattern_arg(a), pattern_arg(b)};
1259 }
1260 
1261 template<typename A, typename B>
1262 HALIDE_ALWAYS_INLINE auto ne(A &&a, B &&b) -> decltype(IRMatcher::operator!=(a, b)) {
1263  return IRMatcher::operator!=(a, b);
1264 }
1265 
// Constant-fold specializations of NE. The signature lines are missing from this
// listing; presumably the int64_t, uint64_t and double variants — confirm.
// Each returns a != b.
1266 template<>
1268  return a != b;
1269 }
1270 
1271 template<>
1273  return a != b;
1274 }
1275 
1276 template<>
1278  return a != b;
1279 }
1280 
// Pattern builder for logical-or. The signature line is missing from this
// listing; presumably IRMatcher::operator||, which or_op() forwards to — confirm.
1281 template<typename A, typename B>
1283  return {pattern_arg(a), pattern_arg(b)};
1284 }
1285 
1286 template<typename A, typename B>
1287 HALIDE_ALWAYS_INLINE auto or_op(A &&a, B &&b) -> decltype(IRMatcher::operator||(a, b)) {
1288  return IRMatcher::operator||(a, b);
1289 }
1290 
// Integer constant-fold specializations of Or. The signature lines are missing
// from this listing; presumably int64_t and uint64_t — confirm. The result is
// masked to bit 0, so it is always 0 or 1.
1291 template<>
1293  return (a | b) & 1;
1294 }
1295 
1296 template<>
1298  return (a | b) & 1;
1299 }
1300 
1301 template<>
1302 HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Or>(halide_type_t &t, double a, double b) noexcept {
1303  // Unreachable, as it would be a type mismatch.
1304  return 0;
1305 }
1306 
// Pattern builder for logical-and. The signature line is missing from this
// listing; presumably IRMatcher::operator&&, which and_op() forwards to — confirm.
1307 template<typename A, typename B>
1309  return {pattern_arg(a), pattern_arg(b)};
1310 }
1311 
1312 template<typename A, typename B>
1313 HALIDE_ALWAYS_INLINE auto and_op(A &&a, B &&b) -> decltype(IRMatcher::operator&&(a, b)) {
1314  return IRMatcher::operator&&(a, b);
1315 }
1316 
// Integer constant-fold specializations of And. The signature lines are missing
// from this listing; presumably int64_t and uint64_t — confirm. The result is
// masked to bit 0, so it is always 0 or 1.
1317 template<>
1319  return a & b & 1;
1320 }
1321 
1322 template<>
1324  return a & b & 1;
1325 }
1326 
1327 template<>
1328 HALIDE_ALWAYS_INLINE double constant_fold_bin_op<And>(halide_type_t &t, double a, double b) noexcept {
1329  // Unreachable
1330  return 0;
1331 }
1332 
// Bitwise-OR of any number of uint32_t masks, usable in constant expressions.
// The zero-argument overload is the identity element and ends the recursion.
constexpr inline uint32_t bitwise_or_reduce() {
    return 0u;
}

template<typename... Args>
constexpr uint32_t bitwise_or_reduce(uint32_t head, Args... tail) {
    // Reduce the tail first, then merge the head in (OR is commutative).
    return bitwise_or_reduce(tail...) | head;
}
1341 
// Logical-AND of any number of bools, usable in constant expressions.
// The zero-argument overload returns true (the identity) and ends the recursion.
constexpr inline bool and_reduce() {
    return true;
}

template<typename... Args>
constexpr bool and_reduce(bool first, Args... rest) {
    // Short-circuits exactly like `first && ...`: the tail is only evaluated when first holds.
    return first ? and_reduce(rest...) : false;
}
1350 
// Minimum of two ints, usable as a template argument. Callers pass
// `sizeof...(Args) - 1` (a size_t) alongside an int literal; both convert to int
// here, whereas std::min would reject the mixed argument types during deduction.
constexpr int const_min(int a, int b) {
    return (b < a) ? b : a;
}
1355 
// Pattern node for a Call to a specific intrinsic (member `intrin`) whose
// arguments match the sub-patterns in `args`.
// NOTE: several member lines are missing from this listing, including the
// declarations of `intrin` and `optional_type_hint`, and the constructor signature.
1356 template<typename... Args>
1357 struct Intrin {
1358  struct pattern_tag {};
1360  std::tuple<Args...> args;
1361  // The type of the output of the intrinsic node.
1362  // Only necessary in cases where it can't be inferred
1363  // from the input types (e.g. saturating_cast).
1365 
1367 
 // Canonical only if every argument pattern is canonical.
1370  constexpr static bool canonical = and_reduce((Args::canonical)...);
1371 
 // Matches Call argument i against the i-th sub-pattern, then recurses on i + 1,
 // adding the wildcards bound by this argument to `bound`. Enabled only while
 // i < sizeof...(Args). Callers pass the literal 0 (an int), so this overload is
 // preferred over the double-tagged terminator whenever it is enabled.
 // NOTE(review): c.args.size() is not checked against sizeof...(Args);
 // presumably guaranteed by the intrinsic's arity — confirm.
1372  template<int i,
1373  uint32_t bound,
1374  typename = typename std::enable_if<(i < sizeof...(Args))>::type>
1375  HALIDE_ALWAYS_INLINE bool match_args(int, const Call &c, MatcherState &state) const noexcept {
1376  using T = decltype(std::get<i>(args));
1377  return (std::get<i>(args).template match<bound>(*c.args[i].get(), state) &&
1378  match_args<i + 1, bound | bindings<T>::mask>(0, c, state));
1379  }
1380 
 // Recursion terminator: chosen once i == sizeof...(Args) disables the int overload.
1381  template<int i, uint32_t binds>
1382  HALIDE_ALWAYS_INLINE bool match_args(double, const Call &c, MatcherState &state) const noexcept {
1383  return true;
1384  }
1385 
 // Matches e if it is a Call to `intrin`, its type equals optional_type_hint
 // (unless the hint is the default Type()), and every argument matches.
1386  template<uint32_t bound>
1387  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1388  if (e.node_type != IRNodeType::Call) {
1389  return false;
1390  }
1391  const Call &c = (const Call &)e;
1392  return (c.is_intrinsic(intrin) &&
1393  ((optional_type_hint == Type()) || optional_type_hint == e.type) &&
1394  match_args<0, bound>(0, c, state));
1395  }
1396 
 // Prints the argument patterns separated by ", ", using the same int/double tag
 // dispatch as match_args to stop at sizeof...(Args).
 // NOTE(review): `i + 1 < sizeof...(Args)` compares int with size_t (sign-compare warning).
1397  template<int i,
1398  typename = typename std::enable_if<(i < sizeof...(Args))>::type>
1399  HALIDE_ALWAYS_INLINE void print_args(int, std::ostream &s) const {
1400  s << std::get<i>(args);
1401  if (i + 1 < sizeof...(Args)) {
1402  s << ", ";
1403  }
1404  print_args<i + 1>(0, s);
1405  }
1406 
1407  template<int i>
1408  HALIDE_ALWAYS_INLINE void print_args(double, std::ostream &s) const {
1409  }
1410 
 // Entry point that prints all argument patterns (used by operator<< for Intrin).
1412  void print_args(std::ostream &s) const {
1413  print_args<0>(0, s);
1414  }
1415 
 // Rebuilds the intrinsic as an Expr. Arguments are materialized in order
 // (arg0, then arg1, then arg2), and the function returns as soon as it reaches
 // the intrinsic's arity. const_min clamps the std::get index so this still
 // compiles when there are fewer than two or three argument patterns.
1417  Expr make(MatcherState &state, halide_type_t type_hint) const {
1418  Expr arg0 = std::get<0>(args).make(state, type_hint);
1419  if (intrin == Call::likely) {
1420  return likely(arg0);
1421  } else if (intrin == Call::likely_if_innermost) {
1422  return likely_if_innermost(arg0);
1423  } else if (intrin == Call::abs) {
1424  return abs(arg0);
1425  } else if (intrin == Call::saturating_cast) {
 // The output type cannot be inferred from arg0, so use the stored hint.
1426  return saturating_cast(optional_type_hint, arg0);
1427  }
1428 
1429  Expr arg1 = std::get<const_min(1, sizeof...(Args) - 1)>(args).make(state, type_hint);
1430  if (intrin == Call::absd) {
1431  return absd(arg0, arg1);
1432  } else if (intrin == Call::widen_right_add) {
1433  return widen_right_add(arg0, arg1);
1434  } else if (intrin == Call::widen_right_mul) {
1435  return widen_right_mul(arg0, arg1);
1436  } else if (intrin == Call::widen_right_sub) {
1437  return widen_right_sub(arg0, arg1);
1438  } else if (intrin == Call::widening_add) {
1439  return widening_add(arg0, arg1);
1440  } else if (intrin == Call::widening_sub) {
1441  return widening_sub(arg0, arg1);
1442  } else if (intrin == Call::widening_mul) {
1443  return widening_mul(arg0, arg1);
1444  } else if (intrin == Call::saturating_add) {
1445  return saturating_add(arg0, arg1);
1446  } else if (intrin == Call::saturating_sub) {
1447  return saturating_sub(arg0, arg1);
1448  } else if (intrin == Call::halving_add) {
1449  return halving_add(arg0, arg1);
1450  } else if (intrin == Call::halving_sub) {
1451  return halving_sub(arg0, arg1);
1452  } else if (intrin == Call::rounding_halving_add) {
1453  return rounding_halving_add(arg0, arg1);
1454  } else if (intrin == Call::shift_left) {
1455  return arg0 << arg1;
1456  } else if (intrin == Call::shift_right) {
1457  return arg0 >> arg1;
1458  } else if (intrin == Call::rounding_shift_left) {
1459  return rounding_shift_left(arg0, arg1);
1460  } else if (intrin == Call::rounding_shift_right) {
1461  return rounding_shift_right(arg0, arg1);
1462  }
1463 
1464  Expr arg2 = std::get<const_min(2, sizeof...(Args) - 1)>(args).make(state, type_hint);
1465  if (intrin == Call::mul_shift_right) {
1466  return mul_shift_right(arg0, arg1, arg2);
1467  } else if (intrin == Call::rounding_mul_shift_right) {
1468  return rounding_mul_shift_right(arg0, arg1, arg2);
1469  }
1470 
 // Any intrinsic not listed above is unsupported by this matcher.
1471  internal_error << "Unhandled intrinsic in IRMatcher: " << intrin;
1472  return Expr();
1473  }
1474 
1475  constexpr static bool foldable = true;
1476 
 // make_folded_const (signature line missing from this listing): folds the
 // intrinsic into (val, ty). Only shift_left and shift_right are handled; a
 // negative shift amount shifts in the opposite direction.
 // NOTE(review): shift counts with magnitude >= 64, or arg1 == INT64_MIN, are
 // undefined behavior for the built-in shifts and are not clamped here; also the
 // left-shift result is not truncated to ty.bits — confirm callers guarantee this.
1478  halide_scalar_value_t arg1;
1479  // Assuming the args have the same type as the intrinsic is incorrect in
1480  // general. But for the intrinsics we can fold (just shifts), the LHS
1481  // has the same type as the intrinsic, and we can always treat the RHS
1482  // as a signed int, because we're using 64 bits for it.
1483  std::get<0>(args).make_folded_const(val, ty, state);
1484  halide_type_t signed_ty = ty;
1485  signed_ty.code = halide_type_int;
1486  // We can just directly get the second arg here, because we only want to
1487  // instantiate this method for shifts, which have two args.
1488  std::get<1>(args).make_folded_const(arg1, signed_ty, state);
1489 
1490  if (intrin == Call::shift_left) {
1491  if (arg1.u.i64 < 0) {
1492  if (ty.code == halide_type_int) {
1493  // Arithmetic shift
1494  val.u.i64 >>= -arg1.u.i64;
1495  } else {
1496  // Logical shift
1497  val.u.u64 >>= -arg1.u.i64;
1498  }
1499  } else {
1500  val.u.u64 <<= arg1.u.i64;
1501  }
1502  } else if (intrin == Call::shift_right) {
1503  if (arg1.u.i64 > 0) {
1504  if (ty.code == halide_type_int) {
1505  // Arithmetic shift
1506  val.u.i64 >>= arg1.u.i64;
1507  } else {
1508  // Logical shift
1509  val.u.u64 >>= arg1.u.i64;
1510  }
1511  } else {
1512  val.u.u64 <<= -arg1.u.i64;
1513  }
1514  } else {
1515  internal_error << "Folding not implemented for intrinsic: " << intrin;
1516  }
1517  }
1518 
 // Constructor (signature line missing from this listing): stores the intrinsic
 // op and the argument patterns.
1521  : intrin(intrin), args(args...) {
1522  }
1523 };
1524 
1525 template<typename... Args>
1526 std::ostream &operator<<(std::ostream &s, const Intrin<Args...> &op) {
1527  s << op.intrin << "(";
1528  op.print_args(s);
1529  s << ")";
1530  return s;
1531 }
1532 
1533 template<typename... Args>
1534 HALIDE_ALWAYS_INLINE auto intrin(Call::IntrinsicOp intrinsic_op, Args... args) noexcept -> Intrin<decltype(pattern_arg(args))...> {
1535  return {intrinsic_op, pattern_arg(args)...};
1536 }
1537 
// Convenience builders for individual intrinsic patterns. Many lines (including
// every signature) are missing from this listing. Each visible body returns an
// Intrin brace-initialized with a Call::IntrinsicOp and pattern_arg-wrapped operands.
1538 template<typename A, typename B>
1541 }
1542 template<typename A, typename B>
1545 }
1546 template<typename A, typename B>
1549 }
1550 
// Widening arithmetic builders.
1551 template<typename A, typename B>
1553  return {Call::widening_add, pattern_arg(a), pattern_arg(b)};
1554 }
1555 template<typename A, typename B>
1557  return {Call::widening_sub, pattern_arg(a), pattern_arg(b)};
1558 }
1559 template<typename A, typename B>
1561  return {Call::widening_mul, pattern_arg(a), pattern_arg(b)};
1562 }
// Saturating arithmetic builders.
1563 template<typename A, typename B>
1565  return {Call::saturating_add, pattern_arg(a), pattern_arg(b)};
1566 }
1567 template<typename A, typename B>
1569  return {Call::saturating_sub, pattern_arg(a), pattern_arg(b)};
1570 }
// saturating_cast: the target type cannot be inferred from the operand, so it
// is stored in optional_type_hint (the line constructing `p` is missing here).
1571 template<typename A>
1572 auto saturating_cast(const Type &t, A &&a) noexcept -> Intrin<decltype(pattern_arg(a))> {
1574  p.optional_type_hint = t;
1575  return p;
1576 }
// Halving arithmetic builders.
1577 template<typename A, typename B>
1579  return {Call::halving_add, pattern_arg(a), pattern_arg(b)};
1580 }
1581 template<typename A, typename B>
1583  return {Call::halving_sub, pattern_arg(a), pattern_arg(b)};
1584 }
1585 template<typename A, typename B>
1588 }
// Shift builders.
1589 template<typename A, typename B>
1591  return {Call::shift_left, pattern_arg(a), pattern_arg(b)};
1592 }
1593 template<typename A, typename B>
1595  return {Call::shift_right, pattern_arg(a), pattern_arg(b)};
1596 }
1597 template<typename A, typename B>
1600 }
1601 template<typename A, typename B>
1604 }
// Three-operand builders (presumably mul_shift_right and
// rounding_mul_shift_right — confirm; bodies are missing from this listing).
1605 template<typename A, typename B, typename C>
1608 }
1609 template<typename A, typename B, typename C>
1612 }
1613 
// Pattern node for logical negation (a Not IR node) of sub-pattern `a`.
// Some member lines are missing from this listing.
1614 template<typename A>
1615 struct NotOp {
1616  struct pattern_tag {};
1617  A a;
1618 
1619  constexpr static uint32_t binds = bindings<A>::mask;
1620 
1623  constexpr static bool canonical = A::canonical;
1624 
 // Matches a Not node whose operand matches `a`.
1625  template<uint32_t bound>
1626  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1627  if (e.node_type != IRNodeType::Not) {
1628  return false;
1629  }
1630  const Not &op = (const Not &)e;
1631  return (a.template match<bound>(*op.a.get(), state));
1632  }
1633 
 // Pattern-vs-pattern match: compares the operands of two NotOp patterns.
1634  template<uint32_t bound, typename A2>
1635  HALIDE_ALWAYS_INLINE bool match(const NotOp<A2> &op, MatcherState &state) const noexcept {
1636  return a.template match<bound>(unwrap(op.a), state);
1637  }
1638 
 // Builds Not::make of the operand; type_hint is passed through unchanged.
1640  Expr make(MatcherState &state, halide_type_t type_hint) const {
1641  return Not::make(a.make(state, type_hint));
1642  }
1643 
1644  constexpr static bool foldable = A::foldable;
1645 
 // make_folded_const (signature line missing): folds the operand, flips every
 // bit, then keeps only bit 0, so the result is 0 or 1.
1646  template<typename A1 = A>
1648  a.make_folded_const(val, ty, state);
1649  val.u.u64 = ~val.u.u64;
1650  val.u.u64 &= 1;
1651  }
1652 };
1653 
// Pattern builder for logical not. The signature line is missing from this
// listing; presumably IRMatcher::operator!, which not_op() forwards to — confirm.
1654 template<typename A>
1656  assert_is_lvalue_if_expr<A>();
1657  return {pattern_arg(a)};
1658 }
1659 
1660 template<typename A>
1661 HALIDE_ALWAYS_INLINE auto not_op(A &&a) -> decltype(IRMatcher::operator!(a)) {
1662  assert_is_lvalue_if_expr<A>();
1663  return IRMatcher::operator!(a);
1664 }
1665 
1666 template<typename A>
1667 inline std::ostream &operator<<(std::ostream &s, const NotOp<A> &op) {
1668  s << "!(" << op.a << ")";
1669  return s;
1670 }
1671 
// Pattern node for Select(condition c, true value t, false value f).
// Several member lines are missing from this listing.
1672 template<typename C, typename T, typename F>
1673 struct SelectOp {
1674  struct pattern_tag {};
1675  C c;
1676  T t;
1677  F f;
1678 
1680 
1683 
1684  constexpr static bool canonical = C::canonical && T::canonical && F::canonical;
1685 
 // Matches a Select node. Sub-patterns are matched left to right, and the
 // wildcards bound by earlier ones are added to `bound` for later ones.
1686  template<uint32_t bound>
1687  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1688  if (e.node_type != Select::_node_type) {
1689  return false;
1690  }
1691  const Select &op = (const Select &)e;
1692  return (c.template match<bound>(*op.condition.get(), state) &&
1693  t.template match<bound | bindings<C>::mask>(*op.true_value.get(), state) &&
1694  f.template match<bound | bindings<C>::mask | bindings<T>::mask>(*op.false_value.get(), state));
1695  }
 // Pattern-vs-pattern match against another SelectOp instance.
1696  template<uint32_t bound, typename C2, typename T2, typename F2>
1697  HALIDE_ALWAYS_INLINE bool match(const SelectOp<C2, T2, F2> &instance, MatcherState &state) const noexcept {
1698  return (c.template match<bound>(unwrap(instance.c), state) &&
1699  t.template match<bound | bindings<C>::mask>(unwrap(instance.t), state) &&
1700  f.template match<bound | bindings<C>::mask | bindings<T>::mask>(unwrap(instance.f), state));
1701  }
1702 
 // The condition is built with an empty type hint; both branches get type_hint.
1704  Expr make(MatcherState &state, halide_type_t type_hint) const {
1705  return Select::make(c.make(state, {}), t.make(state, type_hint), f.make(state, type_hint));
1706  }
1707 
1708  constexpr static bool foldable = C::foldable && T::foldable && F::foldable;
1709 
 // make_folded_const (signature line missing): folds the condition and then only
 // the branch it selects (by bit 0). Special-value bits from the condition's
 // lanes field are ORed into ty.lanes.
 // NOTE(review): t_val and f_val are declared but never used.
1710  template<typename C1 = C>
1712  halide_scalar_value_t c_val, t_val, f_val;
1713  halide_type_t c_ty;
1714  c.make_folded_const(c_val, c_ty, state);
1715  if ((c_val.u.u64 & 1) == 1) {
1716  t.make_folded_const(val, ty, state);
1717  } else {
1718  f.make_folded_const(val, ty, state);
1719  }
1720  ty.lanes |= c_ty.lanes & MatcherState::special_values_mask;
1721  }
1722 };
1723 
1724 template<typename C, typename T, typename F>
1725 std::ostream &operator<<(std::ostream &s, const SelectOp<C, T, F> &op) {
1726  s << "select(" << op.c << ", " << op.t << ", " << op.f << ")";
1727  return s;
1728 }
1729 
// Pattern builder for select(c, t, f). The signature line is missing from this
// listing. Checks all three argument types with assert_is_lvalue_if_expr and
// wraps them with pattern_arg.
1730 template<typename C, typename T, typename F>
1732  assert_is_lvalue_if_expr<C>();
1733  assert_is_lvalue_if_expr<T>();
1734  assert_is_lvalue_if_expr<F>();
1735  return {pattern_arg(c), pattern_arg(t), pattern_arg(f)};
1736 }
1737 
// Pattern node for Broadcast(value a, lane count `lanes`). The declaration of
// `lanes` and other member lines are missing from this listing.
1738 template<typename A, typename B>
1739 struct BroadcastOp {
1740  struct pattern_tag {};
1741  A a;
1743 
1745 
1748 
1749  constexpr static bool canonical = A::canonical && B::canonical;
1750 
 // Matches a Broadcast node whose value matches `a` and whose lane count matches `lanes`.
 // NOTE(review): `lanes` is matched with `bound` alone, while the pattern overload
 // below (and RampOp/VectorReduceOp) use `bound | bindings<A>::mask`; if `a` and
 // `lanes` share a wildcard, this path may rebind instead of compare — confirm.
1751  template<uint32_t bound>
1752  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1753  if (e.node_type == Broadcast::_node_type) {
1754  const Broadcast &op = (const Broadcast &)e;
1755  if (a.template match<bound>(*op.value.get(), state) &&
1756  lanes.template match<bound>(op.lanes, state)) {
1757  return true;
1758  }
1759  }
1760  return false;
1761  }
1762 
 // Pattern-vs-pattern match against another BroadcastOp instance.
1763  template<uint32_t bound, typename A2, typename B2>
1764  HALIDE_ALWAYS_INLINE bool match(const BroadcastOp<A2, B2> &op, MatcherState &state) const noexcept {
1765  return (a.template match<bound>(unwrap(op.a), state) &&
1766  lanes.template match<bound | bindings<A>::mask>(unwrap(op.lanes), state));
1767  }
1768 
 // Folds the lane count l, divides type_hint.lanes by l for the operand, and
 // returns the operand itself when l == 1.
 // NOTE(review): l == 0 would divide by zero here.
1770  Expr make(MatcherState &state, halide_type_t type_hint) const {
1771  halide_scalar_value_t lanes_val;
1772  halide_type_t ty;
1773  lanes.make_folded_const(lanes_val, ty, state);
1774  int32_t l = (int32_t)lanes_val.u.i64;
1775  type_hint.lanes /= l;
1776  Expr val = a.make(state, type_hint);
1777  if (l == 1) {
1778  return val;
1779  } else {
1780  return Broadcast::make(std::move(val), l);
1781  }
1782  }
1783 
1784  constexpr static bool foldable = false;
1785 
 // make_folded_const (signature line missing): folds the operand, then sets
 // ty.lanes to the folded lane count (truncated to uint16_t) while keeping its
 // special-value bits. Present even though foldable is false.
1786  template<typename A1 = A>
1788  halide_scalar_value_t lanes_val;
1789  halide_type_t lanes_ty;
1790  lanes.make_folded_const(lanes_val, lanes_ty, state);
1791  uint16_t l = (uint16_t)lanes_val.u.i64;
1792  a.make_folded_const(val, ty, state);
1793  ty.lanes = l | (ty.lanes & MatcherState::special_values_mask);
1794  }
1795 };
1796 
1797 template<typename A, typename B>
1798 inline std::ostream &operator<<(std::ostream &s, const BroadcastOp<A, B> &op) {
1799  s << "broadcast(" << op.a << ", " << op.lanes << ")";
1800  return s;
1801 }
1802 
// Pattern builder for broadcast(a, lanes). The signature line is missing from
// this listing. Only `a` is checked with assert_is_lvalue_if_expr.
1803 template<typename A, typename B>
1805  assert_is_lvalue_if_expr<A>();
1806  return {pattern_arg(a), pattern_arg(lanes)};
1807 }
1808 
// Pattern node for Ramp(base a, stride b, lane count `lanes`). The declaration of
// `lanes` and other member lines are missing from this listing.
1809 template<typename A, typename B, typename C>
1810 struct RampOp {
1811  struct pattern_tag {};
1812  A a;
1813  B b;
1815 
1817 
1820 
1821  constexpr static bool canonical = A::canonical && B::canonical && C::canonical;
1822 
 // Matches a Ramp node: base, then stride, then lane count, accumulating bound wildcards.
1823  template<uint32_t bound>
1824  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1825  if (e.node_type != Ramp::_node_type) {
1826  return false;
1827  }
1828  const Ramp &op = (const Ramp &)e;
1829  if (a.template match<bound>(*op.base.get(), state) &&
1830  b.template match<bound | bindings<A>::mask>(*op.stride.get(), state) &&
1831  lanes.template match<bound | bindings<A>::mask | bindings<B>::mask>(op.lanes, state)) {
1832  return true;
1833  } else {
1834  return false;
1835  }
1836  }
1837 
 // Pattern-vs-pattern match against another RampOp instance.
1838  template<uint32_t bound, typename A2, typename B2, typename C2>
1839  HALIDE_ALWAYS_INLINE bool match(const RampOp<A2, B2, C2> &op, MatcherState &state) const noexcept {
1840  return (a.template match<bound>(unwrap(op.a), state) &&
1841  b.template match<bound | bindings<A>::mask>(unwrap(op.b), state) &&
1842  lanes.template match<bound | bindings<A>::mask | bindings<B>::mask>(unwrap(op.lanes), state));
1843  }
1844 
 // Folds the lane count and divides type_hint.lanes by it. Builds the stride
 // first, then the base using the stride's type as its hint.
1846  Expr make(MatcherState &state, halide_type_t type_hint) const {
1847  halide_scalar_value_t lanes_val;
1848  halide_type_t ty;
1849  lanes.make_folded_const(lanes_val, ty, state);
1850  int32_t l = (int32_t)lanes_val.u.i64;
1851  type_hint.lanes /= l;
1852  Expr ea, eb;
1853  eb = b.make(state, type_hint);
1854  ea = a.make(state, eb.type());
1855  return Ramp::make(ea, eb, l);
1856  }
1857 
1858  constexpr static bool foldable = false;
1859 };
1860 
1861 template<typename A, typename B, typename C>
1862 std::ostream &operator<<(std::ostream &s, const RampOp<A, B, C> &op) {
1863  s << "ramp(" << op.a << ", " << op.b << ", " << op.lanes << ")";
1864  return s;
1865 }
1866 
// Pattern builder for ramp(base, stride, lanes). The signature line is missing
// from this listing. Checks all three argument types and wraps them with pattern_arg.
1867 template<typename A, typename B, typename C>
1869  assert_is_lvalue_if_expr<A>();
1870  assert_is_lvalue_if_expr<B>();
1871  assert_is_lvalue_if_expr<C>();
1872  return {pattern_arg(a), pattern_arg(b), pattern_arg(c)};
1873 }
1874 
// Pattern node for VectorReduce with a fixed reduction operator `reduce_op`,
// operand `a`, and output lane count `lanes`. The struct header line and some
// members are missing from this listing.
1875 template<typename A, typename B, VectorReduce::Operator reduce_op>
1877  struct pattern_tag {};
1878  A a;
1880 
1882 
1885  constexpr static bool canonical = A::canonical;
1886 
 // Matches a VectorReduce with the same operator; `lanes` is compared against
 // the lane count of the node's output type.
1887  template<uint32_t bound>
1888  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1889  if (e.node_type == VectorReduce::_node_type) {
1890  const VectorReduce &op = (const VectorReduce &)e;
1891  if (op.op == reduce_op &&
1892  a.template match<bound>(*op.value.get(), state) &&
1893  lanes.template match<bound | bindings<A>::mask>(op.type.lanes(), state)) {
1894  return true;
1895  }
1896  }
1897  return false;
1898  }
1899 
 // Pattern-vs-pattern match (signature line partly missing): the operators must agree.
1900  template<uint32_t bound, typename A2, typename B2, VectorReduce::Operator reduce_op_2>
1902  return (reduce_op == reduce_op_2 &&
1903  a.template match<bound>(unwrap(op.a), state) &&
1904  lanes.template match<bound | bindings<A>::mask>(unwrap(op.lanes), state));
1905  }
1906 
 // Folds the lane count and builds the reduction.
 // NOTE(review): type_hint describes the reduced result but is passed unchanged
 // to the wider operand, unlike BroadcastOp/RampOp which adjust lanes — confirm intended.
1908  Expr make(MatcherState &state, halide_type_t type_hint) const {
1909  halide_scalar_value_t lanes_val;
1910  halide_type_t ty;
1911  lanes.make_folded_const(lanes_val, ty, state);
1912  int l = (int)lanes_val.u.i64;
1913  return VectorReduce::make(reduce_op, a.make(state, type_hint), l);
1914  }
1915 
1916  constexpr static bool foldable = false;
1917 };
1918 
1919 template<typename A, typename B, VectorReduce::Operator reduce_op>
1920 inline std::ostream &operator<<(std::ostream &s, const VectorReduceOp<A, B, reduce_op> &op) {
1921  s << "vector_reduce(" << reduce_op << ", " << op.a << ", " << op.lanes << ")";
1922  return s;
1923 }
1924 
// Five pattern builders taking (a, lanes). The signature lines are missing from
// this listing; presumably the horizontal reductions (h_add, h_min, h_max,
// h_and, h_or) that build VectorReduceOp patterns — confirm. Each checks `a`
// with assert_is_lvalue_if_expr and wraps both arguments with pattern_arg.
1925 template<typename A, typename B>
1927  assert_is_lvalue_if_expr<A>();
1928  return {pattern_arg(a), pattern_arg(lanes)};
1929 }
1930 
1931 template<typename A, typename B>
1933  assert_is_lvalue_if_expr<A>();
1934  return {pattern_arg(a), pattern_arg(lanes)};
1935 }
1936 
1937 template<typename A, typename B>
1939  assert_is_lvalue_if_expr<A>();
1940  return {pattern_arg(a), pattern_arg(lanes)};
1941 }
1942 
1943 template<typename A, typename B>
1945  assert_is_lvalue_if_expr<A>();
1946  return {pattern_arg(a), pattern_arg(lanes)};
1947 }
1948 
1949 template<typename A, typename B>
1951  assert_is_lvalue_if_expr<A>();
1952  return {pattern_arg(a), pattern_arg(lanes)};
1953 }
1954 
// Pattern node for arithmetic negation, represented in IR as `0 - a`.
// Some member lines are missing from this listing.
1955 template<typename A>
1956 struct NegateOp {
1957  struct pattern_tag {};
1958  A a;
1959 
1960  constexpr static uint32_t binds = bindings<A>::mask;
1961 
1964 
1965  constexpr static bool canonical = A::canonical;
1966 
 // Matches a Sub whose right operand matches `a` and whose left operand is a
 // constant zero. `a` is matched (and may bind wildcards) before the zero check.
1967  template<uint32_t bound>
1968  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1969  if (e.node_type != Sub::_node_type) {
1970  return false;
1971  }
1972  const Sub &op = (const Sub &)e;
1973  return (a.template match<bound>(*op.b.get(), state) &&
1974  is_const_zero(op.a));
1975  }
1976 
 // Pattern-vs-pattern match.
 // NOTE(review): takes NegateOp<A2>&& whereas sibling ops take const&, so it
 // cannot bind to lvalue/const pattern instances — confirm intended.
1977  template<uint32_t bound, typename A2>
1978  HALIDE_ALWAYS_INLINE bool match(NegateOp<A2> &&p, MatcherState &state) const noexcept {
1979  return a.template match<bound>(unwrap(p.a), state);
1980  }
1981 
 // Builds `0 - a` using a zero of the operand's type.
1983  Expr make(MatcherState &state, halide_type_t type_hint) const {
1984  Expr ea = a.make(state, type_hint);
1985  Expr z = make_zero(ea.type());
1986  return Sub::make(std::move(z), std::move(ea));
1987  }
1988 
1989  constexpr static bool foldable = A::foldable;
1990 
 // make_folded_const (signature line missing): negates the folded constant in
 // place, wrapping to ty.bits for int and uint. The body of the most-negative-int
 // branch is missing from this listing.
1991  template<typename A1 = A>
1993  a.make_folded_const(val, ty, state);
1994  int dead_bits = 64 - ty.bits;
1995  switch (ty.code) {
1996  case halide_type_int:
1997  if (ty.bits >= 32 && val.u.u64 && (val.u.u64 << (65 - ty.bits)) == 0) {
1998  // Trying to negate the most negative signed int for a no-overflow type.
2000  } else {
2001  // Negate, drop the high bits, and then sign-extend them back
2002  val.u.i64 = int64_t(uint64_t(-val.u.i64) << dead_bits) >> dead_bits;
2003  }
2004  break;
2005  case halide_type_uint:
2006  val.u.u64 = ((-val.u.u64) << dead_bits) >> dead_bits;
2007  break;
2008  case halide_type_float:
2009  case halide_type_bfloat:
2010  val.u.f64 = -val.u.f64;
2011  break;
2012  default:
2013  // unreachable
2014  ;
2015  }
2016  }
2017 };
2018 
2019 template<typename A>
2020 std::ostream &operator<<(std::ostream &s, const NegateOp<A> &op) {
2021  s << "-" << op.a;
2022  return s;
2023 }
2024 
// Pattern builder for unary minus. The signature line is missing from this
// listing; presumably IRMatcher::operator-(a), which negate() forwards to — confirm.
2025 template<typename A>
2027  assert_is_lvalue_if_expr<A>();
2028  return {pattern_arg(a)};
2029 }
2030 
2031 template<typename A>
2032 HALIDE_ALWAYS_INLINE auto negate(A &&a) -> decltype(IRMatcher::operator-(a)) {
2033  assert_is_lvalue_if_expr<A>();
2034  return IRMatcher::operator-(a);
2035 }
2036 
// Pattern node for a Cast to a fixed type `t` of sub-pattern `a`. The declaration
// of `t` and other member lines are missing from this listing.
2037 template<typename A>
2038 struct CastOp {
2039  struct pattern_tag {};
2041  A a;
2042 
2043  constexpr static uint32_t binds = bindings<A>::mask;
2044 
2047  constexpr static bool canonical = A::canonical;
2048 
 // Matches a Cast node whose result type equals t and whose operand matches a.
2049  template<uint32_t bound>
2050  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
2051  if (e.node_type != Cast::_node_type) {
2052  return false;
2053  }
2054  const Cast &op = (const Cast &)e;
2055  return (e.type == t &&
2056  a.template match<bound>(*op.value.get(), state));
2057  }
 // Pattern-vs-pattern match: target types must be equal and operands must match.
2058  template<uint32_t bound, typename A2>
2059  HALIDE_ALWAYS_INLINE bool match(const CastOp<A2> &op, MatcherState &state) const noexcept {
2060  return t == op.t && a.template match<bound>(unwrap(op.a), state);
2061  }
2062 
 // The operand is built with an empty type hint; the incoming type_hint is
 // ignored because the cast fixes the result type to t.
2064  Expr make(MatcherState &state, halide_type_t type_hint) const {
2065  return cast(t, a.make(state, {}));
2066  }
2067 
2068  constexpr static bool foldable = false;
2069 };
2070 
2071 template<typename A>
2072 std::ostream &operator<<(std::ostream &s, const CastOp<A> &op) {
2073  s << "cast(" << op.t << ", " << op.a << ")";
2074  return s;
2075 }
2076 
// Pattern builder for cast(t, a). The signature line is missing from this listing.
// Stores the target type t and the pattern_arg-wrapped operand.
2077 template<typename A>
2079  assert_is_lvalue_if_expr<A>();
2080  return {t, pattern_arg(a)};
2081 }
2082 
// Pattern node for a single-vector Shuffle that is a slice of `vec`, described
// by (base, stride, lanes). Some member lines are missing from this listing.
2083 template<typename Vec, typename Base, typename Stride, typename Lanes>
2084 struct SliceOp {
2085  struct pattern_tag {};
2086  Vec vec;
2087  Base base;
2088  Stride stride;
2089  Lanes lanes;
2090 
 // NOTE(review): uses X::binds directly, whereas sibling ops use bindings<X>::mask.
2091  static constexpr uint32_t binds = Vec::binds | Base::binds | Stride::binds | Lanes::binds;
2092 
2095  constexpr static bool canonical = Vec::canonical && Base::canonical && Stride::canonical && Lanes::canonical;
2096 
 // Matches a Shuffle of exactly one vector that is a slice, matching vec, begin
 // and stride in order (the final lanes clause is missing from this listing).
2097  template<uint32_t bound>
2098  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
2099  if (e.node_type != IRNodeType::Shuffle) {
2100  return false;
2101  }
2102  const Shuffle &v = (const Shuffle &)e;
2103  return v.vectors.size() == 1 &&
2104  v.is_slice() &&
2105  vec.template match<bound>(*v.vectors[0].get(), state) &&
2106  base.template match<bound | bindings<Vec>::mask>(v.slice_begin(), state) &&
2107  stride.template match<bound | bindings<Vec>::mask | bindings<Base>::mask>(v.slice_stride(), state) &&
2109  }
2110 
 // Folds base, stride and lanes to ints and builds the slice.
 // NOTE(review): vec is built with type_hint unchanged, although the slice
 // generally has fewer lanes than vec — confirm this is intended.
2112  Expr make(MatcherState &state, halide_type_t type_hint) const {
2113  halide_scalar_value_t base_val, stride_val, lanes_val;
2114  halide_type_t ty;
2115  base.make_folded_const(base_val, ty, state);
2116  int b = (int)base_val.u.i64;
2117  stride.make_folded_const(stride_val, ty, state);
2118  int s = (int)stride_val.u.i64;
2119  lanes.make_folded_const(lanes_val, ty, state);
2120  int l = (int)lanes_val.u.i64;
2121  return Shuffle::make_slice(vec.make(state, type_hint), b, s, l);
2122  }
2123 
2124  constexpr static bool foldable = false;
2125 
 // Constructor: base, stride and lanes must be constant-foldable, because make()
 // folds them (enforced at compile time).
2127  SliceOp(Vec v, Base b, Stride s, Lanes l)
2128  : vec(v), base(b), stride(s), lanes(l) {
2129  static_assert(Base::foldable, "Base of slice should consist only of operations that constant-fold");
2130  static_assert(Stride::foldable, "Stride of slice should consist only of operations that constant-fold");
2131  static_assert(Lanes::foldable, "Lanes of slice should consist only of operations that constant-fold");
2132  }
2133 };
2134 
2135 template<typename Vec, typename Base, typename Stride, typename Lanes>
2136 std::ostream &operator<<(std::ostream &s, const SliceOp<Vec, Base, Stride, Lanes> &op) {
2137  s << "slice(" << op.vec << ", " << op.base << ", " << op.stride << ", " << op.lanes << ")";
2138  return s;
2139 }
2140 
// Pattern builder for slice(vec, base, stride, lanes). The trailing return type
// line is missing from this listing. Wraps each argument with pattern_arg.
2141 template<typename Vec, typename Base, typename Stride, typename Lanes>
2142 HALIDE_ALWAYS_INLINE auto slice(Vec vec, Base base, Stride stride, Lanes lanes) noexcept
2144  return {pattern_arg(vec), pattern_arg(base), pattern_arg(stride), pattern_arg(lanes)};
2145 }
2146 
// Pattern node that constant-folds its sub-pattern `a` when the rewrite result
// is built. Some member lines are missing from this listing.
2147 template<typename A>
2148 struct Fold {
2149  struct pattern_tag {};
2150  A a;
2151 
2152  constexpr static uint32_t binds = bindings<A>::mask;
2153 
2156  constexpr static bool canonical = true;
2157 
 // Folds `a` into a scalar constant, coerces it to type_hint's code and bits
 // when a hint is present, and wraps it with make_const_expr.
 // NOTE(review): only an int-coded fold is converted to double for a float hint;
 // a uint-coded fold with a float hint keeps its raw bits — confirm unreachable.
2159  Expr make(MatcherState &state, halide_type_t type_hint) const noexcept {
2161  halide_type_t ty = type_hint;
2162  a.make_folded_const(c, ty, state);
2163 
2164  // The result of the fold may have an underspecified type
2165  // (e.g. because it's from an int literal). Make the type code
2166  // and bits match the required type, if there is one (we can
2167  // tell from the bits field).
2168  if (type_hint.bits) {
2169  if (((int)ty.code == (int)halide_type_int) &&
2170  ((int)type_hint.code == (int)halide_type_float)) {
2171  int64_t x = c.u.i64;
2172  c.u.f64 = (double)x;
2173  }
2174  ty.code = type_hint.code;
2175  ty.bits = type_hint.bits;
2176  }
2177 
2178  Expr e = make_const_expr(c, ty);
2179  return e;
2180  }
2181 
2182  constexpr static bool foldable = A::foldable;
2183 
 // make_folded_const (signature line missing): forwards directly to `a`.
2184  template<typename A1 = A>
2186  a.make_folded_const(val, ty, state);
2187  }
2188 };
2189 
// Pattern builder for fold(a). The signature line is missing from this listing.
2190 template<typename A>
2192  assert_is_lvalue_if_expr<A>();
2193  return {pattern_arg(a)};
2194 }
2195 
2196 template<typename A>
2197 std::ostream &operator<<(std::ostream &s, const Fold<A> &op) {
2198  s << "fold(" << op.a << ")";
2199  return s;
2200 }
2201 
// Predicate pattern: folds `a` and yields 1 if the folded result carries any
// MatcherState special-value bits in its lanes field, otherwise 0.
// Some member lines are missing from this listing.
2202 template<typename A>
2203 struct Overflows {
2204  struct pattern_tag {};
2205  A a;
2206 
2207  constexpr static uint32_t binds = bindings<A>::mask;
2208 
2209  // This rule is a predicate, so it always evaluates to a boolean,
2210  // which has IRNodeType UIntImm
2213  constexpr static bool canonical = true;
2214 
2215  constexpr static bool foldable = A::foldable;
2216 
 // make_folded_const (signature line missing): result is a scalar 64-bit uint.
2217  template<typename A1 = A>
2219  a.make_folded_const(val, ty, state);
2220  ty.code = halide_type_uint;
2221  ty.bits = 64;
2222  val.u.u64 = (ty.lanes & MatcherState::special_values_mask) != 0;
2223  ty.lanes = 1;
2224  }
2225 };
2226 
// Pattern builder for overflows(a). The signature line is missing from this listing.
2227 template<typename A>
2229  assert_is_lvalue_if_expr<A>();
2230  return {pattern_arg(a)};
2231 }
2232 
2233 template<typename A>
2234 std::ostream &operator<<(std::ostream &s, const Overflows<A> &op) {
2235  s << "overflows(" << op.a << ")";
2236  return s;
2237 }
2238 
// Pattern for the overflow intrinsic. It binds no wildcards. Several lines,
// including the end of match() and parts of make/make_folded_const, are missing
// from this listing.
2239 struct Overflow {
2240  struct pattern_tag {};
2241 
2242  constexpr static uint32_t binds = 0;
2243 
2244  // Overflow is an intrinsic, represented as a Call node
2247  constexpr static bool canonical = true;
2248 
 // Rejects anything that is not a Call node; the remaining condition is not visible here.
2249  template<uint32_t bound>
2250  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
2251  if (e.node_type != Call::_node_type) {
2252  return false;
2253  }
2254  const Call &op = (const Call &)e;
2256  }
2257 
 // Returns make_const_special_expr(type_hint).
2259  Expr make(MatcherState &state, halide_type_t type_hint) const {
2261  return make_const_special_expr(type_hint);
2262  }
2263 
2264  constexpr static bool foldable = true;
2265 
 // Folds to value 0; the line updating ty is missing from this listing.
2267  void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
2268  val.u.u64 = 0;
2270  }
2271 };
2272 
2273 inline std::ostream &operator<<(std::ostream &s, const Overflow &op) {
2274  s << "overflow()";
2275  return s;
2276 }
2277 
2278 template<typename A>
2279 struct IsConst {
2280  struct pattern_tag {};
2281 
2282  constexpr static uint32_t binds = bindings<A>::mask;
2283 
2284  // This rule is a boolean-valued predicate. Bools have type UIntImm.
2287  constexpr static bool canonical = true;
2288 
2289  A a;
2290  bool check_v;
2292 
2293  constexpr static bool foldable = true;
2294 
2295  template<typename A1 = A>
2297  Expr e = a.make(state, {});
2298  ty.code = halide_type_uint;
2299  ty.bits = 64;
2300  ty.lanes = 1;
2301  if (check_v) {
2302  val.u.u64 = ::Halide::Internal::is_const(e, v) ? 1 : 0;
2303  } else {
2304  val.u.u64 = ::Halide::Internal::is_const(e) ? 1 : 0;
2305  }
2306  }
2307 };
2308 
2309 template<typename A>
2311  assert_is_lvalue_if_expr<A>();
2312  return {pattern_arg(a), false, 0};
2313 }
2314 
2315 template<typename A>
2317  assert_is_lvalue_if_expr<A>();
2318  return {pattern_arg(a), true, value};
2319 }
2320 
2321 template<typename A>
2322 std::ostream &operator<<(std::ostream &s, const IsConst<A> &op) {
2323  if (op.check_v) {
2324  s << "is_const(" << op.a << ")";
2325  } else {
2326  s << "is_const(" << op.a << ", " << op.v << ")";
2327  }
2328  return s;
2329 }
2330 
2331 template<typename A, typename Prover>
2332 struct CanProve {
2333  struct pattern_tag {};
2334  A a;
2335  Prover *prover; // An existing simplifying mutator
2336 
2337  constexpr static uint32_t binds = bindings<A>::mask;
2338 
2339  // This rule is a boolean-valued predicate. Bools have type UIntImm.
2342  constexpr static bool canonical = true;
2343 
2344  constexpr static bool foldable = true;
2345 
2346  // Includes a raw call to an inlined make method, so don't inline.
2348  Expr condition = a.make(state, {});
2349  condition = prover->mutate(condition, nullptr);
2350  val.u.u64 = is_const_one(condition);
2351  ty.code = halide_type_uint;
2352  ty.bits = 1;
2353  ty.lanes = condition.type().lanes();
2354  }
2355 };
2356 
2357 template<typename A, typename Prover>
2359  assert_is_lvalue_if_expr<A>();
2360  return {pattern_arg(a), p};
2361 }
2362 
2363 template<typename A, typename Prover>
2364 std::ostream &operator<<(std::ostream &s, const CanProve<A, Prover> &op) {
2365  s << "can_prove(" << op.a << ")";
2366  return s;
2367 }
2368 
2369 template<typename A>
2370 struct IsFloat {
2371  struct pattern_tag {};
2372  A a;
2373 
2374  constexpr static uint32_t binds = bindings<A>::mask;
2375 
2376  // This rule is a boolean-valued predicate. Bools have type UIntImm.
2379  constexpr static bool canonical = true;
2380 
2381  constexpr static bool foldable = true;
2382 
2385  // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2386  Type t = a.make(state, {}).type();
2387  val.u.u64 = t.is_float();
2388  ty.code = halide_type_uint;
2389  ty.bits = 1;
2390  ty.lanes = t.lanes();
2391  }
2392 };
2393 
2394 template<typename A>
2396  assert_is_lvalue_if_expr<A>();
2397  return {pattern_arg(a)};
2398 }
2399 
2400 template<typename A>
2401 std::ostream &operator<<(std::ostream &s, const IsFloat<A> &op) {
2402  s << "is_float(" << op.a << ")";
2403  return s;
2404 }
2405 
2406 template<typename A>
2407 struct IsInt {
2408  struct pattern_tag {};
2409  A a;
2410  int bits, lanes;
2411 
2412  constexpr static uint32_t binds = bindings<A>::mask;
2413 
2414  // This rule is a boolean-valued predicate. Bools have type UIntImm.
2417  constexpr static bool canonical = true;
2418 
2419  constexpr static bool foldable = true;
2420 
2423  // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2424  Type t = a.make(state, {}).type();
2425  val.u.u64 = t.is_int() && (bits == 0 || t.bits() == bits) && (lanes == 0 || t.lanes() == lanes);
2426  ty.code = halide_type_uint;
2427  ty.bits = 1;
2428  ty.lanes = t.lanes();
2429  }
2430 };
2431 
2432 template<typename A>
2433 HALIDE_ALWAYS_INLINE auto is_int(A &&a, int bits = 0, int lanes = 0) noexcept -> IsInt<decltype(pattern_arg(a))> {
2434  assert_is_lvalue_if_expr<A>();
2435  return {pattern_arg(a), bits, lanes};
2436 }
2437 
2438 template<typename A>
2439 std::ostream &operator<<(std::ostream &s, const IsInt<A> &op) {
2440  s << "is_int(" << op.a;
2441  if (op.bits > 0) {
2442  s << ", " << op.bits;
2443  }
2444  if (op.lanes > 0) {
2445  s << ", " << op.lanes;
2446  }
2447  s << ")";
2448  return s;
2449 }
2450 
2451 template<typename A>
2452 struct IsUInt {
2453  struct pattern_tag {};
2454  A a;
2455  int bits, lanes;
2456 
2457  constexpr static uint32_t binds = bindings<A>::mask;
2458 
2459  // This rule is a boolean-valued predicate. Bools have type UIntImm.
2462  constexpr static bool canonical = true;
2463 
2464  constexpr static bool foldable = true;
2465 
2468  // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2469  Type t = a.make(state, {}).type();
2470  val.u.u64 = t.is_uint() && (bits == 0 || t.bits() == bits) && (lanes == 0 || t.lanes() == lanes);
2471  ty.code = halide_type_uint;
2472  ty.bits = 1;
2473  ty.lanes = t.lanes();
2474  }
2475 };
2476 
2477 template<typename A>
2478 HALIDE_ALWAYS_INLINE auto is_uint(A &&a, int bits = 0, int lanes = 0) noexcept -> IsUInt<decltype(pattern_arg(a))> {
2479  assert_is_lvalue_if_expr<A>();
2480  return {pattern_arg(a), bits, lanes};
2481 }
2482 
2483 template<typename A>
2484 std::ostream &operator<<(std::ostream &s, const IsUInt<A> &op) {
2485  s << "is_uint(" << op.a;
2486  if (op.bits > 0) {
2487  s << ", " << op.bits;
2488  }
2489  if (op.lanes > 0) {
2490  s << ", " << op.lanes;
2491  }
2492  s << ")";
2493  return s;
2494 }
2495 
2496 template<typename A>
2497 struct IsScalar {
2498  struct pattern_tag {};
2499  A a;
2500 
2501  constexpr static uint32_t binds = bindings<A>::mask;
2502 
2503  // This rule is a boolean-valued predicate. Bools have type UIntImm.
2506  constexpr static bool canonical = true;
2507 
2508  constexpr static bool foldable = true;
2509 
2512  // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2513  Type t = a.make(state, {}).type();
2514  val.u.u64 = t.is_scalar();
2515  ty.code = halide_type_uint;
2516  ty.bits = 1;
2517  ty.lanes = t.lanes();
2518  }
2519 };
2520 
2521 template<typename A>
2523  assert_is_lvalue_if_expr<A>();
2524  return {pattern_arg(a)};
2525 }
2526 
2527 template<typename A>
2528 std::ostream &operator<<(std::ostream &s, const IsScalar<A> &op) {
2529  s << "is_scalar(" << op.a << ")";
2530  return s;
2531 }
2532 
2533 template<typename A>
2534 struct IsMaxValue {
2535  struct pattern_tag {};
2536  A a;
2537 
2538  constexpr static uint32_t binds = bindings<A>::mask;
2539 
2540  // This rule is a boolean-valued predicate. Bools have type UIntImm.
2543  constexpr static bool canonical = true;
2544 
2545  constexpr static bool foldable = true;
2546 
2549  // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2550  a.make_folded_const(val, ty, state);
2551  const uint64_t max_bits = (uint64_t)(-1) >> (64 - ty.bits + (ty.code == halide_type_int));
2552  if (ty.code == halide_type_uint || ty.code == halide_type_int) {
2553  val.u.u64 = (val.u.u64 == max_bits);
2554  } else {
2555  val.u.u64 = 0;
2556  }
2557  ty.code = halide_type_uint;
2558  ty.bits = 1;
2559  }
2560 };
2561 
2562 template<typename A>
2564  assert_is_lvalue_if_expr<A>();
2565  return {pattern_arg(a)};
2566 }
2567 
2568 template<typename A>
2569 std::ostream &operator<<(std::ostream &s, const IsMaxValue<A> &op) {
2570  s << "is_max_value(" << op.a << ")";
2571  return s;
2572 }
2573 
2574 template<typename A>
2575 struct IsMinValue {
2576  struct pattern_tag {};
2577  A a;
2578 
2579  constexpr static uint32_t binds = bindings<A>::mask;
2580 
2581  // This rule is a boolean-valued predicate. Bools have type UIntImm.
2584  constexpr static bool canonical = true;
2585 
2586  constexpr static bool foldable = true;
2587 
2590  // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2591  a.make_folded_const(val, ty, state);
2592  if (ty.code == halide_type_int) {
2593  const uint64_t min_bits = (uint64_t)(-1) << (ty.bits - 1);
2594  val.u.u64 = (val.u.u64 == min_bits);
2595  } else if (ty.code == halide_type_uint) {
2596  val.u.u64 = (val.u.u64 == 0);
2597  } else {
2598  val.u.u64 = 0;
2599  }
2600  ty.code = halide_type_uint;
2601  ty.bits = 1;
2602  }
2603 };
2604 
2605 template<typename A>
2607  assert_is_lvalue_if_expr<A>();
2608  return {pattern_arg(a)};
2609 }
2610 
2611 template<typename A>
2612 std::ostream &operator<<(std::ostream &s, const IsMinValue<A> &op) {
2613  s << "is_min_value(" << op.a << ")";
2614  return s;
2615 }
2616 
2617 template<typename A>
2618 struct LanesOf {
2619  struct pattern_tag {};
2620  A a;
2621 
2622  constexpr static uint32_t binds = bindings<A>::mask;
2623 
2624  // This rule is not a boolean predicate: it evaluates to the lane count of a, as a scalar uint32 constant (IRNodeType UIntImm).
2627  constexpr static bool canonical = true;
2628 
2629  constexpr static bool foldable = true;
2630 
2633  // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2634  Type t = a.make(state, {}).type();
2635  val.u.u64 = t.lanes();
2636  ty.code = halide_type_uint;
2637  ty.bits = 32;
2638  ty.lanes = 1;
2639  }
2640 };
2641 
2642 template<typename A>
2644  assert_is_lvalue_if_expr<A>();
2645  return {pattern_arg(a)};
2646 }
2647 
2648 template<typename A>
2649 std::ostream &operator<<(std::ostream &s, const LanesOf<A> &op) {
2650  s << "lanes_of(" << op.a << ")";
2651  return s;
2652 }
2653 
2654 // Verify properties of each rewrite rule. Currently just fuzz tests them.
2655 template<typename Before,
2656  typename After,
2657  typename Predicate,
2658  typename = typename std::enable_if<std::decay<Before>::type::foldable &&
2659  std::decay<After>::type::foldable>::type>
2660 HALIDE_NEVER_INLINE void fuzz_test_rule(Before &&before, After &&after, Predicate &&pred,
2661  halide_type_t wildcard_type, halide_type_t output_type) noexcept {
2662 
2663  // We only validate the rules in the scalar case
2664  wildcard_type.lanes = output_type.lanes = 1;
2665 
2666  // Track which types this rule has been tested for before
2667  static std::set<uint32_t> tested;
2668 
2669  if (!tested.insert(reinterpret_bits<uint32_t>(wildcard_type)).second) {
2670  return;
2671  }
2672 
2673  // Print it in a form where it can be piped into a python/z3 validator
2674  debug(0) << "validate('" << before << "', '" << after << "', '" << pred << "', " << Type(wildcard_type) << ", " << Type(output_type) << ")\n";
2675 
2676  // Substitute some random constants into the before and after
2677  // expressions and see if the rule holds true. This should catch
2678  // silly errors, but not necessarily corner cases.
2679  static std::mt19937_64 rng(0);
2680  MatcherState state;
2681 
2682  Expr exprs[max_wild];
2683 
2684  for (int trials = 0; trials < 100; trials++) {
2685  // We want to test small constants more frequently than
2686  // large ones, otherwise we'll just get coverage of
2687  // overflow rules.
2688  int shift = (int)(rng() & (wildcard_type.bits - 1));
2689 
2690  for (int i = 0; i < max_wild; i++) {
2691  // Bind all the exprs and constants
2692  switch (wildcard_type.code) {
2693  case halide_type_uint: {
2694  // Normalize to the type's range by adding zero
2695  uint64_t val = constant_fold_bin_op<Add>(wildcard_type, (uint64_t)rng() >> shift, 0);
2696  state.set_bound_const(i, val, wildcard_type);
2697  val = constant_fold_bin_op<Add>(wildcard_type, (uint64_t)rng() >> shift, 0);
2698  exprs[i] = make_const(wildcard_type, val);
2699  state.set_binding(i, *exprs[i].get());
2700  } break;
2701  case halide_type_int: {
2702  int64_t val = constant_fold_bin_op<Add>(wildcard_type, (int64_t)rng() >> shift, 0);
2703  state.set_bound_const(i, val, wildcard_type);
2704  val = constant_fold_bin_op<Add>(wildcard_type, (int64_t)rng() >> shift, 0);
2705  exprs[i] = make_const(wildcard_type, val);
2706  } break;
2707  case halide_type_float:
2708  case halide_type_bfloat: {
2709  // Use a very narrow range of precise floats, so
2710  // that none of the rules a human is likely to
2711  // write have instabilities.
2712  double val = ((int64_t)(rng() & 15) - 8) / 2.0;
2713  state.set_bound_const(i, val, wildcard_type);
2714  val = ((int64_t)(rng() & 15) - 8) / 2.0;
2715  exprs[i] = make_const(wildcard_type, val);
2716  } break;
2717  default:
2718  return; // Don't care about handles
2719  }
2720  state.set_binding(i, *exprs[i].get());
2721  }
2722 
2723  halide_scalar_value_t val_pred, val_before, val_after;
2724  halide_type_t type = output_type;
2725  if (!evaluate_predicate(pred, state)) {
2726  continue;
2727  }
2728  before.make_folded_const(val_before, type, state);
2729  uint16_t lanes = type.lanes;
2730  after.make_folded_const(val_after, type, state);
2731  lanes |= type.lanes;
2732 
2733  if (lanes & MatcherState::special_values_mask) {
2734  continue;
2735  }
2736 
2737  bool ok = true;
2738  switch (output_type.code) {
2739  case halide_type_uint:
2740  // Compare normalized representations
2741  ok &= (constant_fold_bin_op<Add>(output_type, val_before.u.u64, 0) ==
2742  constant_fold_bin_op<Add>(output_type, val_after.u.u64, 0));
2743  break;
2744  case halide_type_int:
2745  ok &= (constant_fold_bin_op<Add>(output_type, val_before.u.i64, 0) ==
2746  constant_fold_bin_op<Add>(output_type, val_after.u.i64, 0));
2747  break;
2748  case halide_type_float:
2749  case halide_type_bfloat: {
2750  double error = std::abs(val_before.u.f64 - val_after.u.f64);
2751  // We accept an equal bit pattern (e.g. inf vs inf),
2752  // a small floating point difference, or turning a nan into not-a-nan.
2753  ok &= (error < 0.01 ||
2754  val_before.u.u64 == val_after.u.u64 ||
2755  std::isnan(val_before.u.f64));
2756  break;
2757  }
2758  default:
2759  return;
2760  }
2761 
2762  if (!ok) {
2763  debug(0) << "Fails with values:\n";
2764  for (int i = 0; i < max_wild; i++) {
2766  state.get_bound_const(i, val, wildcard_type);
2767  debug(0) << " c" << i << ": " << make_const_expr(val, wildcard_type) << "\n";
2768  }
2769  for (int i = 0; i < max_wild; i++) {
2770  debug(0) << " _" << i << ": " << Expr(state.get_binding(i)) << "\n";
2771  }
2772  debug(0) << " Before: " << make_const_expr(val_before, output_type) << "\n";
2773  debug(0) << " After: " << make_const_expr(val_after, output_type) << "\n";
2774  debug(0) << val_before.u.u64 << " " << val_after.u.u64 << "\n";
2776  }
2777  }
2778 }
2779 
2780 template<typename Before,
2781  typename After,
2782  typename Predicate,
2783  typename = typename std::enable_if<!(std::decay<Before>::type::foldable &&
2784  std::decay<After>::type::foldable)>::type>
2785 HALIDE_ALWAYS_INLINE void fuzz_test_rule(Before &&before, After &&after, Predicate &&pred,
2786  halide_type_t, halide_type_t, int dummy = 0) noexcept {
2787  // We can't verify rewrite rules that can't be constant-folded.
2788 }
2789 
2791 bool evaluate_predicate(bool x, MatcherState &) noexcept {
2792  return x;
2793 }
2794 
2795 template<typename Pattern,
2796  typename = typename enable_if_pattern<Pattern>::type>
2799  halide_type_t ty = halide_type_of<bool>();
2800  p.make_folded_const(c, ty, state);
2801  // Overflow counts as a failed predicate
2802  return (c.u.u64 != 0) && ((ty.lanes & MatcherState::special_values_mask) == 0);
2803 }
2804 
// #defines for testing. Each defaults to 0 (off). Each is guarded by #ifndef,
// so it can be enabled from the build command line
// (e.g. -DHALIDE_FUZZ_TEST_RULES=1) without editing this header and without
// triggering a macro-redefinition diagnostic.

// Print all successful or failed matches
#ifndef HALIDE_DEBUG_MATCHED_RULES
#define HALIDE_DEBUG_MATCHED_RULES 0
#endif
#ifndef HALIDE_DEBUG_UNMATCHED_RULES
#define HALIDE_DEBUG_UNMATCHED_RULES 0
#endif

// Set to true if you want to fuzz test every rewrite passed to
// operator() to ensure the input and the output have the same value
// for lots of random values of the wildcards. Run
// correctness_simplify with this on.
#ifndef HALIDE_FUZZ_TEST_RULES
#define HALIDE_FUZZ_TEST_RULES 0
#endif
2816 
2817 template<typename Instance>
2818 struct Rewriter {
2819  Instance instance;
2823  bool validate;
2824 
2827  : instance(std::move(instance)), output_type(ot), wildcard_type(wt) {
2828  }
2829 
2830  template<typename After>
2832  result = after.make(state, output_type);
2833  }
2834 
2835  template<typename Before,
2836  typename After,
2837  typename = typename enable_if_pattern<Before>::type,
2838  typename = typename enable_if_pattern<After>::type>
2839  HALIDE_ALWAYS_INLINE bool operator()(Before before, After after) {
2840  static_assert((Before::binds & After::binds) == After::binds, "Rule result uses unbound values");
2841  static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2842  static_assert(After::canonical, "RHS of rewrite rule should be in canonical form");
2843 #if HALIDE_FUZZ_TEST_RULES
2844  fuzz_test_rule(before, after, true, wildcard_type, output_type);
2845 #endif
2846  if (before.template match<0>(unwrap(instance), state)) {
2847  build_replacement(after);
2848 #if HALIDE_DEBUG_MATCHED_RULES
2849  debug(0) << instance << " -> " << result << " via " << before << " -> " << after << "\n";
2850 #endif
2851  return true;
2852  } else {
2853 #if HALIDE_DEBUG_UNMATCHED_RULES
2854  debug(0) << instance << " does not match " << before << "\n";
2855 #endif
2856  return false;
2857  }
2858  }
2859 
2860  template<typename Before,
2861  typename = typename enable_if_pattern<Before>::type>
2862  HALIDE_ALWAYS_INLINE bool operator()(Before before, const Expr &after) noexcept {
2863  static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2864  if (before.template match<0>(unwrap(instance), state)) {
2865  result = after;
2866 #if HALIDE_DEBUG_MATCHED_RULES
2867  debug(0) << instance << " -> " << result << " via " << before << " -> " << after << "\n";
2868 #endif
2869  return true;
2870  } else {
2871 #if HALIDE_DEBUG_UNMATCHED_RULES
2872  debug(0) << instance << " does not match " << before << "\n";
2873 #endif
2874  return false;
2875  }
2876  }
2877 
2878  template<typename Before,
2879  typename = typename enable_if_pattern<Before>::type>
2880  HALIDE_ALWAYS_INLINE bool operator()(Before before, int64_t after) noexcept {
2881  static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2882 #if HALIDE_FUZZ_TEST_RULES
2883  fuzz_test_rule(before, IntLiteral(after), true, wildcard_type, output_type);
2884 #endif
2885  if (before.template match<0>(unwrap(instance), state)) {
2886  result = make_const(output_type, after);
2887 #if HALIDE_DEBUG_MATCHED_RULES
2888  debug(0) << instance << " -> " << result << " via " << before << " -> " << after << "\n";
2889 #endif
2890  return true;
2891  } else {
2892 #if HALIDE_DEBUG_UNMATCHED_RULES
2893  debug(0) << instance << " does not match " << before << "\n";
2894 #endif
2895  return false;
2896  }
2897  }
2898 
2899  template<typename Before,
2900  typename After,
2901  typename Predicate,
2902  typename = typename enable_if_pattern<Before>::type,
2903  typename = typename enable_if_pattern<After>::type,
2904  typename = typename enable_if_pattern<Predicate>::type>
2905  HALIDE_ALWAYS_INLINE bool operator()(Before before, After after, Predicate pred) {
2906  static_assert(Predicate::foldable, "Predicates must consist only of operations that can constant-fold");
2907  static_assert((Before::binds & After::binds) == After::binds, "Rule result uses unbound values");
2908  static_assert((Before::binds & Predicate::binds) == Predicate::binds, "Rule predicate uses unbound values");
2909  static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2910  static_assert(After::canonical, "RHS of rewrite rule should be in canonical form");
2911 
2912 #if HALIDE_FUZZ_TEST_RULES
2913  fuzz_test_rule(before, after, pred, wildcard_type, output_type);
2914 #endif
2915  if (before.template match<0>(unwrap(instance), state) &&
2916  evaluate_predicate(pred, state)) {
2917  build_replacement(after);
2918 #if HALIDE_DEBUG_MATCHED_RULES
2919  debug(0) << instance << " -> " << result << " via " << before << " -> " << after << " when " << pred << "\n";
2920 #endif
2921  return true;
2922  } else {
2923 #if HALIDE_DEBUG_UNMATCHED_RULES
2924  debug(0) << instance << " does not match " << before << "\n";
2925 #endif
2926  return false;
2927  }
2928  }
2929 
2930  template<typename Before,
2931  typename Predicate,
2932  typename = typename enable_if_pattern<Before>::type,
2933  typename = typename enable_if_pattern<Predicate>::type>
2934  HALIDE_ALWAYS_INLINE bool operator()(Before before, const Expr &after, Predicate pred) {
2935  static_assert(Predicate::foldable, "Predicates must consist only of operations that can constant-fold");
2936  static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2937 
2938  if (before.template match<0>(unwrap(instance), state) &&
2939  evaluate_predicate(pred, state)) {
2940  result = after;
2941 #if HALIDE_DEBUG_MATCHED_RULES
2942  debug(0) << instance << " -> " << result << " via " << before << " -> " << after << " when " << pred << "\n";
2943 #endif
2944  return true;
2945  } else {
2946 #if HALIDE_DEBUG_UNMATCHED_RULES
2947  debug(0) << instance << " does not match " << before << "\n";
2948 #endif
2949  return false;
2950  }
2951  }
2952 
2953  template<typename Before,
2954  typename Predicate,
2955  typename = typename enable_if_pattern<Before>::type,
2956  typename = typename enable_if_pattern<Predicate>::type>
2957  HALIDE_ALWAYS_INLINE bool operator()(Before before, int64_t after, Predicate pred) {
2958  static_assert(Predicate::foldable, "Predicates must consist only of operations that can constant-fold");
2959  static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2960 #if HALIDE_FUZZ_TEST_RULES
2961  fuzz_test_rule(before, IntLiteral(after), pred, wildcard_type, output_type);
2962 #endif
2963  if (before.template match<0>(unwrap(instance), state) &&
2964  evaluate_predicate(pred, state)) {
2965  result = make_const(output_type, after);
2966 #if HALIDE_DEBUG_MATCHED_RULES
2967  debug(0) << instance << " -> " << result << " via " << before << " -> " << after << " when " << pred << "\n";
2968 #endif
2969  return true;
2970  } else {
2971 #if HALIDE_DEBUG_UNMATCHED_RULES
2972  debug(0) << instance << " does not match " << before << "\n";
2973 #endif
2974  return false;
2975  }
2976  }
2977 };
2978 
2979 /** Construct a rewriter for the given instance, which may be a pattern
2980  * with concrete expressions as leaves, or just an expression. The
2981  * second optional argument (wildcard_type) is a hint as to what the
2982  * type of the wildcards is likely to be. If omitted it uses the same
2983  * type as the expression itself. They are not required to be this
2984  * type, but the rule will only be tested for wildcards of that type
2985  * when testing is enabled.
2986  *
2987  * The rewriter can be used to check to see if the instance is one of
2988  * some number of patterns and if so rewrite it into another form,
2989  * using its operator() method. See Simplify.cpp for a bunch of
2990  * example usage.
2991  *
2992  * Important: Any Exprs in patterns are captured by reference, not by
2993  * value, so ensure they outlive the rewriter.
2994  */
2995 // @{
2996 template<typename Instance,
2997  typename = typename enable_if_pattern<Instance>::type>
2998 HALIDE_ALWAYS_INLINE auto rewriter(Instance instance, halide_type_t output_type, halide_type_t wildcard_type) noexcept -> Rewriter<decltype(pattern_arg(instance))> {
2999  return {pattern_arg(instance), output_type, wildcard_type};
3000 }
3001 
3002 template<typename Instance,
3003  typename = typename enable_if_pattern<Instance>::type>
3005  return {pattern_arg(instance), output_type, output_type};
3006 }
3007 
3009 auto rewriter(const Expr &e, halide_type_t wildcard_type) noexcept -> Rewriter<decltype(pattern_arg(e))> {
3010  return {pattern_arg(e), e.type(), wildcard_type};
3011 }
3012 
3015  return {pattern_arg(e), e.type(), e.type()};
3016 }
3017 // @}
3018 
3019 } // namespace IRMatcher
3020 
3021 } // namespace Internal
3022 } // namespace Halide
3023 
3024 #endif
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:476
HALIDE_ALWAYS_INLINE bool match(const CmpOp< Op2, A2, B2 > &op, MatcherState &state) const noexcept
Definition: IRMatch.h:777
static constexpr IRNodeType max_node_type
Definition: IRMatch.h:507
static constexpr IRNodeType min_node_type
Definition: IRMatch.h:1681
Unsigned integer constants.
Definition: Expr.h:227
static constexpr uint32_t binds
Definition: IRMatch.h:2579
static const IRNodeType _node_type
Definition: IR.h:209
The actual IR nodes begin here.
Definition: IR.h:30
HALIDE_ALWAYS_INLINE bool match(int64_t e, MatcherState &state) const noexcept
Definition: IRMatch.h:437
static constexpr bool canonical
Definition: IRMatch.h:2247
static constexpr bool canonical
Definition: IRMatch.h:1821
HALIDE_ALWAYS_INLINE bool operator()(Before before, int64_t after) noexcept
Definition: IRMatch.h:2880
static constexpr IRNodeType min_node_type
Definition: IRMatch.h:1962
static constexpr uint32_t binds
Definition: IRMatch.h:2282
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< EQ >(int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1242
constexpr bool and_reduce()
Definition: IRMatch.h:1342
A fragment of Halide syntax.
Definition: Expr.h:258
bool is_slice() const
Check if this shuffle is a contiguous strict subset of the vector arguments, and if so...
HALIDE_ALWAYS_INLINE bool operator()(Before before, const Expr &after) noexcept
Definition: IRMatch.h:2862
static constexpr IRNodeType max_node_type
Definition: IRMatch.h:2246
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:311
HALIDE_ALWAYS_INLINE void get_bound_const(int i, halide_scalar_value_t &val, halide_type_t &type) const noexcept
Definition: IRMatch.h:127
static constexpr IRNodeType min_node_type
Definition: IRMatch.h:2154
static constexpr bool canonical
Definition: IRMatch.h:2156
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:245
static constexpr IRNodeType max_node_type
Definition: IRMatch.h:2155
static constexpr bool canonical
Definition: IRMatch.h:473
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:820
static Expr make(Expr condition, Expr true_value, Expr false_value)
Integer constants.
Definition: Expr.h:218
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition: IRMatch.h:2384
static constexpr bool canonical
Definition: IRMatch.h:1684
void expr_match_test()
static constexpr bool canonical
Definition: IRMatch.h:1370
static constexpr IRNodeType max_node_type
Definition: IRMatch.h:2416
static constexpr bool foldable
Definition: IRMatch.h:2293
std::tuple< Args... > args
Definition: IRMatch.h:1360
HALIDE_ALWAYS_INLINE bool match(const VectorReduceOp< A2, B2, reduce_op_2 > &op, MatcherState &state) const noexcept
Definition: IRMatch.h:1901
static constexpr IRNodeType max_node_type
Definition: IRMatch.h:2378
static constexpr IRNodeType min_node_type
Definition: IRMatch.h:359
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:332
The difference of two expressions.
Definition: IR.h:65
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:786
static const IRNodeType _node_type
Definition: IR.h:985
static constexpr IRNodeType max_node_type
Definition: IRMatch.h:1963
A vector with 'lanes' elements, in which every element is 'value'.
Definition: IR.h:259
HALIDE_ALWAYS_INLINE bool match(const BroadcastOp< A2, B2 > &op, MatcherState &state) const noexcept
Definition: IRMatch.h:1764
static constexpr bool foldable
Definition: IRMatch.h:1989
HALIDE_ALWAYS_INLINE auto operator!(A &&a) noexcept -> NotOp< decltype(pattern_arg(a))>
Definition: IRMatch.h:1655
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:1704
static constexpr bool foldable
Definition: IRMatch.h:1858
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:544
static constexpr bool canonical
Definition: IRMatch.h:2047
static constexpr IRNodeType max_node_type
Definition: IRMatch.h:1884
auto widening_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1552
uint16_t lanes
How many elements in a vector.
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:1992
HALIDE_ALWAYS_INLINE bool operator()(Before before, After after)
Definition: IRMatch.h:2839
Logical not - true if the expression is false.
Definition: IR.h:193
Methods to test Exprs and Stmts for equality of value.
Horizontally reduce a vector to a scalar or narrower vector using the given commutative and associati...
Definition: IR.h:966
HALIDE_ALWAYS_INLINE int64_t unwrap(IntLiteral t)
Definition: IRMatch.h:571
static constexpr IRNodeType max_node_type
Definition: IRMatch.h:2583
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition: IRMatch.h:2467
static constexpr bool canonical
Definition: IRMatch.h:760
HALIDE_ALWAYS_INLINE void set_bound_const(int i, int64_t s, halide_type_t t) noexcept
Definition: IRMatch.h:103
HALIDE_ALWAYS_INLINE bool match(const IntLiteral &b, MatcherState &state) const noexcept
Definition: IRMatch.h:539
static constexpr bool foldable
Definition: IRMatch.h:1475
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:485
HALIDE_ALWAYS_INLINE auto max(A &&a, B &&b) noexcept -> BinOp< Max, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1110
static constexpr bool canonical
Definition: IRMatch.h:1965
HALIDE_ALWAYS_INLINE bool evaluate_predicate(bool x, MatcherState &) noexcept
Definition: IRMatch.h:2791
static constexpr uint32_t binds
Definition: IRMatch.h:2538
static const UIntImm * make(Type t, uint64_t value)
IEEE floating point numbers.
HALIDE_ALWAYS_INLINE auto is_max_value(A &&a) noexcept -> IsMaxValue< decltype(pattern_arg(a))>
Definition: IRMatch.h:2563
HALIDE_ALWAYS_INLINE auto operator!=(A &&a, B &&b) noexcept -> CmpOp< NE, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1257
static constexpr IRNodeType max_node_type
Definition: IRMatch.h:2626
static constexpr uint32_t binds
Definition: IRMatch.h:1816
Floating point constants.
Definition: Expr.h:236
HALIDE_ALWAYS_INLINE bool operator()(Before before, After after, Predicate pred)
Definition: IRMatch.h:2905
HALIDE_ALWAYS_INLINE auto mul(A &&a, B &&b) -> decltype(IRMatcher::operator*(a, b))
Definition: IRMatch.h:1006
static constexpr bool foldable
Definition: IRMatch.h:1644
static constexpr bool foldable
Definition: IRMatch.h:2586
uint8_t code
The basic type code: signed integer, unsigned integer, or floating point.
static constexpr IRNodeType min_node_type
Definition: IRMatch.h:1883
A ternary operator.
Definition: IR.h:204
HALIDE_ALWAYS_INLINE bool is_float() const
Is this type a floating point type (float or double).
Definition: Type.h:416
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:2267
static constexpr IRNodeType max_node_type
Definition: IRMatch.h:1369
static constexpr IRNodeType min_node_type
Definition: IRMatch.h:240
static constexpr bool canonical
Definition: IRMatch.h:2543
STL namespace.
HALIDE_ALWAYS_INLINE auto h_add(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Add >
Definition: IRMatch.h:1926
static constexpr bool foldable
Definition: IRMatch.h:2545
HALIDE_ALWAYS_INLINE bool is_uint() const
Is this type an unsigned integer type?
Definition: Type.h:434
HALIDE_ALWAYS_INLINE int lanes() const
Return the number of vector elements in this type.
Definition: Type.h:348
HALIDE_ALWAYS_INLINE auto operator>=(A &&a, B &&b) noexcept -> CmpOp< GE, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1207
static constexpr uint32_t binds
Definition: IRMatch.h:2207
Expr make_zero(Type t)
Construct the representation of zero in the given type.
A linear ramp vector node.
Definition: IR.h:247
HALIDE_ALWAYS_INLINE auto div(A &&a, B &&b) -> decltype(IRMatcher::operator/(a, b))
Definition: IRMatch.h:1039
HALIDE_ALWAYS_INLINE bool match(const BinOp< Op2, A2, B2 > &op, MatcherState &state) const noexcept
Definition: IRMatch.h:667
HALIDE_ALWAYS_INLINE auto fold(A &&a) noexcept -> Fold< decltype(pattern_arg(a))>
Definition: IRMatch.h:2191
static constexpr IRNodeType max_node_type
Definition: IRMatch.h:1747
HALIDE_ALWAYS_INLINE auto ge(A &&a, B &&b) -> decltype(IRMatcher::operator>=(a, b))
Definition: IRMatch.h:1212
static Expr make(Operator op, Expr vec, int lanes)
static constexpr bool foldable
Definition: IRMatch.h:489
static constexpr IRNodeType max_node_type
Definition: IRMatch.h:759
static const IRNodeType _node_type
Definition: IR.h:70
static constexpr IRNodeType max_node_type
Definition: IRMatch.h:2461
HALIDE_ALWAYS_INLINE auto h_and(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::And >
Definition: IRMatch.h:1944
static constexpr bool canonical
Definition: IRMatch.h:2417
This file defines the class FunctionDAG, which is our representation of a Halide pipeline, and contains methods for using Halide's bounds tools to query properties of it.
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition: IRMatch.h:2511
static constexpr IRNodeType max_node_type
Definition: IRMatch.h:2542
auto widen_right_sub(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1547
static constexpr uint32_t binds
Definition: IRMatch.h:304
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:453
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Add >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition: IRMatch.h:947
auto saturating_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1564
auto rounding_shift_right(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1602
static constexpr IRNodeType min_node_type
Definition: IRMatch.h:2285
HALIDE_ALWAYS_INLINE auto broadcast(A &&a, B lanes) noexcept -> BroadcastOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes))>
Definition: IRMatch.h:1804
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:1477
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:516
static constexpr bool foldable
Definition: IRMatch.h:450
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Min >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1095
HALIDE_ALWAYS_INLINE auto operator-(A &&a, B &&b) noexcept -> BinOp< Sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:966
T mod_imp(T a, T b)
Implementations of division and mod that are specific to Halide.
Definition: IROperator.h:239
static constexpr bool foldable
Definition: IRMatch.h:2508
HALIDE_ALWAYS_INLINE auto rewriter(Instance instance, halide_type_t output_type, halide_type_t wildcard_type) noexcept -> Rewriter< decltype(pattern_arg(instance))>
Construct a rewriter for the given instance, which may be a pattern with concrete expressions as leav...
Definition: IRMatch.h:2998
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:1983
unsigned integers
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< GE >(int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1217
static constexpr IRNodeType min_node_type
Definition: IRMatch.h:1818
bool expr_match(const Expr &pattern, const Expr &expr, std::vector< Expr > &result)
Does the first expression have the same structure as the second? Variables in the first expression wi...
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Max >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1117
HALIDE_ALWAYS_INLINE auto can_prove(A &&a, Prover *p) noexcept -> CanProve< decltype(pattern_arg(a)), Prover >
Definition: IRMatch.h:2358
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:2098
IRNodeType
All our IR node types get unique IDs for the purposes of RTTI.
Definition: Expr.h:25
A base class for expression nodes.
Definition: Expr.h:143
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:551
static constexpr bool canonical
Definition: IRMatch.h:2342
static constexpr bool foldable
Definition: IRMatch.h:673
static constexpr uint32_t binds
Definition: IRMatch.h:2337
HALIDE_ALWAYS_INLINE bool match(NegateOp< A2 > &&p, MatcherState &state) const noexcept
Definition: IRMatch.h:1978
int slice_stride() const
Check if this shuffle is a contiguous strict subset of the vector arguments, and if so...
Definition: IR.h:896
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Sub >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition: IRMatch.h:980
static Expr make(Expr value, int lanes)
static constexpr bool foldable
Definition: IRMatch.h:2215
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Or >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1292
static constexpr uint32_t binds
Definition: IRMatch.h:1744
Expr absd(Expr a, Expr b)
Return the absolute difference between two values.
HALIDE_ALWAYS_INLINE auto and_op(A &&a, B &&b) -> decltype(IRMatcher::operator&&(a, b))
Definition: IRMatch.h:1313
HALIDE_ALWAYS_INLINE auto h_min(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Min >
Definition: IRMatch.h:1932
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:395
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:2259
static const IntImm * make(Type t, int64_t value)
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:364
auto halving_sub(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1582
halide_scalar_value_t bound_const[max_wild]
Definition: IRMatch.h:84
HALIDE_ALWAYS_INLINE auto operator/(A &&a, B &&b) noexcept -> BinOp< Div, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1032
static constexpr uint32_t binds
Definition: IRMatch.h:2043
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:657
HALIDE_ALWAYS_INLINE bool is_vector() const
Is this type a vector type? (lanes() != 1).
Definition: Type.h:403
static constexpr bool foldable
Definition: IRMatch.h:2464
HALIDE_ALWAYS_INLINE auto h_or(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Or >
Definition: IRMatch.h:1950
static constexpr uint16_t special_values_mask
Definition: IRMatch.h:88
HALIDE_ALWAYS_INLINE auto lanes_of(A &&a) noexcept -> LanesOf< decltype(pattern_arg(a))>
Definition: IRMatch.h:2643
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:1647
static constexpr IRNodeType max_node_type
Definition: IRMatch.h:414
static constexpr IRNodeType min_node_type
Definition: IRMatch.h:1621
static constexpr uint32_t binds
Definition: IRMatch.h:2152
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:2250
static constexpr uint32_t binds
Definition: IRMatch.h:1619
HALIDE_ALWAYS_INLINE bool equal(const BaseExprNode &a, const BaseExprNode &b) noexcept
Definition: IRMatch.h:195
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< And >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1318
HALIDE_ALWAYS_INLINE auto add(A &&a, B &&b) -> decltype(IRMatcher::operator+(a, b))
Definition: IRMatch.h:940
Expr abs(Expr a)
Returns the absolute value of a signed integer or floating-point expression.
std::vector< Expr > vectors
Definition: IR.h:842
static constexpr IRNodeType min_node_type
Definition: IRMatch.h:2582
IRNodeType node_type
Each IR node subclass has a unique identifier.
Definition: Expr.h:113
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< NE >(int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1267
static constexpr uint32_t binds
Definition: IRMatch.h:2091
HALIDE_ALWAYS_INLINE Type type() const
Get the type of this expression node.
Definition: Expr.h:322
halide_scalar_value_t is a simple union able to represent all the well-known scalar values in a filte...
HALIDE_ALWAYS_INLINE auto not_op(A &&a) -> decltype(IRMatcher::operator!(a))
Definition: IRMatch.h:1661
A function call.
Definition: IR.h:490
HALIDE_ALWAYS_INLINE Intrin(Call::IntrinsicOp intrin, Args... args) noexcept
Definition: IRMatch.h:1520
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:1770
HALIDE_ALWAYS_INLINE void set_bound_const(int i, uint64_t u, halide_type_t t) noexcept
Definition: IRMatch.h:109
static constexpr uint32_t binds
Definition: IRMatch.h:238
#define HALIDE_NEVER_INLINE
Definition: HalideRuntime.h:50
HALIDE_ALWAYS_INLINE auto intrin(Call::IntrinsicOp intrinsic_op, Args... args) noexcept -> Intrin< decltype(pattern_arg(args))... >
Definition: IRMatch.h:1534
uint8_t bits
The number of bits of precision of a single scalar value of this type.
static constexpr bool canonical
Definition: IRMatch.h:415
For optional debugging during codegen, use the debug class as follows:
Definition: Debug.h:49
Defines various operator overloads and utility functions that make it more pleasant to work with Hali...
constexpr int const_min(int a, int b)
Definition: IRMatch.h:1352
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:1824
static const IRNodeType _node_type
Definition: IR.h:752
constexpr uint32_t bitwise_or_reduce()
Definition: IRMatch.h:1333
HALIDE_ALWAYS_INLINE auto operator==(A &&a, B &&b) noexcept -> CmpOp< EQ, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1232
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:443
Expr make_const(Type t, int64_t val)
Construct an immediate of the given type from any numeric C++ type.
HALIDE_ALWAYS_INLINE int bits() const
Return the bit size of a single element of this type.
Definition: Type.h:342
HALIDE_ALWAYS_INLINE auto mod(A &&a, B &&b) -> decltype(IRMatcher::operator%(a, b))
Definition: IRMatch.h:1066
static constexpr uint32_t binds
Definition: IRMatch.h:504
HALIDE_ALWAYS_INLINE auto ramp(A &&a, B &&b, C &&c) noexcept -> RampOp< decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))>
Definition: IRMatch.h:1868
static constexpr IRNodeType max_node_type
Definition: IRMatch.h:2286
static constexpr IRNodeType min_node_type
Definition: IRMatch.h:2093
auto widen_right_mul(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1543
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:222
Guard the loads and stores in the loop with an if statement that prevents evaluation beyond the origi...
HALIDE_ALWAYS_INLINE const BaseExprNode * get_binding(int i) const noexcept
Definition: IRMatch.h:98
static constexpr uint32_t binds
Definition: IRMatch.h:2501
static constexpr bool foldable
Definition: IRMatch.h:2068
static constexpr uint32_t binds
Definition: IRMatch.h:2457
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< LT >(int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1142
static constexpr IRNodeType min_node_type
Definition: IRMatch.h:2541
HALIDE_NEVER_INLINE void build_replacement(After after)
Definition: IRMatch.h:2831
unsigned __INT32_TYPE__ uint32_t
auto widening_mul(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1560
static constexpr IRNodeType min_node_type
Definition: IRMatch.h:2625
union halide_scalar_value_t::@4 u
static constexpr uint32_t binds
Definition: IRMatch.h:357
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:217
static constexpr uint32_t binds
Definition: IRMatch.h:2242
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:767
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:2185
HALIDE_ALWAYS_INLINE IntLiteral(int64_t v)
Definition: IRMatch.h:511
HALIDE_ALWAYS_INLINE auto cast(halide_type_t t, A &&a) noexcept -> CastOp< decltype(pattern_arg(a))>
Definition: IRMatch.h:2078
static constexpr IRNodeType min_node_type
Definition: IRMatch.h:471
To save stack space, the matcher objects are largely stateless and immutable.
Definition: IRMatch.h:82
static constexpr bool canonical
Definition: IRMatch.h:653
Not visible externally, similar to 'static' linkage in C.
HALIDE_ALWAYS_INLINE void set_bound_const(int i, halide_scalar_value_t val, halide_type_t t) noexcept
Definition: IRMatch.h:121
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< LE >(int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1192
static constexpr bool canonical
Definition: IRMatch.h:2213
HALIDE_ALWAYS_INLINE void set_bound_const(int i, double f, halide_type_t t) noexcept
Definition: IRMatch.h:115
static constexpr IRNodeType max_node_type
Definition: IRMatch.h:2212
HALIDE_ALWAYS_INLINE Expr make_const_expr(halide_scalar_value_t val, halide_type_t ty)
Definition: IRMatch.h:160
HALIDE_ALWAYS_INLINE void assert_is_lvalue_if_expr()
Definition: IRMatch.h:588
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Mul >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1013
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Div >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1044
HALIDE_ALWAYS_INLINE auto operator%(A &&a, B &&b) noexcept -> BinOp< Mod, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1059
static constexpr uint32_t binds
Definition: IRMatch.h:411
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:1387
HALIDE_ALWAYS_INLINE auto operator+(A &&a, B &&b) noexcept -> BinOp< Add, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:933
No name mangling.
signed __INT64_TYPE__ int64_t
constexpr int max_wild
Definition: IRMatch.h:74
static constexpr bool canonical
Definition: IRMatch.h:2379
static Expr make(Expr a)
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:1752
HALIDE_ALWAYS_INLINE auto gt(A &&a, B &&b) -> decltype(IRMatcher::operator>(a, b))
Definition: IRMatch.h:1162
HALIDE_ALWAYS_INLINE bool is_scalar() const
Is this type a scalar type? (lanes() == 1).
Definition: Type.h:410
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:1640
bool mul_would_overflow(int bits, int64_t a, int64_t b)
Routines to test if math would overflow for signed integers with the given number of bits...
static constexpr IRNodeType min_node_type
Definition: IRMatch.h:2504
#define internal_error
Definition: Errors.h:23
static constexpr IRNodeType min_node_type
Definition: IRMatch.h:306
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition: IRMatch.h:2548
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition: IRMatch.h:2589
static constexpr uint32_t binds
Definition: IRMatch.h:207
static constexpr IRNodeType max_node_type
Definition: IRMatch.h:1682
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition: IRMatch.h:2422
#define HALIDE_ALWAYS_INLINE
Definition: HalideRuntime.h:49
static constexpr uint32_t mask
Definition: IRMatch.h:146
static const FloatImm * make(Type t, double value)
static constexpr IRNodeType min_node_type
Definition: IRMatch.h:758
static constexpr IRNodeType min_node_type
Definition: IRMatch.h:1368
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:1968
HALIDE_ALWAYS_INLINE bool match(int64_t val, MatcherState &state) const noexcept
Definition: IRMatch.h:534
static constexpr IRNodeType min_node_type
Definition: IRMatch.h:2245
static constexpr bool foldable
Definition: IRMatch.h:1708
HALIDE_ALWAYS_INLINE auto select(C &&c, T &&t, F &&f) noexcept -> SelectOp< decltype(pattern_arg(c)), decltype(pattern_arg(t)), decltype(pattern_arg(f))>
Definition: IRMatch.h:1731
static constexpr uint32_t binds
Definition: IRMatch.h:1960
static const IRNodeType _node_type
Definition: IR.h:35
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition: IRMatch.h:2632
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const noexcept
Definition: IRMatch.h:2159
HALIDE_ALWAYS_INLINE auto ne(A &&a, B &&b) -> decltype(IRMatcher::operator!=(a, b))
Definition: IRMatch.h:1262
static constexpr IRNodeType min_node_type
Definition: IRMatch.h:2211
HALIDE_ALWAYS_INLINE auto operator &&(A &&a, B &&b) noexcept -> BinOp< And, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1308
HALIDE_ALWAYS_INLINE bool match(const NotOp< A2 > &op, MatcherState &state) const noexcept
Definition: IRMatch.h:1635
static constexpr bool foldable
Definition: IRMatch.h:548
HALIDE_ALWAYS_INLINE auto is_const(A &&a) noexcept -> IsConst< decltype(pattern_arg(a))>
Definition: IRMatch.h:2310
A runtime tag for a type in the halide type system.
static constexpr IRNodeType max_node_type
Definition: IRMatch.h:360
constexpr IRNodeType StrongestExprNodeType
Definition: Expr.h:81
HALIDE_NEVER_INLINE void fuzz_test_rule(Before &&before, After &&after, Predicate &&pred, halide_type_t wildcard_type, halide_type_t output_type) noexcept
Definition: IRMatch.h:2660
static constexpr uint32_t binds
Definition: IRMatch.h:2412
Subtypes for Halide expressions (Halide::Expr) and statements (Halide::Internal::Stmt) ...
auto shift_right(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1594
static constexpr bool foldable
Definition: IRMatch.h:2629
static constexpr bool canonical
Definition: IRMatch.h:2584
static constexpr IRNodeType max_node_type
Definition: IRMatch.h:1819
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:1417
std::ostream & operator<<(std::ostream &s, const SpecificExpr &e)
Definition: IRMatch.h:229
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:1787
static constexpr IRNodeType min_node_type
Definition: IRMatch.h:413
static constexpr IRNodeType max_node_type
Definition: IRMatch.h:307
bool is_const_zero(const Expr &e)
Is the expression a const (as defined by is_const), and also equal to zero (in all lanes...
static constexpr IRNodeType max_node_type
Definition: IRMatch.h:241
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition: IRMatch.h:289
static constexpr IRNodeType min_node_type
Definition: IRMatch.h:647
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:2064
auto widening_sub(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1556
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:342
auto rounding_mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))>
Definition: IRMatch.h:1610
static constexpr bool canonical
Definition: IRMatch.h:2627
Expr make_signed_integer_overflow(Type type)
Construct a unique signed_integer_overflow Expr.
HALIDE_ALWAYS_INLINE bool is_int() const
Is this type a signed integer type?
Definition: Type.h:428
int64_t constant_fold_bin_op(halide_type_t &, int64_t, int64_t) noexcept
HALIDE_ALWAYS_INLINE auto min(A &&a, B &&b) noexcept -> BinOp< Min, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1088
HALIDE_ALWAYS_INLINE bool match(const CastOp< A2 > &op, MatcherState &state) const noexcept
Definition: IRMatch.h:2059
HALIDE_ALWAYS_INLINE auto h_max(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Max >
Definition: IRMatch.h:1938
unsigned __INT16_TYPE__ uint16_t
static const IRNodeType _node_type
Definition: IR.h:253
auto saturating_cast(const Type &t, A &&a) noexcept -> Intrin< decltype(pattern_arg(a))>
Definition: IRMatch.h:1572
floating point numbers in the bfloat format
static constexpr uint32_t binds
Definition: IRMatch.h:2622
Types in the halide type system.
Definition: Type.h:276
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:418
HALIDE_ALWAYS_INLINE auto is_uint(A &&a, int bits=0, int lanes=0) noexcept -> IsUInt< decltype(pattern_arg(a))>
Definition: IRMatch.h:2478
auto halving_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1578
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:1626
constexpr bool commutative(IRNodeType t)
Definition: IRMatch.h:627
bool is_intrinsic() const
Definition: IR.h:707
static constexpr uint32_t binds
Definition: IRMatch.h:1366
static constexpr IRNodeType max_node_type
Definition: IRMatch.h:2046
HALIDE_ALWAYS_INLINE auto or_op(A &&a, B &&b) -> decltype(IRMatcher::operator||(a, b))
Definition: IRMatch.h:1287
HALIDE_ALWAYS_INLINE bool match_args(double, const Call &c, MatcherState &state) const noexcept
Definition: IRMatch.h:1382
HALIDE_ALWAYS_INLINE bool match_args(int, const Call &c, MatcherState &state) const noexcept
Definition: IRMatch.h:1375
HALIDE_ALWAYS_INLINE auto lt(A &&a, B &&b) -> decltype(IRMatcher::operator<(a, b))
Definition: IRMatch.h:1137
HALIDE_ALWAYS_INLINE auto operator||(A &&a, B &&b) noexcept -> BinOp< Or, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1282
HALIDE_ALWAYS_INLINE SliceOp(Vec v, Base b, Stride s, Lanes l)
Definition: IRMatch.h:2127
static constexpr IRNodeType max_node_type
Definition: IRMatch.h:2505
halide_type_t bound_const_type[max_wild]
Definition: IRMatch.h:90
static constexpr IRNodeType min_node_type
Definition: IRMatch.h:2377
static constexpr IRNodeType min_node_type
Definition: IRMatch.h:506
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:385
static constexpr uint32_t binds
Definition: IRMatch.h:1679
static constexpr bool foldable
Definition: IRMatch.h:2264
bool is_const_one(const Expr &e)
Is the expression a const (as defined by is_const), and also equal to one (in all lanes...
HALIDE_ALWAYS_INLINE auto operator*(A &&a, B &&b) noexcept -> BinOp< Mul, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:999
HALIDE_ALWAYS_INLINE bool match(int64_t value, MatcherState &state) const noexcept
Definition: IRMatch.h:266
HALIDE_ALWAYS_INLINE auto overflows(A &&a) noexcept -> Overflows< decltype(pattern_arg(a))>
Definition: IRMatch.h:2228
static constexpr bool canonical
Definition: IRMatch.h:508
static Expr make(Expr a, Expr b)
static constexpr bool canonical
Definition: IRMatch.h:2287
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:676
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Mod >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1073
static constexpr uint16_t signed_integer_overflow
Definition: IRMatch.h:87
Expr likely_if_innermost(Expr e)
Equivalent to likely, but only triggers a loop partitioning if found in an innermost loop...
static constexpr bool foldable
Definition: IRMatch.h:783
static constexpr IRNodeType max_node_type
Definition: IRMatch.h:211
static constexpr IRNodeType min_node_type
Definition: IRMatch.h:2460
bool equal_helper(const BaseExprNode &a, const BaseExprNode &b) noexcept
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:2218
auto rounding_shift_left(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1598
bool sub_would_overflow(int bits, int64_t a, int64_t b)
Routines to test if math would overflow for signed integers with the given number of bits...
HALIDE_ALWAYS_INLINE const Internal::BaseExprNode * get() const
Override get() to return a BaseExprNode * instead of an IRNode *.
Definition: Expr.h:316
auto mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))>
Definition: IRMatch.h:1606
static constexpr IRNodeType min_node_type
Definition: IRMatch.h:2340
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:1908
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:2112
static constexpr bool canonical
Definition: IRMatch.h:2506
static constexpr bool foldable
Definition: IRMatch.h:2419
HALIDE_ALWAYS_INLINE bool match(const SelectOp< C2, T2, F2 > &instance, MatcherState &state) const noexcept
Definition: IRMatch.h:1697
Expr with_lanes(const Expr &x, int lanes)
Rewrite the expression x to have `lanes` lanes.
HALIDE_ALWAYS_INLINE void print_args(std::ostream &s) const
Definition: IRMatch.h:1412
HALIDE_ALWAYS_INLINE T pattern_arg(T t)
Definition: IRMatch.h:579
HALIDE_ALWAYS_INLINE auto is_int(A &&a, int bits=0, int lanes=0) noexcept -> IsInt< decltype(pattern_arg(a))>
Definition: IRMatch.h:2433
HALIDE_ALWAYS_INLINE auto operator<(A &&a, B &&b) noexcept -> CmpOp< LT, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1132
static const IRNodeType _node_type
Definition: IR.h:265
HALIDE_ALWAYS_INLINE auto operator<=(A &&a, B &&b) noexcept -> CmpOp< LE, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1182
HALIDE_ALWAYS_INLINE auto operator>(A &&a, B &&b) noexcept -> CmpOp< GT, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1157
auto rounding_halving_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1586
static constexpr uint32_t binds
Definition: IRMatch.h:2374
uint64_t constant_fold_cmp_op(int64_t, int64_t) noexcept
auto widen_right_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1539
static constexpr uint32_t binds
Definition: IRMatch.h:469
auto shift_left(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1590
static constexpr bool canonical
Definition: IRMatch.h:2462
static Expr make(Expr base, Expr stride, int lanes)
HALIDE_ALWAYS_INLINE auto eq(A &&a, B &&b) -> decltype(IRMatcher::operator==(a, b))
Definition: IRMatch.h:1237
static constexpr uint32_t binds
Definition: IRMatch.h:645
static Expr make_slice(Expr vector, int begin, int stride, int size)
Convenience constructor for making a shuffle representing a contiguous subset of a vector...
static constexpr uint32_t binds
Definition: IRMatch.h:756
auto saturating_sub(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1568
HALIDE_NEVER_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition: IRMatch.h:2347
HALIDE_ALWAYS_INLINE auto slice(Vec vec, Base base, Stride stride, Lanes lanes) noexcept -> SliceOp< decltype(pattern_arg(vec)), decltype(pattern_arg(base)), decltype(pattern_arg(stride)), decltype(pattern_arg(lanes))>
Definition: IRMatch.h:2142
unsigned __INT64_TYPE__ uint64_t
static constexpr bool foldable
Definition: IRMatch.h:2124
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< GT >(int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1167
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const noexcept
Definition: IRMatch.h:719
HALIDE_ALWAYS_INLINE MatcherState() noexcept
Definition: IRMatch.h:134
HALIDE_ALWAYS_INLINE void set_binding(int i, const BaseExprNode &n) noexcept
Definition: IRMatch.h:93
HALIDE_ALWAYS_INLINE auto is_const(A &&a, int64_t value) noexcept -> IsConst< decltype(pattern_arg(a))>
Definition: IRMatch.h:2316
HALIDE_ALWAYS_INLINE void print_args(int, std::ostream &s) const
Definition: IRMatch.h:1399
HALIDE_ALWAYS_INLINE auto negate(A &&a) -> decltype(IRMatcher::operator-(a))
Definition: IRMatch.h:2032
bool add_would_overflow(int bits, int64_t a, int64_t b)
Routines to test if math would overflow for signed integers with the given number of bits...
T div_imp(T a, T b)
Implementations of division and mod that are specific to Halide.
Definition: IROperator.h:260
static constexpr IRNodeType max_node_type
Definition: IRMatch.h:648
HALIDE_ALWAYS_INLINE auto le(A &&a, B &&b) -> decltype(IRMatcher::operator<=(a, b))
Definition: IRMatch.h:1187
static constexpr uint32_t binds
Definition: IRMatch.h:1881
static constexpr IRNodeType min_node_type
Definition: IRMatch.h:2045
static constexpr bool foldable
Definition: IRMatch.h:2381
HALIDE_ALWAYS_INLINE bool match(const RampOp< A2, B2, C2 > &op, MatcherState &state) const noexcept
Definition: IRMatch.h:1839
static constexpr IRNodeType max_node_type
Definition: IRMatch.h:472
static constexpr bool foldable
Definition: IRMatch.h:2182
HALIDE_ALWAYS_INLINE bool operator()(Before before, int64_t after, Predicate pred)
Definition: IRMatch.h:2957
Expr likely(Expr e)
Expressions tagged with this intrinsic are considered to be part of the steady state of some loop wit...
signed __INT32_TYPE__ int32_t
static constexpr bool canonical
Definition: IRMatch.h:2095
static constexpr IRNodeType max_node_type
Definition: IRMatch.h:1622
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:1846
static constexpr IRNodeType min_node_type
Definition: IRMatch.h:1746
HALIDE_ALWAYS_INLINE auto sub(A &&a, B &&b) -> decltype(IRMatcher::operator-(a, b))
Definition: IRMatch.h:973
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:2050
HALIDE_NEVER_INLINE Expr make_const_special_expr(halide_type_t ty)
Definition: IRMatch.h:149
signed integers
HALIDE_ALWAYS_INLINE auto is_min_value(A &&a) noexcept -> IsMinValue< decltype(pattern_arg(a))>
Definition: IRMatch.h:2606
int slice_begin() const
Check if this shuffle is a contiguous strict subset of the vector arguments, and if so...
Definition: IR.h:893
static constexpr IRNodeType min_node_type
Definition: IRMatch.h:2415
HALIDE_ALWAYS_INLINE auto is_float(A &&a) noexcept -> IsFloat< decltype(pattern_arg(a))>
Definition: IRMatch.h:2395
HALIDE_ALWAYS_INLINE bool operator()(Before before, const Expr &after, Predicate pred)
Definition: IRMatch.h:2934
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:1711
static constexpr bool canonical
Definition: IRMatch.h:1623
static constexpr IRNodeType max_node_type
Definition: IRMatch.h:2094
HALIDE_ALWAYS_INLINE auto is_scalar(A &&a) noexcept -> IsScalar< decltype(pattern_arg(a))>
Definition: IRMatch.h:2522
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:2296
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:279
static constexpr IRNodeType max_node_type
Definition: IRMatch.h:2341
HALIDE_ALWAYS_INLINE Rewriter(Instance instance, halide_type_t ot, halide_type_t wt)
Definition: IRMatch.h:2826
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:1888
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:1687
static constexpr IRNodeType min_node_type
Definition: IRMatch.h:210
Construct a new vector by taking elements from another sequence of vectors.
Definition: IR.h:841
static constexpr bool foldable
Definition: IRMatch.h:2344
HALIDE_ALWAYS_INLINE void print_args(double, std::ostream &s) const
Definition: IRMatch.h:1408