Halide  17.0.2
Halide compiler and libraries
IRVisitor.h
Go to the documentation of this file.
1 #ifndef HALIDE_IR_VISITOR_H
2 #define HALIDE_IR_VISITOR_H
3 
4 #include <set>
5 
6 #include "IR.h"
7 
8 /** \file
9  * Defines the base class for things that recursively walk over the IR
10  */
11 
12 namespace Halide {
13 namespace Internal {
14 
15 /** A base class for algorithms that need to recursively walk over the
16  * IR. The default implementations just recursively walk over the
17  * children. Override the ones you care about.
18  */
19 class IRVisitor {
20 public:
21  IRVisitor() = default;
22  virtual ~IRVisitor() = default;
23 
24 protected:
25  // ExprNode<> and StmtNode<> are allowed to call visit (to implement accept())
26  template<typename T>
27  friend struct ExprNode;
28 
29  template<typename T>
30  friend struct StmtNode;
31 
32  virtual void visit(const IntImm *);
33  virtual void visit(const UIntImm *);
34  virtual void visit(const FloatImm *);
35  virtual void visit(const StringImm *);
36  virtual void visit(const Cast *);
37  virtual void visit(const Reinterpret *);
38  virtual void visit(const Variable *);
39  virtual void visit(const Add *);
40  virtual void visit(const Sub *);
41  virtual void visit(const Mul *);
42  virtual void visit(const Div *);
43  virtual void visit(const Mod *);
44  virtual void visit(const Min *);
45  virtual void visit(const Max *);
46  virtual void visit(const EQ *);
47  virtual void visit(const NE *);
48  virtual void visit(const LT *);
49  virtual void visit(const LE *);
50  virtual void visit(const GT *);
51  virtual void visit(const GE *);
52  virtual void visit(const And *);
53  virtual void visit(const Or *);
54  virtual void visit(const Not *);
55  virtual void visit(const Select *);
56  virtual void visit(const Load *);
57  virtual void visit(const Ramp *);
58  virtual void visit(const Broadcast *);
59  virtual void visit(const Call *);
60  virtual void visit(const Let *);
61  virtual void visit(const LetStmt *);
62  virtual void visit(const AssertStmt *);
63  virtual void visit(const ProducerConsumer *);
64  virtual void visit(const For *);
65  virtual void visit(const Store *);
66  virtual void visit(const Provide *);
67  virtual void visit(const Allocate *);
68  virtual void visit(const Free *);
69  virtual void visit(const Realize *);
70  virtual void visit(const Block *);
71  virtual void visit(const IfThenElse *);
72  virtual void visit(const Evaluate *);
73  virtual void visit(const Shuffle *);
74  virtual void visit(const VectorReduce *);
75  virtual void visit(const Prefetch *);
76  virtual void visit(const Fork *);
77  virtual void visit(const Acquire *);
78  virtual void visit(const Atomic *);
79  virtual void visit(const HoistedStorage *);
80 };
81 
82 /** A base class for algorithms that walk recursively over the IR
83  * without visiting the same node twice. This is for passes that are
84  * capable of interpreting the IR as a DAG instead of a tree. */
85 class IRGraphVisitor : public IRVisitor {
86 protected:
87  /** By default these methods add the node to the visited set, and
88  * return whether or not it was already there. If it wasn't there,
89  * it delegates to the appropriate visit method. You can override
90  * them if you like. */
91  // @{
92  virtual void include(const Expr &);
93  virtual void include(const Stmt &);
94  // @}
95 
96 private:
97  /** The nodes visited so far */
98  std::set<IRHandle> visited;
99 
100 protected:
101  /** These methods should call 'include' on the children to only
102  * visit them if they haven't been visited already. */
103  // @{
104  void visit(const IntImm *) override;
105  void visit(const UIntImm *) override;
106  void visit(const FloatImm *) override;
107  void visit(const StringImm *) override;
108  void visit(const Cast *) override;
109  void visit(const Reinterpret *) override;
110  void visit(const Variable *) override;
111  void visit(const Add *) override;
112  void visit(const Sub *) override;
113  void visit(const Mul *) override;
114  void visit(const Div *) override;
115  void visit(const Mod *) override;
116  void visit(const Min *) override;
117  void visit(const Max *) override;
118  void visit(const EQ *) override;
119  void visit(const NE *) override;
120  void visit(const LT *) override;
121  void visit(const LE *) override;
122  void visit(const GT *) override;
123  void visit(const GE *) override;
124  void visit(const And *) override;
125  void visit(const Or *) override;
126  void visit(const Not *) override;
127  void visit(const Select *) override;
128  void visit(const Load *) override;
129  void visit(const Ramp *) override;
130  void visit(const Broadcast *) override;
131  void visit(const Call *) override;
132  void visit(const Let *) override;
133  void visit(const LetStmt *) override;
134  void visit(const AssertStmt *) override;
135  void visit(const ProducerConsumer *) override;
136  void visit(const For *) override;
137  void visit(const Store *) override;
138  void visit(const Provide *) override;
139  void visit(const Allocate *) override;
140  void visit(const Free *) override;
141  void visit(const Realize *) override;
142  void visit(const Block *) override;
143  void visit(const IfThenElse *) override;
144  void visit(const Evaluate *) override;
145  void visit(const Shuffle *) override;
146  void visit(const VectorReduce *) override;
147  void visit(const Prefetch *) override;
148  void visit(const Acquire *) override;
149  void visit(const Fork *) override;
150  void visit(const Atomic *) override;
151  void visit(const HoistedStorage *) override;
152  // @}
153 };
154 
155 /** A visitor/mutator capable of passing arbitrary arguments to the
156  * visit methods using CRTP and returning any types from them. All
157  * Expr visitors must have the same signature, and all Stmt visitors
158  * must have the same signature. Does not have default implementations
159  * of the visit methods. */
160 template<typename T, typename ExprRet, typename StmtRet>
162 private:
163  template<typename... Args>
164  ExprRet dispatch_expr(const BaseExprNode *node, Args &&...args) {
165  if (node == nullptr) {
166  return ExprRet{};
167  }
168  switch (node->node_type) {
169  case IRNodeType::IntImm:
170  return ((T *)this)->visit((const IntImm *)node, std::forward<Args>(args)...);
171  case IRNodeType::UIntImm:
172  return ((T *)this)->visit((const UIntImm *)node, std::forward<Args>(args)...);
174  return ((T *)this)->visit((const FloatImm *)node, std::forward<Args>(args)...);
176  return ((T *)this)->visit((const StringImm *)node, std::forward<Args>(args)...);
178  return ((T *)this)->visit((const Broadcast *)node, std::forward<Args>(args)...);
179  case IRNodeType::Cast:
180  return ((T *)this)->visit((const Cast *)node, std::forward<Args>(args)...);
182  return ((T *)this)->visit((const Reinterpret *)node, std::forward<Args>(args)...);
184  return ((T *)this)->visit((const Variable *)node, std::forward<Args>(args)...);
185  case IRNodeType::Add:
186  return ((T *)this)->visit((const Add *)node, std::forward<Args>(args)...);
187  case IRNodeType::Sub:
188  return ((T *)this)->visit((const Sub *)node, std::forward<Args>(args)...);
189  case IRNodeType::Mod:
190  return ((T *)this)->visit((const Mod *)node, std::forward<Args>(args)...);
191  case IRNodeType::Mul:
192  return ((T *)this)->visit((const Mul *)node, std::forward<Args>(args)...);
193  case IRNodeType::Div:
194  return ((T *)this)->visit((const Div *)node, std::forward<Args>(args)...);
195  case IRNodeType::Min:
196  return ((T *)this)->visit((const Min *)node, std::forward<Args>(args)...);
197  case IRNodeType::Max:
198  return ((T *)this)->visit((const Max *)node, std::forward<Args>(args)...);
199  case IRNodeType::EQ:
200  return ((T *)this)->visit((const EQ *)node, std::forward<Args>(args)...);
201  case IRNodeType::NE:
202  return ((T *)this)->visit((const NE *)node, std::forward<Args>(args)...);
203  case IRNodeType::LT:
204  return ((T *)this)->visit((const LT *)node, std::forward<Args>(args)...);
205  case IRNodeType::LE:
206  return ((T *)this)->visit((const LE *)node, std::forward<Args>(args)...);
207  case IRNodeType::GT:
208  return ((T *)this)->visit((const GT *)node, std::forward<Args>(args)...);
209  case IRNodeType::GE:
210  return ((T *)this)->visit((const GE *)node, std::forward<Args>(args)...);
211  case IRNodeType::And:
212  return ((T *)this)->visit((const And *)node, std::forward<Args>(args)...);
213  case IRNodeType::Or:
214  return ((T *)this)->visit((const Or *)node, std::forward<Args>(args)...);
215  case IRNodeType::Not:
216  return ((T *)this)->visit((const Not *)node, std::forward<Args>(args)...);
217  case IRNodeType::Select:
218  return ((T *)this)->visit((const Select *)node, std::forward<Args>(args)...);
219  case IRNodeType::Load:
220  return ((T *)this)->visit((const Load *)node, std::forward<Args>(args)...);
221  case IRNodeType::Ramp:
222  return ((T *)this)->visit((const Ramp *)node, std::forward<Args>(args)...);
223  case IRNodeType::Call:
224  return ((T *)this)->visit((const Call *)node, std::forward<Args>(args)...);
225  case IRNodeType::Let:
226  return ((T *)this)->visit((const Let *)node, std::forward<Args>(args)...);
227  case IRNodeType::Shuffle:
228  return ((T *)this)->visit((const Shuffle *)node, std::forward<Args>(args)...);
230  return ((T *)this)->visit((const VectorReduce *)node, std::forward<Args>(args)...);
231  // Explicitly list the Stmt types rather than using a
232  // default case so that when new IR nodes are added we
233  // don't miss them here.
234  case IRNodeType::LetStmt:
237  case IRNodeType::For:
238  case IRNodeType::Acquire:
239  case IRNodeType::Store:
240  case IRNodeType::Provide:
242  case IRNodeType::Free:
243  case IRNodeType::Realize:
244  case IRNodeType::Block:
245  case IRNodeType::Fork:
249  case IRNodeType::Atomic:
251  internal_error << "Unreachable";
252  }
253  return ExprRet{};
254  }
255 
256  template<typename... Args>
257  StmtRet dispatch_stmt(const BaseStmtNode *node, Args &&...args) {
258  if (node == nullptr) {
259  return StmtRet{};
260  }
261  switch (node->node_type) {
262  case IRNodeType::IntImm:
263  case IRNodeType::UIntImm:
267  case IRNodeType::Cast:
270  case IRNodeType::Add:
271  case IRNodeType::Sub:
272  case IRNodeType::Mod:
273  case IRNodeType::Mul:
274  case IRNodeType::Div:
275  case IRNodeType::Min:
276  case IRNodeType::Max:
277  case IRNodeType::EQ:
278  case IRNodeType::NE:
279  case IRNodeType::LT:
280  case IRNodeType::LE:
281  case IRNodeType::GT:
282  case IRNodeType::GE:
283  case IRNodeType::And:
284  case IRNodeType::Or:
285  case IRNodeType::Not:
286  case IRNodeType::Select:
287  case IRNodeType::Load:
288  case IRNodeType::Ramp:
289  case IRNodeType::Call:
290  case IRNodeType::Let:
291  case IRNodeType::Shuffle:
293  internal_error << "Unreachable";
294  break;
295  case IRNodeType::LetStmt:
296  return ((T *)this)->visit((const LetStmt *)node, std::forward<Args>(args)...);
298  return ((T *)this)->visit((const AssertStmt *)node, std::forward<Args>(args)...);
300  return ((T *)this)->visit((const ProducerConsumer *)node, std::forward<Args>(args)...);
301  case IRNodeType::For:
302  return ((T *)this)->visit((const For *)node, std::forward<Args>(args)...);
303  case IRNodeType::Acquire:
304  return ((T *)this)->visit((const Acquire *)node, std::forward<Args>(args)...);
305  case IRNodeType::Store:
306  return ((T *)this)->visit((const Store *)node, std::forward<Args>(args)...);
307  case IRNodeType::Provide:
308  return ((T *)this)->visit((const Provide *)node, std::forward<Args>(args)...);
310  return ((T *)this)->visit((const Allocate *)node, std::forward<Args>(args)...);
311  case IRNodeType::Free:
312  return ((T *)this)->visit((const Free *)node, std::forward<Args>(args)...);
313  case IRNodeType::Realize:
314  return ((T *)this)->visit((const Realize *)node, std::forward<Args>(args)...);
315  case IRNodeType::Block:
316  return ((T *)this)->visit((const Block *)node, std::forward<Args>(args)...);
317  case IRNodeType::Fork:
318  return ((T *)this)->visit((const Fork *)node, std::forward<Args>(args)...);
320  return ((T *)this)->visit((const IfThenElse *)node, std::forward<Args>(args)...);
322  return ((T *)this)->visit((const Evaluate *)node, std::forward<Args>(args)...);
324  return ((T *)this)->visit((const Prefetch *)node, std::forward<Args>(args)...);
325  case IRNodeType::Atomic:
326  return ((T *)this)->visit((const Atomic *)node, std::forward<Args>(args)...);
328  return ((T *)this)->visit((const HoistedStorage *)node, std::forward<Args>(args)...);
329  }
330  return StmtRet{};
331  }
332 
333 public:
334  template<typename... Args>
335  HALIDE_ALWAYS_INLINE StmtRet dispatch(const Stmt &s, Args &&...args) {
336  return dispatch_stmt(s.get(), std::forward<Args>(args)...);
337  }
338 
339  template<typename... Args>
340  HALIDE_ALWAYS_INLINE StmtRet dispatch(Stmt &&s, Args &&...args) {
341  return dispatch_stmt(s.get(), std::forward<Args>(args)...);
342  }
343 
344  template<typename... Args>
345  HALIDE_ALWAYS_INLINE ExprRet dispatch(const Expr &e, Args &&...args) {
346  return dispatch_expr(e.get(), std::forward<Args>(args)...);
347  }
348 
349  template<typename... Args>
350  HALIDE_ALWAYS_INLINE ExprRet dispatch(Expr &&e, Args &&...args) {
351  return dispatch_expr(e.get(), std::forward<Args>(args)...);
352  }
353 };
354 
355 } // namespace Internal
356 } // namespace Halide
357 
358 #endif
#define internal_error
Definition: Errors.h:23
#define HALIDE_ALWAYS_INLINE
Definition: HalideRuntime.h:49
Subtypes for Halide expressions (Halide::Expr) and statements (Halide::Internal::Stmt)
A base class for algorithms that walk recursively over the IR without visiting the same node twice.
Definition: IRVisitor.h:85
void visit(const Div *) override
void visit(const Shuffle *) override
void visit(const NE *) override
void visit(const Block *) override
void visit(const EQ *) override
void visit(const Let *) override
void visit(const Provide *) override
void visit(const StringImm *) override
virtual void include(const Expr &)
By default these methods add the node to the visited set and, only if it was not already there, delegate to the appropriate visit method. They return nothing.
void visit(const For *) override
void visit(const HoistedStorage *) override
void visit(const Ramp *) override
void visit(const Or *) override
void visit(const UIntImm *) override
void visit(const Mul *) override
void visit(const AssertStmt *) override
void visit(const GE *) override
void visit(const Min *) override
void visit(const Free *) override
void visit(const Add *) override
void visit(const Acquire *) override
void visit(const Store *) override
void visit(const Max *) override
void visit(const IntImm *) override
These methods should call 'include' on the children to only visit them if they haven't been visited already.
void visit(const IfThenElse *) override
void visit(const LT *) override
void visit(const VectorReduce *) override
void visit(const Atomic *) override
void visit(const Sub *) override
void visit(const Not *) override
void visit(const Mod *) override
void visit(const ProducerConsumer *) override
void visit(const LetStmt *) override
void visit(const LE *) override
void visit(const Allocate *) override
void visit(const Load *) override
virtual void include(const Stmt &)
void visit(const Realize *) override
void visit(const Prefetch *) override
void visit(const FloatImm *) override
void visit(const Fork *) override
void visit(const Call *) override
void visit(const Reinterpret *) override
void visit(const And *) override
void visit(const Variable *) override
void visit(const Evaluate *) override
void visit(const Broadcast *) override
void visit(const GT *) override
void visit(const Cast *) override
void visit(const Select *) override
A base class for algorithms that need to recursively walk over the IR.
Definition: IRVisitor.h:19
virtual void visit(const NE *)
virtual void visit(const Mul *)
virtual void visit(const Max *)
virtual void visit(const Select *)
virtual void visit(const Load *)
virtual void visit(const Div *)
virtual void visit(const Fork *)
virtual void visit(const Sub *)
virtual void visit(const LE *)
virtual ~IRVisitor()=default
virtual void visit(const ProducerConsumer *)
virtual void visit(const VectorReduce *)
virtual void visit(const GE *)
virtual void visit(const StringImm *)
virtual void visit(const Allocate *)
virtual void visit(const IfThenElse *)
virtual void visit(const For *)
virtual void visit(const Prefetch *)
virtual void visit(const Block *)
virtual void visit(const UIntImm *)
virtual void visit(const HoistedStorage *)
virtual void visit(const FloatImm *)
virtual void visit(const GT *)
virtual void visit(const Mod *)
virtual void visit(const Acquire *)
virtual void visit(const Atomic *)
virtual void visit(const Ramp *)
virtual void visit(const Free *)
virtual void visit(const IntImm *)
virtual void visit(const Or *)
virtual void visit(const EQ *)
virtual void visit(const Broadcast *)
virtual void visit(const Call *)
virtual void visit(const Min *)
virtual void visit(const Variable *)
virtual void visit(const Realize *)
virtual void visit(const Add *)
virtual void visit(const Shuffle *)
virtual void visit(const Reinterpret *)
virtual void visit(const Evaluate *)
virtual void visit(const AssertStmt *)
virtual void visit(const And *)
virtual void visit(const LetStmt *)
virtual void visit(const Store *)
virtual void visit(const Provide *)
virtual void visit(const LT *)
virtual void visit(const Cast *)
virtual void visit(const Not *)
virtual void visit(const Let *)
A visitor/mutator capable of passing arbitrary arguments to the visit methods using CRTP and returning any types from them.
Definition: IRVisitor.h:161
HALIDE_ALWAYS_INLINE StmtRet dispatch(const Stmt &s, Args &&...args)
Definition: IRVisitor.h:335
HALIDE_ALWAYS_INLINE ExprRet dispatch(Expr &&e, Args &&...args)
Definition: IRVisitor.h:350
HALIDE_ALWAYS_INLINE StmtRet dispatch(Stmt &&s, Args &&...args)
Definition: IRVisitor.h:340
HALIDE_ALWAYS_INLINE ExprRet dispatch(const Expr &e, Args &&...args)
Definition: IRVisitor.h:345
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
@ Internal
Not visible externally, similar to 'static' linkage in C.
A fragment of Halide syntax.
Definition: Expr.h:258
HALIDE_ALWAYS_INLINE const Internal::BaseExprNode * get() const
Override get() to return a BaseExprNode * instead of an IRNode *.
Definition: Expr.h:316
The sum of two expressions.
Definition: IR.h:56
Allocate a scratch area with the given name, type, and size.
Definition: IR.h:371
Logical and - are both expressions true.
Definition: IR.h:175
If the 'condition' is false, then evaluate and return the message, which should be a call to an error function.
Definition: IR.h:294
Lock all the Store nodes in the body statement.
Definition: IR.h:948
A base class for expression nodes.
Definition: Expr.h:143
IR nodes are split into expressions and statements.
Definition: Expr.h:134
A sequence of statements to be executed in-order.
Definition: IR.h:442
A vector with 'lanes' elements, in which every element is 'value'.
Definition: IR.h:259
A function call.
Definition: IR.h:490
The actual IR nodes begin here.
Definition: IR.h:30
The ratio of two expressions.
Definition: IR.h:83
Is the first expression equal to the second.
Definition: IR.h:121
Evaluate and discard an expression, presumably because it has some side-effect.
Definition: IR.h:476
We use the "curiously recurring template pattern" to avoid duplicated code in the IR Nodes.
Definition: Expr.h:158
Floating point constants.
Definition: Expr.h:236
A for loop.
Definition: IR.h:805
A pair of statements executed concurrently.
Definition: IR.h:457
Free the resources associated with the given buffer.
Definition: IR.h:413
Is the first expression greater than or equal to the second.
Definition: IR.h:166
Is the first expression greater than the second.
Definition: IR.h:157
Represents a location where storage will be hoisted to for a Func / Realize node with a given name.
Definition: IR.h:932
IRNodeType node_type
Each IR node subclass has a unique identifier.
Definition: Expr.h:113
An if-then-else block.
Definition: IR.h:466
Integer constants.
Definition: Expr.h:218
Is the first expression less than or equal to the second.
Definition: IR.h:148
Is the first expression less than the second.
Definition: IR.h:139
A let expression, like you might find in a functional language.
Definition: IR.h:271
The statement form of a let node.
Definition: IR.h:282
Load a value from a named symbol if predicate is true.
Definition: IR.h:217
The greater of two values.
Definition: IR.h:112
The lesser of two values.
Definition: IR.h:103
The remainder of a / b.
Definition: IR.h:94
The product of two expressions.
Definition: IR.h:74
Is the first expression not equal to the second.
Definition: IR.h:130
Logical not - true if the expression is false.
Definition: IR.h:193
Logical or - is at least one of the expressions true.
Definition: IR.h:184
Represent a multi-dimensional region of a Func or an ImageParam that needs to be prefetched.
Definition: IR.h:910
This node is a helpful annotation to do with permissions.
Definition: IR.h:315
This defines the value of a function at a multi-dimensional location.
Definition: IR.h:354
A linear ramp vector node.
Definition: IR.h:247
Allocate a multi-dimensional buffer of the given type and size.
Definition: IR.h:427
Reinterpret value as another type, without affecting any of the bits (on little-endian systems).
Definition: IR.h:47
A ternary operator.
Definition: IR.h:204
Construct a new vector by taking elements from another sequence of vectors.
Definition: IR.h:841
A reference-counted handle to a statement node.
Definition: Expr.h:419
HALIDE_ALWAYS_INLINE const BaseStmtNode * get() const
Override get() to return a BaseStmtNode * instead of an IRNode *.
Definition: Expr.h:427
Store a 'value' to the buffer called 'name' at a given 'index' if 'predicate' is true.
Definition: IR.h:333
String constants.
Definition: Expr.h:245
The difference of two expressions.
Definition: IR.h:65
Unsigned integer constants.
Definition: Expr.h:227
A named variable.
Definition: IR.h:758
Horizontally reduce a vector to a scalar or narrower vector using the given commutative and associative binary operator.
Definition: IR.h:966