Halide  17.0.2
Halide compiler and libraries
SpirvIR.h
Go to the documentation of this file.
1 #ifndef HALIDE_SPIRV_IR_H
2 #define HALIDE_SPIRV_IR_H
3 
/** \file
 * Defines methods for constructing and encoding instructions into the Khronos
 * binary format known as SPIR-V (the Standard Portable Intermediate
 * Representation used by Vulkan and other Khronos APIs). These class
 * interfaces adopt Halide's conventions for its own IR, but are implemented as
 * a stand-alone optional component that can be enabled as required for certain
 * runtimes (e.g. Vulkan).
 *
 * NOTE: This file is only used internally for CodeGen! *DO NOT* add this file
 * to the list of exported Halide headers in src/CMakeLists.txt or the
 * top-level Makefile.
 */
15 #ifdef WITH_SPIRV
16 
17 #include <map>
18 #include <set>
19 #include <stack>
20 #include <unordered_map>
21 #include <vector>
22 
23 #include "IntrusivePtr.h"
24 #include "Type.h"
25 
26 #include <spirv/1.6/GLSL.std.450.h> // GLSL extended instructions for common intrinsics
27 #include <spirv/1.6/spirv.h> // Use v1.6 headers but only use the minimal viable format version (for maximum compatiblity)
28 
29 namespace Halide {
30 namespace Internal {
31 
/** Precision requirement for return values.
 *  Values are spelled out explicitly so they stay stable. */
enum SpvPrecision {
    SpvFullPrecision = 0,
    SpvRelaxedPrecision = 1,
};
37 
/** Scope qualifiers for execution and memory operations.
 *  The numeric values follow the SPIR-V "Scope" operand encoding. */
enum SpvScope {
    SpvCrossDeviceScope = 0,  // CrossDevice
    SpvDeviceScope = 1,       // Device
    SpvWorkgroupScope = 2,    // Workgroup
    SpvSubgroupScope = 3,     // Subgroup
    SpvInvocationScope = 4    // Invocation
};
46 
/** Kinds of predefined constants (null, true, false).
 *  Values are spelled out explicitly so they stay stable. */
enum SpvPredefinedConstant {
    SpvNullConstant = 0,
    SpvTrueConstant = 1,
    SpvFalseConstant = 2,
};
53 
/** Kinds of SPIR-V object ids (e.g. the argument of SpvBuilder::reserve_id()
 *  and the result of SpvBuilder::kind_of()).
 *  Values are spelled out explicitly so they are easy to read in debug output. */
enum SpvKind {
    SpvInvalidItem = 0,

    // Type ids
    SpvTypeId = 1,
    SpvVoidTypeId = 2,
    SpvBoolTypeId = 3,
    SpvIntTypeId = 4,
    SpvUIntTypeId = 5,
    SpvFloatTypeId = 6,
    SpvVectorTypeId = 7,
    SpvArrayTypeId = 8,
    SpvRuntimeArrayTypeId = 9,
    SpvStringTypeId = 10,
    SpvPointerTypeId = 11,
    SpvStructTypeId = 12,
    SpvFunctionTypeId = 13,

    // Access chains and constants
    SpvAccessChainId = 14,
    SpvConstantId = 15,
    SpvBoolConstantId = 16,
    SpvIntConstantId = 17,
    SpvFloatConstantId = 18,
    SpvStringConstantId = 19,
    SpvCompositeConstantId = 20,

    // Results, variables and structural items
    SpvResultId = 21,
    SpvVariableId = 22,
    SpvInstructionId = 23,
    SpvFunctionId = 24,
    SpvBlockId = 25,
    SpvLabelId = 26,
    SpvParameterId = 27,
    SpvImportId = 28,
    SpvModuleId = 29,

    SpvUnknownItem = 30,
};
88 
/** Kinds of SPIR-V operand values (used to tag each word of an instruction).
 *  Values are spelled out explicitly so they stay stable. */
enum SpvValueType {
    SpvInvalidValueType = 0,
    SpvOperandId = 1,
    SpvBitMaskLiteral = 2,
    SpvIntegerLiteral = 3,
    SpvIntegerData = 4,
    SpvFloatData = 5,
    SpvStringData = 6,
    SpvUnknownValueType = 7
};
100 
/** Every SPIR-V <id> is a 32-bit unsigned integer, and an encoded module is a
 *  flat sequence of 32-bit words of the same width. */
using SpvId = uint32_t;
using SpvBinary = std::vector<SpvId>;
104 
105 static constexpr SpvStorageClass SpvInvalidStorageClass = SpvStorageClassMax; // sentinel for invalid storage class
106 static constexpr SpvId SpvInvalidId = SpvId(-1);
107 static constexpr SpvId SpvNoResult = 0;
108 static constexpr SpvId SpvNoType = 0;
109 
/** Forward declarations of the SPIR-V IR handle classes and the builder/factory. */
class SpvInstruction;
class SpvBlock;
class SpvFunction;
class SpvModule;
class SpvContext;
class SpvBuilder;
struct SpvFactory;

/** Forward declarations of the data structures held behind the handle classes. */
struct SpvInstructionContents;
struct SpvBlockContents;
struct SpvFunctionContents;
struct SpvModuleContents;
124 
125 /** Intrusive pointer types for SPIR-V IR data */
126 using SpvModuleContentsPtr = IntrusivePtr<SpvModuleContents>;
127 using SpvFunctionContentsPtr = IntrusivePtr<SpvFunctionContents>;
128 using SpvBlockContentsPtr = IntrusivePtr<SpvBlockContents>;
129 using SpvInstructionContentsPtr = IntrusivePtr<SpvInstructionContents>;
130 
131 /** General interface for representing a SPIR-V Instruction */
132 class SpvInstruction {
133 public:
134  using LiteralValue = std::pair<uint32_t, SpvValueType>;
135  using Immediates = std::vector<LiteralValue>;
136  using Operands = std::vector<SpvId>;
137  using ValueTypes = std::vector<SpvValueType>;
138 
139  SpvInstruction() = default;
140  ~SpvInstruction();
141 
142  SpvInstruction(const SpvInstruction &) = default;
143  SpvInstruction &operator=(const SpvInstruction &) = default;
144  SpvInstruction(SpvInstruction &&) = default;
145  SpvInstruction &operator=(SpvInstruction &&) = default;
146 
147  void set_result_id(SpvId id);
148  void set_type_id(SpvId id);
149  void set_op_code(SpvOp opcode);
150  void add_operand(SpvId id);
151  void add_operands(const Operands &operands);
152  void add_immediate(SpvId id, SpvValueType type);
153  void add_immediates(const Immediates &Immediates);
154  void add_data(uint32_t bytes, const void *data, SpvValueType type);
155  void add_string(const std::string &str);
156 
157  template<typename T>
158  void append(const T &operands_or_immediates_or_strings);
159 
160  SpvId result_id() const;
161  SpvId type_id() const;
162  SpvOp op_code() const;
163  SpvId operand(uint32_t index) const;
164  const void *data(uint32_t index = 0) const;
165  SpvValueType value_type(uint32_t index) const;
166  const Operands &operands() const;
167 
168  bool has_type() const;
169  bool has_result() const;
170  bool is_defined() const;
171  bool is_immediate(uint32_t index) const;
172  uint32_t length() const;
173  void check_defined() const;
174  void clear();
175 
176  void encode(SpvBinary &binary) const;
177 
178  static SpvInstruction make(SpvOp op_code);
179 
180 protected:
181  SpvInstructionContentsPtr contents;
182 };
183 
184 /** General interface for representing a SPIR-V Block */
185 class SpvBlock {
186 public:
187  using Instructions = std::vector<SpvInstruction>;
188  using Variables = std::vector<SpvInstruction>;
189  using Blocks = std::vector<SpvBlock>;
190 
191  SpvBlock() = default;
192  ~SpvBlock();
193 
194  SpvBlock(const SpvBlock &) = default;
195  SpvBlock &operator=(const SpvBlock &) = default;
196  SpvBlock(SpvBlock &&) = default;
197  SpvBlock &operator=(SpvBlock &&) = default;
198 
199  void add_instruction(SpvInstruction inst);
200  void add_variable(SpvInstruction var);
201  const Instructions &instructions() const;
202  const Variables &variables() const;
203  bool is_reachable() const;
204  bool is_terminated() const;
205  bool is_defined() const;
206  SpvId id() const;
207  void check_defined() const;
208  void clear();
209 
210  void encode(SpvBinary &binary) const;
211 
212  static SpvBlock make(SpvId block_id);
213 
214 protected:
215  SpvBlockContentsPtr contents;
216 };
217 
218 /** General interface for representing a SPIR-V Function */
219 class SpvFunction {
220 public:
221  using Blocks = std::vector<SpvBlock>;
222  using Parameters = std::vector<SpvInstruction>;
223 
224  SpvFunction() = default;
225  ~SpvFunction();
226 
227  SpvFunction(const SpvFunction &) = default;
228  SpvFunction &operator=(const SpvFunction &) = default;
229  SpvFunction(SpvFunction &&) = default;
230  SpvFunction &operator=(SpvFunction &&) = default;
231 
232  SpvBlock create_block(SpvId block_id);
233  void add_block(SpvBlock block);
234  void add_parameter(SpvInstruction param);
235  void set_return_precision(SpvPrecision precision);
236  void set_parameter_precision(uint32_t index, SpvPrecision precision);
237  bool is_defined() const;
238  void clear();
239 
240  const Blocks &blocks() const;
241  SpvBlock entry_block() const;
242  SpvBlock tail_block() const;
243  SpvPrecision return_precision() const;
244  const Parameters &parameters() const;
245  SpvPrecision parameter_precision(uint32_t index) const;
246  uint32_t parameter_count() const;
247  uint32_t control_mask() const;
248  SpvInstruction declaration() const;
249  SpvId return_type_id() const;
250  SpvId type_id() const;
251  SpvId id() const;
252  void check_defined() const;
253 
254  void encode(SpvBinary &binary) const;
255 
256  static SpvFunction make(SpvId func_id, SpvId func_type_id, SpvId return_type_id, uint32_t control_mask = SpvFunctionControlMaskNone);
257 
258 protected:
259  SpvFunctionContentsPtr contents;
260 };
261 
262 /** General interface for representing a SPIR-V code module */
263 class SpvModule {
264 public:
265  using ImportDefinition = std::pair<SpvId, std::string>;
266  using ImportNames = std::vector<std::string>;
267  using EntryPointNames = std::vector<std::string>;
268  using Instructions = std::vector<SpvInstruction>;
269  using Functions = std::vector<SpvFunction>;
270  using Capabilities = std::vector<SpvCapability>;
271  using Extensions = std::vector<std::string>;
272  using Imports = std::vector<ImportDefinition>;
273 
274  SpvModule() = default;
275  ~SpvModule();
276 
277  SpvModule(const SpvModule &) = default;
278  SpvModule &operator=(const SpvModule &) = default;
279  SpvModule(SpvModule &&) = default;
280  SpvModule &operator=(SpvModule &&) = default;
281 
282  void add_debug_string(SpvId result_id, const std::string &string);
283  void add_debug_symbol(SpvId id, const std::string &symbol);
284  void add_annotation(SpvInstruction val);
285  void add_type(SpvInstruction val);
286  void add_constant(SpvInstruction val);
287  void add_global(SpvInstruction val);
288  void add_execution_mode(SpvInstruction val);
289  void add_function(SpvFunction val);
290  void add_instruction(SpvInstruction val);
291  void add_entry_point(const std::string &name, SpvInstruction entry_point);
292 
293  void import_instruction_set(SpvId id, const std::string &instruction_set);
294  void require_capability(SpvCapability val);
295  void require_extension(const std::string &val);
296 
297  void set_version_format(uint32_t version);
298  void set_source_language(SpvSourceLanguage val);
299  void set_addressing_model(SpvAddressingModel val);
300  void set_memory_model(SpvMemoryModel val);
301  void set_binding_count(SpvId count);
302 
303  uint32_t version_format() const;
304  SpvSourceLanguage source_language() const;
305  SpvAddressingModel addressing_model() const;
306  SpvMemoryModel memory_model() const;
307  SpvInstruction entry_point(const std::string &name) const;
308  EntryPointNames entry_point_names() const;
309  ImportNames import_names() const;
310  SpvId lookup_import(const std::string &Instruction_set) const;
311  uint32_t entry_point_count() const;
312 
313  Imports imports() const;
314  Extensions extensions() const;
315  Capabilities capabilities() const;
316  Instructions entry_points() const;
317  const Instructions &execution_modes() const;
318  const Instructions &debug_source() const;
319  const Instructions &debug_symbols() const;
320  const Instructions &annotations() const;
321  const Instructions &type_definitions() const;
322  const Instructions &global_constants() const;
323  const Instructions &global_variables() const;
324  const Functions &function_definitions() const;
325 
326  uint32_t binding_count() const;
327  SpvModule module() const;
328 
329  bool is_imported(const std::string &instruction_set) const;
330  bool is_capability_required(SpvCapability val) const;
331  bool is_extension_required(const std::string &val) const;
332  bool is_defined() const;
333  SpvId id() const;
334  void check_defined() const;
335  void clear();
336 
337  void encode(SpvBinary &binary) const;
338 
339  static SpvModule make(SpvId module_id,
340  SpvSourceLanguage source_language = SpvSourceLanguageUnknown,
341  SpvAddressingModel addressing_model = SpvAddressingModelLogical,
342  SpvMemoryModel memory_model = SpvMemoryModelSimple);
343 
344 protected:
345  SpvModuleContentsPtr contents;
346 };
347 
348 /** Builder interface for constructing a SPIR-V code module and
349  * all associated types, declarations, blocks, functions &
350  * instructions */
351 class SpvBuilder {
352 public:
353  using ParamTypes = std::vector<SpvId>;
354  using Components = std::vector<SpvId>;
355  using StructMemberTypes = std::vector<SpvId>;
356  using Variables = std::vector<SpvId>;
357  using Indices = std::vector<uint32_t>;
358  using Literals = std::vector<uint32_t>;
359 
360  SpvBuilder();
361  ~SpvBuilder() = default;
362 
363  SpvBuilder(const SpvBuilder &) = delete;
364  SpvBuilder &operator=(const SpvBuilder &) = delete;
365 
366  // Reserve a unique ID to use for identifying a specifc kind of SPIR-V result **/
367  SpvId reserve_id(SpvKind = SpvResultId);
368 
369  // Look up the specific kind of SPIR-V item from its unique ID
370  SpvKind kind_of(SpvId id) const;
371 
372  // Get a human readable name for a specific kind of SPIR-V item
373  std::string kind_name(SpvKind kind) const;
374 
375  // Look up the ID associated with the type for a given variable ID
376  SpvId type_of(SpvId variable_id) const;
377 
378  // Top-Level declaration methods ... each of these is a convenvience
379  // function that checks to see if the requested thing has already been
380  // declared, in which case it returns its existing id, otherwise it
381  // adds a new declaration, and returns the new id. This avoids all
382  // the logic checks in the calling code, and also ensures that
383  // duplicates aren't created.
384 
385  SpvId declare_void_type();
386  SpvId declare_type(const Type &type, uint32_t array_size = 1);
387  SpvId declare_pointer_type(const Type &type, SpvStorageClass storage_class);
388  SpvId declare_pointer_type(SpvId type_id, SpvStorageClass storage_class);
389  SpvId declare_constant(const Type &type, const void *data, bool is_specialization = false);
390  SpvId declare_null_constant(const Type &type);
391  SpvId declare_bool_constant(bool value);
392  SpvId declare_string_constant(const std::string &str);
393  SpvId declare_integer_constant(const Type &type, int64_t value);
394  SpvId declare_float_constant(const Type &type, double value);
395  SpvId declare_scalar_constant(const Type &type, const void *data);
396  SpvId declare_vector_constant(const Type &type, const void *data);
397  SpvId declare_specialization_constant(const Type &type, const void *data);
398  SpvId declare_access_chain(SpvId ptr_type_id, SpvId base_id, const Indices &indices);
399  SpvId declare_pointer_access_chain(SpvId ptr_type_id, SpvId base_id, SpvId element_id, const Indices &indices);
400  SpvId declare_function_type(SpvId return_type, const ParamTypes &param_types = {});
401  SpvId declare_function(const std::string &name, SpvId function_type);
402  SpvId declare_struct(const std::string &name, const StructMemberTypes &member_types);
403  SpvId declare_variable(const std::string &name, SpvId type_id, SpvStorageClass storage_class, SpvId initializer_id = SpvInvalidId);
404  SpvId declare_global_variable(const std::string &name, SpvId type_id, SpvStorageClass storage_class, SpvId initializer_id = SpvInvalidId);
405  SpvId declare_symbol(const std::string &symbol, SpvId id, SpvId scope_id);
406 
407  // Top level creation methods for adding new items ... these have a limited
408  // number of checks and the caller must ensure that duplicates aren't created
409  SpvId add_type(const Type &type, uint32_t array_size = 1);
410  SpvId add_struct(const std::string &name, const StructMemberTypes &member_types);
411  SpvId add_array_with_default_size(SpvId base_type_id, SpvId array_size_id);
412  SpvId add_runtime_array(SpvId base_type_id);
413  SpvId add_pointer_type(const Type &type, SpvStorageClass storage_class);
414  SpvId add_pointer_type(SpvId base_type_id, SpvStorageClass storage_class);
415  SpvId add_constant(const Type &type, const void *data, bool is_specialization = false);
416  SpvId add_function_type(SpvId return_type_id, const ParamTypes &param_type_ids);
417  SpvId add_function(const std::string &name, SpvId return_type, const ParamTypes &param_types = {});
418  SpvId add_instruction(SpvInstruction val);
419 
420  void add_annotation(SpvId target_id, SpvDecoration decoration_type, const Literals &literals = {});
421  void add_struct_annotation(SpvId struct_type_id, uint32_t member_index, SpvDecoration decoration_type, const Literals &literals = {});
422  void add_symbol(const std::string &symbol, SpvId id, SpvId scope_id);
423 
424  void add_entry_point(SpvId func_id, SpvExecutionModel exec_model,
425  const Variables &variables = {});
426 
427  // Define the execution mode with a fixed local size for the workgroup (using literal values)
428  void add_execution_mode_local_size(SpvId entry_point_id, uint32_t local_size_x, uint32_t local_size_y, uint32_t local_size_z);
429 
430  // Same as above but uses id's for the local size (to allow specialization constants to be used)
431  void add_execution_mode_local_size_id(SpvId entry_point_id, SpvId local_size_x, SpvId local_size_y, SpvId local_size_z);
432 
433  // Assigns a specific SPIR-V version format for output (needed for compatibility)
434  void set_version_format(uint32_t version);
435 
436  // Assigns a specific source language hint to the module
437  void set_source_language(SpvSourceLanguage val);
438 
439  // Sets the addressing model to use for the module
440  void set_addressing_model(SpvAddressingModel val);
441 
442  // Sets the memory model to use for the module
443  void set_memory_model(SpvMemoryModel val);
444 
445  // Returns the source language hint for the module
446  SpvSourceLanguage source_language() const;
447 
448  // Returns the addressing model used for the module
449  SpvAddressingModel addressing_model() const;
450 
451  // Returns the memory model used for the module
452  SpvMemoryModel memory_model() const;
453 
454  // Import the GLSL.std.450 external instruction set. Returns its corresponding ID.
455  SpvId import_glsl_intrinsics();
456 
457  // Import an external instruction set bby name. Returns its corresponding ID.
458  SpvId import_instruction_set(const std::string &instruction_set);
459 
460  // Add an extension string to the list of required extensions for the module
461  void require_extension(const std::string &extension);
462 
463  // Add a specific capability to the list of requirements for the module
464  void require_capability(SpvCapability);
465 
466  // Returns true if the given instruction set has been imported
467  bool is_imported(const std::string &instruction_set) const;
468 
469  // Returns true if the given extension string is required by the module
470  bool is_extension_required(const std::string &extension) const;
471 
472  // Returns true if the given capability is required by the module
473  bool is_capability_required(SpvCapability) const;
474 
475  // Change the current build location to the given block. All local
476  // declarations and instructions will be added here.
477  void enter_block(const SpvBlock &block);
478 
479  // Create a new block with the given ID
480  SpvBlock create_block(SpvId block_id);
481 
482  // Returns the current block (the active scope for building)
483  SpvBlock current_block() const;
484 
485  // Resets the block build scope, and unassigns the current block
486  SpvBlock leave_block();
487 
488  // Change the current build scope to be within the given function
489  void enter_function(const SpvFunction &func);
490 
491  // Returns the function object for the given ID (or an invalid function if none is found)
492  SpvFunction lookup_function(SpvId func_id) const;
493 
494  // Returns the current function being used as the active build scope
495  SpvFunction current_function() const;
496 
497  // Resets the function build scope, and unassigns the current function
498  SpvFunction leave_function();
499 
500  // Returns the current id being used for building (ie the last item created)
501  SpvId current_id() const;
502 
503  // Updates the current id being used for building
504  void update_id(SpvId id);
505 
506  // Returns true if the given id is of the corresponding type
507  bool is_pointer_type(SpvId id) const;
508  bool is_struct_type(SpvId id) const;
509  bool is_vector_type(SpvId id) const;
510  bool is_scalar_type(SpvId id) const;
511  bool is_array_type(SpvId id) const;
512  bool is_constant(SpvId id) const;
513 
514  // Looks up the given pointer type id and returns a corresponding base type id (or an invalid id if none is found)
515  SpvId lookup_base_type(SpvId pointer_type) const;
516 
517  // Returns the storage class for the given variable id (or invalid if none is found)
518  SpvStorageClass lookup_storage_class(SpvId id) const;
519 
520  // Returns the item id for the given symbol name (or an invalid id if none is found)
521  SpvId lookup_id(const std::string &symbol) const;
522 
523  // Returns the build scope id for the item id (or an invalid id if none is found)
524  SpvId lookup_scope(SpvId id) const;
525 
526  // Returns the id for the imported instruction set (or an invalid id if none is found)
527  SpvId lookup_import(const std::string &instruction_set) const;
528 
529  // Returns the symbol string for the given id (or an empty string if none is found)
530  std::string lookup_symbol(SpvId id) const;
531 
532  // Returns the current module being used for building
533  SpvModule current_module() const;
534 
535  // Appends the given instruction to the current build location
536  void append(SpvInstruction inst);
537 
538  // Finalizes the module and prepares it for encoding (must be called before module can be used)
539  void finalize();
540 
541  // Encodes the current module to the given binary
542  void encode(SpvBinary &binary) const;
543 
544  // Resets the builder and all internal state
545  void reset();
546 
547 protected:
548  using TypeKey = uint64_t;
549  using TypeMap = std::unordered_map<TypeKey, SpvId>;
550  using KindMap = std::unordered_map<SpvId, SpvKind>;
551  using PointerTypeKey = std::pair<SpvId, SpvStorageClass>;
552  using PointerTypeMap = std::map<PointerTypeKey, SpvId>;
553  using BaseTypeMap = std::unordered_map<SpvId, SpvId>;
554  using VariableTypeMap = std::unordered_map<SpvId, SpvId>;
555  using StorageClassMap = std::unordered_map<SpvId, SpvStorageClass>;
556  using ConstantKey = uint64_t;
557  using ConstantMap = std::unordered_map<ConstantKey, SpvId>;
558  using StringMap = std::unordered_map<ConstantKey, SpvId>;
559  using ScopeMap = std::unordered_map<SpvId, SpvId>;
560  using IdSymbolMap = std::unordered_map<SpvId, std::string>;
561  using SymbolIdMap = std::unordered_map<std::string, SpvId>;
562  using FunctionTypeKey = uint64_t;
563  using FunctionTypeMap = std::unordered_map<FunctionTypeKey, SpvId>;
564  using FunctionMap = std::unordered_map<SpvId, SpvFunction>;
565 
566  // Internal methods for creating ids, keys, and look ups
567 
568  SpvId make_id(SpvKind kind);
569 
570  TypeKey make_type_key(const Type &type, uint32_t array_size = 1) const;
571  SpvId lookup_type(const Type &type, uint32_t array_size = 1) const;
572 
573  TypeKey make_struct_type_key(const StructMemberTypes &member_types) const;
574  SpvId lookup_struct(const std::string &name, const StructMemberTypes &member_types) const;
575 
576  PointerTypeKey make_pointer_type_key(const Type &type, SpvStorageClass storage_class) const;
577  SpvId lookup_pointer_type(const Type &type, SpvStorageClass storage_class) const;
578 
579  PointerTypeKey make_pointer_type_key(SpvId base_type_id, SpvStorageClass storage_class) const;
580  SpvId lookup_pointer_type(SpvId base_type_id, SpvStorageClass storage_class) const;
581 
582  template<typename T>
583  SpvId declare_scalar_constant_of_type(const Type &scalar_type, const T *data);
584 
585  template<typename T>
586  SpvId declare_specialization_constant_of_type(const Type &scalar_type, const T *data);
587 
588  template<typename T>
589  SpvBuilder::Components declare_constants_for_each_lane(Type type, const void *data);
590 
591  ConstantKey make_bool_constant_key(bool value) const;
592  ConstantKey make_string_constant_key(const std::string &value) const;
593  ConstantKey make_constant_key(uint8_t code, uint8_t bits, int lanes, size_t bytes, const void *data, bool is_specialization = false) const;
594  ConstantKey make_constant_key(const Type &type, const void *data, bool is_specialization = false) const;
595  SpvId lookup_constant(const Type &type, const void *data, bool is_specialization = false) const;
596 
597  ConstantKey make_null_constant_key(const Type &type) const;
598  SpvId lookup_null_constant(const Type &type) const;
599 
600  SpvId lookup_variable(const std::string &name, SpvId type_id, SpvStorageClass storage_class, SpvId scope_id) const;
601  bool has_variable(const std::string &name, SpvId type_id, SpvStorageClass storage_class, SpvId scope_id) const;
602 
603  FunctionTypeKey make_function_type_key(SpvId return_type_id, const ParamTypes &param_type_ids) const;
604  SpvId lookup_function_type(SpvId return_type_id, const ParamTypes &param_type_ids) const;
605 
606  SpvId active_id = SpvInvalidId;
607  SpvFunction active_function;
608  SpvBlock active_block;
609  SpvModule module;
610  KindMap kind_map;
611  TypeMap type_map;
612  TypeMap struct_map;
613  ScopeMap scope_map;
614  StringMap string_map;
615  ConstantMap constant_map;
616  FunctionMap function_map;
617  IdSymbolMap id_symbol_map;
618  SymbolIdMap symbol_id_map;
619  BaseTypeMap base_type_map;
620  StorageClassMap storage_class_map;
621  PointerTypeMap pointer_type_map;
622  VariableTypeMap variable_type_map;
623  FunctionTypeMap function_type_map;
624 };
625 
626 /** Factory interface for constructing specific SPIR-V instructions */
627 struct SpvFactory {
628  using Indices = std::vector<uint32_t>;
629  using Literals = std::vector<uint32_t>;
630  using BranchWeights = std::vector<uint32_t>;
631  using Components = std::vector<SpvId>;
632  using ParamTypes = std::vector<SpvId>;
633  using MemberTypeIds = std::vector<SpvId>;
634  using Operands = std::vector<SpvId>;
635  using Variables = std::vector<SpvId>;
636  using VariableBlockIdPair = std::pair<SpvId, SpvId>; // (Variable Id, Block Id)
637  using BlockVariables = std::vector<VariableBlockIdPair>;
638 
639  static SpvInstruction no_op(SpvId result_id);
640  static SpvInstruction capability(const SpvCapability &capability);
641  static SpvInstruction extension(const std::string &extension);
642  static SpvInstruction import(SpvId instruction_set_id, const std::string &instruction_set_name);
643  static SpvInstruction label(SpvId result_id);
644  static SpvInstruction debug_line(SpvId string_id, uint32_t line, uint32_t column);
645  static SpvInstruction debug_string(SpvId result_id, const std::string &string);
646  static SpvInstruction debug_symbol(SpvId target_id, const std::string &symbol);
647  static SpvInstruction decorate(SpvId target_id, SpvDecoration decoration_type, const Literals &literals = {});
648  static SpvInstruction decorate_member(SpvId struct_type_id, uint32_t member_index, SpvDecoration decoration_type, const Literals &literals = {});
649  static SpvInstruction void_type(SpvId void_type_id);
650  static SpvInstruction bool_type(SpvId bool_type_id);
651  static SpvInstruction integer_type(SpvId int_type_id, uint32_t bits, uint32_t signedness);
652  static SpvInstruction float_type(SpvId float_type_id, uint32_t bits);
653  static SpvInstruction vector_type(SpvId vector_type_id, SpvId element_type_id, uint32_t vector_size);
654  static SpvInstruction array_type(SpvId array_type_id, SpvId element_type_id, SpvId array_size_id);
655  static SpvInstruction struct_type(SpvId result_id, const MemberTypeIds &member_type_ids);
656  static SpvInstruction runtime_array_type(SpvId result_type_id, SpvId base_type_id);
657  static SpvInstruction pointer_type(SpvId pointer_type_id, SpvStorageClass storage_class, SpvId base_type_id);
658  static SpvInstruction function_type(SpvId function_type_id, SpvId return_type_id, const ParamTypes &param_type_ids);
659  static SpvInstruction constant(SpvId result_id, SpvId type_id, size_t bytes, const void *data, SpvValueType value_type);
660  static SpvInstruction null_constant(SpvId result_id, SpvId type_id);
661  static SpvInstruction bool_constant(SpvId result_id, SpvId type_id, bool value);
662  static SpvInstruction string_constant(SpvId result_id, const std::string &value);
663  static SpvInstruction composite_constant(SpvId result_id, SpvId type_id, const Components &components);
664  static SpvInstruction specialization_constant(SpvId result_id, SpvId type_id, size_t bytes, const void *data, SpvValueType value_type);
665  static SpvInstruction variable(SpvId result_id, SpvId result_type_id, uint32_t storage_class, SpvId initializer_id = SpvInvalidId);
666  static SpvInstruction function(SpvId return_type_id, SpvId func_id, uint32_t control_mask, SpvId func_type_id);
667  static SpvInstruction function_parameter(SpvId param_type_id, SpvId param_id);
668  static SpvInstruction function_end();
669  static SpvInstruction return_stmt(SpvId return_value_id = SpvInvalidId);
670  static SpvInstruction entry_point(SpvId exec_model, SpvId func_id, const std::string &name, const Variables &variables);
671  static SpvInstruction memory_model(SpvAddressingModel addressing_model, SpvMemoryModel memory_model);
672  static SpvInstruction exec_mode_local_size(SpvId function_id, uint32_t local_size_size_x, uint32_t local_size_size_y, uint32_t local_size_size_z);
673  static SpvInstruction exec_mode_local_size_id(SpvId function_id, SpvId local_size_x_id, SpvId local_size_y_id, SpvId local_size_z_id); // only avail in 1.2
674  static SpvInstruction memory_barrier(SpvId memory_scope_id, SpvId semantics_mask_id);
675  static SpvInstruction control_barrier(SpvId execution_scope_id, SpvId memory_scope_id, SpvId semantics_mask_id);
676  static SpvInstruction bitwise_not(SpvId type_id, SpvId result_id, SpvId src_id);
677  static SpvInstruction bitwise_and(SpvId type_id, SpvId result_id, SpvId src_a_id, SpvId src_b_id);
678  static SpvInstruction logical_not(SpvId type_id, SpvId result_id, SpvId src_id);
679  static SpvInstruction logical_and(SpvId type_id, SpvId result_id, SpvId src_a_id, SpvId src_b_id);
680  static SpvInstruction shift_right_logical(SpvId type_id, SpvId result_id, SpvId src_id, SpvId shift_id);
681  static SpvInstruction shift_right_arithmetic(SpvId type_id, SpvId result_id, SpvId src_id, SpvId shift_id);
682  static SpvInstruction multiply_extended(SpvId type_id, SpvId result_id, SpvId src_a_id, SpvId src_b_id, bool is_signed);
683  static SpvInstruction select(SpvId type_id, SpvId result_id, SpvId condition_id, SpvId true_id, SpvId false_id);
684  static SpvInstruction in_bounds_access_chain(SpvId type_id, SpvId result_id, SpvId base_id, const Indices &indices);
685  static SpvInstruction pointer_access_chain(SpvId type_id, SpvId result_id, SpvId base_id, SpvId element_id, const Indices &indices);
686  static SpvInstruction load(SpvId type_id, SpvId result_id, SpvId ptr_id, uint32_t access_mask = 0x0);
687  static SpvInstruction store(SpvId ptr_id, SpvId obj_id, uint32_t access_mask = 0x0);
688  static SpvInstruction vector_insert_dynamic(SpvId type_id, SpvId result_id, SpvId vector_id, SpvId value_id, SpvId index_id);
689  static SpvInstruction vector_extract_dynamic(SpvId type_id, SpvId result_id, SpvId vector_id, SpvId value_id, SpvId index_id);
690  static SpvInstruction vector_shuffle(SpvId type_id, SpvId result_id, SpvId src_a_id, SpvId src_b_id, const Indices &indices);
691  static SpvInstruction composite_insert(SpvId type_id, SpvId result_id, SpvId object_id, SpvId composite_id, const SpvFactory::Indices &indices);
692  static SpvInstruction composite_extract(SpvId type_id, SpvId result_id, SpvId composite_id, const Indices &indices);
693  static SpvInstruction composite_construct(SpvId type_id, SpvId result_id, const Components &constituents);
694  static SpvInstruction is_inf(SpvId type_id, SpvId result_id, SpvId src_id);
695  static SpvInstruction is_nan(SpvId type_id, SpvId result_id, SpvId src_id);
696  static SpvInstruction bitcast(SpvId type_id, SpvId result_id, SpvId src_id);
697  static SpvInstruction float_add(SpvId type_id, SpvId result_id, SpvId src_a_id, SpvId src_b_id);
698  static SpvInstruction integer_add(SpvId type_id, SpvId result_id, SpvId src_a_id, SpvId src_b_id);
699  static SpvInstruction integer_equal(SpvId type_id, SpvId result_id, SpvId src_a_id, SpvId src_b_id);
700  static SpvInstruction integer_not_equal(SpvId type_id, SpvId result_id, SpvId src_a_id, SpvId src_b_id);
701  static SpvInstruction integer_less_than(SpvId type_id, SpvId result_id, SpvId src_a_id, SpvId src_b_id, bool is_signed);
702  static SpvInstruction integer_less_than_equal(SpvId type_id, SpvId result_id, SpvId src_a_id, SpvId src_b_id, bool is_signed);
703  static SpvInstruction integer_greater_than(SpvId type_id, SpvId result_id, SpvId src_a_id, SpvId src_b_id, bool is_signed);
704  static SpvInstruction integer_greater_than_equal(SpvId type_id, SpvId result_id, SpvId src_a_id, SpvId src_b_id, bool is_signed);
705  static SpvInstruction branch(SpvId target_label_id);
706  static SpvInstruction conditional_branch(SpvId condition_label_id, SpvId true_label_id, SpvId false_label_id, const BranchWeights &weights = {});
707  static SpvInstruction loop_merge(SpvId merge_label_id, SpvId continue_label_id, uint32_t loop_control_mask = SpvLoopControlMaskNone);
708  static SpvInstruction selection_merge(SpvId merge_label_id, uint32_t selection_control_mask = SpvSelectionControlMaskNone);
709  static SpvInstruction phi(SpvId type_id, SpvId result_id, const BlockVariables &block_vars);
710  static SpvInstruction unary_op(SpvOp op_code, SpvId type_id, SpvId result_id, SpvId src_id);
711  static SpvInstruction binary_op(SpvOp op_code, SpvId type_id, SpvId result_id, SpvId src_a_id, SpvId src_b_id);
712  static SpvInstruction convert(SpvOp op_code, SpvId type_id, SpvId result_id, SpvId src_id);
713  static SpvInstruction extended(SpvId instruction_set_id, SpvId instruction_number, SpvId type_id, SpvId result_id, const SpvFactory::Operands &operands);
714 };
715 
716 /** Contents of a SPIR-V Instruction */
717 struct SpvInstructionContents {
718  using Operands = std::vector<SpvId>;
719  using ValueTypes = std::vector<SpvValueType>;
720  mutable RefCount ref_count;
721  SpvOp op_code = SpvOpNop;
722  SpvId result_id = SpvNoResult;
723  SpvId type_id = SpvNoType;
724  Operands operands;
725  ValueTypes value_types;
726 };
727 
728 /** Contents of a SPIR-V code block */
729 struct SpvBlockContents {
730  using Instructions = std::vector<SpvInstruction>;
731  using Variables = std::vector<SpvInstruction>;
732  using Blocks = std::vector<SpvBlock>;
733  mutable RefCount ref_count;
734  SpvId block_id = SpvInvalidId;
735  Instructions instructions;
736  Variables variables;
737  Blocks before;
738  Blocks after;
739  bool reachable = true;
740 };
741 
742 /** Contents of a SPIR-V function */
743 struct SpvFunctionContents {
744  using PrecisionMap = std::unordered_map<SpvId, SpvPrecision>;
745  using Parameters = std::vector<SpvInstruction>;
746  using Blocks = std::vector<SpvBlock>;
747  mutable RefCount ref_count;
748  SpvId function_id;
749  SpvId function_type_id;
750  SpvId return_type_id;
751  uint32_t control_mask;
752  SpvInstruction declaration;
753  Parameters parameters;
754  PrecisionMap precision;
755  Blocks blocks;
756 };
757 
758 /** Contents of a SPIR-V code module */
759 struct SpvModuleContents {
760  using Capabilities = std::set<SpvCapability>;
761  using Extensions = std::set<std::string>;
762  using Imports = std::unordered_map<std::string, SpvId>;
763  using Functions = std::vector<SpvFunction>;
764  using Instructions = std::vector<SpvInstruction>;
765  using EntryPoints = std::unordered_map<std::string, SpvInstruction>;
766 
767  mutable RefCount ref_count;
768  SpvId module_id = SpvInvalidId;
769  SpvId version_format = SpvVersion;
770  SpvId binding_count = 0;
771  SpvSourceLanguage source_language = SpvSourceLanguageUnknown;
772  SpvAddressingModel addressing_model = SpvAddressingModelLogical;
773  SpvMemoryModel memory_model = SpvMemoryModelSimple;
774  Capabilities capabilities;
775  Extensions extensions;
776  Imports imports;
777  EntryPoints entry_points;
778  Instructions execution_modes;
779  Instructions debug_source;
780  Instructions debug_symbols;
781  Instructions annotations;
782  Instructions types;
783  Instructions constants;
784  Instructions globals;
785  Functions functions;
786  Instructions instructions;
787 };
788 
789 /** Helper functions for determining calling convention of GLSL builtins **/
790 bool is_glsl_unary_op(SpvId glsl_op_code);
791 bool is_glsl_binary_op(SpvId glsl_op_code);
792 uint32_t glsl_operand_count(SpvId glsl_op_code);
793 
794 /** Output the contents of a SPIR-V module in human-readable form **/
795 std::ostream &operator<<(std::ostream &stream, const SpvModule &);
796 
797 /** Output the definition of a SPIR-V function in human-readable form **/
798 std::ostream &operator<<(std::ostream &stream, const SpvFunction &);
799 
800 /** Output the contents of a SPIR-V block in human-readable form **/
801 std::ostream &operator<<(std::ostream &stream, const SpvBlock &);
802 
803 /** Output a SPIR-V instruction in human-readable form **/
804 std::ostream &operator<<(std::ostream &stream, const SpvInstruction &);
805 
806 } // namespace Internal
807 } // namespace Halide
808 
809 #endif // WITH_SPIRV
810 
namespace Halide {
namespace Internal {

/** Internal test for SPIR-V IR.
 *  Declared after the WITH_SPIRV section closes, so the declaration is present
 *  whether or not WITH_SPIRV is defined. **/
void spirv_ir_test();

}  // namespace Internal
}  // namespace Halide
819 
820 #endif // HALIDE_SPIRV_IR_H
Support classes for reference-counting via intrusive shared pointers.
Defines halide types.
void * lookup_symbol(const char *sym, const known_symbol *map)
void spirv_ir_test()
Internal test for SPIR-V IR.
std::ostream & operator<<(std::ostream &stream, const Stmt &)
Emit a halide statement on an output stream (such as std::cout) in a human-readable form.
RefCount & ref_count(const T *t) noexcept
Because in this header we don't yet know how client classes store their RefCount (and we don't want t...
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
@ Internal
Not visible externally, similar to 'static' linkage in C.
Type type_of()
Construct the halide equivalent of a C type.
Definition: Type.h:561
Expr select(Expr condition, Expr true_value, Expr false_value)
Returns an expression similar to the ternary operator in C, except that it always evaluates all argum...
Expr is_nan(Expr x)
Returns true if the argument is a Not a Number (NaN).
Expr is_inf(Expr x)
Returns true if the argument is Inf or -Inf.
unsigned __INT64_TYPE__ uint64_t
signed __INT64_TYPE__ int64_t
unsigned __INT8_TYPE__ uint8_t
unsigned __INT32_TYPE__ uint32_t