Halide  17.0.2
Halide compiler and libraries
Expr.h
Go to the documentation of this file.
1 #ifndef HALIDE_EXPR_H
2 #define HALIDE_EXPR_H
3 
4 /** \file
5  * Base classes for Halide expressions (\ref Halide::Expr) and statements (\ref Halide::Internal::Stmt)
6  */
7 
8 #include <string>
9 #include <vector>
10 
11 #include "IntrusivePtr.h"
12 #include "Type.h"
13 
14 namespace Halide {
15 
16 struct bfloat16_t;
17 struct float16_t;
18 
19 namespace Internal {
20 
21 class IRMutator;
22 class IRVisitor;
23 
24 /** All our IR node types get unique IDs for the purposes of RTTI */
25 enum class IRNodeType {
26  // Exprs, in order of strength. Code in IRMatch.h and the
27  // simplifier relies on this order for canonicalization of
28  // expressions, so you may need to update those modules if you
29  // change this list.
30  IntImm,
31  UIntImm,
32  FloatImm,
33  StringImm,
34  Broadcast,
35  Cast,
37  Variable,
38  Add,
39  Sub,
40  Mod,
41  Mul,
42  Div,
43  Min,
44  Max,
45  EQ,
46  NE,
47  LT,
48  LE,
49  GT,
50  GE,
51  And,
52  Or,
53  Not,
54  Select,
55  Load,
56  Ramp,
57  Call,
58  Let,
59  Shuffle,
61  // Stmts
62  LetStmt,
63  AssertStmt,
65  For,
66  Acquire,
67  Store,
68  Provide,
69  Allocate,
70  Free,
71  Realize,
72  Block,
73  Fork,
74  IfThenElse,
75  Evaluate,
76  Prefetch,
77  Atomic,
79 };
80 
82 
83 /** The abstract base class for a node in the Halide IR. */
84 struct IRNode {
85 
86  /** We use the visitor pattern to traverse IR nodes throughout the
87  * compiler, so we have a virtual accept method which accepts
88  * visitors.
89  */
90  virtual void accept(IRVisitor *v) const = 0;
92  : node_type(t) {
93  }
94  virtual ~IRNode() = default;
95 
96  /** These classes are all managed with intrusive reference
97  * counting, so we also track a reference count. It's mutable
98  * so that we can do reference counting even through const
99  * references to IR nodes.
100  */
102 
103  /** Each IR node subclass has a unique identifier. We can compare
104  * these values to do runtime type identification. We don't
105  * compile with rtti because that injects run-time type
106  * identification stuff everywhere (and often breaks when linking
107  * external libraries compiled without it), and we only want it
108  * for IR nodes. One might want to put this value in the vtable,
109  * but that adds another level of indirection, and for Exprs we
110  * have 32 free bits in between the ref count and the Type
111  * anyway, so this doesn't increase the memory footprint of an IR node.
112  */
114 };
115 
116 template<>  // Specialization for IRNode: hands back the reference count stored in the node itself.
117 inline RefCount &ref_count<IRNode>(const IRNode *t) noexcept {
118  return t->ref_count;  // non-const reference obtained through a pointer-to-const; t is dereferenced without a null check
119 }
120 
121 template<>  // Specialization for IRNode: releases a node through its base-class pointer.
122 inline void destroy<IRNode>(const IRNode *t) {
123  delete t;  // IRNode declares a virtual destructor, so the derived node's destructor runs
124 }
125 
126 /** IR nodes are split into expressions and statements. These are
127  similar to expressions and statements in C - expressions
128  represent some value and have some type (e.g. x + 3), and
129  statements are side-effecting pieces of code that do not
130  represent a value (e.g. assert(x > 3)) */
131 
132 /** A base class for statement nodes. They have no properties or
133  methods beyond base IR nodes for now. */
134 struct BaseStmtNode : public IRNode {
136  : IRNode(t) {
137  }
138  virtual Stmt mutate_stmt(IRMutator *v) const = 0;
139 };
140 
141 /** A base class for expression nodes. They all contain their types
142  * (e.g. Int(32), Float(32)) */
143 struct BaseExprNode : public IRNode {
145  : IRNode(t) {
146  }
147  virtual Expr mutate_expr(IRMutator *v) const = 0;
149 };
150 
151 /** We use the "curiously recurring template pattern" to avoid
152  duplicated code in the IR Nodes. These classes live between the
153  abstract base classes and the actual IR Nodes in the
154  inheritance hierarchy. They provide an implementation of the
155  accept function necessary for the visitor pattern to work, and
156  a concrete instantiation of a unique IRNodeType per class. */
157 template<typename T>
158 struct ExprNode : public BaseExprNode {
159  void accept(IRVisitor *v) const override;
160  Expr mutate_expr(IRMutator *v) const override;
162  : BaseExprNode(T::_node_type) {
163  }
164  ~ExprNode() override = default;
165 };
166 
167 template<typename T>
168 struct StmtNode : public BaseStmtNode {
169  void accept(IRVisitor *v) const override;
170  Stmt mutate_stmt(IRMutator *v) const override;
172  : BaseStmtNode(T::_node_type) {
173  }
174  ~StmtNode() override = default;
175 };
176 
177 /** IR nodes are passed around via opaque handles to them. This is a
178  base class for those handles. It manages the reference count,
179  and dispatches visitors. */
180 struct IRHandle : public IntrusivePtr<const IRNode> {
182  IRHandle() = default;
183 
185  IRHandle(const IRNode *p)
186  : IntrusivePtr<const IRNode>(p) {
187  }
188 
189  /** Dispatch to the correct visitor method for this node. E.g. if
190  * this node is actually an Add node, then this will call
191  * IRVisitor::visit(const Add *) */
192  void accept(IRVisitor *v) const {
193  ptr->accept(v);
194  }
195 
196  /** Downcast this IR node to its actual type (e.g. Add, or
197  * Select). This returns nullptr if the node is not of the requested
198  * type. Example usage:
199  *
200  * if (const Add *add = node->as<Add>()) {
201  * // This is an add node
202  * }
203  */
204  template<typename T>
205  const T *as() const {
206  if (ptr && ptr->node_type == T::_node_type) {
207  return (const T *)ptr;
208  }
209  return nullptr;
210  }
211 
213  return ptr->node_type;
214  }
215 };
216 
217 /** Integer constants */
218 struct IntImm : public ExprNode<IntImm> {
220 
221  static const IntImm *make(Type t, int64_t value);
222 
224 };
225 
226 /** Unsigned integer constants */
227 struct UIntImm : public ExprNode<UIntImm> {
229 
230  static const UIntImm *make(Type t, uint64_t value);
231 
233 };
234 
235 /** Floating point constants */
236 struct FloatImm : public ExprNode<FloatImm> {
237  double value;
238 
239  static const FloatImm *make(Type t, double value);
240 
242 };
243 
244 /** String constants */
245 struct StringImm : public ExprNode<StringImm> {
246  std::string value;
247 
248  static const StringImm *make(const std::string &val);
249 
251 };
252 
253 } // namespace Internal
254 
255 /** A fragment of Halide syntax. It's implemented as a reference-counted
256  * handle to a concrete expression node, but it's immutable, so you
257  * can treat it as a value type. */
258 struct Expr : public Internal::IRHandle {
259  /** Make an undefined expression */
261  Expr() = default;
262 
263  /** Make an expression from a concrete expression node pointer (e.g. Add) */
266  : IRHandle(n) {
267  }
268 
269  /** Make an expression representing numeric constants of various types. */
270  // @{
271  explicit Expr(int8_t x)
272  : IRHandle(Internal::IntImm::make(Int(8), x)) {
273  }
274  explicit Expr(int16_t x)
275  : IRHandle(Internal::IntImm::make(Int(16), x)) {
276  }
278  : IRHandle(Internal::IntImm::make(Int(32), x)) {
279  }
280  explicit Expr(int64_t x)
281  : IRHandle(Internal::IntImm::make(Int(64), x)) {
282  }
283  explicit Expr(uint8_t x)
284  : IRHandle(Internal::UIntImm::make(UInt(8), x)) {
285  }
286  explicit Expr(uint16_t x)
287  : IRHandle(Internal::UIntImm::make(UInt(16), x)) {
288  }
289  explicit Expr(uint32_t x)
290  : IRHandle(Internal::UIntImm::make(UInt(32), x)) {
291  }
292  explicit Expr(uint64_t x)
293  : IRHandle(Internal::UIntImm::make(UInt(64), x)) {
294  }
296  : IRHandle(Internal::FloatImm::make(Float(16), (double)x)) {
297  }
299  : IRHandle(Internal::FloatImm::make(BFloat(16), (double)x)) {
300  }
301  Expr(float x)
302  : IRHandle(Internal::FloatImm::make(Float(32), x)) {
303  }
304  explicit Expr(double x)
305  : IRHandle(Internal::FloatImm::make(Float(64), x)) {
306  }
307  // @}
308 
309  /** Make an expression representing a const string (i.e. a StringImm) */
310  Expr(const std::string &s)
311  : IRHandle(Internal::StringImm::make(s)) {
312  }
313 
314  /** Override get() to return a BaseExprNode * instead of an IRNode * */
316  const Internal::BaseExprNode *get() const {
317  return (const Internal::BaseExprNode *)ptr;
318  }
319 
320  /** Get the type of this expression node */
322  Type type() const {
323  return get()->type;
324  }
325 };
326 
327 /** This lets you use an Expr as a key in a map of the form
328  * map<Expr, Foo, ExprCompare>. Keys are ordered by the address of the node each Expr points to, not by the structure of the expression. */
329 struct ExprCompare {
330  bool operator()(const Expr &a, const Expr &b) const {
331  return a.get() < b.get();  // raw pointer comparison of the underlying nodes
332  }
333 };
334 
335 /** A single-dimensional span. Includes all numbers between min and
336  * (min + extent - 1). */
337 struct Range {
339 
340  Range() = default;
341  Range(const Expr &min_in, const Expr &extent_in);
342 };
343 
344 /** A multi-dimensional box. The outer product of the elements, each of which is a single-dimensional Range. */
345 using Region = std::vector<Range>;
346 
347 /** An enum describing different address spaces to be used with Func::store_in. */
348 enum class MemoryType {
349  /** Let Halide select a storage type automatically */
350  Auto,
351 
352  /** Heap/global memory. Allocated using halide_malloc, or
353  * halide_device_malloc */
354  Heap,
355 
356  /** Stack memory. Allocated using alloca. Requires a constant
357  * size. Corresponds to per-thread local memory on the GPU. If all
358  * accesses are at constant coordinates, may be promoted into the
359  * register file at the discretion of the register allocator. */
360  Stack,
361 
362  /** Register memory. The allocation should be promoted into the
363  * register file. All stores must be at constant coordinates. May
364  * be spilled to the stack at the discretion of the register
365  * allocator. */
366  Register,
367 
368  /** Allocation is stored in GPU shared memory. Also known as
369  * "local" in OpenCL, and "threadgroup" in Metal. Can be shared
370  * across GPU threads within the same block. */
371  GPUShared,
372 
373  /** Allocation is stored in GPU texture memory and accessed through
374  * a hardware sampler */
375  GPUTexture,
376 
377  /** Allocate Locked Cache Memory to act as local memory */
378  LockedCache,
379  /** Vector Tightly Coupled Memory. HVX (Hexagon) local memory available on
380  * v65+. This memory has higher performance and lower power. Ideal for
381  * intermediate buffers. Necessary for vgather-vscatter instructions
382  * on Hexagon */
383  VTCM,
384 
385  /** AMX Tile register for X86. Any data that would be used in an AMX matrix
386  * multiplication must first be loaded into an AMX tile register. */
387  AMXTile,
388 };
389 
390 namespace Internal {
391 
392 /** An enum describing a type of loop traversal. Used in schedules,
393  * and in the For loop IR node. Serial is a conventional ordered for
394  * loop. Iterations occur in increasing order, and each iteration must
395  * appear to have finished before the next begins. Parallel, GPUBlock,
396  * and GPUThread are parallel and unordered: iterations may occur in
397  * any order, and multiple iterations may occur
398  * simultaneously. Vectorized and GPULane are parallel and
399  * synchronous: they act as if all iterations occur at the same time
400  * in lockstep. */
401 enum class ForType {
402  Serial,  // ordered: each iteration appears to finish before the next begins
403  Parallel,  // parallel and unordered
404  Vectorized,  // parallel and synchronous (lockstep)
405  Unrolled,
406  Extern,
407  GPUBlock,  // parallel and unordered
408  GPUThread,  // parallel and unordered
409  GPULane,  // parallel and synchronous (lockstep)
410 };
411 
412 /** Check if for_type executes for loop iterations in parallel and unordered. */
414 
415 /** Returns true if for_type executes for loop iterations in parallel. */
416 bool is_parallel(ForType for_type);
417 
418 /** A reference-counted handle to a statement node. */
419 struct Stmt : public IRHandle {
420  Stmt() = default;
421  Stmt(const BaseStmtNode *n)  // not explicit: a statement node pointer converts implicitly to a Stmt
422  : IRHandle(n) {
423  }
424 
425  /** Override get() to return a BaseStmtNode * instead of an IRNode * */
427  const BaseStmtNode *get() const {
428  return (const Internal::BaseStmtNode *)ptr;  // unchecked C-style downcast of the stored node pointer
429  }
430 
431  /** This lets you use a Stmt as a key in a map of the form
432  * map<Stmt, Foo, Stmt::Compare>. Keys are ordered by node address, not by the contents of the statement. */
433  struct Compare {
434  bool operator()(const Stmt &a, const Stmt &b) const {
435  return a.ptr < b.ptr;  // raw pointer comparison of the underlying nodes
436  }
437  };
438 };
439 
440 } // namespace Internal
441 } // namespace Halide
442 
443 #endif
#define HALIDE_ALWAYS_INLINE
Definition: HalideRuntime.h:49
Support classes for reference-counting via intrusive shared pointers.
Defines halide types.
A base class for passes over the IR which modify it (e.g.
Definition: IRMutator.h:26
A base class for algorithms that need to recursively walk over the IR.
Definition: IRVisitor.h:19
A class representing a reference count to be used with IntrusivePtr.
Definition: IntrusivePtr.h:19
constexpr IRNodeType StrongestExprNodeType
Definition: Expr.h:81
RefCount & ref_count< IRNode >(const IRNode *t) noexcept
Definition: Expr.h:117
ForType
An enum describing a type of loop traversal.
Definition: Expr.h:401
bool is_unordered_parallel(ForType for_type)
Check if for_type executes for loop iterations in parallel and unordered.
bool is_parallel(ForType for_type)
Returns true if for_type executes for loop iterations in parallel.
void destroy< IRNode >(const IRNode *t)
Definition: Expr.h:122
IRNodeType
All our IR node types get unique IDs for the purposes of RTTI.
Definition: Expr.h:25
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
Type BFloat(int bits, int lanes=1)
Construct a floating-point type in the bfloat format.
Definition: Type.h:545
Type UInt(int bits, int lanes=1)
Constructing an unsigned integer type.
Definition: Type.h:535
Type Float(int bits, int lanes=1)
Construct a floating-point type.
Definition: Type.h:540
@ Internal
Not visible externally, similar to 'static' linkage in C.
Type Int(int bits, int lanes=1)
Constructing a signed integer type.
Definition: Type.h:530
std::vector< Range > Region
A multi-dimensional box.
Definition: Expr.h:345
MemoryType
An enum describing different address spaces to be used with Func::store_in.
Definition: Expr.h:348
@ Auto
Let Halide select a storage type automatically.
@ Register
Register memory.
@ Stack
Stack memory.
@ VTCM
Vector Tightly Coupled Memory.
@ AMXTile
AMX Tile register for X86.
@ LockedCache
Allocate Locked Cache Memory to act as local memory.
@ Heap
Heap/global memory.
@ GPUTexture
Allocation is stored in GPU texture memory and accessed through hardware sampler.
@ GPUShared
Allocation is stored in GPU shared memory.
unsigned __INT64_TYPE__ uint64_t
signed __INT64_TYPE__ int64_t
signed __INT32_TYPE__ int32_t
unsigned __INT8_TYPE__ uint8_t
unsigned __INT16_TYPE__ uint16_t
unsigned __INT32_TYPE__ uint32_t
signed __INT16_TYPE__ int16_t
signed __INT8_TYPE__ int8_t
This lets you use an Expr as a key in a map of the form map<Expr, Foo, ExprCompare>
Definition: Expr.h:329
bool operator()(const Expr &a, const Expr &b) const
Definition: Expr.h:330
A fragment of Halide syntax.
Definition: Expr.h:258
Expr(float x)
Definition: Expr.h:301
HALIDE_ALWAYS_INLINE Expr()=default
Make an undefined expression.
Expr(int32_t x)
Definition: Expr.h:277
Expr(bfloat16_t x)
Definition: Expr.h:298
Expr(uint32_t x)
Definition: Expr.h:289
Expr(const std::string &s)
Make an expression representing a const string (i.e. a StringImm).
Definition: Expr.h:310
HALIDE_ALWAYS_INLINE Type type() const
Get the type of this expression node.
Definition: Expr.h:322
Expr(int64_t x)
Definition: Expr.h:280
Expr(int16_t x)
Definition: Expr.h:274
Expr(uint64_t x)
Definition: Expr.h:292
Expr(uint16_t x)
Definition: Expr.h:286
Expr(double x)
Definition: Expr.h:304
Expr(int8_t x)
Make an expression representing numeric constants of various types.
Definition: Expr.h:271
HALIDE_ALWAYS_INLINE Expr(const Internal::BaseExprNode *n)
Make an expression from a concrete expression node pointer (e.g. Add).
Definition: Expr.h:265
Expr(uint8_t x)
Definition: Expr.h:283
Expr(float16_t x)
Definition: Expr.h:295
HALIDE_ALWAYS_INLINE const Internal::BaseExprNode * get() const
Override get() to return a BaseExprNode * instead of an IRNode *.
Definition: Expr.h:316
A base class for expression nodes.
Definition: Expr.h:143
virtual Expr mutate_expr(IRMutator *v) const =0
BaseExprNode(IRNodeType t)
Definition: Expr.h:144
IR nodes are split into expressions and statements.
Definition: Expr.h:134
BaseStmtNode(IRNodeType t)
Definition: Expr.h:135
virtual Stmt mutate_stmt(IRMutator *v) const =0
We use the "curiously recurring template pattern" to avoid duplicated code in the IR Nodes.
Definition: Expr.h:158
~ExprNode() override=default
Expr mutate_expr(IRMutator *v) const override
void accept(IRVisitor *v) const override
We use the visitor pattern to traverse IR nodes throughout the compiler, so we have a virtual accept ...
Floating point constants.
Definition: Expr.h:236
static const IRNodeType _node_type
Definition: Expr.h:241
static const FloatImm * make(Type t, double value)
Represents a location where storage will be hoisted to for a Func / Realize node with a given name.
Definition: IR.h:932
IR nodes are passed around opaque handles to them.
Definition: Expr.h:180
void accept(IRVisitor *v) const
Dispatch to the correct visitor method for this node.
Definition: Expr.h:192
HALIDE_ALWAYS_INLINE IRHandle()=default
IRNodeType node_type() const
Definition: Expr.h:212
HALIDE_ALWAYS_INLINE IRHandle(const IRNode *p)
Definition: Expr.h:185
const T * as() const
Downcast this IR node to its actual type (e.g. Add, or Select).
Definition: Expr.h:205
The abstract base classes for a node in the Halide IR.
Definition: Expr.h:84
virtual ~IRNode()=default
virtual void accept(IRVisitor *v) const =0
We use the visitor pattern to traverse IR nodes throughout the compiler, so we have a virtual accept ...
IRNodeType node_type
Each IR node subclass has a unique identifier.
Definition: Expr.h:113
RefCount ref_count
These classes are all managed with intrusive reference counting, so we also track a reference count.
Definition: Expr.h:101
IRNode(IRNodeType t)
Definition: Expr.h:91
Integer constants.
Definition: Expr.h:218
static const IRNodeType _node_type
Definition: Expr.h:223
static const IntImm * make(Type t, int64_t value)
Intrusive shared pointers have a reference count (a RefCount object) stored in the class itself.
Definition: IntrusivePtr.h:68
This lets you use a Stmt as a key in a map of the form map<Stmt, Foo, Stmt::Compare>
Definition: Expr.h:433
bool operator()(const Stmt &a, const Stmt &b) const
Definition: Expr.h:434
A reference-counted handle to a statement node.
Definition: Expr.h:419
Stmt(const BaseStmtNode *n)
Definition: Expr.h:421
HALIDE_ALWAYS_INLINE const BaseStmtNode * get() const
Override get() to return a BaseStmtNode * instead of an IRNode *.
Definition: Expr.h:427
void accept(IRVisitor *v) const override
We use the visitor pattern to traverse IR nodes throughout the compiler, so we have a virtual accept ...
Stmt mutate_stmt(IRMutator *v) const override
~StmtNode() override=default
String constants.
Definition: Expr.h:245
static const StringImm * make(const std::string &val)
static const IRNodeType _node_type
Definition: Expr.h:250
Unsigned integer constants.
Definition: Expr.h:227
static const IRNodeType _node_type
Definition: Expr.h:232
static const UIntImm * make(Type t, uint64_t value)
A single-dimensional span.
Definition: Expr.h:337
Range()=default
Expr min
Definition: Expr.h:338
Expr extent
Definition: Expr.h:338
Range(const Expr &min_in, const Expr &extent_in)
Types in the halide type system.
Definition: Type.h:276
Class that provides a type that implements half precision floating point using the bfloat16 format.
Definition: Float16.h:158
Class that provides a type that implements half precision floating point (IEEE754 2008 binary16) in s...
Definition: Float16.h:17