Halide  17.0.2
Halide compiler and libraries
IRMatch.h
Go to the documentation of this file.
1 #ifndef HALIDE_IR_MATCH_H
2 #define HALIDE_IR_MATCH_H
3 
4 /** \file
5  * Defines a method to match a fragment of IR against a pattern containing wildcards
6  */
7 
8 #include <map>
9 #include <random>
10 #include <set>
11 #include <vector>
12 
13 #include "IR.h"
14 #include "IREquality.h"
15 #include "IROperator.h"
16 
17 namespace Halide {
18 namespace Internal {
19 
20 /** Does the first expression have the same structure as the second?
21  * Variables in the first expression with the name * are interpreted
22  * as wildcards, and their matching equivalent in the second
23  * expression is placed in the vector give as the third argument.
24  * Wildcards require the types to match. For the type bits and width,
25  * a 0 indicates "match anything". So an Int(8, 0) will match 8-bit
26  * integer vectors of any width (including scalars), and a UInt(0, 0)
27  * will match any unsigned integer type.
28  *
29  * For example:
30  \code
31  Expr x = Variable::make(Int(32), "*");
32  match(x + x, 3 + (2*k), result)
33  \endcode
34  * should return true, and set result[0] to 3 and
35  * result[1] to 2*k.
36  */
37 bool expr_match(const Expr &pattern, const Expr &expr, std::vector<Expr> &result);
38 
39 /** Does the first expression have the same structure as the second?
40  * Variables are matched consistently. The first time a variable is
41  * matched, it assumes the value of the matching part of the second
42  * expression. Subsequent matches must be equal to the first match.
43  *
44  * For example:
45  \code
46  Var x("x"), y("y");
47  match(x*(x + y), a*(a + b), result)
48  \endcode
49  * should return true, and set result["x"] = a, and result["y"] = b.
50  */
51 bool expr_match(const Expr &pattern, const Expr &expr, std::map<std::string, Expr> &result);
52 
53 /** Rewrite the expression x to have `lanes` lanes. This is useful
54  * for substituting the results of expr_match into a pattern expression. */
55 Expr with_lanes(const Expr &x, int lanes);
56 
58 
59 /** An alternative template-metaprogramming approach to expression
60  * matching. Potentially more efficient. We lift the expression
61  * pattern into a type, and then use force-inlined functions to
62  * generate efficient matching and reconstruction code for any
63  * pattern. Pattern elements are either one of the classes in the
64  * namespace IRMatcher, or are non-null Exprs (represented as
65  * BaseExprNode &).
66  *
67  * Pattern elements that are fully specified by their pattern can be
68  * built into an expression using the make method. Some patterns,
69  * such as a broadcast that matches any number of lanes, don't have
70  * enough information to recreate an Expr.
71  */
72 namespace IRMatcher {
73 
// Upper bound (exclusive) on wildcard indices; the constant-wildcard matchers
// static_assert that their index lies in [0, max_wild).
constexpr int max_wild{6};
75 
76 static const halide_type_t i64_type = {halide_type_int, 64, 1};
77 
78 /** To save stack space, the matcher objects are largely stateless and
79  * immutable. This state object is built up during matching and then
80  * consumed when constructing a replacement Expr.
81  */
82 struct MatcherState {
85 
86  // values of the lanes field with special meaning.
87  static constexpr uint16_t signed_integer_overflow = 0x8000;
88  static constexpr uint16_t special_values_mask = 0x8000; // currently only one
89 
91 
93  void set_binding(int i, const BaseExprNode &n) noexcept {
94  bindings[i] = &n;
95  }
96 
98  const BaseExprNode *get_binding(int i) const noexcept {
99  return bindings[i];
100  }
101 
103  void set_bound_const(int i, int64_t s, halide_type_t t) noexcept {
104  bound_const[i].u.i64 = s;
105  bound_const_type[i] = t;
106  }
107 
109  void set_bound_const(int i, uint64_t u, halide_type_t t) noexcept {
110  bound_const[i].u.u64 = u;
111  bound_const_type[i] = t;
112  }
113 
115  void set_bound_const(int i, double f, halide_type_t t) noexcept {
116  bound_const[i].u.f64 = f;
117  bound_const_type[i] = t;
118  }
119 
121  void set_bound_const(int i, halide_scalar_value_t val, halide_type_t t) noexcept {
122  bound_const[i] = val;
123  bound_const_type[i] = t;
124  }
125 
127  void get_bound_const(int i, halide_scalar_value_t &val, halide_type_t &type) const noexcept {
128  val = bound_const[i];
129  type = bound_const_type[i];
130  }
131 
133  // NOLINTNEXTLINE(modernize-use-equals-default): Can't use `= default`; clang-tidy complains about noexcept mismatch
134  MatcherState() noexcept {
135  }
136 };
137 
138 template<typename T,
139  typename = typename std::remove_reference<T>::type::pattern_tag>
141  struct type {};
142 };
143 
/** Compile-time accessor for a pattern type's `binds` bitmask.
 * Reference qualifiers on T are stripped first, so bindings<P>, bindings<P &>
 * and bindings<P &&> all report the same mask. */
template<typename T>
struct bindings {
    using stripped_type = typename std::remove_reference<T>::type;
    static constexpr uint32_t mask = stripped_type::binds;
};
148 
151  ty.lanes &= ~MatcherState::special_values_mask;
153  return make_signed_integer_overflow(ty);
154  }
155  // unreachable
156  return Expr();
157 }
158 
161  halide_type_t scalar_type = ty;
162  if (scalar_type.lanes & MatcherState::special_values_mask) {
163  return make_const_special_expr(scalar_type);
164  }
165 
166  const int lanes = scalar_type.lanes;
167  scalar_type.lanes = 1;
168 
169  Expr e;
170  switch (scalar_type.code) {
171  case halide_type_int:
172  e = IntImm::make(scalar_type, val.u.i64);
173  break;
174  case halide_type_uint:
175  e = UIntImm::make(scalar_type, val.u.u64);
176  break;
177  case halide_type_float:
178  case halide_type_bfloat:
179  e = FloatImm::make(scalar_type, val.u.f64);
180  break;
181  default:
182  // Unreachable
183  return Expr();
184  }
185  if (lanes > 1) {
186  e = Broadcast::make(e, lanes);
187  }
188  return e;
189 }
190 
191 bool equal_helper(const BaseExprNode &a, const BaseExprNode &b) noexcept;
192 
193 // A fast version of expression equality that assumes a well-typed non-null expression tree.
195 bool equal(const BaseExprNode &a, const BaseExprNode &b) noexcept {
196  // Early out
197  return (&a == &b) ||
198  ((a.type == b.type) &&
199  (a.node_type == b.node_type) &&
200  equal_helper(a, b));
201 }
202 
203 // A pattern that matches a specific expression
204 struct SpecificExpr {
205  struct pattern_tag {};
206 
207  constexpr static uint32_t binds = 0;
208 
209  // What is the weakest and strongest IR node this could possibly be
212  constexpr static bool canonical = true;
213 
215 
216  template<uint32_t bound>
217  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
218  return equal(expr, e);
219  }
220 
222  Expr make(MatcherState &state, halide_type_t type_hint) const {
223  return Expr(&expr);
224  }
225 
226  constexpr static bool foldable = false;
227 };
228 
229 inline std::ostream &operator<<(std::ostream &s, const SpecificExpr &e) {
230  s << Expr(&e.expr);
231  return s;
232 }
233 
234 template<int i>
235 struct WildConstInt {
236  struct pattern_tag {};
237 
238  constexpr static uint32_t binds = 1 << i;
239 
242  constexpr static bool canonical = true;
243 
244  template<uint32_t bound>
245  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
246  static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
247  const BaseExprNode *op = &e;
248  if (op->node_type == IRNodeType::Broadcast) {
249  op = ((const Broadcast *)op)->value.get();
250  }
251  if (op->node_type != IRNodeType::IntImm) {
252  return false;
253  }
254  int64_t value = ((const IntImm *)op)->value;
255  if (bound & binds) {
257  halide_type_t type;
258  state.get_bound_const(i, val, type);
259  return (halide_type_t)e.type == type && value == val.u.i64;
260  }
261  state.set_bound_const(i, value, e.type);
262  return true;
263  }
264 
265  template<uint32_t bound>
266  HALIDE_ALWAYS_INLINE bool match(int64_t value, MatcherState &state) const noexcept {
267  static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
268  if (bound & binds) {
270  halide_type_t type;
271  state.get_bound_const(i, val, type);
272  return type == i64_type && value == val.u.i64;
273  }
274  state.set_bound_const(i, value, i64_type);
275  return true;
276  }
277 
279  Expr make(MatcherState &state, halide_type_t type_hint) const {
281  halide_type_t type;
282  state.get_bound_const(i, val, type);
283  return make_const_expr(val, type);
284  }
285 
286  constexpr static bool foldable = true;
287 
290  state.get_bound_const(i, val, ty);
291  }
292 };
293 
294 template<int i>
295 std::ostream &operator<<(std::ostream &s, const WildConstInt<i> &c) {
296  s << "ci" << i;
297  return s;
298 }
299 
300 template<int i>
302  struct pattern_tag {};
303 
304  constexpr static uint32_t binds = 1 << i;
305 
308  constexpr static bool canonical = true;
309 
310  template<uint32_t bound>
311  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
312  static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
313  const BaseExprNode *op = &e;
314  if (op->node_type == IRNodeType::Broadcast) {
315  op = ((const Broadcast *)op)->value.get();
316  }
317  if (op->node_type != IRNodeType::UIntImm) {
318  return false;
319  }
320  uint64_t value = ((const UIntImm *)op)->value;
321  if (bound & binds) {
323  halide_type_t type;
324  state.get_bound_const(i, val, type);
325  return (halide_type_t)e.type == type && value == val.u.u64;
326  }
327  state.set_bound_const(i, value, e.type);
328  return true;
329  }
330 
332  Expr make(MatcherState &state, halide_type_t type_hint) const {
334  halide_type_t type;
335  state.get_bound_const(i, val, type);
336  return make_const_expr(val, type);
337  }
338 
339  constexpr static bool foldable = true;
340 
342  void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
343  state.get_bound_const(i, val, ty);
344  }
345 };
346 
347 template<int i>
348 std::ostream &operator<<(std::ostream &s, const WildConstUInt<i> &c) {
349  s << "cu" << i;
350  return s;
351 }
352 
353 template<int i>
355  struct pattern_tag {};
356 
357  constexpr static uint32_t binds = 1 << i;
358 
361  constexpr static bool canonical = true;
362 
363  template<uint32_t bound>
364  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
365  static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
366  const BaseExprNode *op = &e;
367  if (op->node_type == IRNodeType::Broadcast) {
368  op = ((const Broadcast *)op)->value.get();
369  }
370  if (op->node_type != IRNodeType::FloatImm) {
371  return false;
372  }
373  double value = ((const FloatImm *)op)->value;
374  if (bound & binds) {
376  halide_type_t type;
377  state.get_bound_const(i, val, type);
378  return (halide_type_t)e.type == type && value == val.u.f64;
379  }
380  state.set_bound_const(i, value, e.type);
381  return true;
382  }
383 
385  Expr make(MatcherState &state, halide_type_t type_hint) const {
387  halide_type_t type;
388  state.get_bound_const(i, val, type);
389  return make_const_expr(val, type);
390  }
391 
392  constexpr static bool foldable = true;
393 
395  void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
396  state.get_bound_const(i, val, ty);
397  }
398 };
399 
400 template<int i>
401 std::ostream &operator<<(std::ostream &s, const WildConstFloat<i> &c) {
402  s << "cf" << i;
403  return s;
404 }
405 
406 // Matches and binds to any constant Expr. Does not support constant-folding.
407 template<int i>
408 struct WildConst {
409  struct pattern_tag {};
410 
411  constexpr static uint32_t binds = 1 << i;
412 
415  constexpr static bool canonical = true;
416 
417  template<uint32_t bound>
418  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
419  static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
420  const BaseExprNode *op = &e;
421  if (op->node_type == IRNodeType::Broadcast) {
422  op = ((const Broadcast *)op)->value.get();
423  }
424  switch (op->node_type) {
425  case IRNodeType::IntImm:
426  return WildConstInt<i>().template match<bound>(e, state);
427  case IRNodeType::UIntImm:
428  return WildConstUInt<i>().template match<bound>(e, state);
430  return WildConstFloat<i>().template match<bound>(e, state);
431  default:
432  return false;
433  }
434  }
435 
436  template<uint32_t bound>
437  HALIDE_ALWAYS_INLINE bool match(int64_t e, MatcherState &state) const noexcept {
438  static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
439  return WildConstInt<i>().template match<bound>(e, state);
440  }
441 
443  Expr make(MatcherState &state, halide_type_t type_hint) const {
445  halide_type_t type;
446  state.get_bound_const(i, val, type);
447  return make_const_expr(val, type);
448  }
449 
450  constexpr static bool foldable = true;
451 
453  void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
454  state.get_bound_const(i, val, ty);
455  }
456 };
457 
458 template<int i>
459 std::ostream &operator<<(std::ostream &s, const WildConst<i> &c) {
460  s << "c" << i;
461  return s;
462 }
463 
464 // Matches and binds to any Expr
465 template<int i>
466 struct Wild {
467  struct pattern_tag {};
468 
469  constexpr static uint32_t binds = 1 << (i + 16);
470 
473  constexpr static bool canonical = true;
474 
475  template<uint32_t bound>
476  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
477  if (bound & binds) {
478  return equal(*state.get_binding(i), e);
479  }
480  state.set_binding(i, e);
481  return true;
482  }
483 
485  Expr make(MatcherState &state, halide_type_t type_hint) const {
486  return state.get_binding(i);
487  }
488 
489  constexpr static bool foldable = false;
490 };
491 
492 template<int i>
493 std::ostream &operator<<(std::ostream &s, const Wild<i> &op) {
494  s << "_" << i;
495  return s;
496 }
497 
498 // Matches a specific constant or broadcast of that constant. The
499 // constant must be representable as an int64_t.
500 struct IntLiteral {
501  struct pattern_tag {};
503 
504  constexpr static uint32_t binds = 0;
505 
508  constexpr static bool canonical = true;
509 
511  explicit IntLiteral(int64_t v)
512  : v(v) {
513  }
514 
515  template<uint32_t bound>
516  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
517  const BaseExprNode *op = &e;
518  if (e.node_type == IRNodeType::Broadcast) {
519  op = ((const Broadcast *)op)->value.get();
520  }
521  switch (op->node_type) {
522  case IRNodeType::IntImm:
523  return ((const IntImm *)op)->value == (int64_t)v;
524  case IRNodeType::UIntImm:
525  return ((const UIntImm *)op)->value == (uint64_t)v;
527  return ((const FloatImm *)op)->value == (double)v;
528  default:
529  return false;
530  }
531  }
532 
533  template<uint32_t bound>
534  HALIDE_ALWAYS_INLINE bool match(int64_t val, MatcherState &state) const noexcept {
535  return v == val;
536  }
537 
538  template<uint32_t bound>
539  HALIDE_ALWAYS_INLINE bool match(const IntLiteral &b, MatcherState &state) const noexcept {
540  return v == b.v;
541  }
542 
544  Expr make(MatcherState &state, halide_type_t type_hint) const {
545  return make_const(type_hint, v);
546  }
547 
548  constexpr static bool foldable = true;
549 
551  void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
552  // Assume type is already correct
553  switch (ty.code) {
554  case halide_type_int:
555  val.u.i64 = v;
556  break;
557  case halide_type_uint:
558  val.u.u64 = (uint64_t)v;
559  break;
560  case halide_type_float:
561  case halide_type_bfloat:
562  val.u.f64 = (double)v;
563  break;
564  default:
565  // Unreachable
566  ;
567  }
568  }
569 };
570 
572  return t.v;
573 }
574 
575 // Convert a provided pattern, expr, or constant int into the internal
576 // representation we use in the matcher trees.
577 template<typename T,
578  typename = typename std::decay<T>::type::pattern_tag>
580  return t;
581 }
584  return IntLiteral{x};
585 }
586 
587 template<typename T>
589  static_assert(!std::is_same<typename std::decay<T>::type, Expr>::value || std::is_lvalue_reference<T>::value,
590  "Exprs are captured by reference by IRMatcher objects and so must be lvalues");
591 }
592 
594  return {*e.get()};
595 }
596 
597 // Helpers to deref SpecificExprs to const BaseExprNode & rather than
598 // passing them by value anywhere (incurring lots of refcounting)
599 template<typename T,
600  // T must be a pattern node
601  typename = typename std::decay<T>::type::pattern_tag,
602  // But T may not be SpecificExpr
603  typename = typename std::enable_if<!std::is_same<typename std::decay<T>::type, SpecificExpr>::value>::type>
605  return t;
606 }
607 
609 const BaseExprNode &unwrap(const SpecificExpr &e) {
610  return e.expr;
611 }
612 
613 inline std::ostream &operator<<(std::ostream &s, const IntLiteral &op) {
614  s << op.v;
615  return s;
616 }
617 
618 template<typename Op>
620 
621 template<typename Op>
623 
624 template<typename Op>
625 double constant_fold_bin_op(halide_type_t &, double, double) noexcept;
626 
627 constexpr bool commutative(IRNodeType t) {
628  return (t == IRNodeType::Add ||
629  t == IRNodeType::Mul ||
630  t == IRNodeType::And ||
631  t == IRNodeType::Or ||
632  t == IRNodeType::Min ||
633  t == IRNodeType::Max ||
634  t == IRNodeType::EQ ||
635  t == IRNodeType::NE);
636 }
637 
638 // Matches one of the binary operators
639 template<typename Op, typename A, typename B>
640 struct BinOp {
641  struct pattern_tag {};
642  A a;
643  B b;
644 
646 
647  constexpr static IRNodeType min_node_type = Op::_node_type;
648  constexpr static IRNodeType max_node_type = Op::_node_type;
649 
650  // For commutative bin ops, we expect the weaker IR node type on
651  // the right. That is, for the rule to be canonical it must be
652  // possible that A is at least as strong as B.
653  constexpr static bool canonical =
654  A::canonical && B::canonical && (!commutative(Op::_node_type) || (A::max_node_type >= B::min_node_type));
655 
656  template<uint32_t bound>
657  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
658  if (e.node_type != Op::_node_type) {
659  return false;
660  }
661  const Op &op = (const Op &)e;
662  return (a.template match<bound>(*op.a.get(), state) &&
663  b.template match<bound | bindings<A>::mask>(*op.b.get(), state));
664  }
665 
666  template<uint32_t bound, typename Op2, typename A2, typename B2>
667  HALIDE_ALWAYS_INLINE bool match(const BinOp<Op2, A2, B2> &op, MatcherState &state) const noexcept {
668  return (std::is_same<Op, Op2>::value &&
669  a.template match<bound>(unwrap(op.a), state) &&
670  b.template match<bound | bindings<A>::mask>(unwrap(op.b), state));
671  }
672 
673  constexpr static bool foldable = A::foldable && B::foldable;
674 
676  void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
677  halide_scalar_value_t val_a, val_b;
678  if (std::is_same<A, IntLiteral>::value) {
679  b.make_folded_const(val_b, ty, state);
680  if ((std::is_same<Op, And>::value && val_b.u.u64 == 0) ||
681  (std::is_same<Op, Or>::value && val_b.u.u64 == 1)) {
682  // Short circuit
683  val = val_b;
684  return;
685  }
686  const uint16_t l = ty.lanes;
687  a.make_folded_const(val_a, ty, state);
688  ty.lanes |= l; // Make sure the overflow bits are sticky
689  } else {
690  a.make_folded_const(val_a, ty, state);
691  if ((std::is_same<Op, And>::value && val_a.u.u64 == 0) ||
692  (std::is_same<Op, Or>::value && val_a.u.u64 == 1)) {
693  // Short circuit
694  val = val_a;
695  return;
696  }
697  const uint16_t l = ty.lanes;
698  b.make_folded_const(val_b, ty, state);
699  ty.lanes |= l;
700  }
701  switch (ty.code) {
702  case halide_type_int:
703  val.u.i64 = constant_fold_bin_op<Op>(ty, val_a.u.i64, val_b.u.i64);
704  break;
705  case halide_type_uint:
706  val.u.u64 = constant_fold_bin_op<Op>(ty, val_a.u.u64, val_b.u.u64);
707  break;
708  case halide_type_float:
709  case halide_type_bfloat:
710  val.u.f64 = constant_fold_bin_op<Op>(ty, val_a.u.f64, val_b.u.f64);
711  break;
712  default:
713  // unreachable
714  ;
715  }
716  }
717 
719  Expr make(MatcherState &state, halide_type_t type_hint) const noexcept {
720  Expr ea, eb;
721  if (std::is_same<A, IntLiteral>::value) {
722  eb = b.make(state, type_hint);
723  ea = a.make(state, eb.type());
724  } else {
725  ea = a.make(state, type_hint);
726  eb = b.make(state, ea.type());
727  }
728  // We sometimes mix vectors and scalars in the rewrite rules,
729  // so insert a broadcast if necessary.
730  if (ea.type().is_vector() && !eb.type().is_vector()) {
731  eb = Broadcast::make(eb, ea.type().lanes());
732  }
733  if (eb.type().is_vector() && !ea.type().is_vector()) {
734  ea = Broadcast::make(ea, eb.type().lanes());
735  }
736  return Op::make(std::move(ea), std::move(eb));
737  }
738 };
739 
740 template<typename Op>
742 
743 template<typename Op>
745 
746 template<typename Op>
747 uint64_t constant_fold_cmp_op(double, double) noexcept;
748 
749 // Matches one of the comparison operators
750 template<typename Op, typename A, typename B>
751 struct CmpOp {
752  struct pattern_tag {};
753  A a;
754  B b;
755 
757 
758  constexpr static IRNodeType min_node_type = Op::_node_type;
759  constexpr static IRNodeType max_node_type = Op::_node_type;
760  constexpr static bool canonical = (A::canonical &&
761  B::canonical &&
762  (!commutative(Op::_node_type) || A::max_node_type >= B::min_node_type) &&
763  (Op::_node_type != IRNodeType::GE) &&
764  (Op::_node_type != IRNodeType::GT));
765 
766  template<uint32_t bound>
767  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
768  if (e.node_type != Op::_node_type) {
769  return false;
770  }
771  const Op &op = (const Op &)e;
772  return (a.template match<bound>(*op.a.get(), state) &&
773  b.template match<bound | bindings<A>::mask>(*op.b.get(), state));
774  }
775 
776  template<uint32_t bound, typename Op2, typename A2, typename B2>
777  HALIDE_ALWAYS_INLINE bool match(const CmpOp<Op2, A2, B2> &op, MatcherState &state) const noexcept {
778  return (std::is_same<Op, Op2>::value &&
779  a.template match<bound>(unwrap(op.a), state) &&
780  b.template match<bound | bindings<A>::mask>(unwrap(op.b), state));
781  }
782 
783  constexpr static bool foldable = A::foldable && B::foldable;
784 
786  void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
787  halide_scalar_value_t val_a, val_b;
788  // If one side is an untyped const, evaluate the other side first to get a type hint.
789  if (std::is_same<A, IntLiteral>::value) {
790  b.make_folded_const(val_b, ty, state);
791  const uint16_t l = ty.lanes;
792  a.make_folded_const(val_a, ty, state);
793  ty.lanes |= l;
794  } else {
795  a.make_folded_const(val_a, ty, state);
796  const uint16_t l = ty.lanes;
797  b.make_folded_const(val_b, ty, state);
798  ty.lanes |= l;
799  }
800  switch (ty.code) {
801  case halide_type_int:
802  val.u.u64 = constant_fold_cmp_op<Op>(val_a.u.i64, val_b.u.i64);
803  break;
804  case halide_type_uint:
805  val.u.u64 = constant_fold_cmp_op<Op>(val_a.u.u64, val_b.u.u64);
806  break;
807  case halide_type_float:
808  case halide_type_bfloat:
809  val.u.u64 = constant_fold_cmp_op<Op>(val_a.u.f64, val_b.u.f64);
810  break;
811  default:
812  // unreachable
813  ;
814  }
815  ty.code = halide_type_uint;
816  ty.bits = 1;
817  }
818 
820  Expr make(MatcherState &state, halide_type_t type_hint) const {
821  // If one side is an untyped const, evaluate the other side first to get a type hint.
822  Expr ea, eb;
823  if (std::is_same<A, IntLiteral>::value) {
824  eb = b.make(state, {});
825  ea = a.make(state, eb.type());
826  } else {
827  ea = a.make(state, {});
828  eb = b.make(state, ea.type());
829  }
830  // We sometimes mix vectors and scalars in the rewrite rules,
831  // so insert a broadcast if necessary.
832  if (ea.type().is_vector() && !eb.type().is_vector()) {
833  eb = Broadcast::make(eb, ea.type().lanes());
834  }
835  if (eb.type().is_vector() && !ea.type().is_vector()) {
836  ea = Broadcast::make(ea, eb.type().lanes());
837  }
838  return Op::make(std::move(ea), std::move(eb));
839  }
840 };
841 
842 template<typename A, typename B>
843 std::ostream &operator<<(std::ostream &s, const BinOp<Add, A, B> &op) {
844  s << "(" << op.a << " + " << op.b << ")";
845  return s;
846 }
847 
848 template<typename A, typename B>
849 std::ostream &operator<<(std::ostream &s, const BinOp<Sub, A, B> &op) {
850  s << "(" << op.a << " - " << op.b << ")";
851  return s;
852 }
853 
854 template<typename A, typename B>
855 std::ostream &operator<<(std::ostream &s, const BinOp<Mul, A, B> &op) {
856  s << "(" << op.a << " * " << op.b << ")";
857  return s;
858 }
859 
860 template<typename A, typename B>
861 std::ostream &operator<<(std::ostream &s, const BinOp<Div, A, B> &op) {
862  s << "(" << op.a << " / " << op.b << ")";
863  return s;
864 }
865 
866 template<typename A, typename B>
867 std::ostream &operator<<(std::ostream &s, const BinOp<And, A, B> &op) {
868  s << "(" << op.a << " && " << op.b << ")";
869  return s;
870 }
871 
872 template<typename A, typename B>
873 std::ostream &operator<<(std::ostream &s, const BinOp<Or, A, B> &op) {
874  s << "(" << op.a << " || " << op.b << ")";
875  return s;
876 }
877 
878 template<typename A, typename B>
879 std::ostream &operator<<(std::ostream &s, const BinOp<Min, A, B> &op) {
880  s << "min(" << op.a << ", " << op.b << ")";
881  return s;
882 }
883 
884 template<typename A, typename B>
885 std::ostream &operator<<(std::ostream &s, const BinOp<Max, A, B> &op) {
886  s << "max(" << op.a << ", " << op.b << ")";
887  return s;
888 }
889 
890 template<typename A, typename B>
891 std::ostream &operator<<(std::ostream &s, const CmpOp<LE, A, B> &op) {
892  s << "(" << op.a << " <= " << op.b << ")";
893  return s;
894 }
895 
896 template<typename A, typename B>
897 std::ostream &operator<<(std::ostream &s, const CmpOp<LT, A, B> &op) {
898  s << "(" << op.a << " < " << op.b << ")";
899  return s;
900 }
901 
902 template<typename A, typename B>
903 std::ostream &operator<<(std::ostream &s, const CmpOp<GE, A, B> &op) {
904  s << "(" << op.a << " >= " << op.b << ")";
905  return s;
906 }
907 
908 template<typename A, typename B>
909 std::ostream &operator<<(std::ostream &s, const CmpOp<GT, A, B> &op) {
910  s << "(" << op.a << " > " << op.b << ")";
911  return s;
912 }
913 
914 template<typename A, typename B>
915 std::ostream &operator<<(std::ostream &s, const CmpOp<EQ, A, B> &op) {
916  s << "(" << op.a << " == " << op.b << ")";
917  return s;
918 }
919 
920 template<typename A, typename B>
921 std::ostream &operator<<(std::ostream &s, const CmpOp<NE, A, B> &op) {
922  s << "(" << op.a << " != " << op.b << ")";
923  return s;
924 }
925 
926 template<typename A, typename B>
927 std::ostream &operator<<(std::ostream &s, const BinOp<Mod, A, B> &op) {
928  s << "(" << op.a << " % " << op.b << ")";
929  return s;
930 }
931 
932 template<typename A, typename B>
933 HALIDE_ALWAYS_INLINE auto operator+(A &&a, B &&b) noexcept -> BinOp<Add, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
934  assert_is_lvalue_if_expr<A>();
935  assert_is_lvalue_if_expr<B>();
936  return {pattern_arg(a), pattern_arg(b)};
937 }
938 
939 template<typename A, typename B>
940 HALIDE_ALWAYS_INLINE auto add(A &&a, B &&b) -> decltype(IRMatcher::operator+(a, b)) {
941  assert_is_lvalue_if_expr<A>();
942  assert_is_lvalue_if_expr<B>();
943  return IRMatcher::operator+(a, b);
944 }
945 
946 template<>
948  t.lanes |= ((t.bits >= 32) && add_would_overflow(t.bits, a, b)) ? MatcherState::signed_integer_overflow : 0;
949  int dead_bits = 64 - t.bits;
950  // Drop the high bits then sign-extend them back
951  return int64_t((uint64_t(a) + uint64_t(b)) << dead_bits) >> dead_bits;
952 }
953 
954 template<>
956  uint64_t ones = (uint64_t)(-1);
957  return (a + b) & (ones >> (64 - t.bits));
958 }
959 
960 template<>
961 HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Add>(halide_type_t &t, double a, double b) noexcept {
962  return a + b;
963 }
964 
965 template<typename A, typename B>
966 HALIDE_ALWAYS_INLINE auto operator-(A &&a, B &&b) noexcept -> BinOp<Sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
967  assert_is_lvalue_if_expr<A>();
968  assert_is_lvalue_if_expr<B>();
969  return {pattern_arg(a), pattern_arg(b)};
970 }
971 
972 template<typename A, typename B>
973 HALIDE_ALWAYS_INLINE auto sub(A &&a, B &&b) -> decltype(IRMatcher::operator-(a, b)) {
974  assert_is_lvalue_if_expr<A>();
975  assert_is_lvalue_if_expr<B>();
976  return IRMatcher::operator-(a, b);
977 }
978 
979 template<>
981  t.lanes |= ((t.bits >= 32) && sub_would_overflow(t.bits, a, b)) ? MatcherState::signed_integer_overflow : 0;
982  // Drop the high bits then sign-extend them back
983  int dead_bits = 64 - t.bits;
984  return int64_t((uint64_t(a) - uint64_t(b)) << dead_bits) >> dead_bits;
985 }
986 
987 template<>
989  uint64_t ones = (uint64_t)(-1);
990  return (a - b) & (ones >> (64 - t.bits));
991 }
992 
993 template<>
994 HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Sub>(halide_type_t &t, double a, double b) noexcept {
995  return a - b;
996 }
997 
998 template<typename A, typename B>
999 HALIDE_ALWAYS_INLINE auto operator*(A &&a, B &&b) noexcept -> BinOp<Mul, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1000  assert_is_lvalue_if_expr<A>();
1001  assert_is_lvalue_if_expr<B>();
1002  return {pattern_arg(a), pattern_arg(b)};
1003 }
1004 
1005 template<typename A, typename B>
1006 HALIDE_ALWAYS_INLINE auto mul(A &&a, B &&b) -> decltype(IRMatcher::operator*(a, b)) {
1007  assert_is_lvalue_if_expr<A>();
1008  assert_is_lvalue_if_expr<B>();
1009  return IRMatcher::operator*(a, b);
1010 }
1011 
1012 template<>
1014  t.lanes |= ((t.bits >= 32) && mul_would_overflow(t.bits, a, b)) ? MatcherState::signed_integer_overflow : 0;
1015  int dead_bits = 64 - t.bits;
1016  // Drop the high bits then sign-extend them back
1017  return int64_t((uint64_t(a) * uint64_t(b)) << dead_bits) >> dead_bits;
1018 }
1019 
1020 template<>
1022  uint64_t ones = (uint64_t)(-1);
1023  return (a * b) & (ones >> (64 - t.bits));
1024 }
1025 
1026 template<>
1027 HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Mul>(halide_type_t &t, double a, double b) noexcept {
1028  return a * b;
1029 }
1030 
1031 template<typename A, typename B>
1032 HALIDE_ALWAYS_INLINE auto operator/(A &&a, B &&b) noexcept -> BinOp<Div, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1033  assert_is_lvalue_if_expr<A>();
1034  assert_is_lvalue_if_expr<B>();
1035  return {pattern_arg(a), pattern_arg(b)};
1036 }
1037 
1038 template<typename A, typename B>
1039 HALIDE_ALWAYS_INLINE auto div(A &&a, B &&b) -> decltype(IRMatcher::operator/(a, b)) {
1040  return IRMatcher::operator/(a, b);
1041 }
1042 
1043 template<>
1045  return div_imp(a, b);
1046 }
1047 
1048 template<>
1050  return div_imp(a, b);
1051 }
1052 
1053 template<>
1054 HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Div>(halide_type_t &t, double a, double b) noexcept {
1055  return div_imp(a, b);
1056 }
1057 
1058 template<typename A, typename B>
1059 HALIDE_ALWAYS_INLINE auto operator%(A &&a, B &&b) noexcept -> BinOp<Mod, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1060  assert_is_lvalue_if_expr<A>();
1061  assert_is_lvalue_if_expr<B>();
1062  return {pattern_arg(a), pattern_arg(b)};
1063 }
1064 
1065 template<typename A, typename B>
1066 HALIDE_ALWAYS_INLINE auto mod(A &&a, B &&b) -> decltype(IRMatcher::operator%(a, b)) {
1067  assert_is_lvalue_if_expr<A>();
1068  assert_is_lvalue_if_expr<B>();
1069  return IRMatcher::operator%(a, b);
1070 }
1071 
1072 template<>
1074  return mod_imp(a, b);
1075 }
1076 
1077 template<>
1079  return mod_imp(a, b);
1080 }
1081 
1082 template<>
1083 HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Mod>(halide_type_t &t, double a, double b) noexcept {
1084  return mod_imp(a, b);
1085 }
1086 
1087 template<typename A, typename B>
1088 HALIDE_ALWAYS_INLINE auto min(A &&a, B &&b) noexcept -> BinOp<Min, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1089  assert_is_lvalue_if_expr<A>();
1090  assert_is_lvalue_if_expr<B>();
1091  return {pattern_arg(a), pattern_arg(b)};
1092 }
1093 
1094 template<>
1096  return std::min(a, b);
1097 }
1098 
1099 template<>
1101  return std::min(a, b);
1102 }
1103 
1104 template<>
1105 HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Min>(halide_type_t &t, double a, double b) noexcept {
1106  return std::min(a, b);
1107 }
1108 
1109 template<typename A, typename B>
1110 HALIDE_ALWAYS_INLINE auto max(A &&a, B &&b) noexcept -> BinOp<Max, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1111  assert_is_lvalue_if_expr<A>();
1112  assert_is_lvalue_if_expr<B>();
1113  return {pattern_arg(std::forward<A>(a)), pattern_arg(std::forward<B>(b))};
1114 }
1115 
1116 template<>
1118  return std::max(a, b);
1119 }
1120 
1121 template<>
1123  return std::max(a, b);
1124 }
1125 
1126 template<>
1127 HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Max>(halide_type_t &t, double a, double b) noexcept {
1128  return std::max(a, b);
1129 }
1130 
1131 template<typename A, typename B>
1132 HALIDE_ALWAYS_INLINE auto operator<(A &&a, B &&b) noexcept -> CmpOp<LT, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1133  return {pattern_arg(a), pattern_arg(b)};
1134 }
1135 
1136 template<typename A, typename B>
1137 HALIDE_ALWAYS_INLINE auto lt(A &&a, B &&b) -> decltype(IRMatcher::operator<(a, b)) {
1138  return IRMatcher::operator<(a, b);
1139 }
1140 
1141 template<>
1143  return a < b;
1144 }
1145 
1146 template<>
1148  return a < b;
1149 }
1150 
1151 template<>
1153  return a < b;
1154 }
1155 
1156 template<typename A, typename B>
1157 HALIDE_ALWAYS_INLINE auto operator>(A &&a, B &&b) noexcept -> CmpOp<GT, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1158  return {pattern_arg(a), pattern_arg(b)};
1159 }
1160 
1161 template<typename A, typename B>
1162 HALIDE_ALWAYS_INLINE auto gt(A &&a, B &&b) -> decltype(IRMatcher::operator>(a, b)) {
1163  return IRMatcher::operator>(a, b);
1164 }
1165 
1166 template<>
1168  return a > b;
1169 }
1170 
1171 template<>
1173  return a > b;
1174 }
1175 
1176 template<>
1178  return a > b;
1179 }
1180 
1181 template<typename A, typename B>
1182 HALIDE_ALWAYS_INLINE auto operator<=(A &&a, B &&b) noexcept -> CmpOp<LE, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1183  return {pattern_arg(a), pattern_arg(b)};
1184 }
1185 
1186 template<typename A, typename B>
1187 HALIDE_ALWAYS_INLINE auto le(A &&a, B &&b) -> decltype(IRMatcher::operator<=(a, b)) {
1188  return IRMatcher::operator<=(a, b);
1189 }
1190 
1191 template<>
1193  return a <= b;
1194 }
1195 
1196 template<>
1198  return a <= b;
1199 }
1200 
1201 template<>
1203  return a <= b;
1204 }
1205 
1206 template<typename A, typename B>
1207 HALIDE_ALWAYS_INLINE auto operator>=(A &&a, B &&b) noexcept -> CmpOp<GE, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1208  return {pattern_arg(a), pattern_arg(b)};
1209 }
1210 
1211 template<typename A, typename B>
1212 HALIDE_ALWAYS_INLINE auto ge(A &&a, B &&b) -> decltype(IRMatcher::operator>=(a, b)) {
1213  return IRMatcher::operator>=(a, b);
1214 }
1215 
1216 template<>
1218  return a >= b;
1219 }
1220 
1221 template<>
1223  return a >= b;
1224 }
1225 
1226 template<>
1228  return a >= b;
1229 }
1230 
1231 template<typename A, typename B>
1232 HALIDE_ALWAYS_INLINE auto operator==(A &&a, B &&b) noexcept -> CmpOp<EQ, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1233  return {pattern_arg(a), pattern_arg(b)};
1234 }
1235 
1236 template<typename A, typename B>
1237 HALIDE_ALWAYS_INLINE auto eq(A &&a, B &&b) -> decltype(IRMatcher::operator==(a, b)) {
1238  return IRMatcher::operator==(a, b);
1239 }
1240 
1241 template<>
1243  return a == b;
1244 }
1245 
1246 template<>
1248  return a == b;
1249 }
1250 
1251 template<>
1253  return a == b;
1254 }
1255 
1256 template<typename A, typename B>
1257 HALIDE_ALWAYS_INLINE auto operator!=(A &&a, B &&b) noexcept -> CmpOp<NE, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1258  return {pattern_arg(a), pattern_arg(b)};
1259 }
1260 
1261 template<typename A, typename B>
1262 HALIDE_ALWAYS_INLINE auto ne(A &&a, B &&b) -> decltype(IRMatcher::operator!=(a, b)) {
1263  return IRMatcher::operator!=(a, b);
1264 }
1265 
1266 template<>
1268  return a != b;
1269 }
1270 
1271 template<>
1273  return a != b;
1274 }
1275 
1276 template<>
1278  return a != b;
1279 }
1280 
1281 template<typename A, typename B>
1282 HALIDE_ALWAYS_INLINE auto operator||(A &&a, B &&b) noexcept -> BinOp<Or, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1283  return {pattern_arg(a), pattern_arg(b)};
1284 }
1285 
1286 template<typename A, typename B>
1287 HALIDE_ALWAYS_INLINE auto or_op(A &&a, B &&b) -> decltype(IRMatcher::operator||(a, b)) {
1288  return IRMatcher::operator||(a, b);
1289 }
1290 
1291 template<>
1293  return (a | b) & 1;
1294 }
1295 
1296 template<>
1298  return (a | b) & 1;
1299 }
1300 
1301 template<>
1302 HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Or>(halide_type_t &t, double a, double b) noexcept {
1303  // Unreachable, as it would be a type mismatch.
1304  return 0;
1305 }
1306 
1307 template<typename A, typename B>
1308 HALIDE_ALWAYS_INLINE auto operator&&(A &&a, B &&b) noexcept -> BinOp<And, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1309  return {pattern_arg(a), pattern_arg(b)};
1310 }
1311 
1312 template<typename A, typename B>
1313 HALIDE_ALWAYS_INLINE auto and_op(A &&a, B &&b) -> decltype(IRMatcher::operator&&(a, b)) {
1314  return IRMatcher::operator&&(a, b);
1315 }
1316 
1317 template<>
1319  return a & b & 1;
1320 }
1321 
1322 template<>
1324  return a & b & 1;
1325 }
1326 
1327 template<>
1328 HALIDE_ALWAYS_INLINE double constant_fold_bin_op<And>(halide_type_t &t, double a, double b) noexcept {
1329  // Unreachable
1330  return 0;
1331 }
1332 
/** Compile-time bitwise OR of any number of uint32_t values.
 * The empty reduction is 0, the identity for `|`. */
constexpr inline uint32_t bitwise_or_reduce() {
    return 0u;
}

template<typename... Args>
constexpr uint32_t bitwise_or_reduce(uint32_t first, Args... rest) {
    // Reduce the tail first, then fold in the head; `|` is commutative.
    return bitwise_or_reduce(rest...) | first;
}
1341 
/** Compile-time logical AND of any number of bools.
 * The empty reduction is true, the identity for `&&`. */
constexpr inline bool and_reduce() {
    return true;
}

template<typename... Args>
constexpr bool and_reduce(bool first, Args... rest) {
    // Stop at the first false value; otherwise the answer is that of the tail.
    return first ? and_reduce(rest...) : false;
}
1350 
// TODO: this can be replaced with std::min() once we require C++14 or later
/** Minimum of two ints, usable in constant expressions such as template
 * arguments. */
constexpr int const_min(int a, int b) {
    return (b < a) ? b : a;
}
1355 
1356 template<typename... Args>
1357 struct Intrin {
1358  struct pattern_tag {};
1360  std::tuple<Args...> args;
1361  // The type of the output of the intrinsic node.
1362  // Only necessary in cases where it can't be inferred
1363  // from the input types (e.g. saturating_cast).
1365 
1367 
1370  constexpr static bool canonical = and_reduce((Args::canonical)...);
1371 
1372  template<int i,
1373  uint32_t bound,
1374  typename = typename std::enable_if<(i < sizeof...(Args))>::type>
1375  HALIDE_ALWAYS_INLINE bool match_args(int, const Call &c, MatcherState &state) const noexcept {
1376  using T = decltype(std::get<i>(args));
1377  return (std::get<i>(args).template match<bound>(*c.args[i].get(), state) &&
1378  match_args<i + 1, bound | bindings<T>::mask>(0, c, state));
1379  }
1380 
1381  template<int i, uint32_t binds>
1382  HALIDE_ALWAYS_INLINE bool match_args(double, const Call &c, MatcherState &state) const noexcept {
1383  return true;
1384  }
1385 
1386  template<uint32_t bound>
1387  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1388  if (e.node_type != IRNodeType::Call) {
1389  return false;
1390  }
1391  const Call &c = (const Call &)e;
1392  return (c.is_intrinsic(intrin) &&
1393  ((optional_type_hint == Type()) || optional_type_hint == e.type) &&
1394  match_args<0, bound>(0, c, state));
1395  }
1396 
1397  template<int i,
1398  typename = typename std::enable_if<(i < sizeof...(Args))>::type>
1399  HALIDE_ALWAYS_INLINE void print_args(int, std::ostream &s) const {
1400  s << std::get<i>(args);
1401  if (i + 1 < sizeof...(Args)) {
1402  s << ", ";
1403  }
1404  print_args<i + 1>(0, s);
1405  }
1406 
1407  template<int i>
1408  HALIDE_ALWAYS_INLINE void print_args(double, std::ostream &s) const {
1409  }
1410 
1412  void print_args(std::ostream &s) const {
1413  print_args<0>(0, s);
1414  }
1415 
1417  Expr make(MatcherState &state, halide_type_t type_hint) const {
1418  Expr arg0 = std::get<0>(args).make(state, type_hint);
1419  if (intrin == Call::likely) {
1420  return likely(arg0);
1421  } else if (intrin == Call::likely_if_innermost) {
1422  return likely_if_innermost(arg0);
1423  } else if (intrin == Call::abs) {
1424  return abs(arg0);
1425  } else if (intrin == Call::saturating_cast) {
1426  return saturating_cast(optional_type_hint, arg0);
1427  }
1428 
1429  Expr arg1 = std::get<const_min(1, sizeof...(Args) - 1)>(args).make(state, type_hint);
1430  if (intrin == Call::absd) {
1431  return absd(arg0, arg1);
1432  } else if (intrin == Call::widen_right_add) {
1433  return widen_right_add(arg0, arg1);
1434  } else if (intrin == Call::widen_right_mul) {
1435  return widen_right_mul(arg0, arg1);
1436  } else if (intrin == Call::widen_right_sub) {
1437  return widen_right_sub(arg0, arg1);
1438  } else if (intrin == Call::widening_add) {
1439  return widening_add(arg0, arg1);
1440  } else if (intrin == Call::widening_sub) {
1441  return widening_sub(arg0, arg1);
1442  } else if (intrin == Call::widening_mul) {
1443  return widening_mul(arg0, arg1);
1444  } else if (intrin == Call::saturating_add) {
1445  return saturating_add(arg0, arg1);
1446  } else if (intrin == Call::saturating_sub) {
1447  return saturating_sub(arg0, arg1);
1448  } else if (intrin == Call::halving_add) {
1449  return halving_add(arg0, arg1);
1450  } else if (intrin == Call::halving_sub) {
1451  return halving_sub(arg0, arg1);
1452  } else if (intrin == Call::rounding_halving_add) {
1453  return rounding_halving_add(arg0, arg1);
1454  } else if (intrin == Call::shift_left) {
1455  return arg0 << arg1;
1456  } else if (intrin == Call::shift_right) {
1457  return arg0 >> arg1;
1458  } else if (intrin == Call::rounding_shift_left) {
1459  return rounding_shift_left(arg0, arg1);
1460  } else if (intrin == Call::rounding_shift_right) {
1461  return rounding_shift_right(arg0, arg1);
1462  }
1463 
1464  Expr arg2 = std::get<const_min(2, sizeof...(Args) - 1)>(args).make(state, type_hint);
1465  if (intrin == Call::mul_shift_right) {
1466  return mul_shift_right(arg0, arg1, arg2);
1467  } else if (intrin == Call::rounding_mul_shift_right) {
1468  return rounding_mul_shift_right(arg0, arg1, arg2);
1469  }
1470 
1471  internal_error << "Unhandled intrinsic in IRMatcher: " << intrin;
1472  return Expr();
1473  }
1474 
1475  constexpr static bool foldable = true;
1476 
1478  halide_scalar_value_t arg1;
1479  // Assuming the args have the same type as the intrinsic is incorrect in
1480  // general. But for the intrinsics we can fold (just shifts), the LHS
1481  // has the same type as the intrinsic, and we can always treat the RHS
1482  // as a signed int, because we're using 64 bits for it.
1483  std::get<0>(args).make_folded_const(val, ty, state);
1484  halide_type_t signed_ty = ty;
1485  signed_ty.code = halide_type_int;
1486  // We can just directly get the second arg here, because we only want to
1487  // instantiate this method for shifts, which have two args.
1488  std::get<1>(args).make_folded_const(arg1, signed_ty, state);
1489 
1490  if (intrin == Call::shift_left) {
1491  if (arg1.u.i64 < 0) {
1492  if (ty.code == halide_type_int) {
1493  // Arithmetic shift
1494  val.u.i64 >>= -arg1.u.i64;
1495  } else {
1496  // Logical shift
1497  val.u.u64 >>= -arg1.u.i64;
1498  }
1499  } else {
1500  val.u.u64 <<= arg1.u.i64;
1501  }
1502  } else if (intrin == Call::shift_right) {
1503  if (arg1.u.i64 > 0) {
1504  if (ty.code == halide_type_int) {
1505  // Arithmetic shift
1506  val.u.i64 >>= arg1.u.i64;
1507  } else {
1508  // Logical shift
1509  val.u.u64 >>= arg1.u.i64;
1510  }
1511  } else {
1512  val.u.u64 <<= -arg1.u.i64;
1513  }
1514  } else {
1515  internal_error << "Folding not implemented for intrinsic: " << intrin;
1516  }
1517  }
1518 
1521  : intrin(intrin), args(args...) {
1522  }
1523 };
1524 
1525 template<typename... Args>
1526 std::ostream &operator<<(std::ostream &s, const Intrin<Args...> &op) {
1527  s << op.intrin << "(";
1528  op.print_args(s);
1529  s << ")";
1530  return s;
1531 }
1532 
1533 template<typename... Args>
1534 HALIDE_ALWAYS_INLINE auto intrin(Call::IntrinsicOp intrinsic_op, Args... args) noexcept -> Intrin<decltype(pattern_arg(args))...> {
1535  return {intrinsic_op, pattern_arg(args)...};
1536 }
1537 
1538 template<typename A, typename B>
1539 auto widen_right_add(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1541 }
1542 template<typename A, typename B>
1543 auto widen_right_mul(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1545 }
1546 template<typename A, typename B>
1547 auto widen_right_sub(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1549 }
1550 
1551 template<typename A, typename B>
1552 auto widening_add(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1553  return {Call::widening_add, pattern_arg(a), pattern_arg(b)};
1554 }
1555 template<typename A, typename B>
1556 auto widening_sub(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1557  return {Call::widening_sub, pattern_arg(a), pattern_arg(b)};
1558 }
1559 template<typename A, typename B>
1560 auto widening_mul(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1561  return {Call::widening_mul, pattern_arg(a), pattern_arg(b)};
1562 }
1563 template<typename A, typename B>
1564 auto saturating_add(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1565  return {Call::saturating_add, pattern_arg(a), pattern_arg(b)};
1566 }
1567 template<typename A, typename B>
1568 auto saturating_sub(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1569  return {Call::saturating_sub, pattern_arg(a), pattern_arg(b)};
1570 }
1571 template<typename A>
1572 auto saturating_cast(const Type &t, A &&a) noexcept -> Intrin<decltype(pattern_arg(a))> {
1573  Intrin<decltype(pattern_arg(a))> p = {Call::saturating_cast, pattern_arg(a)};
1574  p.optional_type_hint = t;
1575  return p;
1576 }
1577 template<typename A, typename B>
1578 auto halving_add(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1579  return {Call::halving_add, pattern_arg(a), pattern_arg(b)};
1580 }
1581 template<typename A, typename B>
1582 auto halving_sub(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1583  return {Call::halving_sub, pattern_arg(a), pattern_arg(b)};
1584 }
1585 template<typename A, typename B>
1586 auto rounding_halving_add(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1588 }
1589 template<typename A, typename B>
1590 auto shift_left(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1591  return {Call::shift_left, pattern_arg(a), pattern_arg(b)};
1592 }
1593 template<typename A, typename B>
1594 auto shift_right(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1595  return {Call::shift_right, pattern_arg(a), pattern_arg(b)};
1596 }
1597 template<typename A, typename B>
1598 auto rounding_shift_left(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1600 }
1601 template<typename A, typename B>
1602 auto rounding_shift_right(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1604 }
1605 template<typename A, typename B, typename C>
1606 auto mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))> {
1608 }
1609 template<typename A, typename B, typename C>
1610 auto rounding_mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))> {
1612 }
1613 
1614 template<typename A>
1615 struct NotOp {
1616  struct pattern_tag {};
1617  A a;
1618 
1619  constexpr static uint32_t binds = bindings<A>::mask;
1620 
1623  constexpr static bool canonical = A::canonical;
1624 
1625  template<uint32_t bound>
1626  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1627  if (e.node_type != IRNodeType::Not) {
1628  return false;
1629  }
1630  const Not &op = (const Not &)e;
1631  return (a.template match<bound>(*op.a.get(), state));
1632  }
1633 
1634  template<uint32_t bound, typename A2>
1635  HALIDE_ALWAYS_INLINE bool match(const NotOp<A2> &op, MatcherState &state) const noexcept {
1636  return a.template match<bound>(unwrap(op.a), state);
1637  }
1638 
1640  Expr make(MatcherState &state, halide_type_t type_hint) const {
1641  return Not::make(a.make(state, type_hint));
1642  }
1643 
1644  constexpr static bool foldable = A::foldable;
1645 
1646  template<typename A1 = A>
1648  a.make_folded_const(val, ty, state);
1649  val.u.u64 = ~val.u.u64;
1650  val.u.u64 &= 1;
1651  }
1652 };
1653 
1654 template<typename A>
1655 HALIDE_ALWAYS_INLINE auto operator!(A &&a) noexcept -> NotOp<decltype(pattern_arg(a))> {
1656  assert_is_lvalue_if_expr<A>();
1657  return {pattern_arg(a)};
1658 }
1659 
1660 template<typename A>
1661 HALIDE_ALWAYS_INLINE auto not_op(A &&a) -> decltype(IRMatcher::operator!(a)) {
1662  assert_is_lvalue_if_expr<A>();
1663  return IRMatcher::operator!(a);
1664 }
1665 
1666 template<typename A>
1667 inline std::ostream &operator<<(std::ostream &s, const NotOp<A> &op) {
1668  s << "!(" << op.a << ")";
1669  return s;
1670 }
1671 
1672 template<typename C, typename T, typename F>
1673 struct SelectOp {
1674  struct pattern_tag {};
1675  C c;
1676  T t;
1677  F f;
1678 
1680 
1683 
1684  constexpr static bool canonical = C::canonical && T::canonical && F::canonical;
1685 
1686  template<uint32_t bound>
1687  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1688  if (e.node_type != Select::_node_type) {
1689  return false;
1690  }
1691  const Select &op = (const Select &)e;
1692  return (c.template match<bound>(*op.condition.get(), state) &&
1693  t.template match<bound | bindings<C>::mask>(*op.true_value.get(), state) &&
1694  f.template match<bound | bindings<C>::mask | bindings<T>::mask>(*op.false_value.get(), state));
1695  }
1696  template<uint32_t bound, typename C2, typename T2, typename F2>
1697  HALIDE_ALWAYS_INLINE bool match(const SelectOp<C2, T2, F2> &instance, MatcherState &state) const noexcept {
1698  return (c.template match<bound>(unwrap(instance.c), state) &&
1699  t.template match<bound | bindings<C>::mask>(unwrap(instance.t), state) &&
1700  f.template match<bound | bindings<C>::mask | bindings<T>::mask>(unwrap(instance.f), state));
1701  }
1702 
1704  Expr make(MatcherState &state, halide_type_t type_hint) const {
1705  return Select::make(c.make(state, {}), t.make(state, type_hint), f.make(state, type_hint));
1706  }
1707 
1708  constexpr static bool foldable = C::foldable && T::foldable && F::foldable;
1709 
1710  template<typename C1 = C>
1712  halide_scalar_value_t c_val, t_val, f_val;
1713  halide_type_t c_ty;
1714  c.make_folded_const(c_val, c_ty, state);
1715  if ((c_val.u.u64 & 1) == 1) {
1716  t.make_folded_const(val, ty, state);
1717  } else {
1718  f.make_folded_const(val, ty, state);
1719  }
1720  ty.lanes |= c_ty.lanes & MatcherState::special_values_mask;
1721  }
1722 };
1723 
1724 template<typename C, typename T, typename F>
1725 std::ostream &operator<<(std::ostream &s, const SelectOp<C, T, F> &op) {
1726  s << "select(" << op.c << ", " << op.t << ", " << op.f << ")";
1727  return s;
1728 }
1729 
1730 template<typename C, typename T, typename F>
1731 HALIDE_ALWAYS_INLINE auto select(C &&c, T &&t, F &&f) noexcept -> SelectOp<decltype(pattern_arg(c)), decltype(pattern_arg(t)), decltype(pattern_arg(f))> {
1732  assert_is_lvalue_if_expr<C>();
1733  assert_is_lvalue_if_expr<T>();
1734  assert_is_lvalue_if_expr<F>();
1735  return {pattern_arg(c), pattern_arg(t), pattern_arg(f)};
1736 }
1737 
1738 template<typename A, typename B>
1739 struct BroadcastOp {
1740  struct pattern_tag {};
1741  A a;
1743 
1745 
1748 
1749  constexpr static bool canonical = A::canonical && B::canonical;
1750 
1751  template<uint32_t bound>
1752  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1753  if (e.node_type == Broadcast::_node_type) {
1754  const Broadcast &op = (const Broadcast &)e;
1755  if (a.template match<bound>(*op.value.get(), state) &&
1756  lanes.template match<bound>(op.lanes, state)) {
1757  return true;
1758  }
1759  }
1760  return false;
1761  }
1762 
1763  template<uint32_t bound, typename A2, typename B2>
1764  HALIDE_ALWAYS_INLINE bool match(const BroadcastOp<A2, B2> &op, MatcherState &state) const noexcept {
1765  return (a.template match<bound>(unwrap(op.a), state) &&
1766  lanes.template match<bound | bindings<A>::mask>(unwrap(op.lanes), state));
1767  }
1768 
1770  Expr make(MatcherState &state, halide_type_t type_hint) const {
1771  halide_scalar_value_t lanes_val;
1772  halide_type_t ty;
1773  lanes.make_folded_const(lanes_val, ty, state);
1774  int32_t l = (int32_t)lanes_val.u.i64;
1775  type_hint.lanes /= l;
1776  Expr val = a.make(state, type_hint);
1777  if (l == 1) {
1778  return val;
1779  } else {
1780  return Broadcast::make(std::move(val), l);
1781  }
1782  }
1783 
1784  constexpr static bool foldable = false;
1785 
1786  template<typename A1 = A>
1788  halide_scalar_value_t lanes_val;
1789  halide_type_t lanes_ty;
1790  lanes.make_folded_const(lanes_val, lanes_ty, state);
1791  uint16_t l = (uint16_t)lanes_val.u.i64;
1792  a.make_folded_const(val, ty, state);
1793  ty.lanes = l | (ty.lanes & MatcherState::special_values_mask);
1794  }
1795 };
1796 
1797 template<typename A, typename B>
1798 inline std::ostream &operator<<(std::ostream &s, const BroadcastOp<A, B> &op) {
1799  s << "broadcast(" << op.a << ", " << op.lanes << ")";
1800  return s;
1801 }
1802 
1803 template<typename A, typename B>
1804 HALIDE_ALWAYS_INLINE auto broadcast(A &&a, B lanes) noexcept -> BroadcastOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes))> {
1805  assert_is_lvalue_if_expr<A>();
1806  return {pattern_arg(a), pattern_arg(lanes)};
1807 }
1808 
1809 template<typename A, typename B, typename C>
1810 struct RampOp {
1811  struct pattern_tag {};
1812  A a;
1813  B b;
1815 
1817 
1820 
1821  constexpr static bool canonical = A::canonical && B::canonical && C::canonical;
1822 
1823  template<uint32_t bound>
1824  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1825  if (e.node_type != Ramp::_node_type) {
1826  return false;
1827  }
1828  const Ramp &op = (const Ramp &)e;
1829  if (a.template match<bound>(*op.base.get(), state) &&
1830  b.template match<bound | bindings<A>::mask>(*op.stride.get(), state) &&
1831  lanes.template match<bound | bindings<A>::mask | bindings<B>::mask>(op.lanes, state)) {
1832  return true;
1833  } else {
1834  return false;
1835  }
1836  }
1837 
1838  template<uint32_t bound, typename A2, typename B2, typename C2>
1839  HALIDE_ALWAYS_INLINE bool match(const RampOp<A2, B2, C2> &op, MatcherState &state) const noexcept {
1840  return (a.template match<bound>(unwrap(op.a), state) &&
1841  b.template match<bound | bindings<A>::mask>(unwrap(op.b), state) &&
1842  lanes.template match<bound | bindings<A>::mask | bindings<B>::mask>(unwrap(op.lanes), state));
1843  }
1844 
1846  Expr make(MatcherState &state, halide_type_t type_hint) const {
1847  halide_scalar_value_t lanes_val;
1848  halide_type_t ty;
1849  lanes.make_folded_const(lanes_val, ty, state);
1850  int32_t l = (int32_t)lanes_val.u.i64;
1851  type_hint.lanes /= l;
1852  Expr ea, eb;
1853  eb = b.make(state, type_hint);
1854  ea = a.make(state, eb.type());
1855  return Ramp::make(ea, eb, l);
1856  }
1857 
1858  constexpr static bool foldable = false;
1859 };
1860 
1861 template<typename A, typename B, typename C>
1862 std::ostream &operator<<(std::ostream &s, const RampOp<A, B, C> &op) {
1863  s << "ramp(" << op.a << ", " << op.b << ", " << op.lanes << ")";
1864  return s;
1865 }
1866 
1867 template<typename A, typename B, typename C>
1868 HALIDE_ALWAYS_INLINE auto ramp(A &&a, B &&b, C &&c) noexcept -> RampOp<decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))> {
1869  assert_is_lvalue_if_expr<A>();
1870  assert_is_lvalue_if_expr<B>();
1871  assert_is_lvalue_if_expr<C>();
1872  return {pattern_arg(a), pattern_arg(b), pattern_arg(c)};
1873 }
1874 
1875 template<typename A, typename B, VectorReduce::Operator reduce_op>
1877  struct pattern_tag {};
1878  A a;
1880 
1882 
1885  constexpr static bool canonical = A::canonical;
1886 
1887  template<uint32_t bound>
1888  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1889  if (e.node_type == VectorReduce::_node_type) {
1890  const VectorReduce &op = (const VectorReduce &)e;
1891  if (op.op == reduce_op &&
1892  a.template match<bound>(*op.value.get(), state) &&
1893  lanes.template match<bound | bindings<A>::mask>(op.type.lanes(), state)) {
1894  return true;
1895  }
1896  }
1897  return false;
1898  }
1899 
1900  template<uint32_t bound, typename A2, typename B2, VectorReduce::Operator reduce_op_2>
1902  return (reduce_op == reduce_op_2 &&
1903  a.template match<bound>(unwrap(op.a), state) &&
1904  lanes.template match<bound | bindings<A>::mask>(unwrap(op.lanes), state));
1905  }
1906 
1908  Expr make(MatcherState &state, halide_type_t type_hint) const {
1909  halide_scalar_value_t lanes_val;
1910  halide_type_t ty;
1911  lanes.make_folded_const(lanes_val, ty, state);
1912  int l = (int)lanes_val.u.i64;
1913  return VectorReduce::make(reduce_op, a.make(state, type_hint), l);
1914  }
1915 
1916  constexpr static bool foldable = false;
1917 };
1918 
1919 template<typename A, typename B, VectorReduce::Operator reduce_op>
1920 inline std::ostream &operator<<(std::ostream &s, const VectorReduceOp<A, B, reduce_op> &op) {
1921  s << "vector_reduce(" << reduce_op << ", " << op.a << ", " << op.lanes << ")";
1922  return s;
1923 }
1924 
1925 template<typename A, typename B>
1926 HALIDE_ALWAYS_INLINE auto h_add(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Add> {
1927  assert_is_lvalue_if_expr<A>();
1928  return {pattern_arg(a), pattern_arg(lanes)};
1929 }
1930 
1931 template<typename A, typename B>
1932 HALIDE_ALWAYS_INLINE auto h_min(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Min> {
1933  assert_is_lvalue_if_expr<A>();
1934  return {pattern_arg(a), pattern_arg(lanes)};
1935 }
1936 
1937 template<typename A, typename B>
1938 HALIDE_ALWAYS_INLINE auto h_max(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Max> {
1939  assert_is_lvalue_if_expr<A>();
1940  return {pattern_arg(a), pattern_arg(lanes)};
1941 }
1942 
1943 template<typename A, typename B>
1944 HALIDE_ALWAYS_INLINE auto h_and(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::And> {
1945  assert_is_lvalue_if_expr<A>();
1946  return {pattern_arg(a), pattern_arg(lanes)};
1947 }
1948 
1949 template<typename A, typename B>
1950 HALIDE_ALWAYS_INLINE auto h_or(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Or> {
1951  assert_is_lvalue_if_expr<A>();
1952  return {pattern_arg(a), pattern_arg(lanes)};
1953 }
1954 
1955 template<typename A>
1956 struct NegateOp {
1957  struct pattern_tag {};
1958  A a;
1959 
1960  constexpr static uint32_t binds = bindings<A>::mask;
1961 
1964 
1965  constexpr static bool canonical = A::canonical;
1966 
1967  template<uint32_t bound>
1968  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1969  if (e.node_type != Sub::_node_type) {
1970  return false;
1971  }
1972  const Sub &op = (const Sub &)e;
1973  return (a.template match<bound>(*op.b.get(), state) &&
1974  is_const_zero(op.a));
1975  }
1976 
1977  template<uint32_t bound, typename A2>
1978  HALIDE_ALWAYS_INLINE bool match(NegateOp<A2> &&p, MatcherState &state) const noexcept {
1979  return a.template match<bound>(unwrap(p.a), state);
1980  }
1981 
1983  Expr make(MatcherState &state, halide_type_t type_hint) const {
1984  Expr ea = a.make(state, type_hint);
1985  Expr z = make_zero(ea.type());
1986  return Sub::make(std::move(z), std::move(ea));
1987  }
1988 
1989  constexpr static bool foldable = A::foldable;
1990 
1991  template<typename A1 = A>
1993  a.make_folded_const(val, ty, state);
1994  int dead_bits = 64 - ty.bits;
1995  switch (ty.code) {
1996  case halide_type_int:
1997  if (ty.bits >= 32 && val.u.u64 && (val.u.u64 << (65 - ty.bits)) == 0) {
1998  // Trying to negate the most negative signed int for a no-overflow type.
2000  } else {
2001  // Negate, drop the high bits, and then sign-extend them back
2002  val.u.i64 = int64_t(uint64_t(-val.u.i64) << dead_bits) >> dead_bits;
2003  }
2004  break;
2005  case halide_type_uint:
2006  val.u.u64 = ((-val.u.u64) << dead_bits) >> dead_bits;
2007  break;
2008  case halide_type_float:
2009  case halide_type_bfloat:
2010  val.u.f64 = -val.u.f64;
2011  break;
2012  default:
2013  // unreachable
2014  ;
2015  }
2016  }
2017 };
2018 
2019 template<typename A>
2020 std::ostream &operator<<(std::ostream &s, const NegateOp<A> &op) {
2021  s << "-" << op.a;
2022  return s;
2023 }
2024 
2025 template<typename A>
2026 HALIDE_ALWAYS_INLINE auto operator-(A &&a) noexcept -> NegateOp<decltype(pattern_arg(a))> {
2027  assert_is_lvalue_if_expr<A>();
2028  return {pattern_arg(a)};
2029 }
2030 
2031 template<typename A>
2032 HALIDE_ALWAYS_INLINE auto negate(A &&a) -> decltype(IRMatcher::operator-(a)) {
2033  assert_is_lvalue_if_expr<A>();
2034  return IRMatcher::operator-(a);
2035 }
2036 
2037 template<typename A>
2038 struct CastOp {
2039  struct pattern_tag {};
2041  A a;
2042 
2043  constexpr static uint32_t binds = bindings<A>::mask;
2044 
2047  constexpr static bool canonical = A::canonical;
2048 
2049  template<uint32_t bound>
2050  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
2051  if (e.node_type != Cast::_node_type) {
2052  return false;
2053  }
2054  const Cast &op = (const Cast &)e;
2055  return (e.type == t &&
2056  a.template match<bound>(*op.value.get(), state));
2057  }
2058  template<uint32_t bound, typename A2>
2059  HALIDE_ALWAYS_INLINE bool match(const CastOp<A2> &op, MatcherState &state) const noexcept {
2060  return t == op.t && a.template match<bound>(unwrap(op.a), state);
2061  }
2062 
2064  Expr make(MatcherState &state, halide_type_t type_hint) const {
2065  return cast(t, a.make(state, {}));
2066  }
2067 
2068  constexpr static bool foldable = false;
2069 };
2070 
2071 template<typename A>
2072 std::ostream &operator<<(std::ostream &s, const CastOp<A> &op) {
2073  s << "cast(" << op.t << ", " << op.a << ")";
2074  return s;
2075 }
2076 
2077 template<typename A>
2078 HALIDE_ALWAYS_INLINE auto cast(halide_type_t t, A &&a) noexcept -> CastOp<decltype(pattern_arg(a))> {
2079  assert_is_lvalue_if_expr<A>();
2080  return {t, pattern_arg(a)};
2081 }
2082 
// SliceOp: pattern node for a single-vector Shuffle that is a slice of `vec`.
// base/stride/lanes must be constant-foldable (static_asserted in the
// constructor), because make() needs them as plain integers.
2083 template<typename Vec, typename Base, typename Stride, typename Lanes>
2084 struct SliceOp {
2085  struct pattern_tag {};
2086  Vec vec;
2087  Base base;
2088  Stride stride;
2089  Lanes lanes;
2090 
2091  static constexpr uint32_t binds = Vec::binds | Base::binds | Stride::binds | Lanes::binds;
2092 
2095  constexpr static bool canonical = Vec::canonical && Base::canonical && Stride::canonical && Lanes::canonical;
2096 
2097  template<uint32_t bound>
// Each successive sub-match is told which wildcards the earlier ones already
// bound (the growing `bound | bindings<...>::mask`), so repeated wildcards must agree.
// NOTE(review): the final conjunct (doxygen line 2108, presumably matching
// `lanes`) is missing from this listing.
2098  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
2099  if (e.node_type != IRNodeType::Shuffle) {
2100  return false;
2101  }
2102  const Shuffle &v = (const Shuffle &)e;
2103  return v.vectors.size() == 1 &&
2104  v.is_slice() &&
2105  vec.template match<bound>(*v.vectors[0].get(), state) &&
2106  base.template match<bound | bindings<Vec>::mask>(v.slice_begin(), state) &&
2107  stride.template match<bound | bindings<Vec>::mask | bindings<Base>::mask>(v.slice_stride(), state) &&
2109  }
2110 
// Folds base, stride and lanes to constants, narrows them to int, and slices
// the made vector.
// NOTE(review): `ty` is passed uninitialized to make_folded_const; fine only if
// make_folded_const treats it as output-only — confirm.
2112  Expr make(MatcherState &state, halide_type_t type_hint) const {
2113  halide_scalar_value_t base_val, stride_val, lanes_val;
2114  halide_type_t ty;
2115  base.make_folded_const(base_val, ty, state);
2116  int b = (int)base_val.u.i64;
2117  stride.make_folded_const(stride_val, ty, state);
2118  int s = (int)stride_val.u.i64;
2119  lanes.make_folded_const(lanes_val, ty, state);
2120  int l = (int)lanes_val.u.i64;
2121  return Shuffle::make_slice(vec.make(state, type_hint), b, s, l);
2122  }
2123 
2124  constexpr static bool foldable = false;
2125 
2127  SliceOp(Vec v, Base b, Stride s, Lanes l)
2128  : vec(v), base(b), stride(s), lanes(l) {
2129  static_assert(Base::foldable, "Base of slice should consist only of operations that constant-fold");
2130  static_assert(Stride::foldable, "Stride of slice should consist only of operations that constant-fold");
2131  static_assert(Lanes::foldable, "Lanes of slice should consist only of operations that constant-fold");
2132  }
2133 };
2134 
2135 template<typename Vec, typename Base, typename Stride, typename Lanes>
2136 std::ostream &operator<<(std::ostream &s, const SliceOp<Vec, Base, Stride, Lanes> &op) {
2137  s << "slice(" << op.vec << ", " << op.base << ", " << op.stride << ", " << op.lanes << ")";
2138  return s;
2139 }
2140 
2141 template<typename Vec, typename Base, typename Stride, typename Lanes>
2142 HALIDE_ALWAYS_INLINE auto slice(Vec vec, Base base, Stride stride, Lanes lanes) noexcept
2143  -> SliceOp<decltype(pattern_arg(vec)), decltype(pattern_arg(base)), decltype(pattern_arg(stride)), decltype(pattern_arg(lanes))> {
2144  return {pattern_arg(vec), pattern_arg(base), pattern_arg(stride), pattern_arg(lanes)};
2145 }
2146 
// Fold: wraps a sub-pattern so the rewriter emits its constant-folded value
// (via make_const_expr) instead of rebuilding the expression tree.
2147 template<typename A>
2148 struct Fold {
2149  struct pattern_tag {};
2150  A a;
2151 
2152  constexpr static uint32_t binds = bindings<A>::mask;
2153 
2156  constexpr static bool canonical = true;
2157 
// NOTE(review): the declaration of `c` (doxygen line 2160) is missing from this
// listing; presumably a halide_scalar_value_t, given make_const_expr(c, ty).
2159  Expr make(MatcherState &state, halide_type_t type_hint) const noexcept {
2161  halide_type_t ty = type_hint;
2162  a.make_folded_const(c, ty, state);
2163 
2164  // The result of the fold may have an underspecified type
2165  // (e.g. because it's from an int literal). Make the type code
2166  // and bits match the required type, if there is one (we can
2167  // tell from the bits field).
// NOTE(review): only an int-coded result is converted when the hint is float;
// any other code mismatch just relabels the type without converting the bits.
2168  if (type_hint.bits) {
2169  if (((int)ty.code == (int)halide_type_int) &&
2170  ((int)type_hint.code == (int)halide_type_float)) {
2171  int64_t x = c.u.i64;
2172  c.u.f64 = (double)x;
2173  }
2174  ty.code = type_hint.code;
2175  ty.bits = type_hint.bits;
2176  }
2177 
2178  Expr e = make_const_expr(c, ty);
2179  return e;
2180  }
2181 
2182  constexpr static bool foldable = A::foldable;
2183 
// Folding a Fold simply folds the wrapped pattern.
2184  template<typename A1 = A>
2186  a.make_folded_const(val, ty, state);
2187  }
2188 };
2189 
2190 template<typename A>
2191 HALIDE_ALWAYS_INLINE auto fold(A &&a) noexcept -> Fold<decltype(pattern_arg(a))> {
2192  assert_is_lvalue_if_expr<A>();
2193  return {pattern_arg(a)};
2194 }
2195 
2196 template<typename A>
2197 std::ostream &operator<<(std::ostream &s, const Fold<A> &op) {
2198  s << "fold(" << op.a << ")";
2199  return s;
2200 }
2201 
// Overflows: predicate that is true when constant-folding `a` produces a special
// value, i.e. when the folded type's lanes field carries a bit from
// MatcherState::special_values_mask. Yields a scalar uint64 0/1.
2202 template<typename A>
2203 struct Overflows {
2204  struct pattern_tag {};
2205  A a;
2206 
2207  constexpr static uint32_t binds = bindings<A>::mask;
2208 
2209  // This rule is a predicate, so it always evaluates to a boolean,
2210  // which has IRNodeType UIntImm
2213  constexpr static bool canonical = true;
2214 
2215  constexpr static bool foldable = A::foldable;
2216 
2217  template<typename A1 = A>
2219  a.make_folded_const(val, ty, state);
2220  ty.code = halide_type_uint;
2221  ty.bits = 64;
// Read the special-value flag out of lanes before lanes is reset to 1.
2222  val.u.u64 = (ty.lanes & MatcherState::special_values_mask) != 0;
2223  ty.lanes = 1;
2224  }
2225 };
2226 
2227 template<typename A>
2228 HALIDE_ALWAYS_INLINE auto overflows(A &&a) noexcept -> Overflows<decltype(pattern_arg(a))> {
2229  assert_is_lvalue_if_expr<A>();
2230  return {pattern_arg(a)};
2231 }
2232 
2233 template<typename A>
2234 std::ostream &operator<<(std::ostream &s, const Overflows<A> &op) {
2235  s << "overflows(" << op.a << ")";
2236  return s;
2237 }
2238 
// Overflow: pattern for the overflow intrinsic (a Call node). It binds no
// wildcards (binds == 0); make() produces a special-value constant of the hinted type.
2239 struct Overflow {
2240  struct pattern_tag {};
2241 
2242  constexpr static uint32_t binds = 0;
2243 
2244  // Overflow is an intrinsic, represented as a Call node
2247  constexpr static bool canonical = true;
2248 
2249  template<uint32_t bound>
// NOTE(review): the rest of the Call check (doxygen line 2255) is missing from
// this listing.
2250  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
2251  if (e.node_type != Call::_node_type) {
2252  return false;
2253  }
2254  const Call &op = (const Call &)e;
2256  }
2257 
2259  Expr make(MatcherState &state, halide_type_t type_hint) const {
2261  return make_const_special_expr(type_hint);
2262  }
2263 
2264  constexpr static bool foldable = true;
2265 
// Folds to value 0; the remainder of the body (doxygen line 2269) is not shown here.
2267  void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
2268  val.u.u64 = 0;
2270  }
2271 };
2272 
2273 inline std::ostream &operator<<(std::ostream &s, const Overflow &op) {
2274  s << "overflow()";
2275  return s;
2276 }
2277 
// IsConst: predicate that builds `a` and asks ::Halide::Internal::is_const
// whether it is a constant, or, when check_v is set, a constant equal to v.
// Yields a scalar uint64 0/1.
2278 template<typename A>
2279 struct IsConst {
2280  struct pattern_tag {};
2281 
2282  constexpr static uint32_t binds = bindings<A>::mask;
2283 
2284  // This rule is a boolean-valued predicate. Bools have type UIntImm.
2287  constexpr static bool canonical = true;
2288 
2289  A a;
2290  bool check_v;
// NOTE(review): the declaration of `v` (doxygen line 2291) is missing from this
// listing; is_const(a, int64_t value) suggests it holds an int64_t — confirm.
2292 
2293  constexpr static bool foldable = true;
2294 
2295  template<typename A1 = A>
2297  Expr e = a.make(state, {});
2298  ty.code = halide_type_uint;
2299  ty.bits = 64;
2300  ty.lanes = 1;
2301  if (check_v) {
2302  val.u.u64 = ::Halide::Internal::is_const(e, v) ? 1 : 0;
2303  } else {
2304  val.u.u64 = ::Halide::Internal::is_const(e) ? 1 : 0;
2305  }
2306  }
2307 };
2308 
2309 template<typename A>
2310 HALIDE_ALWAYS_INLINE auto is_const(A &&a) noexcept -> IsConst<decltype(pattern_arg(a))> {
2311  assert_is_lvalue_if_expr<A>();
2312  return {pattern_arg(a), false, 0};
2313 }
2314 
2315 template<typename A>
2316 HALIDE_ALWAYS_INLINE auto is_const(A &&a, int64_t value) noexcept -> IsConst<decltype(pattern_arg(a))> {
2317  assert_is_lvalue_if_expr<A>();
2318  return {pattern_arg(a), true, value};
2319 }
2320 
2321 template<typename A>
2322 std::ostream &operator<<(std::ostream &s, const IsConst<A> &op) {
2323  if (op.check_v) {
2324  s << "is_const(" << op.a << ")";
2325  } else {
2326  s << "is_const(" << op.a << ", " << op.v << ")";
2327  }
2328  return s;
2329 }
2330 
// CanProve: predicate that builds `a`, runs it through prover->mutate, and is
// true only if the result is the constant one (is_const_one). A false result
// means "could not prove", not "proved false". The result is 1-bit unsigned,
// with the lane count of the simplified condition.
2331 template<typename A, typename Prover>
2332 struct CanProve {
2333  struct pattern_tag {};
2334  A a;
2335  Prover *prover; // An existing simplifying mutator
2336 
2337  constexpr static uint32_t binds = bindings<A>::mask;
2338 
2339  // This rule is a boolean-valued predicate. Bools have type UIntImm.
2342  constexpr static bool canonical = true;
2343 
2344  constexpr static bool foldable = true;
2345 
2346  // Includes a raw call to an inlined make method, so don't inline.
2348  Expr condition = a.make(state, {});
2349  condition = prover->mutate(condition, nullptr);
2350  val.u.u64 = is_const_one(condition);
2351  ty.code = halide_type_uint;
2352  ty.bits = 1;
2353  ty.lanes = condition.type().lanes();
2354  }
2355 };
2356 
2357 template<typename A, typename Prover>
2358 HALIDE_ALWAYS_INLINE auto can_prove(A &&a, Prover *p) noexcept -> CanProve<decltype(pattern_arg(a)), Prover> {
2359  assert_is_lvalue_if_expr<A>();
2360  return {pattern_arg(a), p};
2361 }
2362 
2363 template<typename A, typename Prover>
2364 std::ostream &operator<<(std::ostream &s, const CanProve<A, Prover> &op) {
2365  s << "can_prove(" << op.a << ")";
2366  return s;
2367 }
2368 
// IsFloat: predicate that is true when the type of the Expr made from `a` is a
// floating-point type. The result is 1-bit unsigned with a's lane count.
2369 template<typename A>
2370 struct IsFloat {
2371  struct pattern_tag {};
2372  A a;
2373 
2374  constexpr static uint32_t binds = bindings<A>::mask;
2375 
2376  // This rule is a boolean-valued predicate. Bools have type UIntImm.
2379  constexpr static bool canonical = true;
2380 
2381  constexpr static bool foldable = true;
2382 
2385  // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2386  Type t = a.make(state, {}).type();
2387  val.u.u64 = t.is_float();
2388  ty.code = halide_type_uint;
2389  ty.bits = 1;
2390  ty.lanes = t.lanes();
2391  }
2392 };
2393 
2394 template<typename A>
2395 HALIDE_ALWAYS_INLINE auto is_float(A &&a) noexcept -> IsFloat<decltype(pattern_arg(a))> {
2396  assert_is_lvalue_if_expr<A>();
2397  return {pattern_arg(a)};
2398 }
2399 
2400 template<typename A>
2401 std::ostream &operator<<(std::ostream &s, const IsFloat<A> &op) {
2402  s << "is_float(" << op.a << ")";
2403  return s;
2404 }
2405 
// IsInt: predicate that is true when a's type is a signed integer type,
// optionally restricted to a bit width and/or lane count (0 means "any").
// The result is 1-bit unsigned with a's lane count.
2406 template<typename A>
2407 struct IsInt {
2408  struct pattern_tag {};
2409  A a;
2410  int bits, lanes;
2411 
2412  constexpr static uint32_t binds = bindings<A>::mask;
2413 
2414  // This rule is a boolean-valued predicate. Bools have type UIntImm.
2417  constexpr static bool canonical = true;
2418 
2419  constexpr static bool foldable = true;
2420 
2423  // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2424  Type t = a.make(state, {}).type();
2425  val.u.u64 = t.is_int() && (bits == 0 || t.bits() == bits) && (lanes == 0 || t.lanes() == lanes);
2426  ty.code = halide_type_uint;
2427  ty.bits = 1;
2428  ty.lanes = t.lanes();
2429  }
2430 };
2431 
2432 template<typename A>
2433 HALIDE_ALWAYS_INLINE auto is_int(A &&a, int bits = 0, int lanes = 0) noexcept -> IsInt<decltype(pattern_arg(a))> {
2434  assert_is_lvalue_if_expr<A>();
2435  return {pattern_arg(a), bits, lanes};
2436 }
2437 
2438 template<typename A>
2439 std::ostream &operator<<(std::ostream &s, const IsInt<A> &op) {
2440  s << "is_int(" << op.a;
2441  if (op.bits > 0) {
2442  s << ", " << op.bits;
2443  }
2444  if (op.lanes > 0) {
2445  s << ", " << op.lanes;
2446  }
2447  s << ")";
2448  return s;
2449 }
2450 
// IsUInt: predicate that is true when a's type is an unsigned integer type,
// optionally restricted to a bit width and/or lane count (0 means "any").
// The result is 1-bit unsigned with a's lane count.
2451 template<typename A>
2452 struct IsUInt {
2453  struct pattern_tag {};
2454  A a;
2455  int bits, lanes;
2456 
2457  constexpr static uint32_t binds = bindings<A>::mask;
2458 
2459  // This rule is a boolean-valued predicate. Bools have type UIntImm.
2462  constexpr static bool canonical = true;
2463 
2464  constexpr static bool foldable = true;
2465 
2468  // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2469  Type t = a.make(state, {}).type();
2470  val.u.u64 = t.is_uint() && (bits == 0 || t.bits() == bits) && (lanes == 0 || t.lanes() == lanes);
2471  ty.code = halide_type_uint;
2472  ty.bits = 1;
2473  ty.lanes = t.lanes();
2474  }
2475 };
2476 
2477 template<typename A>
2478 HALIDE_ALWAYS_INLINE auto is_uint(A &&a, int bits = 0, int lanes = 0) noexcept -> IsUInt<decltype(pattern_arg(a))> {
2479  assert_is_lvalue_if_expr<A>();
2480  return {pattern_arg(a), bits, lanes};
2481 }
2482 
2483 template<typename A>
2484 std::ostream &operator<<(std::ostream &s, const IsUInt<A> &op) {
2485  s << "is_uint(" << op.a;
2486  if (op.bits > 0) {
2487  s << ", " << op.bits;
2488  }
2489  if (op.lanes > 0) {
2490  s << ", " << op.lanes;
2491  }
2492  s << ")";
2493  return s;
2494 }
2495 
// IsScalar: predicate that is true when a's type has a single lane. The result
// is 1-bit unsigned with a's lane count.
2496 template<typename A>
2497 struct IsScalar {
2498  struct pattern_tag {};
2499  A a;
2500 
2501  constexpr static uint32_t binds = bindings<A>::mask;
2502 
2503  // This rule is a boolean-valued predicate. Bools have type UIntImm.
2506  constexpr static bool canonical = true;
2507 
2508  constexpr static bool foldable = true;
2509 
2512  // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2513  Type t = a.make(state, {}).type();
2514  val.u.u64 = t.is_scalar();
2515  ty.code = halide_type_uint;
2516  ty.bits = 1;
2517  ty.lanes = t.lanes();
2518  }
2519 };
2520 
2521 template<typename A>
2522 HALIDE_ALWAYS_INLINE auto is_scalar(A &&a) noexcept -> IsScalar<decltype(pattern_arg(a))> {
2523  assert_is_lvalue_if_expr<A>();
2524  return {pattern_arg(a)};
2525 }
2526 
2527 template<typename A>
2528 std::ostream &operator<<(std::ostream &s, const IsScalar<A> &op) {
2529  s << "is_scalar(" << op.a << ")";
2530  return s;
2531 }
2532 
// IsMaxValue: predicate that constant-folds `a` and is true when the result is
// the maximum value of its integer type. Always false for non-integer types.
// The lanes field is left as produced by folding `a`.
2533 template<typename A>
2534 struct IsMaxValue {
2535  struct pattern_tag {};
2536  A a;
2537 
2538  constexpr static uint32_t binds = bindings<A>::mask;
2539 
2540  // This rule is a boolean-valued predicate. Bools have type UIntImm.
2543  constexpr static bool canonical = true;
2544 
2545  constexpr static bool foldable = true;
2546 
2549  // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2550  a.make_folded_const(val, ty, state);
// All-ones in the low `bits` bits (one fewer for signed ints, excluding the sign bit).
// NOTE(review): if ty.bits were 0 the shift count would be 64 (undefined
// behavior); presumably folded constants always carry a nonzero width — confirm.
2551  const uint64_t max_bits = (uint64_t)(-1) >> (64 - ty.bits + (ty.code == halide_type_int));
2552  if (ty.code == halide_type_uint || ty.code == halide_type_int) {
2553  val.u.u64 = (val.u.u64 == max_bits);
2554  } else {
2555  val.u.u64 = 0;
2556  }
2557  ty.code = halide_type_uint;
2558  ty.bits = 1;
2559  }
2560 };
2561 
2562 template<typename A>
2563 HALIDE_ALWAYS_INLINE auto is_max_value(A &&a) noexcept -> IsMaxValue<decltype(pattern_arg(a))> {
2564  assert_is_lvalue_if_expr<A>();
2565  return {pattern_arg(a)};
2566 }
2567 
2568 template<typename A>
2569 std::ostream &operator<<(std::ostream &s, const IsMaxValue<A> &op) {
2570  s << "is_max_value(" << op.a << ")";
2571  return s;
2572 }
2573 
// IsMinValue: predicate that constant-folds `a` and is true when the result is
// the minimum value of its integer type (0 for unsigned). Always false for
// non-integer types. The lanes field is left as produced by folding `a`.
2574 template<typename A>
2575 struct IsMinValue {
2576  struct pattern_tag {};
2577  A a;
2578 
2579  constexpr static uint32_t binds = bindings<A>::mask;
2580 
2581  // This rule is a boolean-valued predicate. Bools have type UIntImm.
2584  constexpr static bool canonical = true;
2585 
2586  constexpr static bool foldable = true;
2587 
2590  // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2591  a.make_folded_const(val, ty, state);
2592  if (ty.code == halide_type_int) {
// The 64-bit two's-complement pattern of -(2^(bits-1)). This assumes the folded
// signed value is stored sign-extended in val.u — TODO confirm.
2593  const uint64_t min_bits = (uint64_t)(-1) << (ty.bits - 1);
2594  val.u.u64 = (val.u.u64 == min_bits);
2595  } else if (ty.code == halide_type_uint) {
2596  val.u.u64 = (val.u.u64 == 0);
2597  } else {
2598  val.u.u64 = 0;
2599  }
2600  ty.code = halide_type_uint;
2601  ty.bits = 1;
2602  }
2603 };
2604 
2605 template<typename A>
2606 HALIDE_ALWAYS_INLINE auto is_min_value(A &&a) noexcept -> IsMinValue<decltype(pattern_arg(a))> {
2607  assert_is_lvalue_if_expr<A>();
2608  return {pattern_arg(a)};
2609 }
2610 
2611 template<typename A>
2612 std::ostream &operator<<(std::ostream &s, const IsMinValue<A> &op) {
2613  s << "is_min_value(" << op.a << ")";
2614  return s;
2615 }
2616 
// LanesOf: folds to the lane count of the Expr made from `a`, as a scalar
// uint32 constant.
2617 template<typename A>
2618 struct LanesOf {
2619  struct pattern_tag {};
2620  A a;
2621 
2622  constexpr static uint32_t binds = bindings<A>::mask;
2623 
2624  // This rule evaluates to an unsigned integer (the lane count of a), which has IRNodeType UIntImm.
2627  constexpr static bool canonical = true;
2628 
2629  constexpr static bool foldable = true;
2630 
2633  // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2634  Type t = a.make(state, {}).type();
2635  val.u.u64 = t.lanes();
2636  ty.code = halide_type_uint;
2637  ty.bits = 32;
2638  ty.lanes = 1;
2639  }
2640 };
2641 
2642 template<typename A>
2643 HALIDE_ALWAYS_INLINE auto lanes_of(A &&a) noexcept -> LanesOf<decltype(pattern_arg(a))> {
2644  assert_is_lvalue_if_expr<A>();
2645  return {pattern_arg(a)};
2646 }
2647 
2648 template<typename A>
2649 std::ostream &operator<<(std::ostream &s, const LanesOf<A> &op) {
2650  s << "lanes_of(" << op.a << ")";
2651  return s;
2652 }
2653 
// fuzz_test_rule (overload for rules whose LHS and RHS are both constant-foldable).
// Binds every wildcard to random scalar constants and folds both sides. Trials
// where the predicate is false, or either side folds to a special value, are
// skipped. Any disagreement is logged via debug(0). Within this header it is
// called only from Rewriter, under #if HALIDE_FUZZ_TEST_RULES.
2654 // Verify properties of each rewrite rule. Currently just fuzz tests them.
2655 template<typename Before,
2656  typename After,
2657  typename Predicate,
2658  typename = typename std::enable_if<std::decay<Before>::type::foldable &&
2659  std::decay<After>::type::foldable>::type>
2660 HALIDE_NEVER_INLINE void fuzz_test_rule(Before &&before, After &&after, Predicate &&pred,
2661  halide_type_t wildcard_type, halide_type_t output_type) noexcept {
2662 
2663  // We only validate the rules in the scalar case
2664  wildcard_type.lanes = output_type.lanes = 1;
2665 
2666  // Track which types this rule has been tested for before
// Function-local static: one set per instantiation of this template (i.e. per
// distinct rule type), keyed on wildcard_type only. Not guarded by a lock.
2667  static std::set<uint32_t> tested;
2668 
2669  if (!tested.insert(reinterpret_bits<uint32_t>(wildcard_type)).second) {
2670  return;
2671  }
2672 
2673  // Print it in a form where it can be piped into a python/z3 validator
2674  debug(0) << "validate('" << before << "', '" << after << "', '" << pred << "', " << Type(wildcard_type) << ", " << Type(output_type) << ")\n";
2675 
2676  // Substitute some random constants into the before and after
2677  // expressions and see if the rule holds true. This should catch
2678  // silly errors, but not necessarily corner cases.
2679  static std::mt19937_64 rng(0);
2680  MatcherState state;
2681 
2682  Expr exprs[max_wild];
2683 
2684  for (int trials = 0; trials < 100; trials++) {
2685  // We want to test small constants more frequently than
2686  // large ones, otherwise we'll just get coverage of
2687  // overflow rules.
// NOTE(review): masking with (bits - 1) gives a shift in [0, bits) only when
// bits is a power of two.
2688  int shift = (int)(rng() & (wildcard_type.bits - 1));
2689 
2690  for (int i = 0; i < max_wild; i++) {
2691  // Bind all the exprs and constants
2692  switch (wildcard_type.code) {
2693  case halide_type_uint: {
2694  // Normalize to the type's range by adding zero
2695  uint64_t val = constant_fold_bin_op<Add>(wildcard_type, (uint64_t)rng() >> shift, 0);
2696  state.set_bound_const(i, val, wildcard_type);
2697  val = constant_fold_bin_op<Add>(wildcard_type, (uint64_t)rng() >> shift, 0);
2698  exprs[i] = make_const(wildcard_type, val);
// NOTE(review): redundant with the unconditional set_binding after the switch.
2699  state.set_binding(i, *exprs[i].get());
2700  } break;
2701  case halide_type_int: {
2702  int64_t val = constant_fold_bin_op<Add>(wildcard_type, (int64_t)rng() >> shift, 0);
2703  state.set_bound_const(i, val, wildcard_type);
2704  val = constant_fold_bin_op<Add>(wildcard_type, (int64_t)rng() >> shift, 0);
2705  exprs[i] = make_const(wildcard_type, val);
2706  } break;
2707  case halide_type_float:
2708  case halide_type_bfloat: {
2709  // Use a very narrow range of precise floats, so
2710  // that none of the rules a human is likely to
2711  // write have instabilities.
2712  double val = ((int64_t)(rng() & 15) - 8) / 2.0;
2713  state.set_bound_const(i, val, wildcard_type);
2714  val = ((int64_t)(rng() & 15) - 8) / 2.0;
2715  exprs[i] = make_const(wildcard_type, val);
2716  } break;
2717  default:
2718  return; // Don't care about handles
2719  }
2720  state.set_binding(i, *exprs[i].get());
2721  }
2722 
// NOTE(review): val_pred is never used in the code shown here.
2723  halide_scalar_value_t val_pred, val_before, val_after;
2724  halide_type_t type = output_type;
2725  if (!evaluate_predicate(pred, state)) {
2726  continue;
2727  }
2728  before.make_folded_const(val_before, type, state);
2729  uint16_t lanes = type.lanes;
2730  after.make_folded_const(val_after, type, state);
2731  lanes |= type.lanes;
2732 
// Special-value (e.g. overflow) flags travel in the lanes field. Both sides'
// lanes were OR-ed together above, so a flag on either side skips the trial.
2733  if (lanes & MatcherState::special_values_mask) {
2734  continue;
2735  }
2736 
2737  bool ok = true;
2738  switch (output_type.code) {
2739  case halide_type_uint:
2740  // Compare normalized representations
2741  ok &= (constant_fold_bin_op<Add>(output_type, val_before.u.u64, 0) ==
2742  constant_fold_bin_op<Add>(output_type, val_after.u.u64, 0));
2743  break;
2744  case halide_type_int:
2745  ok &= (constant_fold_bin_op<Add>(output_type, val_before.u.i64, 0) ==
2746  constant_fold_bin_op<Add>(output_type, val_after.u.i64, 0));
2747  break;
2748  case halide_type_float:
2749  case halide_type_bfloat: {
2750  double error = std::abs(val_before.u.f64 - val_after.u.f64);
2751  // We accept an equal bit pattern (e.g. inf vs inf),
2752  // a small floating point difference, or turning a nan into not-a-nan.
2753  ok &= (error < 0.01 ||
2754  val_before.u.u64 == val_after.u.u64 ||
2755  std::isnan(val_before.u.f64));
2756  break;
2757  }
2758  default:
2759  return;
2760  }
2761 
2762  if (!ok) {
2763  debug(0) << "Fails with values:\n";
2764  for (int i = 0; i < max_wild; i++) {
// `val` is declared on a line missing from this listing (doxygen line 2765).
2766  state.get_bound_const(i, val, wildcard_type);
2767  debug(0) << " c" << i << ": " << make_const_expr(val, wildcard_type) << "\n";
2768  }
2769  for (int i = 0; i < max_wild; i++) {
2770  debug(0) << " _" << i << ": " << Expr(state.get_binding(i)) << "\n";
2771  }
2772  debug(0) << " Before: " << make_const_expr(val_before, output_type) << "\n";
2773  debug(0) << " After: " << make_const_expr(val_after, output_type) << "\n";
2774  debug(0) << val_before.u.u64 << " " << val_after.u.u64 << "\n";
// NOTE(review): doxygen line 2775 is missing here; presumably it reports the
// failure — confirm against the header.
2776  }
2777  }
2778 }
2779 
2780 template<typename Before,
2781  typename After,
2782  typename Predicate,
2783  typename = typename std::enable_if<!(std::decay<Before>::type::foldable &&
2784  std::decay<After>::type::foldable)>::type>
2785 HALIDE_ALWAYS_INLINE void fuzz_test_rule(Before &&before, After &&after, Predicate &&pred,
2786  halide_type_t, halide_type_t, int dummy = 0) noexcept {
2787  // We can't verify rewrite rules that can't be constant-folded.
2788 }
2789 
2791 bool evaluate_predicate(bool x, MatcherState &) noexcept {
2792  return x;
2793 }
2794 
// evaluate_predicate for pattern predicates: constant-folds the predicate with
// the current bindings, starting from a bool type. It is true only when the
// folded value is nonzero and no special-value flag is set in the lanes field.
// NOTE(review): the signature and the declaration of `c` (doxygen lines
// 2797-2798) are missing from this listing.
2795 template<typename Pattern,
2796  typename = typename enable_if_pattern<Pattern>::type>
2799  halide_type_t ty = halide_type_of<bool>();
2800  p.make_folded_const(c, ty, state);
2801  // Overflow counts as a failed predicate
2802  return (c.u.u64 != 0) && ((ty.lanes & MatcherState::special_values_mask) == 0);
2803 }
2804 
// #defines for testing. Each defaults to 0. Because each is guarded by
// #ifndef, it can be overridden from the compiler command line (e.g.
// -DHALIDE_FUZZ_TEST_RULES=1) without editing this header.

// Print all successful or failed matches
#ifndef HALIDE_DEBUG_MATCHED_RULES
#define HALIDE_DEBUG_MATCHED_RULES 0
#endif
#ifndef HALIDE_DEBUG_UNMATCHED_RULES
#define HALIDE_DEBUG_UNMATCHED_RULES 0
#endif

// Set to true if you want to fuzz test every rewrite passed to
// operator() to ensure the input and the output have the same value
// for lots of random values of the wildcards. Run
// correctness_simplify with this on.
#ifndef HALIDE_FUZZ_TEST_RULES
#define HALIDE_FUZZ_TEST_RULES 0
#endif
2816 
// Rewriter: holds an instance (a pattern with concrete leaves, or an Expr) and
// tries rewrite rules against it via operator(). When `before` matches (and the
// predicate, if any, holds), each overload stores the replacement in `result`
// and returns true; otherwise it returns false.
// NOTE(review): the declarations of output_type, wildcard_type, state and
// result, and the constructor's signature line, fall on doxygen lines missing
// from this listing (2820-2822, 2825-2826).
2817 template<typename Instance>
2818 struct Rewriter {
2819  Instance instance;
// NOTE(review): `validate` is not read anywhere in the code shown.
2823  bool validate;
2824 
2827  : instance(std::move(instance)), output_type(ot), wildcard_type(wt) {
2828  }
2829 
2830  template<typename After>
// Instantiate `after` from the current matcher state, using output_type as the type hint.
2832  result = after.make(state, output_type);
2833  }
2834 
// Rule with a pattern RHS and no predicate. The static_asserts reject, at
// compile time, RHS wildcards the LHS does not bind and non-canonical patterns.
2835  template<typename Before,
2836  typename After,
2837  typename = typename enable_if_pattern<Before>::type,
2838  typename = typename enable_if_pattern<After>::type>
2839  HALIDE_ALWAYS_INLINE bool operator()(Before before, After after) {
2840  static_assert((Before::binds & After::binds) == After::binds, "Rule result uses unbound values");
2841  static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2842  static_assert(After::canonical, "RHS of rewrite rule should be in canonical form");
2843 #if HALIDE_FUZZ_TEST_RULES
2844  fuzz_test_rule(before, after, true, wildcard_type, output_type);
2845 #endif
2846  if (before.template match<0>(unwrap(instance), state)) {
2847  build_replacement(after);
2848 #if HALIDE_DEBUG_MATCHED_RULES
2849  debug(0) << instance << " -> " << result << " via " << before << " -> " << after << "\n";
2850 #endif
2851  return true;
2852  } else {
2853 #if HALIDE_DEBUG_UNMATCHED_RULES
2854  debug(0) << instance << " does not match " << before << "\n";
2855 #endif
2856  return false;
2857  }
2858  }
2859 
// Rule with a prebuilt Expr RHS: on a match, `after` is used as-is. This form
// performs no fuzz testing and no binds check.
2860  template<typename Before,
2861  typename = typename enable_if_pattern<Before>::type>
2862  HALIDE_ALWAYS_INLINE bool operator()(Before before, const Expr &after) noexcept {
2863  static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2864  if (before.template match<0>(unwrap(instance), state)) {
2865  result = after;
2866 #if HALIDE_DEBUG_MATCHED_RULES
2867  debug(0) << instance << " -> " << result << " via " << before << " -> " << after << "\n";
2868 #endif
2869  return true;
2870  } else {
2871 #if HALIDE_DEBUG_UNMATCHED_RULES
2872  debug(0) << instance << " does not match " << before << "\n";
2873 #endif
2874  return false;
2875  }
2876  }
2877 
// Rule whose RHS is an integer literal, materialized with make_const in output_type.
2878  template<typename Before,
2879  typename = typename enable_if_pattern<Before>::type>
2880  HALIDE_ALWAYS_INLINE bool operator()(Before before, int64_t after) noexcept {
2881  static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2882 #if HALIDE_FUZZ_TEST_RULES
2883  fuzz_test_rule(before, IntLiteral(after), true, wildcard_type, output_type);
2884 #endif
2885  if (before.template match<0>(unwrap(instance), state)) {
2886  result = make_const(output_type, after);
2887 #if HALIDE_DEBUG_MATCHED_RULES
2888  debug(0) << instance << " -> " << result << " via " << before << " -> " << after << "\n";
2889 #endif
2890  return true;
2891  } else {
2892 #if HALIDE_DEBUG_UNMATCHED_RULES
2893  debug(0) << instance << " does not match " << before << "\n";
2894 #endif
2895  return false;
2896  }
2897  }
2898 
// Conditional rule with a pattern RHS. `&&` short-circuits, so the predicate is
// evaluated only after `before` has matched and bound its wildcards.
2899  template<typename Before,
2900  typename After,
2901  typename Predicate,
2902  typename = typename enable_if_pattern<Before>::type,
2903  typename = typename enable_if_pattern<After>::type,
2904  typename = typename enable_if_pattern<Predicate>::type>
2905  HALIDE_ALWAYS_INLINE bool operator()(Before before, After after, Predicate pred) {
2906  static_assert(Predicate::foldable, "Predicates must consist only of operations that can constant-fold");
2907  static_assert((Before::binds & After::binds) == After::binds, "Rule result uses unbound values");
2908  static_assert((Before::binds & Predicate::binds) == Predicate::binds, "Rule predicate uses unbound values");
2909  static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2910  static_assert(After::canonical, "RHS of rewrite rule should be in canonical form");
2911 
2912 #if HALIDE_FUZZ_TEST_RULES
2913  fuzz_test_rule(before, after, pred, wildcard_type, output_type);
2914 #endif
2915  if (before.template match<0>(unwrap(instance), state) &&
2916  evaluate_predicate(pred, state)) {
2917  build_replacement(after);
2918 #if HALIDE_DEBUG_MATCHED_RULES
2919  debug(0) << instance << " -> " << result << " via " << before << " -> " << after << " when " << pred << "\n";
2920 #endif
2921  return true;
2922  } else {
2923 #if HALIDE_DEBUG_UNMATCHED_RULES
2924  debug(0) << instance << " does not match " << before << "\n";
2925 #endif
2926  return false;
2927  }
2928  }
2929 
// Conditional rule with a prebuilt Expr RHS (no fuzz testing for this form).
2930  template<typename Before,
2931  typename Predicate,
2932  typename = typename enable_if_pattern<Before>::type,
2933  typename = typename enable_if_pattern<Predicate>::type>
2934  HALIDE_ALWAYS_INLINE bool operator()(Before before, const Expr &after, Predicate pred) {
2935  static_assert(Predicate::foldable, "Predicates must consist only of operations that can constant-fold");
2936  static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2937 
2938  if (before.template match<0>(unwrap(instance), state) &&
2939  evaluate_predicate(pred, state)) {
2940  result = after;
2941 #if HALIDE_DEBUG_MATCHED_RULES
2942  debug(0) << instance << " -> " << result << " via " << before << " -> " << after << " when " << pred << "\n";
2943 #endif
2944  return true;
2945  } else {
2946 #if HALIDE_DEBUG_UNMATCHED_RULES
2947  debug(0) << instance << " does not match " << before << "\n";
2948 #endif
2949  return false;
2950  }
2951  }
2952 
// Conditional rule whose RHS is an integer literal, materialized in output_type.
2953  template<typename Before,
2954  typename Predicate,
2955  typename = typename enable_if_pattern<Before>::type,
2956  typename = typename enable_if_pattern<Predicate>::type>
2957  HALIDE_ALWAYS_INLINE bool operator()(Before before, int64_t after, Predicate pred) {
2958  static_assert(Predicate::foldable, "Predicates must consist only of operations that can constant-fold");
2959  static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2960 #if HALIDE_FUZZ_TEST_RULES
2961  fuzz_test_rule(before, IntLiteral(after), pred, wildcard_type, output_type);
2962 #endif
2963  if (before.template match<0>(unwrap(instance), state) &&
2964  evaluate_predicate(pred, state)) {
2965  result = make_const(output_type, after);
2966 #if HALIDE_DEBUG_MATCHED_RULES
2967  debug(0) << instance << " -> " << result << " via " << before << " -> " << after << " when " << pred << "\n";
2968 #endif
2969  return true;
2970  } else {
2971 #if HALIDE_DEBUG_UNMATCHED_RULES
2972  debug(0) << instance << " does not match " << before << "\n";
2973 #endif
2974  return false;
2975  }
2976  }
2977 };
2978 
2979 /** Construct a rewriter for the given instance, which may be a pattern
2980  * with concrete expressions as leaves, or just an expression. The
2981  * wildcard_type argument (second for the Expr overloads, third for the
2982  * pattern overloads) hints at what the type of the wildcards is likely
2983  * to be. If omitted, it defaults to the output type (for an Expr, the
2984  * expression's own type). The wildcards are not required to be this type,
2985  * but the rule will only be tested for wildcards of that type when HALIDE_FUZZ_TEST_RULES is enabled.
2986  *
2987  * The rewriter can be used to check to see if the instance is one of
2988  * some number of patterns and if so rewrite it into another form,
2989  * using its operator() method. See Simplify.cpp for a bunch of
2990  * example usage.
2991  *
2992  * Important: Any Exprs in patterns are captured by reference, not by
2993  * value, so ensure they outlive the rewriter.
2994  */
2995 // @{
2996 template<typename Instance,
2997  typename = typename enable_if_pattern<Instance>::type>
2998 HALIDE_ALWAYS_INLINE auto rewriter(Instance instance, halide_type_t output_type, halide_type_t wildcard_type) noexcept -> Rewriter<decltype(pattern_arg(instance))> {
2999  return {pattern_arg(instance), output_type, wildcard_type};
3000 }
3001 
3002 template<typename Instance,
3003  typename = typename enable_if_pattern<Instance>::type>
3004 HALIDE_ALWAYS_INLINE auto rewriter(Instance instance, halide_type_t output_type) noexcept -> Rewriter<decltype(pattern_arg(instance))> {
3005  return {pattern_arg(instance), output_type, output_type};
3006 }
3007 
3009 auto rewriter(const Expr &e, halide_type_t wildcard_type) noexcept -> Rewriter<decltype(pattern_arg(e))> {
3010  return {pattern_arg(e), e.type(), wildcard_type};
3011 }
3012 
3014 auto rewriter(const Expr &e) noexcept -> Rewriter<decltype(pattern_arg(e))> {
3015  return {pattern_arg(e), e.type(), e.type()};
3016 }
3017 // @}
3018 
3019 } // namespace IRMatcher
3020 
3021 } // namespace Internal
3022 } // namespace Halide
3023 
3024 #endif
#define internal_error
Definition: Errors.h:23
@ halide_type_float
IEEE floating point numbers.
@ halide_type_bfloat
floating point numbers in the bfloat format
@ halide_type_int
signed integers
@ halide_type_uint
unsigned integers
#define HALIDE_NEVER_INLINE
Definition: HalideRuntime.h:50
#define HALIDE_ALWAYS_INLINE
Definition: HalideRuntime.h:49
Subtypes for Halide expressions (Halide::Expr) and statements (Halide::Internal::Stmt)
Methods to test Exprs and Stmts for equality of value.
Defines various operator overloads and utility functions that make it more pleasant to work with Hali...
For optional debugging during codegen, use the debug class as follows:
Definition: Debug.h:49
std::ostream & operator<<(std::ostream &s, const SpecificExpr &e)
Definition: IRMatch.h:229
auto rounding_shift_left(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1598
auto shift_left(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1590
HALIDE_ALWAYS_INLINE auto rewriter(Instance instance, halide_type_t output_type, halide_type_t wildcard_type) noexcept -> Rewriter< decltype(pattern_arg(instance))>
Construct a rewriter for the given instance, which may be a pattern with concrete expressions as leav...
Definition: IRMatch.h:2998
HALIDE_ALWAYS_INLINE T pattern_arg(T t)
Definition: IRMatch.h:579
auto widen_right_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1539
HALIDE_ALWAYS_INLINE auto or_op(A &&a, B &&b) -> decltype(IRMatcher::operator||(a, b))
Definition: IRMatch.h:1287
HALIDE_ALWAYS_INLINE auto operator!(A &&a) noexcept -> NotOp< decltype(pattern_arg(a))>
Definition: IRMatch.h:1655
HALIDE_ALWAYS_INLINE auto min(A &&a, B &&b) noexcept -> BinOp< Min, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1088
HALIDE_ALWAYS_INLINE bool evaluate_predicate(bool x, MatcherState &) noexcept
Definition: IRMatch.h:2791
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Div >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1044
HALIDE_ALWAYS_INLINE auto ne(A &&a, B &&b) -> decltype(IRMatcher::operator!=(a, b))
Definition: IRMatch.h:1262
HALIDE_ALWAYS_INLINE auto negate(A &&a) -> decltype(IRMatcher::operator-(a))
Definition: IRMatch.h:2032
uint64_t constant_fold_cmp_op(int64_t, int64_t) noexcept
HALIDE_ALWAYS_INLINE auto operator<=(A &&a, B &&b) noexcept -> CmpOp< LE, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1182
HALIDE_ALWAYS_INLINE auto operator+(A &&a, B &&b) noexcept -> BinOp< Add, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:933
HALIDE_ALWAYS_INLINE auto is_max_value(A &&a) noexcept -> IsMaxValue< decltype(pattern_arg(a))>
Definition: IRMatch.h:2563
HALIDE_ALWAYS_INLINE auto and_op(A &&a, B &&b) -> decltype(IRMatcher::operator&&(a, b))
Definition: IRMatch.h:1313
HALIDE_ALWAYS_INLINE auto h_and(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::And >
Definition: IRMatch.h:1944
HALIDE_ALWAYS_INLINE auto gt(A &&a, B &&b) -> decltype(IRMatcher::operator>(a, b))
Definition: IRMatch.h:1162
HALIDE_ALWAYS_INLINE auto is_const(A &&a) noexcept -> IsConst< decltype(pattern_arg(a))>
Definition: IRMatch.h:2310
HALIDE_ALWAYS_INLINE auto intrin(Call::IntrinsicOp intrinsic_op, Args... args) noexcept -> Intrin< decltype(pattern_arg(args))... >
Definition: IRMatch.h:1534
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< LE >(int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1192
HALIDE_ALWAYS_INLINE auto operator*(A &&a, B &&b) noexcept -> BinOp< Mul, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:999
auto rounding_halving_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1586
auto rounding_shift_right(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1602
auto widen_right_sub(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1547
HALIDE_ALWAYS_INLINE auto add(A &&a, B &&b) -> decltype(IRMatcher::operator+(a, b))
Definition: IRMatch.h:940
HALIDE_ALWAYS_INLINE auto div(A &&a, B &&b) -> decltype(IRMatcher::operator/(a, b))
Definition: IRMatch.h:1039
auto saturating_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1564
HALIDE_ALWAYS_INLINE auto mul(A &&a, B &&b) -> decltype(IRMatcher::operator*(a, b))
Definition: IRMatch.h:1006
HALIDE_ALWAYS_INLINE auto max(A &&a, B &&b) noexcept -> BinOp< Max, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1110
HALIDE_ALWAYS_INLINE auto slice(Vec vec, Base base, Stride stride, Lanes lanes) noexcept -> SliceOp< decltype(pattern_arg(vec)), decltype(pattern_arg(base)), decltype(pattern_arg(stride)), decltype(pattern_arg(lanes))>
Definition: IRMatch.h:2142
HALIDE_ALWAYS_INLINE auto ramp(A &&a, B &&b, C &&c) noexcept -> RampOp< decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))>
Definition: IRMatch.h:1868
HALIDE_ALWAYS_INLINE auto operator/(A &&a, B &&b) noexcept -> BinOp< Div, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1032
auto widening_mul(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1560
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Mod >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1073
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< And >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1318
HALIDE_ALWAYS_INLINE int64_t unwrap(IntLiteral t)
Definition: IRMatch.h:571
HALIDE_ALWAYS_INLINE auto operator>(A &&a, B &&b) noexcept -> CmpOp< GT, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1157
HALIDE_ALWAYS_INLINE auto cast(halide_type_t t, A &&a) noexcept -> CastOp< decltype(pattern_arg(a))>
Definition: IRMatch.h:2078
HALIDE_ALWAYS_INLINE auto overflows(A &&a) noexcept -> Overflows< decltype(pattern_arg(a))>
Definition: IRMatch.h:2228
auto widening_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1552
HALIDE_ALWAYS_INLINE void assert_is_lvalue_if_expr()
Definition: IRMatch.h:588
HALIDE_ALWAYS_INLINE auto operator%(A &&a, B &&b) noexcept -> BinOp< Mod, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1059
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Sub >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition: IRMatch.h:980
HALIDE_ALWAYS_INLINE auto is_scalar(A &&a) noexcept -> IsScalar< decltype(pattern_arg(a))>
Definition: IRMatch.h:2522
HALIDE_ALWAYS_INLINE auto fold(A &&a) noexcept -> Fold< decltype(pattern_arg(a))>
Definition: IRMatch.h:2191
HALIDE_ALWAYS_INLINE auto not_op(A &&a) -> decltype(IRMatcher::operator!(a))
Definition: IRMatch.h:1661
auto halving_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1578
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Max >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1117
constexpr bool and_reduce()
Definition: IRMatch.h:1342
HALIDE_ALWAYS_INLINE auto operator||(A &&a, B &&b) noexcept -> BinOp< Or, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1282
auto widening_sub(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1556
constexpr int max_wild
Definition: IRMatch.h:74
HALIDE_ALWAYS_INLINE auto operator!=(A &&a, B &&b) noexcept -> CmpOp< NE, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1257
HALIDE_ALWAYS_INLINE bool equal(const BaseExprNode &a, const BaseExprNode &b) noexcept
Definition: IRMatch.h:195
HALIDE_ALWAYS_INLINE auto is_float(A &&a) noexcept -> IsFloat< decltype(pattern_arg(a))>
Definition: IRMatch.h:2395
HALIDE_ALWAYS_INLINE auto operator>=(A &&a, B &&b) noexcept -> CmpOp< GE, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1207
bool equal_helper(const BaseExprNode &a, const BaseExprNode &b) noexcept
HALIDE_ALWAYS_INLINE auto operator<(A &&a, B &&b) noexcept -> CmpOp< LT, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1132
HALIDE_ALWAYS_INLINE auto operator&&(A &&a, B &&b) noexcept -> BinOp< And, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1308
HALIDE_ALWAYS_INLINE auto is_uint(A &&a, int bits=0, int lanes=0) noexcept -> IsUInt< decltype(pattern_arg(a))>
Definition: IRMatch.h:2478
HALIDE_ALWAYS_INLINE auto h_or(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Or >
Definition: IRMatch.h:1950
constexpr bool commutative(IRNodeType t)
Definition: IRMatch.h:627
auto widen_right_mul(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1543
HALIDE_ALWAYS_INLINE auto sub(A &&a, B &&b) -> decltype(IRMatcher::operator-(a, b))
Definition: IRMatch.h:973
HALIDE_ALWAYS_INLINE auto h_max(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Max >
Definition: IRMatch.h:1938
HALIDE_ALWAYS_INLINE auto broadcast(A &&a, B lanes) noexcept -> BroadcastOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes))>
Definition: IRMatch.h:1804
HALIDE_ALWAYS_INLINE auto is_int(A &&a, int bits=0, int lanes=0) noexcept -> IsInt< decltype(pattern_arg(a))>
Definition: IRMatch.h:2433
HALIDE_ALWAYS_INLINE auto select(C &&c, T &&t, F &&f) noexcept -> SelectOp< decltype(pattern_arg(c)), decltype(pattern_arg(t)), decltype(pattern_arg(f))>
Definition: IRMatch.h:1731
HALIDE_ALWAYS_INLINE auto is_min_value(A &&a) noexcept -> IsMinValue< decltype(pattern_arg(a))>
Definition: IRMatch.h:2606
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Min >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1095
HALIDE_NEVER_INLINE void fuzz_test_rule(Before &&before, After &&after, Predicate &&pred, halide_type_t wildcard_type, halide_type_t output_type) noexcept
Definition: IRMatch.h:2660
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< GT >(int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1167
auto halving_sub(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1582
auto saturating_sub(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1568
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Mul >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1013
auto mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))>
Definition: IRMatch.h:1606
auto shift_right(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1594
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< GE >(int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1217
HALIDE_ALWAYS_INLINE auto operator-(A &&a, B &&b) noexcept -> BinOp< Sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:966
HALIDE_ALWAYS_INLINE auto le(A &&a, B &&b) -> decltype(IRMatcher::operator<=(a, b))
Definition: IRMatch.h:1187
HALIDE_ALWAYS_INLINE auto lt(A &&a, B &&b) -> decltype(IRMatcher::operator<(a, b))
Definition: IRMatch.h:1137
HALIDE_ALWAYS_INLINE auto is_const(A &&a, int64_t value) noexcept -> IsConst< decltype(pattern_arg(a))>
Definition: IRMatch.h:2316
HALIDE_ALWAYS_INLINE auto lanes_of(A &&a) noexcept -> LanesOf< decltype(pattern_arg(a))>
Definition: IRMatch.h:2643
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< LT >(int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1142
HALIDE_ALWAYS_INLINE auto h_min(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Min >
Definition: IRMatch.h:1932
HALIDE_ALWAYS_INLINE auto h_add(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Add >
Definition: IRMatch.h:1926
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Or >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1292
HALIDE_ALWAYS_INLINE Expr make_const_expr(halide_scalar_value_t val, halide_type_t ty)
Definition: IRMatch.h:160
constexpr uint32_t bitwise_or_reduce()
Definition: IRMatch.h:1333
auto rounding_mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))>
Definition: IRMatch.h:1610
int64_t constant_fold_bin_op(halide_type_t &, int64_t, int64_t) noexcept
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< EQ >(int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1242
HALIDE_NEVER_INLINE Expr make_const_special_expr(halide_type_t ty)
Definition: IRMatch.h:149
HALIDE_ALWAYS_INLINE auto ge(A &&a, B &&b) -> decltype(IRMatcher::operator>=(a, b))
Definition: IRMatch.h:1212
auto saturating_cast(const Type &t, A &&a) noexcept -> Intrin< decltype(pattern_arg(a))>
Definition: IRMatch.h:1572
constexpr int const_min(int a, int b)
Definition: IRMatch.h:1352
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< NE >(int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1267
HALIDE_ALWAYS_INLINE auto mod(A &&a, B &&b) -> decltype(IRMatcher::operator%(a, b))
Definition: IRMatch.h:1066
HALIDE_ALWAYS_INLINE auto operator==(A &&a, B &&b) noexcept -> CmpOp< EQ, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1232
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Add >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition: IRMatch.h:947
HALIDE_ALWAYS_INLINE auto can_prove(A &&a, Prover *p) noexcept -> CanProve< decltype(pattern_arg(a)), Prover >
Definition: IRMatch.h:2358
HALIDE_ALWAYS_INLINE auto eq(A &&a, B &&b) -> decltype(IRMatcher::operator==(a, b))
Definition: IRMatch.h:1237
T div_imp(T a, T b)
Definition: IROperator.h:260
bool is_const_zero(const Expr &e)
Is the expression a const (as defined by is_const), and also equal to zero (in all lanes,...
Expr make_zero(Type t)
Construct the representation of zero in the given type.
void expr_match_test()
bool is_const_one(const Expr &e)
Is the expression a const (as defined by is_const), and also equal to one (in all lanes,...
constexpr IRNodeType StrongestExprNodeType
Definition: Expr.h:81
Expr make_const(Type t, int64_t val)
Construct an immediate of the given type from any numeric C++ type.
T mod_imp(T a, T b)
Implementations of division and mod that are specific to Halide.
Definition: IROperator.h:239
bool sub_would_overflow(int bits, int64_t a, int64_t b)
bool add_would_overflow(int bits, int64_t a, int64_t b)
Routines to test if math would overflow for signed integers with the given number of bits.
bool mul_would_overflow(int bits, int64_t a, int64_t b)
Expr with_lanes(const Expr &x, int lanes)
Rewrite the expression x to have lanes lanes.
bool expr_match(const Expr &pattern, const Expr &expr, std::vector< Expr > &result)
Does the first expression have the same structure as the second? Variables in the first expression wi...
Expr make_signed_integer_overflow(Type type)
Construct a unique signed_integer_overflow Expr.
IRNodeType
All our IR node types get unique IDs for the purposes of RTTI.
Definition: Expr.h:25
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
@ Internal
Not visible externally, similar to 'static' linkage in C.
@ Predicate
Guard the loads and stores in the loop with an if statement that prevents evaluation beyond the origi...
Expr absd(Expr a, Expr b)
Return the absolute difference between two values.
@ C
No name mangling.
Expr likely_if_innermost(Expr e)
Equivalent to likely, but only triggers a loop partitioning if found in an innermost loop.
Expr abs(Expr a)
Returns the absolute value of a signed integer or floating-point expression.
Expr likely(Expr e)
Expressions tagged with this intrinsic are considered to be part of the steady state of some loop wit...
unsigned __INT64_TYPE__ uint64_t
signed __INT64_TYPE__ int64_t
signed __INT32_TYPE__ int32_t
unsigned __INT16_TYPE__ uint16_t
unsigned __INT32_TYPE__ uint32_t
A fragment of Halide syntax.
Definition: Expr.h:258
HALIDE_ALWAYS_INLINE Type type() const
Get the type of this expression node.
Definition: Expr.h:322
HALIDE_ALWAYS_INLINE const Internal::BaseExprNode * get() const
Override get() to return a BaseExprNode * instead of an IRNode *.
Definition: Expr.h:316
The sum of two expressions.
Definition: IR.h:56
Logical and - are both expressions true.
Definition: IR.h:175
A base class for expression nodes.
Definition: Expr.h:143
A vector with 'lanes' elements, in which every element is 'value'.
Definition: IR.h:259
static Expr make(Expr value, int lanes)
static const IRNodeType _node_type
Definition: IR.h:265
A function call.
Definition: IR.h:490
@ signed_integer_overflow
Definition: IR.h:594
@ rounding_mul_shift_right
Definition: IR.h:584
bool is_intrinsic() const
Definition: IR.h:707
static const IRNodeType _node_type
Definition: IR.h:752
The actual IR nodes begin here.
Definition: IR.h:30
static const IRNodeType _node_type
Definition: IR.h:35
The ratio of two expressions.
Definition: IR.h:83
Is the first expression equal to the second.
Definition: IR.h:121
Floating point constants.
Definition: Expr.h:236
static const FloatImm * make(Type t, double value)
Is the first expression greater than or equal to the second.
Definition: IR.h:166
Is the first expression greater than the second.
Definition: IR.h:157
constexpr static uint32_t binds
Definition: IRMatch.h:645
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:648
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:676
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:657
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:647
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const noexcept
Definition: IRMatch.h:719
HALIDE_ALWAYS_INLINE bool match(const BinOp< Op2, A2, B2 > &op, MatcherState &state) const noexcept
Definition: IRMatch.h:667
constexpr static bool canonical
Definition: IRMatch.h:653
constexpr static bool foldable
Definition: IRMatch.h:673
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:1770
HALIDE_ALWAYS_INLINE bool match(const BroadcastOp< A2, B2 > &op, MatcherState &state) const noexcept
Definition: IRMatch.h:1764
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:1752
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:1746
constexpr static uint32_t binds
Definition: IRMatch.h:1744
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:1787
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:1747
HALIDE_NEVER_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition: IRMatch.h:2347
constexpr static bool foldable
Definition: IRMatch.h:2344
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2340
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2341
constexpr static uint32_t binds
Definition: IRMatch.h:2337
constexpr static bool canonical
Definition: IRMatch.h:2342
constexpr static bool canonical
Definition: IRMatch.h:2047
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2046
constexpr static bool foldable
Definition: IRMatch.h:2068
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:2050
constexpr static uint32_t binds
Definition: IRMatch.h:2043
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2045
HALIDE_ALWAYS_INLINE bool match(const CastOp< A2 > &op, MatcherState &state) const noexcept
Definition: IRMatch.h:2059
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:2064
constexpr static bool canonical
Definition: IRMatch.h:760
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:820
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:758
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:759
constexpr static bool foldable
Definition: IRMatch.h:783
constexpr static uint32_t binds
Definition: IRMatch.h:756
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:767
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:786
HALIDE_ALWAYS_INLINE bool match(const CmpOp< Op2, A2, B2 > &op, MatcherState &state) const noexcept
Definition: IRMatch.h:777
constexpr static bool foldable
Definition: IRMatch.h:2182
constexpr static uint32_t binds
Definition: IRMatch.h:2152
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2154
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2155
constexpr static bool canonical
Definition: IRMatch.h:2156
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const noexcept
Definition: IRMatch.h:2159
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:2185
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:516
constexpr static bool canonical
Definition: IRMatch.h:508
constexpr static uint32_t binds
Definition: IRMatch.h:504
HALIDE_ALWAYS_INLINE IntLiteral(int64_t v)
Definition: IRMatch.h:511
HALIDE_ALWAYS_INLINE bool match(const IntLiteral &b, MatcherState &state) const noexcept
Definition: IRMatch.h:539
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:551
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:506
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:507
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:544
HALIDE_ALWAYS_INLINE bool match(int64_t val, MatcherState &state) const noexcept
Definition: IRMatch.h:534
constexpr static bool foldable
Definition: IRMatch.h:548
HALIDE_ALWAYS_INLINE bool match_args(double, const Call &c, MatcherState &state) const noexcept
Definition: IRMatch.h:1382
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:1417
constexpr static bool canonical
Definition: IRMatch.h:1370
HALIDE_ALWAYS_INLINE void print_args(std::ostream &s) const
Definition: IRMatch.h:1412
constexpr static bool foldable
Definition: IRMatch.h:1475
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:1477
static constexpr uint32_t binds
Definition: IRMatch.h:1366
HALIDE_ALWAYS_INLINE bool match_args(int, const Call &c, MatcherState &state) const noexcept
Definition: IRMatch.h:1375
HALIDE_ALWAYS_INLINE void print_args(int, std::ostream &s) const
Definition: IRMatch.h:1399
std::tuple< Args... > args
Definition: IRMatch.h:1360
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:1387
HALIDE_ALWAYS_INLINE void print_args(double, std::ostream &s) const
Definition: IRMatch.h:1408
HALIDE_ALWAYS_INLINE Intrin(Call::IntrinsicOp intrin, Args... args) noexcept
Definition: IRMatch.h:1520
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:1369
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:1368
constexpr static bool canonical
Definition: IRMatch.h:2287
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:2296
constexpr static bool foldable
Definition: IRMatch.h:2293
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2286
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2285
constexpr static uint32_t binds
Definition: IRMatch.h:2282
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2377
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition: IRMatch.h:2384
constexpr static bool canonical
Definition: IRMatch.h:2379
constexpr static uint32_t binds
Definition: IRMatch.h:2374
constexpr static bool foldable
Definition: IRMatch.h:2381
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2378
constexpr static uint32_t binds
Definition: IRMatch.h:2412
constexpr static bool foldable
Definition: IRMatch.h:2419
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2415
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition: IRMatch.h:2422
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2416
constexpr static bool canonical
Definition: IRMatch.h:2417
constexpr static bool canonical
Definition: IRMatch.h:2543
constexpr static bool foldable
Definition: IRMatch.h:2545
constexpr static uint32_t binds
Definition: IRMatch.h:2538
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2541
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2542
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition: IRMatch.h:2548
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2582
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2583
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition: IRMatch.h:2589
constexpr static bool canonical
Definition: IRMatch.h:2584
constexpr static uint32_t binds
Definition: IRMatch.h:2579
constexpr static bool foldable
Definition: IRMatch.h:2586
constexpr static bool foldable
Definition: IRMatch.h:2508
constexpr static bool canonical
Definition: IRMatch.h:2506
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition: IRMatch.h:2511
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2505
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2504
constexpr static uint32_t binds
Definition: IRMatch.h:2501
constexpr static bool canonical
Definition: IRMatch.h:2462
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition: IRMatch.h:2467
constexpr static uint32_t binds
Definition: IRMatch.h:2457
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2460
constexpr static bool foldable
Definition: IRMatch.h:2464
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2461
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition: IRMatch.h:2632
constexpr static bool foldable
Definition: IRMatch.h:2629
constexpr static bool canonical
Definition: IRMatch.h:2627
constexpr static uint32_t binds
Definition: IRMatch.h:2622
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2625
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2626
To save stack space, the matcher objects are largely stateless and immutable.
Definition: IRMatch.h:82
HALIDE_ALWAYS_INLINE void get_bound_const(int i, halide_scalar_value_t &val, halide_type_t &type) const noexcept
Definition: IRMatch.h:127
HALIDE_ALWAYS_INLINE void set_bound_const(int i, int64_t s, halide_type_t t) noexcept
Definition: IRMatch.h:103
HALIDE_ALWAYS_INLINE void set_bound_const(int i, double f, halide_type_t t) noexcept
Definition: IRMatch.h:115
static constexpr uint16_t special_values_mask
Definition: IRMatch.h:88
HALIDE_ALWAYS_INLINE void set_bound_const(int i, halide_scalar_value_t val, halide_type_t t) noexcept
Definition: IRMatch.h:121
halide_type_t bound_const_type[max_wild]
Definition: IRMatch.h:90
HALIDE_ALWAYS_INLINE void set_binding(int i, const BaseExprNode &n) noexcept
Definition: IRMatch.h:93
HALIDE_ALWAYS_INLINE MatcherState() noexcept
Definition: IRMatch.h:134
halide_scalar_value_t bound_const[max_wild]
Definition: IRMatch.h:84
HALIDE_ALWAYS_INLINE const BaseExprNode * get_binding(int i) const noexcept
Definition: IRMatch.h:98
HALIDE_ALWAYS_INLINE void set_bound_const(int i, uint64_t u, halide_type_t t) noexcept
Definition: IRMatch.h:109
static constexpr uint16_t signed_integer_overflow
Definition: IRMatch.h:87
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:1962
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:1963
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:1968
constexpr static uint32_t binds
Definition: IRMatch.h:1960
constexpr static bool canonical
Definition: IRMatch.h:1965
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:1983
HALIDE_ALWAYS_INLINE bool match(NegateOp< A2 > &&p, MatcherState &state) const noexcept
Definition: IRMatch.h:1978
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:1992
constexpr static bool foldable
Definition: IRMatch.h:1989
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:1621
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:1626
HALIDE_ALWAYS_INLINE bool match(const NotOp< A2 > &op, MatcherState &state) const noexcept
Definition: IRMatch.h:1635
constexpr static uint32_t binds
Definition: IRMatch.h:1619
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:1622
constexpr static bool foldable
Definition: IRMatch.h:1644
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:1640
constexpr static bool canonical
Definition: IRMatch.h:1623
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:1647
constexpr static bool canonical
Definition: IRMatch.h:2247
constexpr static bool foldable
Definition: IRMatch.h:2264
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:2250
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:2259
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2246
constexpr static uint32_t binds
Definition: IRMatch.h:2242
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:2267
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2245
constexpr static bool foldable
Definition: IRMatch.h:2215
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:2218
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2212
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2211
constexpr static uint32_t binds
Definition: IRMatch.h:2207
constexpr static bool canonical
Definition: IRMatch.h:2213
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:1846
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:1819
constexpr static bool canonical
Definition: IRMatch.h:1821
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:1818
constexpr static bool foldable
Definition: IRMatch.h:1858
HALIDE_ALWAYS_INLINE bool match(const RampOp< A2, B2, C2 > &op, MatcherState &state) const noexcept
Definition: IRMatch.h:1839
constexpr static uint32_t binds
Definition: IRMatch.h:1816
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:1824
HALIDE_NEVER_INLINE void build_replacement(After after)
Definition: IRMatch.h:2831
HALIDE_ALWAYS_INLINE bool operator()(Before before, After after, Predicate pred)
Definition: IRMatch.h:2905
HALIDE_ALWAYS_INLINE bool operator()(Before before, int64_t after) noexcept
Definition: IRMatch.h:2880
HALIDE_ALWAYS_INLINE Rewriter(Instance instance, halide_type_t ot, halide_type_t wt)
Definition: IRMatch.h:2826
HALIDE_ALWAYS_INLINE bool operator()(Before before, const Expr &after, Predicate pred)
Definition: IRMatch.h:2934
HALIDE_ALWAYS_INLINE bool operator()(Before before, const Expr &after) noexcept
Definition: IRMatch.h:2862
HALIDE_ALWAYS_INLINE bool operator()(Before before, int64_t after, Predicate pred)
Definition: IRMatch.h:2957
HALIDE_ALWAYS_INLINE bool operator()(Before before, After after)
Definition: IRMatch.h:2839
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:1682
constexpr static bool canonical
Definition: IRMatch.h:1684
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:1711
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:1681
constexpr static uint32_t binds
Definition: IRMatch.h:1679
constexpr static bool foldable
Definition: IRMatch.h:1708
HALIDE_ALWAYS_INLINE bool match(const SelectOp< C2, T2, F2 > &instance, MatcherState &state) const noexcept
Definition: IRMatch.h:1697
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:1687
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:1704
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2094
constexpr static bool foldable
Definition: IRMatch.h:2124
HALIDE_ALWAYS_INLINE SliceOp(Vec v, Base b, Stride s, Lanes l)
Definition: IRMatch.h:2127
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:2098
static constexpr uint32_t binds
Definition: IRMatch.h:2091
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2093
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:2112
constexpr static bool canonical
Definition: IRMatch.h:2095
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:210
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:217
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:222
constexpr static uint32_t binds
Definition: IRMatch.h:207
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:211
HALIDE_ALWAYS_INLINE bool match(const VectorReduceOp< A2, B2, reduce_op_2 > &op, MatcherState &state) const noexcept
Definition: IRMatch.h:1901
constexpr static uint32_t binds
Definition: IRMatch.h:1881
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:1884
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:1888
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:1883
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:1908
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:364
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:360
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:385
constexpr static uint32_t binds
Definition: IRMatch.h:357
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:359
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:395
constexpr static uint32_t binds
Definition: IRMatch.h:411
constexpr static bool foldable
Definition: IRMatch.h:450
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:443
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:414
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:418
constexpr static bool canonical
Definition: IRMatch.h:415
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:453
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:413
HALIDE_ALWAYS_INLINE bool match(int64_t e, MatcherState &state) const noexcept
Definition: IRMatch.h:437
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:240
constexpr static uint32_t binds
Definition: IRMatch.h:238
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:279
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition: IRMatch.h:289
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:245
HALIDE_ALWAYS_INLINE bool match(int64_t value, MatcherState &state) const noexcept
Definition: IRMatch.h:266
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:241
constexpr static uint32_t binds
Definition: IRMatch.h:304
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:307
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:306
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:311
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:342
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:332
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:472
constexpr static bool foldable
Definition: IRMatch.h:489
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:485
constexpr static bool canonical
Definition: IRMatch.h:473
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:471
constexpr static uint32_t binds
Definition: IRMatch.h:469
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:476
constexpr static uint32_t mask
Definition: IRMatch.h:146
IRNodeType node_type
Each IR node subclass has a unique identifier.
Definition: Expr.h:113
Integer constants.
Definition: Expr.h:218
static const IntImm * make(Type t, int64_t value)
Is the first expression less than or equal to the second.
Definition: IR.h:148
Is the first expression less than the second.
Definition: IR.h:139
The greater of two values.
Definition: IR.h:112
The lesser of two values.
Definition: IR.h:103
The remainder of a / b.
Definition: IR.h:94
The product of two expressions.
Definition: IR.h:74
Is the first expression not equal to the second.
Definition: IR.h:130
Logical not - true if the expression is false.
Definition: IR.h:193
static Expr make(Expr a)
Logical or - is at least one of the expressions true.
Definition: IR.h:184
A linear ramp vector node.
Definition: IR.h:247
static const IRNodeType _node_type
Definition: IR.h:253
static Expr make(Expr base, Expr stride, int lanes)
A ternary operator.
Definition: IR.h:204
static Expr make(Expr condition, Expr true_value, Expr false_value)
static const IRNodeType _node_type
Definition: IR.h:209
Construct a new vector by taking elements from another sequence of vectors.
Definition: IR.h:841
static Expr make_slice(Expr vector, int begin, int stride, int size)
Convenience constructor for making a shuffle representing a contiguous subset of a vector.
std::vector< Expr > vectors
Definition: IR.h:842
bool is_slice() const
Check if this shuffle is a contiguous strict subset of the vector arguments, and if so,...
int slice_stride() const
Check if this shuffle is a contiguous strict subset of the vector arguments, and if so,...
Definition: IR.h:896
int slice_begin() const
Check if this shuffle is a contiguous strict subset of the vector arguments, and if so,...
Definition: IR.h:893
The difference of two expressions.
Definition: IR.h:65
static const IRNodeType _node_type
Definition: IR.h:70
static Expr make(Expr a, Expr b)
Unsigned integer constants.
Definition: Expr.h:227
static const UIntImm * make(Type t, uint64_t value)
Horizontally reduce a vector to a scalar or narrower vector using the given commutative and associati...
Definition: IR.h:966
static const IRNodeType _node_type
Definition: IR.h:985
static Expr make(Operator op, Expr vec, int lanes)
Types in the halide type system.
Definition: Type.h:276
HALIDE_ALWAYS_INLINE bool is_int() const
Is this type a signed integer type?
Definition: Type.h:428
HALIDE_ALWAYS_INLINE int lanes() const
Return the number of vector elements in this type.
Definition: Type.h:348
HALIDE_ALWAYS_INLINE bool is_uint() const
Is this type an unsigned integer type?
Definition: Type.h:434
HALIDE_ALWAYS_INLINE int bits() const
Return the bit size of a single element of this type.
Definition: Type.h:342
HALIDE_ALWAYS_INLINE bool is_vector() const
Is this type a vector type? (lanes() != 1).
Definition: Type.h:403
HALIDE_ALWAYS_INLINE bool is_scalar() const
Is this type a scalar type? (lanes() == 1).
Definition: Type.h:410
HALIDE_ALWAYS_INLINE bool is_float() const
Is this type a floating point type (float or double).
Definition: Type.h:416
halide_scalar_value_t is a simple union able to represent all the well-known scalar values in a filte...
union halide_scalar_value_t::@4 u
A runtime tag for a type in the halide type system.
uint8_t bits
The number of bits of precision of a single scalar value of this type.
uint16_t lanes
How many elements in a vector.
uint8_t code
The basic type code: signed integer, unsigned integer, or floating point.