Halide 17.0.2
Halide compiler and libraries
Loading...
Searching...
No Matches
Expr.h
Go to the documentation of this file.
1#ifndef HALIDE_EXPR_H
2#define HALIDE_EXPR_H
3
4/** \file
5 * Base classes for Halide expressions (\ref Halide::Expr) and statements (\ref Halide::Internal::Stmt)
6 */
7
8#include <string>
9#include <vector>
10
11#include "IntrusivePtr.h"
12#include "Type.h"
13
14namespace Halide {
15
16struct bfloat16_t;
17struct float16_t;
18
19namespace Internal {
20
21class IRMutator;
22class IRVisitor;
23
24/** All our IR node types get unique IDs for the purposes of RTTI */
25enum class IRNodeType {
26 // Exprs, in order of strength. Code in IRMatch.h and the
27 // simplifier relies on this order for canonicalization of
28 // expressions, so you may need to update those modules if you
29 // change this list.
30 IntImm,
31 UIntImm,
35 Cast,
38 Add,
39 Sub,
40 Mod,
41 Mul,
42 Div,
43 Min,
44 Max,
45 EQ,
46 NE,
47 LT,
48 LE,
49 GT,
50 GE,
51 And,
52 Or,
53 Not,
54 Select,
55 Load,
56 Ramp,
57 Call,
58 Let,
59 Shuffle,
61 // Stmts
62 LetStmt,
65 For,
66 Acquire,
67 Store,
68 Provide,
70 Free,
71 Realize,
72 Block,
73 Fork,
77 Atomic,
79};
80
82
83/** The abstract base class for a node in the Halide IR. */
84struct IRNode {
85
86 /** We use the visitor pattern to traverse IR nodes throughout the
87 * compiler, so we have a virtual accept method which accepts
88 * visitors.
89 */
90 virtual void accept(IRVisitor *v) const = 0;
92 : node_type(t) {
93 }
94 virtual ~IRNode() = default;
95
96 /** These classes are all managed with intrusive reference
97 * counting, so we also track a reference count. It's mutable
98 * so that we can do reference counting even through const
99 * references to IR nodes.
100 */
102
103 /** Each IR node subclass has a unique identifier. We can compare
104 * these values to do runtime type identification. We don't
105 * compile with rtti because that injects run-time type
106 * identification stuff everywhere (and often breaks when linking
107 * external libraries compiled without it), and we only want it
108 * for IR nodes. One might want to put this value in the vtable,
109 * but that adds another level of indirection, and for Exprs we
110 * have 32 free bits in between the ref count and the Type
111 * anyway, so this doesn't increase the memory footprint of an IR node.
112 */
114};
115
/** Specialization of ref_count for IRNode: hands back a reference to the
 * ref_count member stored inside the node itself. */
116template<>
117inline RefCount &ref_count<IRNode>(const IRNode *t) noexcept {
118 return t->ref_count;
119}
120
/** Specialization of destroy for IRNode: deletes the node through its
 * IRNode pointer (IRNode declares a virtual destructor). */
121template<>
122inline void destroy<IRNode>(const IRNode *t) {
123 delete t;
124}
125
126/** IR nodes are split into expressions and statements. These are
127 similar to expressions and statements in C - expressions
128 represent some value and have some type (e.g. x + 3), and
129 statements are side-effecting pieces of code that do not
130 represent a value (e.g. assert(x > 3)) */
131
132/** A base class for statement nodes. They have no properties or
133 methods beyond base IR nodes for now. */
134struct BaseStmtNode : public IRNode {
136 : IRNode(t) {
137 }
138 virtual Stmt mutate_stmt(IRMutator *v) const = 0;
139};
140
141/** A base class for expression nodes. They all contain their types
142 * (e.g. Int(32), Float(32)) */
143struct BaseExprNode : public IRNode {
145 : IRNode(t) {
146 }
147 virtual Expr mutate_expr(IRMutator *v) const = 0;
149};
150
151/** We use the "curiously recurring template pattern" to avoid
152 duplicated code in the IR Nodes. These classes live between the
153 abstract base classes and the actual IR Nodes in the
154 inheritance hierarchy. They provide an implementation of the
155 accept function necessary for the visitor pattern to work, and
156 a concrete instantiation of a unique IRNodeType per class. */
157template<typename T>
158struct ExprNode : public BaseExprNode {
159 void accept(IRVisitor *v) const override;
160 Expr mutate_expr(IRMutator *v) const override;
162 : BaseExprNode(T::_node_type) {
163 }
164 ~ExprNode() override = default;
165};
166
167template<typename T>
168struct StmtNode : public BaseStmtNode {
169 void accept(IRVisitor *v) const override;
170 Stmt mutate_stmt(IRMutator *v) const override;
172 : BaseStmtNode(T::_node_type) {
173 }
174 ~StmtNode() override = default;
175};
176
177/** IR nodes are passed around as opaque handles to them. This is a
178 base class for those handles. It manages the reference count,
179 and dispatches visitors. */
180struct IRHandle : public IntrusivePtr<const IRNode> {
182 IRHandle() = default;
183
185 IRHandle(const IRNode *p)
186 : IntrusivePtr<const IRNode>(p) {
187 }
188
189 /** Dispatch to the correct visitor method for this node. E.g. if
190 * this node is actually an Add node, then this will call
191 * IRVisitor::visit(const Add *) */
192 void accept(IRVisitor *v) const {
193 ptr->accept(v);
194 }
195
196 /** Downcast this ir node to its actual type (e.g. Add, or
197 * Select). This returns nullptr if the node is not of the requested
198 * type. Example usage:
199 *
200 * if (const Add *add = node->as<Add>()) {
201 * // This is an add node
202 * }
203 */
204 template<typename T>
205 const T *as() const {
 // A null ptr, or a node whose node_type tag differs from
 // T::_node_type, falls through to the nullptr return below.
206 if (ptr && ptr->node_type == T::_node_type) {
 // The tag matched T::_node_type, so hand back the node viewed as a T.
207 return (const T *)ptr;
208 }
209 return nullptr;
210 }
211
213 return ptr->node_type;
214 }
215};
216
217/** Integer constants */
218struct IntImm : public ExprNode<IntImm> {
220
221 static const IntImm *make(Type t, int64_t value);
222
224};
225
226/** Unsigned integer constants */
227struct UIntImm : public ExprNode<UIntImm> {
229
230 static const UIntImm *make(Type t, uint64_t value);
231
233};
234
235/** Floating point constants */
236struct FloatImm : public ExprNode<FloatImm> {
237 double value;
238
239 static const FloatImm *make(Type t, double value);
240
242};
243
244/** String constants */
245struct StringImm : public ExprNode<StringImm> {
246 std::string value;
247
248 static const StringImm *make(const std::string &val);
249
251};
252
253} // namespace Internal
254
255/** A fragment of Halide syntax. It's implemented as a reference-counted
256 * handle to a concrete expression node, but it's immutable, so you
257 * can treat it as a value type. */
258struct Expr : public Internal::IRHandle {
259 /** Make an undefined expression */
261 Expr() = default;
262
263 /** Make an expression from a concrete expression node pointer (e.g. Add) */
266 : IRHandle(n) {
267 }
268
269 /** Make an expression representing numeric constants of various types. */
270 // @{
271 explicit Expr(int8_t x)
272 : IRHandle(Internal::IntImm::make(Int(8), x)) {
273 }
274 explicit Expr(int16_t x)
275 : IRHandle(Internal::IntImm::make(Int(16), x)) {
276 }
278 : IRHandle(Internal::IntImm::make(Int(32), x)) {
279 }
280 explicit Expr(int64_t x)
281 : IRHandle(Internal::IntImm::make(Int(64), x)) {
282 }
283 explicit Expr(uint8_t x)
284 : IRHandle(Internal::UIntImm::make(UInt(8), x)) {
285 }
286 explicit Expr(uint16_t x)
287 : IRHandle(Internal::UIntImm::make(UInt(16), x)) {
288 }
289 explicit Expr(uint32_t x)
290 : IRHandle(Internal::UIntImm::make(UInt(32), x)) {
291 }
292 explicit Expr(uint64_t x)
293 : IRHandle(Internal::UIntImm::make(UInt(64), x)) {
294 }
296 : IRHandle(Internal::FloatImm::make(Float(16), (double)x)) {
297 }
299 : IRHandle(Internal::FloatImm::make(BFloat(16), (double)x)) {
300 }
// NOTE(review): unlike the int8_t/int16_t/int64_t, uint*_t and double
// overloads visible in this listing, this constructor is not marked
// explicit, so a float converts implicitly to an Expr wrapping a FloatImm
// made with Float(32).
301 Expr(float x)
302 : IRHandle(Internal::FloatImm::make(Float(32), x)) {
303 }
304 explicit Expr(double x)
305 : IRHandle(Internal::FloatImm::make(Float(64), x)) {
306 }
307 // @}
308
309 /** Make an expression representing a const string (i.e. a StringImm) */
// NOTE(review): not marked explicit, so a std::string converts implicitly
// to an Expr wrapping the result of StringImm::make(s).
310 Expr(const std::string &s)
311 : IRHandle(Internal::StringImm::make(s)) {
312 }
313
314 /** Override get() to return a BaseExprNode * instead of an IRNode * */
317 return (const Internal::BaseExprNode *)ptr;
318 }
319
320 /** Get the type of this expression node */
322 Type type() const {
323 return get()->type;
324 }
325};
326
327/** This lets you use an Expr as a key in a map of the form
328 * map<Expr, Foo, ExprCompare> */
330 bool operator()(const Expr &a, const Expr &b) const {
 // Orders by the address of the underlying node (the pointer returned by
 // get()), not by the structure or value of the expression.
331 return a.get() < b.get();
332 }
333};
334
335/** A single-dimensional span. Includes all numbers between min and
336 * (min + extent - 1). */
337struct Range {
339
340 Range() = default;
341 Range(const Expr &min_in, const Expr &extent_in);
342};
343
344/** A multi-dimensional box. The outer product of the elements */
345typedef std::vector<Range> Region;
346
347/** An enum describing different address spaces to be used with Func::store_in. */
348enum class MemoryType {
349 /** Let Halide select a storage type automatically */
350 Auto,
351
352 /** Heap/global memory. Allocated using halide_malloc, or
353 * halide_device_malloc */
354 Heap,
355
356 /** Stack memory. Allocated using alloca. Requires a constant
357 * size. Corresponds to per-thread local memory on the GPU. If all
358 * accesses are at constant coordinates, may be promoted into the
359 * register file at the discretion of the register allocator. */
360 Stack,
361
362 /** Register memory. The allocation should be promoted into the
363 * register file. All stores must be at constant coordinates. May
364 * be spilled to the stack at the discretion of the register
365 * allocator. */
366 Register,
367
368 /** Allocation is stored in GPU shared memory. Also known as
369 * "local" in OpenCL, and "threadgroup" in metal. Can be shared
370 * across GPU threads within the same block. */
371 GPUShared,
372
373 /** Allocation is stored in GPU texture memory and accessed through
374 * a hardware sampler */
376
377 /** Allocate Locked Cache Memory to act as local memory */
379 /** Vector Tightly Coupled Memory. HVX (Hexagon) local memory available on
380 * v65+. This memory has higher performance and lower power. Ideal for
381 * intermediate buffers. Necessary for vgather-vscatter instructions
382 * on Hexagon */
383 VTCM,
384
385 /** AMX Tile register for X86. Any data that would be used in an AMX matrix
386 * multiplication must first be loaded into an AMX tile register. */
387 AMXTile,
388};
389
390namespace Internal {
391
392/** An enum describing a type of loop traversal. Used in schedules,
393 * and in the For loop IR node. Serial is a conventional ordered for
394 * loop. Iterations occur in increasing order, and each iteration must
395 * appear to have finished before the next begins. Parallel, GPUBlock,
396 * and GPUThread are parallel and unordered: iterations may occur in
397 * any order, and multiple iterations may occur
398 * simultaneously. Vectorized and GPULane are parallel and
399 * synchronous: they act as if all iterations occur at the same time
400 * in lockstep. */
401enum class ForType {
402 Serial,
403 Parallel,
405 Unrolled,
406 Extern,
407 GPUBlock,
408 GPUThread,
409 GPULane,
410};
411
412/** Check if for_type executes for loop iterations in parallel and unordered. */
414
415/** Returns true if for_type executes for loop iterations in parallel. */
416bool is_parallel(ForType for_type);
417
418/** A reference-counted handle to a statement node. */
419struct Stmt : public IRHandle {
420 Stmt() = default;
422 : IRHandle(n) {
423 }
424
425 /** Override get() to return a BaseStmtNode * instead of an IRNode * */
427 const BaseStmtNode *get() const {
428 return (const Internal::BaseStmtNode *)ptr;
429 }
430
431 /** This lets you use a Stmt as a key in a map of the form
432 * map<Stmt, Foo, Stmt::Compare> */
433 struct Compare {
434 bool operator()(const Stmt &a, const Stmt &b) const {
 // Compares the raw node pointers held by the two handles, so the
 // ordering reflects node address rather than statement contents.
435 return a.ptr < b.ptr;
436 }
437 };
438};
439
440} // namespace Internal
441} // namespace Halide
442
443#endif
#define HALIDE_ALWAYS_INLINE
Support classes for reference-counting via intrusive shared pointers.
Defines halide types.
A base class for passes over the IR which modify it (e.g.
Definition IRMutator.h:26
A base class for algorithms that need to recursively walk over the IR.
Definition IRVisitor.h:19
A class representing a reference count to be used with IntrusivePtr.
constexpr IRNodeType StrongestExprNodeType
Definition Expr.h:81
ForType
An enum describing a type of loop traversal.
Definition Expr.h:401
RefCount & ref_count< IRNode >(const IRNode *t) noexcept
Definition Expr.h:117
bool is_unordered_parallel(ForType for_type)
Check if for_type executes for loop iterations in parallel and unordered.
bool is_parallel(ForType for_type)
Returns true if for_type executes for loop iterations in parallel.
void destroy< IRNode >(const IRNode *t)
Definition Expr.h:122
IRNodeType
All our IR node types get unique IDs for the purposes of RTTI.
Definition Expr.h:25
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
Type BFloat(int bits, int lanes=1)
Construct a floating-point type in the bfloat format.
Definition Type.h:545
Type UInt(int bits, int lanes=1)
Constructing an unsigned integer type.
Definition Type.h:535
Type Float(int bits, int lanes=1)
Construct a floating-point type.
Definition Type.h:540
@ Internal
Not visible externally, similar to 'static' linkage in C.
Type Int(int bits, int lanes=1)
Constructing a signed integer type.
Definition Type.h:530
std::vector< Range > Region
A multi-dimensional box.
Definition Expr.h:345
MemoryType
An enum describing different address spaces to be used with Func::store_in.
Definition Expr.h:348
@ Auto
Let Halide select a storage type automatically.
@ Register
Register memory.
@ Stack
Stack memory.
@ VTCM
Vector Tightly Coupled Memory.
@ AMXTile
AMX Tile register for X86.
@ LockedCache
Allocate Locked Cache Memory to act as local memory.
@ Heap
Heap/global memory.
@ GPUTexture
Allocation is stored in GPU texture memory and accessed through hardware sampler.
@ GPUShared
Allocation is stored in GPU shared memory.
unsigned __INT64_TYPE__ uint64_t
signed __INT64_TYPE__ int64_t
signed __INT32_TYPE__ int32_t
unsigned __INT8_TYPE__ uint8_t
unsigned __INT16_TYPE__ uint16_t
unsigned __INT32_TYPE__ uint32_t
signed __INT16_TYPE__ int16_t
signed __INT8_TYPE__ int8_t
This lets you use an Expr as a key in a map of the form map<Expr, Foo, ExprCompare>
Definition Expr.h:329
bool operator()(const Expr &a, const Expr &b) const
Definition Expr.h:330
A fragment of Halide syntax.
Definition Expr.h:258
Expr(float x)
Definition Expr.h:301
HALIDE_ALWAYS_INLINE Expr()=default
Make an undefined expression.
Expr(int32_t x)
Definition Expr.h:277
Expr(bfloat16_t x)
Definition Expr.h:298
Expr(uint32_t x)
Definition Expr.h:289
Expr(const std::string &s)
Make an expression representing a const string (i.e.
Definition Expr.h:310
HALIDE_ALWAYS_INLINE Type type() const
Get the type of this expression node.
Definition Expr.h:322
HALIDE_ALWAYS_INLINE const Internal::BaseExprNode * get() const
Override get() to return a BaseExprNode * instead of an IRNode *.
Definition Expr.h:316
Expr(int64_t x)
Definition Expr.h:280
Expr(int16_t x)
Definition Expr.h:274
Expr(uint64_t x)
Definition Expr.h:292
Expr(uint16_t x)
Definition Expr.h:286
Expr(double x)
Definition Expr.h:304
Expr(int8_t x)
Make an expression representing numeric constants of various types.
Definition Expr.h:271
HALIDE_ALWAYS_INLINE Expr(const Internal::BaseExprNode *n)
Make an expression from a concrete expression node pointer (e.g.
Definition Expr.h:265
Expr(uint8_t x)
Definition Expr.h:283
Expr(float16_t x)
Definition Expr.h:295
The sum of two expressions.
Definition IR.h:56
Allocate a scratch area called with the given name, type, and size.
Definition IR.h:371
Logical and - are both expressions true.
Definition IR.h:175
If the 'condition' is false, then evaluate and return the message, which should be a call to an error...
Definition IR.h:294
Lock all the Store nodes in the body statement.
Definition IR.h:948
A base class for expression nodes.
Definition Expr.h:143
virtual Expr mutate_expr(IRMutator *v) const =0
BaseExprNode(IRNodeType t)
Definition Expr.h:144
IR nodes are split into expressions and statements.
Definition Expr.h:134
BaseStmtNode(IRNodeType t)
Definition Expr.h:135
virtual Stmt mutate_stmt(IRMutator *v) const =0
A sequence of statements to be executed in-order.
Definition IR.h:442
A vector with 'lanes' elements, in which every element is 'value'.
Definition IR.h:259
A function call.
Definition IR.h:490
The actual IR nodes begin here.
Definition IR.h:30
The ratio of two expressions.
Definition IR.h:83
Is the first expression equal to the second.
Definition IR.h:121
Evaluate and discard an expression, presumably because it has some side-effect.
Definition IR.h:476
We use the "curiously recurring template pattern" to avoid duplicated code in the IR Nodes.
Definition Expr.h:158
~ExprNode() override=default
Expr mutate_expr(IRMutator *v) const override
void accept(IRVisitor *v) const override
We use the visitor pattern to traverse IR nodes throughout the compiler, so we have a virtual accept ...
Floating point constants.
Definition Expr.h:236
static const IRNodeType _node_type
Definition Expr.h:241
static const FloatImm * make(Type t, double value)
A for loop.
Definition IR.h:805
A pair of statements executed concurrently.
Definition IR.h:457
Free the resources associated with the given buffer.
Definition IR.h:413
Is the first expression greater than or equal to the second.
Definition IR.h:166
Is the first expression greater than the second.
Definition IR.h:157
Represents a location where storage will be hoisted to for a Func / Realize node with a given name.
Definition IR.h:932
IR nodes are passed around as opaque handles to them.
Definition Expr.h:180
void accept(IRVisitor *v) const
Dispatch to the correct visitor method for this node.
Definition Expr.h:192
HALIDE_ALWAYS_INLINE IRHandle()=default
const T * as() const
Downcast this ir node to its actual type (e.g.
Definition Expr.h:205
IRNodeType node_type() const
Definition Expr.h:212
HALIDE_ALWAYS_INLINE IRHandle(const IRNode *p)
Definition Expr.h:185
The abstract base class for a node in the Halide IR.
Definition Expr.h:84
virtual ~IRNode()=default
virtual void accept(IRVisitor *v) const =0
We use the visitor pattern to traverse IR nodes throughout the compiler, so we have a virtual accept ...
IRNodeType node_type
Each IR node subclass has a unique identifier.
Definition Expr.h:113
RefCount ref_count
These classes are all managed with intrusive reference counting, so we also track a reference count.
Definition Expr.h:101
IRNode(IRNodeType t)
Definition Expr.h:91
An if-then-else block.
Definition IR.h:466
Integer constants.
Definition Expr.h:218
static const IRNodeType _node_type
Definition Expr.h:223
static const IntImm * make(Type t, int64_t value)
Intrusive shared pointers have a reference count (a RefCount object) stored in the class itself.
Is the first expression less than or equal to the second.
Definition IR.h:148
Is the first expression less than the second.
Definition IR.h:139
A let expression, like you might find in a functional language.
Definition IR.h:271
The statement form of a let node.
Definition IR.h:282
Load a value from a named symbol if predicate is true.
Definition IR.h:217
The greater of two values.
Definition IR.h:112
The lesser of two values.
Definition IR.h:103
The remainder of a / b.
Definition IR.h:94
The product of two expressions.
Definition IR.h:74
Is the first expression not equal to the second.
Definition IR.h:130
Logical not - true if the expression is false.
Definition IR.h:193
Logical or - is at least one of the expressions true.
Definition IR.h:184
Represent a multi-dimensional region of a Func or an ImageParam that needs to be prefetched.
Definition IR.h:910
This node is a helpful annotation to do with permissions.
Definition IR.h:315
This defines the value of a function at a multi-dimensional location.
Definition IR.h:354
A linear ramp vector node.
Definition IR.h:247
Allocate a multi-dimensional buffer of the given type and size.
Definition IR.h:427
Reinterpret value as another type, without affecting any of the bits (on little-endian systems).
Definition IR.h:47
A ternary operator.
Definition IR.h:204
Construct a new vector by taking elements from another sequence of vectors.
Definition IR.h:841
This lets you use a Stmt as a key in a map of the form map<Stmt, Foo, Stmt::Compare>
Definition Expr.h:433
bool operator()(const Stmt &a, const Stmt &b) const
Definition Expr.h:434
A reference-counted handle to a statement node.
Definition Expr.h:419
Stmt(const BaseStmtNode *n)
Definition Expr.h:421
HALIDE_ALWAYS_INLINE const BaseStmtNode * get() const
Override get() to return a BaseStmtNode * instead of an IRNode *.
Definition Expr.h:427
void accept(IRVisitor *v) const override
We use the visitor pattern to traverse IR nodes throughout the compiler, so we have a virtual accept ...
Stmt mutate_stmt(IRMutator *v) const override
~StmtNode() override=default
Store a 'value' to the buffer called 'name' at a given 'index' if 'predicate' is true.
Definition IR.h:333
String constants.
Definition Expr.h:245
static const StringImm * make(const std::string &val)
static const IRNodeType _node_type
Definition Expr.h:250
The difference of two expressions.
Definition IR.h:65
Unsigned integer constants.
Definition Expr.h:227
static const IRNodeType _node_type
Definition Expr.h:232
static const UIntImm * make(Type t, uint64_t value)
A named variable.
Definition IR.h:758
Horizontally reduce a vector to a scalar or narrower vector using the given commutative and associati...
Definition IR.h:966
A single-dimensional span.
Definition Expr.h:337
Range()=default
Expr min
Definition Expr.h:338
Expr extent
Definition Expr.h:338
Range(const Expr &min_in, const Expr &extent_in)
Types in the halide type system.
Definition Type.h:276
Class that provides a type that implements half precision floating point using the bfloat16 format.
Definition Float16.h:158
Class that provides a type that implements half precision floating point (IEEE754 2008 binary16) in s...
Definition Float16.h:17